Files
butterfliu/migrations/cli.go
2026-01-08 20:48:55 +08:00

206 lines
5.0 KiB
Go

package migrations
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strconv"
	"unicode"

	"butterfliu/config"

	_ "github.com/mattn/go-sqlite3"
)
// RunMigrationCLI runs the migration command-line interface.
//
// The command is read from os.Args[1]; "down" and "create" additionally read
// one argument from os.Args[2]. Before dispatching, the function opens the
// sqlite3 database at the path taken from the loaded configuration, pings it,
// and calls InitMigrationTable.
//
// Supported commands:
//
//	up       calls MigrateUp
//	down N   calls MigrateDown with target version N
//	reset    calls MigrateDown with target version 0
//	refresh  calls MigrateDown with target version 0, then MigrateUp
//	status   prints one line per entry returned by ListMigrations
//	create   calls CreateMigration with the given name
//
// The function never returns an error: usage problems print the usage text
// and exit with status 1, and every other failure is reported via log.Fatalf.
// Both paths terminate the process without running the deferred db.Close.
func RunMigrationCLI() {
	if len(os.Args) < 2 {
		printUsage()
		os.Exit(1)
	}
	command := os.Args[1]

	// Load configuration; the database file path comes from it.
	cfg := config.LoadConfig()
	dbPath := cfg.DatabasePath
	log.Printf("Using database: %s", dbPath)

	// Open database connection.
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		log.Fatalf("Error opening database %s: %v", dbPath, err)
	}
	defer db.Close()

	// sql.Open may not touch the database at all, so verify the connection
	// explicitly before doing any work.
	if err = db.Ping(); err != nil {
		log.Fatalf("Error connecting to database %s: %v", dbPath, err)
	}

	// Initialize migration table.
	if err = InitMigrationTable(db); err != nil {
		log.Fatalf("Error initializing migration table: %v", err)
	}

	switch command {
	case "up":
		// Run all pending migrations.
		if err := MigrateUp(db); err != nil {
			log.Fatalf("Error applying migrations: %v", err)
		}
		fmt.Println("All migrations applied successfully")

	case "down":
		if len(os.Args) < 3 {
			fmt.Println("Error: Missing target version")
			printUsage()
			os.Exit(1)
		}
		targetVersion, err := strconv.Atoi(os.Args[2])
		if err != nil {
			log.Fatalf("Error parsing target version: %v", err)
		}
		// Roll back migrations to target version.
		if err := MigrateDown(db, targetVersion); err != nil {
			log.Fatalf("Error rolling back migrations: %v", err)
		}
		fmt.Printf("Migrations rolled back to version %d successfully\n", targetVersion)

	case "reset":
		// Roll back all migrations.
		if err := MigrateDown(db, 0); err != nil {
			log.Fatalf("Error rolling back migrations: %v", err)
		}
		fmt.Println("All migrations rolled back successfully")

	case "refresh":
		// Roll back all migrations and then apply them again.
		if err := MigrateDown(db, 0); err != nil {
			log.Fatalf("Error rolling back migrations: %v", err)
		}
		if err := MigrateUp(db); err != nil {
			log.Fatalf("Error applying migrations: %v", err)
		}
		fmt.Println("All migrations refreshed successfully")

	case "status":
		// Show migration status.
		migrations, err := ListMigrations(db)
		if err != nil {
			log.Fatalf("Error listing migrations: %v", err)
		}
		fmt.Println("Migration Status:")
		fmt.Println("=================")
		for _, migration := range migrations {
			status := "Pending"
			appliedAt := ""
			if applied, ok := migration["applied"].(bool); ok && applied {
				status = "Applied"
				// A bare .(string) assertion panics when the value is
				// missing or is not a string, so inspect the dynamic type
				// and fall back to its default formatting instead.
				switch v := migration["applied_at"].(type) {
				case string:
					appliedAt = v
				case nil:
					// No timestamp recorded; leave the column empty.
				default:
					appliedAt = fmt.Sprint(v)
				}
			}
			fmt.Printf("%d: %s - %s %s\n", migration["version"], migration["description"], status, appliedAt)
		}

	case "create":
		if len(os.Args) < 3 {
			fmt.Println("Error: Missing migration name")
			printUsage()
			os.Exit(1)
		}
		name := os.Args[2]
		if err := CreateMigration(name); err != nil {
			log.Fatalf("Error creating migration: %v", err)
		}

	default:
		fmt.Printf("Error: Unknown command '%s'\n", command)
		printUsage()
		os.Exit(1)
	}
}
// printUsage writes the command summary for the migration CLI to standard
// output, one line per fmt.Println call.
func printUsage() {
	helpLines := []string{
		"Usage: go run cmd/migrate/main.go <command> [args]",
		"",
		"Commands:",
		" up Apply all pending migrations",
		" down <version> Roll back migrations to specified version",
		" reset Roll back all migrations",
		" refresh Roll back all migrations and apply them again",
		" status Show migration status",
		" create <name> Create a new migration",
	}
	for i := range helpLines {
		fmt.Println(helpLines[i])
	}
}
// CreateMigration generates a new migration source file containing empty
// up/down stubs.
//
// The version is one more than the highest Version among the entries returned
// by getMigrations, or 1 when there are none. The file is named
// "<version, zero-padded to 3 digits>_<name>.go" and is created inside the
// "migrations" directory resolved against the current working directory.
//
// name is embedded verbatim in the file name, in a string literal, and in the
// generated function names (migrate<name>Up / migrate<name>Down). It must
// therefore be non-empty and consist only of letters, digits and underscores;
// any other name would produce a file that does not compile or that lands
// outside the migrations directory, and is rejected.
//
// It returns an error if name is invalid, if the directory path cannot be
// resolved, if the target file already exists, or if the file cannot be
// written. On success the path of the new file is printed to standard output.
func CreateMigration(name string) error {
	if !isValidMigrationName(name) {
		return fmt.Errorf("invalid migration name %q: use only letters, digits and underscores", name)
	}

	// Get next version number.
	allMigrations := getMigrations()
	nextVersion := 1
	for _, migration := range allMigrations {
		if migration.Version >= nextVersion {
			nextVersion = migration.Version + 1
		}
	}

	// Build the absolute path of the new file inside the migrations directory.
	filename := fmt.Sprintf("%03d_%s.go", nextVersion, name)
	migrationDir, err := filepath.Abs("migrations")
	if err != nil {
		return fmt.Errorf("failed to get migrations directory path: %w", err)
	}
	migrationPath := filepath.Join(migrationDir, filename)

	// Create migration file content. The template text is runtime output and
	// is written to disk exactly as it appears here.
	content := fmt.Sprintf(`package migrations
import (
"database/sql"
)
func init() {
RegisterMigration(
%d,
"%s",
migrate%sUp,
migrate%sDown,
)
}
func migrate%sUp(tx *sql.Tx) error {
// TODO: Implement migration up
_, err := tx.Exec(`+"`"+`
-- Add your SQL here
`+"`"+`)
return err
}
func migrate%sDown(tx *sql.Tx) error {
// TODO: Implement migration down
_, err := tx.Exec(`+"`"+`
-- Add your SQL here
`+"`"+`)
return err
}
`, nextVersion, name, name, name, name, name)

	// O_EXCL makes "check that the file is absent" and "create it" a single
	// atomic step, so an existing migration can never be overwritten.
	f, err := os.OpenFile(migrationPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
	if err != nil {
		if errors.Is(err, os.ErrExist) {
			return fmt.Errorf("migration file already exists: %s", migrationPath)
		}
		return fmt.Errorf("creating migration file %s: %w", migrationPath, err)
	}
	if _, err := f.WriteString(content); err != nil {
		// Best-effort cleanup so a partial file does not block a retry; the
		// write error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(migrationPath)
		return fmt.Errorf("writing migration file %s: %w", migrationPath, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("closing migration file %s: %w", migrationPath, err)
	}

	fmt.Printf("Created migration file: %s\n", migrationPath)
	return nil
}

// isValidMigrationName reports whether name is non-empty and made up solely
// of characters that may appear inside a Go identifier (letters, digits, '_').
func isValidMigrationName(name string) bool {
	if name == "" {
		return false
	}
	for _, r := range name {
		if r != '_' && !unicode.IsLetter(r) && !unicode.IsDigit(r) {
			return false
		}
	}
	return true
}