Files
butterfliu/migrations/migrations.go
2026-01-08 20:48:55 +08:00

219 lines
5.3 KiB
Go

package migrations
import (
"database/sql"
"fmt"
"log"
"sort"
"sync"
"time"
)
// Migration represents a database migration
type Migration struct {
Version int
Description string
Up func(*sql.Tx) error
Down func(*sql.Tx) error
}
// Package-level migration registry.
var (
	// mu guards migrations: writers take Lock, readers take RLock.
	mu sync.RWMutex

	// migrations holds every migration passed to RegisterMigration, kept
	// sorted by ascending Version.
	migrations = make([]Migration, 0)
)
// RegisterMigration adds a migration to the registry consulted by MigrateUp,
// MigrateDown and ListMigrations, keeping the registry sorted by ascending
// version. It is safe for concurrent use.
//
// RegisterMigration panics if up is nil or if a migration with the same
// version is already registered. Both are programming errors that would
// otherwise only surface later: a nil up panics inside MigrateUp, and a
// duplicate version makes MigrateUp fail on the schema_migrations primary key
// every time it runs. down is not validated here.
func RegisterMigration(version int, description string, up, down func(*sql.Tx) error) {
	if up == nil {
		panic(fmt.Sprintf("migrations: up function for migration %d is nil", version))
	}

	mu.Lock()
	defer mu.Unlock()

	// migrations is always sorted (this function is its only writer), so the
	// insertion point can be found by binary search instead of re-sorting.
	i := sort.Search(len(migrations), func(j int) bool {
		return migrations[j].Version >= version
	})
	if i < len(migrations) && migrations[i].Version == version {
		panic(fmt.Sprintf("migrations: duplicate migration version %d (%q and %q)",
			version, migrations[i].Description, description))
	}

	// Grow by one, shift the tail right, and place the new entry at i.
	migrations = append(migrations, Migration{})
	copy(migrations[i+1:], migrations[i:])
	migrations[i] = Migration{
		Version:     version,
		Description: description,
		Up:          up,
		Down:        down,
	}
}
// getMigrations returns a snapshot of the registered migrations, taken under
// the read lock. The returned slice is a fresh copy, so callers may reorder
// or modify it without affecting the registry.
func getMigrations() []Migration {
	mu.RLock()
	defer mu.RUnlock()

	snapshot := append(make([]Migration, 0, len(migrations)), migrations...)
	return snapshot
}
// InitMigrationTable ensures the schema_migrations bookkeeping table exists.
// The statement uses IF NOT EXISTS, so calling it on an already-initialized
// database is harmless. Any error from the driver is returned unchanged.
func InitMigrationTable(db *sql.DB) error {
	const createTable = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP NOT NULL
)
`
	if _, execErr := db.Exec(createTable); execErr != nil {
		return execErr
	}
	return nil
}
// GetCurrentVersion reports the highest version recorded in
// schema_migrations, or 0 when the table has no rows. The query error, if
// any, is returned alongside the (zero) version.
func GetCurrentVersion(db *sql.DB) (int, error) {
	const maxVersionQuery = `
SELECT COALESCE(MAX(version), 0) FROM schema_migrations
`
	var current int
	scanErr := db.QueryRow(maxVersionQuery).Scan(&current)
	return current, scanErr
}
// MigrateUp applies all pending migrations
func MigrateUp(db *sql.DB) error {
// Get current version
currentVersion, err := GetCurrentVersion(db)
if err != nil {
return err
}
// Apply pending migrations
allMigrations := getMigrations()
for _, migration := range allMigrations {
if migration.Version > currentVersion {
log.Printf("Applying migration %d: %s", migration.Version, migration.Description)
// Begin transaction
tx, err := db.Begin()
if err != nil {
return err
}
// Apply migration
if err := migration.Up(tx); err != nil {
tx.Rollback()
return fmt.Errorf("error applying migration %d: %w", migration.Version, err)
}
// Record migration
_, err = tx.Exec(
"INSERT INTO schema_migrations (version, applied_at) VALUES (?, ?)",
migration.Version, time.Now(),
)
if err != nil {
tx.Rollback()
return fmt.Errorf("error recording migration %d: %w", migration.Version, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("error committing migration %d: %w", migration.Version, err)
}
log.Printf("Successfully applied migration %d", migration.Version)
}
}
return nil
}
// MigrateDown rolls back migrations to the specified version
// If targetVersion is 0, rolls back all migrations
func MigrateDown(db *sql.DB, targetVersion int) error {
// Get current version
currentVersion, err := GetCurrentVersion(db)
if err != nil {
return err
}
// Get a copy of migrations and sort in reverse order
allMigrations := getMigrations()
sort.Slice(allMigrations, func(i, j int) bool {
return allMigrations[i].Version > allMigrations[j].Version
})
// Roll back migrations
for _, migration := range allMigrations {
if migration.Version <= currentVersion && migration.Version > targetVersion {
log.Printf("Rolling back migration %d: %s", migration.Version, migration.Description)
// Begin transaction
tx, err := db.Begin()
if err != nil {
return err
}
// Apply down migration
if err := migration.Down(tx); err != nil {
tx.Rollback()
return fmt.Errorf("error rolling back migration %d: %w", migration.Version, err)
}
// Remove migration record
_, err = tx.Exec(
"DELETE FROM schema_migrations WHERE version = ?",
migration.Version,
)
if err != nil {
tx.Rollback()
return fmt.Errorf("error removing migration record %d: %w", migration.Version, err)
}
// Commit transaction
if err := tx.Commit(); err != nil {
return fmt.Errorf("error committing rollback of migration %d: %w", migration.Version, err)
}
log.Printf("Successfully rolled back migration %d", migration.Version)
}
}
return nil
}
// ListMigrations returns a list of all available migrations and their status
func ListMigrations(db *sql.DB) ([]map[string]interface{}, error) {
// Get current version (unused but kept for reference)
_, err := GetCurrentVersion(db)
if err != nil {
return nil, err
}
// Get applied migrations with timestamps
rows, err := db.Query("SELECT version, applied_at FROM schema_migrations")
if err != nil {
return nil, err
}
defer rows.Close()
appliedMigrations := make(map[int]time.Time)
for rows.Next() {
var version int
var appliedAt time.Time
if err := rows.Scan(&version, &appliedAt); err != nil {
return nil, err
}
appliedMigrations[version] = appliedAt
}
// Create result
var result []map[string]interface{}
allMigrations := getMigrations()
for _, migration := range allMigrations {
appliedAt, applied := appliedMigrations[migration.Version]
appliedAtStr := ""
if applied {
appliedAtStr = appliedAt.Format("2006-01-02 15:04:05")
}
result = append(result, map[string]interface{}{
"version": migration.Version,
"description": migration.Description,
"applied": applied,
"applied_at": appliedAtStr,
})
}
return result, nil
}