ETest-Vue-FastAPI/ruoyi-fastapi-backend/config/migration.py

200 lines
6.6 KiB
Python

"""
数据库迁移管理模块
启动时自动检查并执行数据库迁移
"""
import os
import re
import asyncio
from datetime import datetime
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from config.database import async_engine, AsyncSessionLocal
import logging
logger = logging.getLogger(__name__)
# 迁移文件目录
MIGRATION_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'sql', 'migration')
class DatabaseMigration:
"""
数据库迁移管理类
"""
@classmethod
async def init_migration_table(cls, db: AsyncSession):
"""
初始化迁移记录表
"""
create_table_sql = """
CREATE TABLE IF NOT EXISTS db_migration (
id INT AUTO_INCREMENT PRIMARY KEY,
version VARCHAR(50) NOT NULL COMMENT '迁移版本号',
filename VARCHAR(255) NOT NULL COMMENT '迁移文件名',
executed_at DATETIME NOT NULL COMMENT '执行时间',
checksum VARCHAR(64) COMMENT '文件校验和',
UNIQUE KEY uk_version (version)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='数据库迁移记录表';
"""
await db.execute(text(create_table_sql))
await db.commit()
logger.info("迁移记录表初始化完成")
@classmethod
async def get_executed_migrations(cls, db: AsyncSession) -> list:
"""
获取已执行的迁移记录
"""
result = await db.execute(
text("SELECT version FROM db_migration ORDER BY version")
)
return [row[0] for row in result.fetchall()]
@classmethod
async def record_migration(cls, db: AsyncSession, version: str, filename: str, checksum: str = None):
"""
记录迁移执行
"""
await db.execute(
text("""
INSERT INTO db_migration (version, filename, executed_at, checksum)
VALUES (:version, :filename, :executed_at, :checksum)
"""),
{
'version': version,
'filename': filename,
'executed_at': datetime.now(),
'checksum': checksum
}
)
await db.commit()
@classmethod
def get_migration_files(cls) -> list:
"""
获取所有迁移文件,按版本号排序
"""
if not os.path.exists(MIGRATION_DIR):
logger.warning(f"迁移目录不存在: {MIGRATION_DIR}")
return []
migration_files = []
pattern = re.compile(r'^(\d{8})_.*\.sql$')
for filename in os.listdir(MIGRATION_DIR):
match = pattern.match(filename)
if match:
version = match.group(1)
migration_files.append({
'version': version,
'filename': filename,
'filepath': os.path.join(MIGRATION_DIR, filename)
})
# 按版本号排序
migration_files.sort(key=lambda x: x['version'])
return migration_files
@classmethod
async def execute_migration_file(cls, db: AsyncSession, filepath: str):
"""
执行单个迁移文件
"""
with open(filepath, 'r', encoding='utf-8') as f:
sql_content = f.read()
# 分割多条SQL语句
statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()]
for statement in statements:
if statement:
try:
await db.execute(text(statement))
logger.info(f"执行SQL: {statement[:100]}...")
except Exception as e:
logger.error(f"执行SQL失败: {statement[:100]}... 错误: {e}")
raise
await db.commit()
@classmethod
async def run_migrations(cls):
"""
运行所有待执行的迁移
"""
async with AsyncSessionLocal() as db:
try:
# 1. 初始化迁移表
await cls.init_migration_table(db)
# 2. 获取已执行的迁移
executed_versions = await cls.get_executed_migrations(db)
logger.info(f"已执行的迁移: {executed_versions}")
# 3. 获取所有迁移文件
migration_files = cls.get_migration_files()
if not migration_files:
logger.info("没有待执行的迁移文件")
return
# 4. 执行未执行的迁移
for migration in migration_files:
version = migration['version']
filename = migration['filename']
filepath = migration['filepath']
if version in executed_versions:
logger.info(f"迁移 {version} 已执行,跳过")
continue
logger.info(f"开始执行迁移: {filename}")
try:
# 执行迁移文件
await cls.execute_migration_file(db, filepath)
# 记录迁移
await cls.record_migration(db, version, filename)
logger.info(f"迁移 {version} 执行成功")
except Exception as e:
logger.error(f"迁移 {version} 执行失败: {e}")
raise
logger.info("数据库迁移完成")
except Exception as e:
await db.rollback()
logger.error(f"数据库迁移失败: {e}")
raise
async def run_database_migrations():
"""
入口函数:运行数据库迁移
"""
logger.info("=" * 50)
logger.info("开始数据库迁移检查...")
logger.info("=" * 50)
try:
await DatabaseMigration.run_migrations()
logger.info("数据库迁移检查完成")
except Exception as e:
logger.error(f"数据库迁移检查失败: {e}")
# 不抛出异常,让应用继续启动(可以根据需求调整)
# raise
# 兼容旧版本的同步调用方式
def run_migrations_sync():
"""
同步方式运行迁移(用于非异步上下文)
"""
try:
asyncio.run(run_database_migrations())
except Exception as e:
logger.error(f"同步迁移执行失败: {e}")