ETest-Vue-FastAPI/ruoyi-fastapi-backend/config/migration.py

200 lines
6.6 KiB
Python
Raw Permalink Normal View History

2026-04-14 10:53:22 +08:00
"""
数据库迁移管理模块
启动时自动检查并执行数据库迁移
"""
import os
import re
import asyncio
from datetime import datetime
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from config.database import async_engine, AsyncSessionLocal
import logging
logger = logging.getLogger(__name__)
# 迁移文件目录
MIGRATION_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'sql', 'migration')
class DatabaseMigration:
"""
数据库迁移管理类
"""
@classmethod
async def init_migration_table(cls, db: AsyncSession):
"""
初始化迁移记录表
"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS db_migration (
id INT AUTO_INCREMENT PRIMARY KEY,
version VARCHAR(50) NOT NULL COMMENT '迁移版本号',
filename VARCHAR(255) NOT NULL COMMENT '迁移文件名',
executed_at DATETIME NOT NULL COMMENT '执行时间',
checksum VARCHAR(64) COMMENT '文件校验和',
UNIQUE KEY uk_version (version)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='数据库迁移记录表';
"""
await db.execute(text(create_table_sql))
await db.commit()
logger.info("迁移记录表初始化完成")
@classmethod
async def get_executed_migrations(cls, db: AsyncSession) -> list:
"""
获取已执行的迁移记录
"""
result = await db.execute(
text("SELECT version FROM db_migration ORDER BY version")
)
return [row[0] for row in result.fetchall()]
@classmethod
async def record_migration(cls, db: AsyncSession, version: str, filename: str, checksum: str = None):
"""
记录迁移执行
"""
await db.execute(
text("""
INSERT INTO db_migration (version, filename, executed_at, checksum)
VALUES (:version, :filename, :executed_at, :checksum)
"""),
{
'version': version,
'filename': filename,
'executed_at': datetime.now(),
'checksum': checksum
}
)
await db.commit()
@classmethod
def get_migration_files(cls) -> list:
"""
获取所有迁移文件按版本号排序
"""
if not os.path.exists(MIGRATION_DIR):
logger.warning(f"迁移目录不存在: {MIGRATION_DIR}")
return []
migration_files = []
pattern = re.compile(r'^(\d{8})_.*\.sql$')
for filename in os.listdir(MIGRATION_DIR):
match = pattern.match(filename)
if match:
version = match.group(1)
migration_files.append({
'version': version,
'filename': filename,
'filepath': os.path.join(MIGRATION_DIR, filename)
})
# 按版本号排序
migration_files.sort(key=lambda x: x['version'])
return migration_files
@classmethod
async def execute_migration_file(cls, db: AsyncSession, filepath: str):
"""
执行单个迁移文件
"""
with open(filepath, 'r', encoding='utf-8') as f:
sql_content = f.read()
# 分割多条SQL语句
statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()]
for statement in statements:
if statement:
try:
await db.execute(text(statement))
logger.info(f"执行SQL: {statement[:100]}...")
except Exception as e:
logger.error(f"执行SQL失败: {statement[:100]}... 错误: {e}")
raise
await db.commit()
@classmethod
async def run_migrations(cls):
"""
运行所有待执行的迁移
"""
async with AsyncSessionLocal() as db:
try:
# 1. 初始化迁移表
await cls.init_migration_table(db)
# 2. 获取已执行的迁移
executed_versions = await cls.get_executed_migrations(db)
logger.info(f"已执行的迁移: {executed_versions}")
# 3. 获取所有迁移文件
migration_files = cls.get_migration_files()
if not migration_files:
logger.info("没有待执行的迁移文件")
return
# 4. 执行未执行的迁移
for migration in migration_files:
version = migration['version']
filename = migration['filename']
filepath = migration['filepath']
if version in executed_versions:
logger.info(f"迁移 {version} 已执行,跳过")
continue
logger.info(f"开始执行迁移: {filename}")
try:
# 执行迁移文件
await cls.execute_migration_file(db, filepath)
# 记录迁移
await cls.record_migration(db, version, filename)
logger.info(f"迁移 {version} 执行成功")
except Exception as e:
logger.error(f"迁移 {version} 执行失败: {e}")
raise
logger.info("数据库迁移完成")
except Exception as e:
await db.rollback()
logger.error(f"数据库迁移失败: {e}")
raise
async def run_database_migrations():
"""
入口函数运行数据库迁移
"""
logger.info("=" * 50)
logger.info("开始数据库迁移检查...")
logger.info("=" * 50)
try:
await DatabaseMigration.run_migrations()
logger.info("数据库迁移检查完成")
except Exception as e:
logger.error(f"数据库迁移检查失败: {e}")
# 不抛出异常,让应用继续启动(可以根据需求调整)
# raise
# 兼容旧版本的同步调用方式
def run_migrations_sync():
"""
同步方式运行迁移用于非异步上下文
"""
try:
asyncio.run(run_database_migrations())
except Exception as e:
logger.error(f"同步迁移执行失败: {e}")