219 lines
5.3 KiB
Go
219 lines
5.3 KiB
Go
package migrations
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Migration describes one versioned schema change.
//
// Version orders the migrations and is the value stored in the
// schema_migrations table once the change has been applied. Up performs the
// change and Down reverts it; both receive the *sql.Tx supplied by the
// migration runner.
type Migration struct {
	Version     int    // ordering key; recorded in schema_migrations
	Description string // human-readable summary used in logs and listings

	Up, Down func(*sql.Tx) error
}
|
|
|
|
var (
|
|
migrations = []Migration{}
|
|
mu sync.RWMutex
|
|
)
|
|
|
|
// RegisterMigration adds a migration to the list of available migrations
|
|
func RegisterMigration(version int, description string, up, down func(*sql.Tx) error) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
migrations = append(migrations, Migration{
|
|
Version: version,
|
|
Description: description,
|
|
Up: up,
|
|
Down: down,
|
|
})
|
|
|
|
// Sort migrations by version
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
}
|
|
|
|
// getMigrations returns a copy of the migrations list (thread-safe)
|
|
func getMigrations() []Migration {
|
|
mu.RLock()
|
|
defer mu.RUnlock()
|
|
|
|
result := make([]Migration, len(migrations))
|
|
copy(result, migrations)
|
|
return result
|
|
}
|
|
|
|
// InitMigrationTable creates the schema_migrations bookkeeping table unless it
// already exists. The table holds one row per applied migration version and
// the time it was applied. Because the statement uses IF NOT EXISTS, the
// function can be called on every start-up.
func InitMigrationTable(db *sql.DB) error {
	const createTableSQL = `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied_at TIMESTAMP NOT NULL
		)
	`

	if _, err := db.Exec(createTableSQL); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// GetCurrentVersion reports the highest version recorded in schema_migrations,
// or 0 when the table is empty. Ignore the returned version when err is
// non-nil.
func GetCurrentVersion(db *sql.DB) (int, error) {
	const maxVersionSQL = `
		SELECT COALESCE(MAX(version), 0) FROM schema_migrations
	`

	var current int
	row := db.QueryRow(maxVersionSQL)
	err := row.Scan(&current)
	return current, err
}
|
|
|
|
// MigrateUp applies all pending migrations
|
|
func MigrateUp(db *sql.DB) error {
|
|
// Get current version
|
|
currentVersion, err := GetCurrentVersion(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply pending migrations
|
|
allMigrations := getMigrations()
|
|
for _, migration := range allMigrations {
|
|
if migration.Version > currentVersion {
|
|
log.Printf("Applying migration %d: %s", migration.Version, migration.Description)
|
|
|
|
// Begin transaction
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply migration
|
|
if err := migration.Up(tx); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("error applying migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Record migration
|
|
_, err = tx.Exec(
|
|
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
|
|
migration.Version, time.Now(),
|
|
)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("error recording migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Commit transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("error committing migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
log.Printf("Successfully applied migration %d", migration.Version)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MigrateDown rolls back migrations to the specified version
|
|
// If targetVersion is 0, rolls back all migrations
|
|
func MigrateDown(db *sql.DB, targetVersion int) error {
|
|
// Get current version
|
|
currentVersion, err := GetCurrentVersion(db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get a copy of migrations and sort in reverse order
|
|
allMigrations := getMigrations()
|
|
sort.Slice(allMigrations, func(i, j int) bool {
|
|
return allMigrations[i].Version > allMigrations[j].Version
|
|
})
|
|
|
|
// Roll back migrations
|
|
for _, migration := range allMigrations {
|
|
if migration.Version <= currentVersion && migration.Version > targetVersion {
|
|
log.Printf("Rolling back migration %d: %s", migration.Version, migration.Description)
|
|
|
|
// Begin transaction
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply down migration
|
|
if err := migration.Down(tx); err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("error rolling back migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Remove migration record
|
|
_, err = tx.Exec(
|
|
"DELETE FROM schema_migrations WHERE version = ?",
|
|
migration.Version,
|
|
)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("error removing migration record %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Commit transaction
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("error committing rollback of migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
log.Printf("Successfully rolled back migration %d", migration.Version)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ListMigrations returns a list of all available migrations and their status
|
|
func ListMigrations(db *sql.DB) ([]map[string]interface{}, error) {
|
|
// Get current version (unused but kept for reference)
|
|
_, err := GetCurrentVersion(db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get applied migrations with timestamps
|
|
rows, err := db.Query("SELECT version, applied_at FROM schema_migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
appliedMigrations := make(map[int]time.Time)
|
|
for rows.Next() {
|
|
var version int
|
|
var appliedAt time.Time
|
|
if err := rows.Scan(&version, &appliedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
appliedMigrations[version] = appliedAt
|
|
}
|
|
|
|
// Create result
|
|
var result []map[string]interface{}
|
|
allMigrations := getMigrations()
|
|
for _, migration := range allMigrations {
|
|
appliedAt, applied := appliedMigrations[migration.Version]
|
|
appliedAtStr := ""
|
|
if applied {
|
|
appliedAtStr = appliedAt.Format("2006-01-02 15:04:05")
|
|
}
|
|
result = append(result, map[string]interface{}{
|
|
"version": migration.Version,
|
|
"description": migration.Description,
|
|
"applied": applied,
|
|
"applied_at": appliedAtStr,
|
|
})
|
|
}
|
|
|
|
return result, nil
|
|
}
|