206 lines
5.0 KiB
Go
206 lines
5.0 KiB
Go
package migrations
|
|
|
|
import (
|
|
"butterfliu/config"
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// RunMigrationCLI runs the migration command-line interface
|
|
func RunMigrationCLI() {
|
|
if len(os.Args) < 2 {
|
|
printUsage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
command := os.Args[1]
|
|
|
|
// Load configuration
|
|
cfg := config.LoadConfig()
|
|
|
|
// Database file path - use configuration
|
|
dbPath := cfg.DatabasePath
|
|
log.Printf("Using database: %s", dbPath)
|
|
|
|
// Open database connection
|
|
db, err := sql.Open("sqlite3", dbPath)
|
|
if err != nil {
|
|
log.Fatalf("Error opening database %s: %v", dbPath, err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// Test connection
|
|
if err = db.Ping(); err != nil {
|
|
log.Fatalf("Error connecting to database %s: %v", dbPath, err)
|
|
}
|
|
|
|
// Initialize migration table
|
|
if err = InitMigrationTable(db); err != nil {
|
|
log.Fatalf("Error initializing migration table: %v", err)
|
|
}
|
|
|
|
switch command {
|
|
case "up":
|
|
// Run all pending migrations
|
|
if err := MigrateUp(db); err != nil {
|
|
log.Fatalf("Error applying migrations: %v", err)
|
|
}
|
|
fmt.Println("All migrations applied successfully")
|
|
|
|
case "down":
|
|
if len(os.Args) < 3 {
|
|
fmt.Println("Error: Missing target version")
|
|
printUsage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
targetVersion, err := strconv.Atoi(os.Args[2])
|
|
if err != nil {
|
|
log.Fatalf("Error parsing target version: %v", err)
|
|
}
|
|
|
|
// Roll back migrations to target version
|
|
if err := MigrateDown(db, targetVersion); err != nil {
|
|
log.Fatalf("Error rolling back migrations: %v", err)
|
|
}
|
|
fmt.Printf("Migrations rolled back to version %d successfully\n", targetVersion)
|
|
|
|
case "reset":
|
|
// Roll back all migrations
|
|
if err := MigrateDown(db, 0); err != nil {
|
|
log.Fatalf("Error rolling back migrations: %v", err)
|
|
}
|
|
fmt.Println("All migrations rolled back successfully")
|
|
|
|
case "refresh":
|
|
// Roll back all migrations and then apply them again
|
|
if err := MigrateDown(db, 0); err != nil {
|
|
log.Fatalf("Error rolling back migrations: %v", err)
|
|
}
|
|
if err := MigrateUp(db); err != nil {
|
|
log.Fatalf("Error applying migrations: %v", err)
|
|
}
|
|
fmt.Println("All migrations refreshed successfully")
|
|
|
|
case "status":
|
|
// Show migration status
|
|
migrations, err := ListMigrations(db)
|
|
if err != nil {
|
|
log.Fatalf("Error listing migrations: %v", err)
|
|
}
|
|
|
|
fmt.Println("Migration Status:")
|
|
fmt.Println("=================")
|
|
for _, migration := range migrations {
|
|
status := "Pending"
|
|
appliedAt := ""
|
|
if applied, ok := migration["applied"].(bool); ok && applied {
|
|
status = "Applied"
|
|
appliedAt = migration["applied_at"].(string)
|
|
}
|
|
fmt.Printf("%d: %s - %s %s\n", migration["version"], migration["description"], status, appliedAt)
|
|
}
|
|
|
|
case "create":
|
|
if len(os.Args) < 3 {
|
|
fmt.Println("Error: Missing migration name")
|
|
printUsage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
name := os.Args[2]
|
|
if err := CreateMigration(name); err != nil {
|
|
log.Fatalf("Error creating migration: %v", err)
|
|
}
|
|
|
|
default:
|
|
fmt.Printf("Error: Unknown command '%s'\n", command)
|
|
printUsage()
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// printUsage writes the migration CLI help text to standard output,
// one line per entry, including a blank separator line after the usage line.
func printUsage() {
	helpLines := []string{
		"Usage: go run cmd/migrate/main.go <command> [args]",
		"",
		"Commands:",
		" up Apply all pending migrations",
		" down <version> Roll back migrations to specified version",
		" reset Roll back all migrations",
		" refresh Roll back all migrations and apply them again",
		" status Show migration status",
		" create <name> Create a new migration",
	}
	for _, line := range helpLines {
		fmt.Println(line)
	}
}
|
|
|
|
func CreateMigration(name string) error {
|
|
// Get next version number
|
|
allMigrations := getMigrations()
|
|
nextVersion := 1
|
|
for _, migration := range allMigrations {
|
|
if migration.Version >= nextVersion {
|
|
nextVersion = migration.Version + 1
|
|
}
|
|
}
|
|
|
|
// Create migration file - use absolute path from current working directory
|
|
filename := fmt.Sprintf("%03d_%s.go", nextVersion, name)
|
|
// Get the absolute path to the migrations directory
|
|
migrationDir, err := filepath.Abs("migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get migrations directory path: %v", err)
|
|
}
|
|
migrationPath := filepath.Join(migrationDir, filename)
|
|
|
|
// Check if file already exists
|
|
if _, err := os.Stat(migrationPath); err == nil {
|
|
return fmt.Errorf("migration file already exists: %s", migrationPath)
|
|
}
|
|
|
|
// Create migration file content
|
|
content := fmt.Sprintf(`package migrations
|
|
|
|
import (
|
|
"database/sql"
|
|
)
|
|
|
|
func init() {
|
|
RegisterMigration(
|
|
%d,
|
|
"%s",
|
|
migrate%sUp,
|
|
migrate%sDown,
|
|
)
|
|
}
|
|
|
|
func migrate%sUp(tx *sql.Tx) error {
|
|
// TODO: Implement migration up
|
|
_, err := tx.Exec(`+"`"+`
|
|
-- Add your SQL here
|
|
`+"`"+`)
|
|
return err
|
|
}
|
|
|
|
func migrate%sDown(tx *sql.Tx) error {
|
|
// TODO: Implement migration down
|
|
_, err := tx.Exec(`+"`"+`
|
|
-- Add your SQL here
|
|
`+"`"+`)
|
|
return err
|
|
}
|
|
`, nextVersion, name, name, name, name, name)
|
|
|
|
// Write file
|
|
if err := os.WriteFile(migrationPath, []byte(content), 0644); err != nil {
|
|
return err
|
|
}
|
|
|
|
fmt.Printf("Created migration file: %s\n", migrationPath)
|
|
return nil
|
|
}
|