""" 数据库迁移管理模块 启动时自动检查并执行数据库迁移 """ import os import re import asyncio from datetime import datetime from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from config.database import async_engine, AsyncSessionLocal import logging logger = logging.getLogger(__name__) # 迁移文件目录 MIGRATION_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'sql', 'migration') class DatabaseMigration: """ 数据库迁移管理类 """ @classmethod async def init_migration_table(cls, db: AsyncSession): """ 初始化迁移记录表 """ create_table_sql = """ CREATE TABLE IF NOT EXISTS db_migration ( id INT AUTO_INCREMENT PRIMARY KEY, version VARCHAR(50) NOT NULL COMMENT '迁移版本号', filename VARCHAR(255) NOT NULL COMMENT '迁移文件名', executed_at DATETIME NOT NULL COMMENT '执行时间', checksum VARCHAR(64) COMMENT '文件校验和', UNIQUE KEY uk_version (version) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='数据库迁移记录表'; """ await db.execute(text(create_table_sql)) await db.commit() logger.info("迁移记录表初始化完成") @classmethod async def get_executed_migrations(cls, db: AsyncSession) -> list: """ 获取已执行的迁移记录 """ result = await db.execute( text("SELECT version FROM db_migration ORDER BY version") ) return [row[0] for row in result.fetchall()] @classmethod async def record_migration(cls, db: AsyncSession, version: str, filename: str, checksum: str = None): """ 记录迁移执行 """ await db.execute( text(""" INSERT INTO db_migration (version, filename, executed_at, checksum) VALUES (:version, :filename, :executed_at, :checksum) """), { 'version': version, 'filename': filename, 'executed_at': datetime.now(), 'checksum': checksum } ) await db.commit() @classmethod def get_migration_files(cls) -> list: """ 获取所有迁移文件,按版本号排序 """ if not os.path.exists(MIGRATION_DIR): logger.warning(f"迁移目录不存在: {MIGRATION_DIR}") return [] migration_files = [] pattern = re.compile(r'^(\d{8})_.*\.sql$') for filename in os.listdir(MIGRATION_DIR): match = pattern.match(filename) if match: version = match.group(1) migration_files.append({ 'version': version, 'filename': filename, 'filepath': os.path.join(MIGRATION_DIR, filename) }) # 按版本号排序 migration_files.sort(key=lambda x: x['version']) return migration_files @classmethod async def execute_migration_file(cls, db: AsyncSession, filepath: str): """ 执行单个迁移文件 """ with open(filepath, 'r', encoding='utf-8') as f: sql_content = f.read() # 分割多条SQL语句 statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()] for statement in statements: if statement: try: await db.execute(text(statement)) logger.info(f"执行SQL: {statement[:100]}...") except Exception as e: logger.error(f"执行SQL失败: {statement[:100]}... 错误: {e}") raise await db.commit() @classmethod async def run_migrations(cls): """ 运行所有待执行的迁移 """ async with AsyncSessionLocal() as db: try: # 1. 初始化迁移表 await cls.init_migration_table(db) # 2. 获取已执行的迁移 executed_versions = await cls.get_executed_migrations(db) logger.info(f"已执行的迁移: {executed_versions}") # 3. 获取所有迁移文件 migration_files = cls.get_migration_files() if not migration_files: logger.info("没有待执行的迁移文件") return # 4. 执行未执行的迁移 for migration in migration_files: version = migration['version'] filename = migration['filename'] filepath = migration['filepath'] if version in executed_versions: logger.info(f"迁移 {version} 已执行,跳过") continue logger.info(f"开始执行迁移: {filename}") try: # 执行迁移文件 await cls.execute_migration_file(db, filepath) # 记录迁移 await cls.record_migration(db, version, filename) logger.info(f"迁移 {version} 执行成功") except Exception as e: logger.error(f"迁移 {version} 执行失败: {e}") raise logger.info("数据库迁移完成") except Exception as e: await db.rollback() logger.error(f"数据库迁移失败: {e}") raise async def run_database_migrations(): """ 入口函数:运行数据库迁移 """ logger.info("=" * 50) logger.info("开始数据库迁移检查...") logger.info("=" * 50) try: await DatabaseMigration.run_migrations() logger.info("数据库迁移检查完成") except Exception as e: logger.error(f"数据库迁移检查失败: {e}") # 不抛出异常,让应用继续启动(可以根据需求调整) # raise # 兼容旧版本的同步调用方式 def run_migrations_sync(): """ 同步方式运行迁移(用于非异步上下文) """ try: asyncio.run(run_database_migrations()) except Exception as e: logger.error(f"同步迁移执行失败: {e}")