// Package migrations implements a small versioned schema-migration runner on
// top of database/sql. Migrations are registered in process with
// RegisterMigration, and applied versions are recorded in the
// schema_migrations table.
package migrations

import (
	"database/sql"
	"fmt"
	"log"
	"sort"
	"sync"
	"time"
)

// Migration represents a database migration.
//
// Up and Down each run inside their own transaction. Version is the key the
// migration is recorded under in schema_migrations.
type Migration struct {
	Version     int
	Description string
	Up          func(*sql.Tx) error
	Down        func(*sql.Tx) error
}

var (
	// migrations is the registry of known migrations, kept sorted by
	// ascending Version. Guarded by mu.
	migrations []Migration
	mu         sync.RWMutex
)

// RegisterMigration adds a migration to the list of available migrations.
//
// The registry is re-sorted by ascending version after every registration, so
// the order in which callers register migrations does not matter.
func RegisterMigration(version int, description string, up, down func(*sql.Tx) error) {
	mu.Lock()
	defer mu.Unlock()

	migrations = append(migrations, Migration{
		Version:     version,
		Description: description,
		Up:          up,
		Down:        down,
	})

	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
}

// getMigrations returns a copy of the registry, sorted by ascending version.
// The copy is taken under the read lock, so callers may iterate or reorder it
// without holding mu.
func getMigrations() []Migration {
	mu.RLock()
	defer mu.RUnlock()

	result := make([]Migration, len(migrations))
	copy(result, migrations)
	return result
}

// InitMigrationTable creates the schema_migrations table if it doesn't exist.
func InitMigrationTable(db *sql.DB) error {
	_, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL
		)
	`)
	return err
}

// GetCurrentVersion returns the current database schema version, defined as
// the highest version recorded in schema_migrations. It returns 0 when no
// migration has been recorded.
func GetCurrentVersion(db *sql.DB) (int, error) {
	var version int
	err := db.QueryRow(`
		SELECT COALESCE(MAX(version), 0) FROM schema_migrations
	`).Scan(&version)
	if err != nil {
		return 0, err
	}
	return version, nil
}

// MigrateUp applies all pending migrations.
//
// A migration is pending when its version is greater than the current schema
// version. Pending migrations are applied in ascending version order, each in
// its own transaction together with its schema_migrations record. The first
// failure stops the run and is returned; migrations committed before it stay
// applied.
func MigrateUp(db *sql.DB) error {
	currentVersion, err := GetCurrentVersion(db)
	if err != nil {
		return err
	}

	for _, migration := range getMigrations() {
		if migration.Version <= currentVersion {
			continue
		}

		log.Printf("Applying migration %d: %s", migration.Version, migration.Description)
		if err := applyMigration(db, migration); err != nil {
			return err
		}
		log.Printf("Successfully applied migration %d", migration.Version)
	}

	return nil
}

// applyMigration runs m.Up and records m.Version in one transaction.
func applyMigration(db *sql.DB, m Migration) error {
	// Calling a nil func panics; report a missing Up as an ordinary error.
	if m.Up == nil {
		return fmt.Errorf("error applying migration %d: no up function registered", m.Version)
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}

	if err := m.Up(tx); err != nil {
		// Best effort: the migration error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("error applying migration %d: %w", m.Version, err)
	}

	_, err = tx.Exec(
		"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
		m.Version, time.Now(),
	)
	if err != nil {
		// Best effort: the insert error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("error recording migration %d: %w", m.Version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error committing migration %d: %w", m.Version, err)
	}
	return nil
}

// MigrateDown rolls back migrations to the specified version.
// If targetVersion is 0, rolls back all migrations.
//
// Every registered migration whose version is at most the current schema
// version and greater than targetVersion is rolled back, highest version
// first, each in its own transaction. The first failure stops the run and is
// returned.
func MigrateDown(db *sql.DB, targetVersion int) error {
	currentVersion, err := GetCurrentVersion(db)
	if err != nil {
		return err
	}

	// The registry copy is sorted ascending, so walking it backwards yields
	// the newest migration first.
	all := getMigrations()
	for i := len(all) - 1; i >= 0; i-- {
		migration := all[i]
		if migration.Version > currentVersion || migration.Version <= targetVersion {
			continue
		}

		log.Printf("Rolling back migration %d: %s", migration.Version, migration.Description)
		if err := revertMigration(db, migration); err != nil {
			return err
		}
		log.Printf("Successfully rolled back migration %d", migration.Version)
	}

	return nil
}

// revertMigration runs m.Down and deletes m.Version's record in one transaction.
func revertMigration(db *sql.DB, m Migration) error {
	// Calling a nil func panics; report a missing Down as an ordinary error.
	if m.Down == nil {
		return fmt.Errorf("error rolling back migration %d: no down function registered", m.Version)
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}

	if err := m.Down(tx); err != nil {
		// Best effort: the migration error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("error rolling back migration %d: %w", m.Version, err)
	}

	_, err = tx.Exec(
		"DELETE FROM schema_migrations WHERE version = ?",
		m.Version,
	)
	if err != nil {
		// Best effort: the delete error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("error removing migration record %d: %w", m.Version, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error committing rollback of migration %d: %w", m.Version, err)
	}
	return nil
}

// ListMigrations returns a list of all available migrations and their status.
//
// The result has one entry per registered migration, in ascending version
// order, with the keys:
//
//	"version"     int    - the migration version
//	"description" string - the registered description
//	"applied"     bool   - whether schema_migrations has a row for it
//	"applied_at"  string - "2006-01-02 15:04:05" formatted time, "" if not applied
func ListMigrations(db *sql.DB) ([]map[string]interface{}, error) {
	rows, err := db.Query("SELECT version, applied_at FROM schema_migrations")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	appliedMigrations := make(map[int]time.Time)
	for rows.Next() {
		var (
			version   int
			appliedAt time.Time
		)
		if err := rows.Scan(&version, &appliedAt); err != nil {
			return nil, err
		}
		appliedMigrations[version] = appliedAt
	}
	// rows.Next also returns false when iteration fails; without this check a
	// partial read would be reported as a complete list.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	var result []map[string]interface{}
	for _, migration := range getMigrations() {
		appliedAt, applied := appliedMigrations[migration.Version]

		appliedAtStr := ""
		if applied {
			appliedAtStr = appliedAt.Format("2006-01-02 15:04:05")
		}

		result = append(result, map[string]interface{}{
			"version":     migration.Version,
			"description": migration.Description,
			"applied":     applied,
			"applied_at":  appliedAtStr,
		})
	}

	return result, nil
}