200 lines
6.6 KiB
Python
200 lines
6.6 KiB
Python
"""
|
|
数据库迁移管理模块
|
|
启动时自动检查并执行数据库迁移
|
|
"""
|
|
import os
|
|
import re
|
|
import asyncio
|
|
from datetime import datetime
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from config.database import async_engine, AsyncSessionLocal
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# 迁移文件目录
|
|
MIGRATION_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'sql', 'migration')
|
|
|
|
|
|
class DatabaseMigration:
|
|
"""
|
|
数据库迁移管理类
|
|
"""
|
|
|
|
@classmethod
|
|
async def init_migration_table(cls, db: AsyncSession):
|
|
"""
|
|
初始化迁移记录表
|
|
"""
|
|
create_table_sql = """
|
|
CREATE TABLE IF NOT EXISTS db_migration (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
version VARCHAR(50) NOT NULL COMMENT '迁移版本号',
|
|
filename VARCHAR(255) NOT NULL COMMENT '迁移文件名',
|
|
executed_at DATETIME NOT NULL COMMENT '执行时间',
|
|
checksum VARCHAR(64) COMMENT '文件校验和',
|
|
UNIQUE KEY uk_version (version)
|
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='数据库迁移记录表';
|
|
"""
|
|
await db.execute(text(create_table_sql))
|
|
await db.commit()
|
|
logger.info("迁移记录表初始化完成")
|
|
|
|
@classmethod
|
|
async def get_executed_migrations(cls, db: AsyncSession) -> list:
|
|
"""
|
|
获取已执行的迁移记录
|
|
"""
|
|
result = await db.execute(
|
|
text("SELECT version FROM db_migration ORDER BY version")
|
|
)
|
|
return [row[0] for row in result.fetchall()]
|
|
|
|
@classmethod
|
|
async def record_migration(cls, db: AsyncSession, version: str, filename: str, checksum: str = None):
|
|
"""
|
|
记录迁移执行
|
|
"""
|
|
await db.execute(
|
|
text("""
|
|
INSERT INTO db_migration (version, filename, executed_at, checksum)
|
|
VALUES (:version, :filename, :executed_at, :checksum)
|
|
"""),
|
|
{
|
|
'version': version,
|
|
'filename': filename,
|
|
'executed_at': datetime.now(),
|
|
'checksum': checksum
|
|
}
|
|
)
|
|
await db.commit()
|
|
|
|
@classmethod
|
|
def get_migration_files(cls) -> list:
|
|
"""
|
|
获取所有迁移文件,按版本号排序
|
|
"""
|
|
if not os.path.exists(MIGRATION_DIR):
|
|
logger.warning(f"迁移目录不存在: {MIGRATION_DIR}")
|
|
return []
|
|
|
|
migration_files = []
|
|
pattern = re.compile(r'^(\d{8})_.*\.sql$')
|
|
|
|
for filename in os.listdir(MIGRATION_DIR):
|
|
match = pattern.match(filename)
|
|
if match:
|
|
version = match.group(1)
|
|
migration_files.append({
|
|
'version': version,
|
|
'filename': filename,
|
|
'filepath': os.path.join(MIGRATION_DIR, filename)
|
|
})
|
|
|
|
# 按版本号排序
|
|
migration_files.sort(key=lambda x: x['version'])
|
|
return migration_files
|
|
|
|
@classmethod
|
|
async def execute_migration_file(cls, db: AsyncSession, filepath: str):
|
|
"""
|
|
执行单个迁移文件
|
|
"""
|
|
with open(filepath, 'r', encoding='utf-8') as f:
|
|
sql_content = f.read()
|
|
|
|
# 分割多条SQL语句
|
|
statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip()]
|
|
|
|
for statement in statements:
|
|
if statement:
|
|
try:
|
|
await db.execute(text(statement))
|
|
logger.info(f"执行SQL: {statement[:100]}...")
|
|
except Exception as e:
|
|
logger.error(f"执行SQL失败: {statement[:100]}... 错误: {e}")
|
|
raise
|
|
|
|
await db.commit()
|
|
|
|
@classmethod
|
|
async def run_migrations(cls):
|
|
"""
|
|
运行所有待执行的迁移
|
|
"""
|
|
async with AsyncSessionLocal() as db:
|
|
try:
|
|
# 1. 初始化迁移表
|
|
await cls.init_migration_table(db)
|
|
|
|
# 2. 获取已执行的迁移
|
|
executed_versions = await cls.get_executed_migrations(db)
|
|
logger.info(f"已执行的迁移: {executed_versions}")
|
|
|
|
# 3. 获取所有迁移文件
|
|
migration_files = cls.get_migration_files()
|
|
|
|
if not migration_files:
|
|
logger.info("没有待执行的迁移文件")
|
|
return
|
|
|
|
# 4. 执行未执行的迁移
|
|
for migration in migration_files:
|
|
version = migration['version']
|
|
filename = migration['filename']
|
|
filepath = migration['filepath']
|
|
|
|
if version in executed_versions:
|
|
logger.info(f"迁移 {version} 已执行,跳过")
|
|
continue
|
|
|
|
logger.info(f"开始执行迁移: {filename}")
|
|
|
|
try:
|
|
# 执行迁移文件
|
|
await cls.execute_migration_file(db, filepath)
|
|
|
|
# 记录迁移
|
|
await cls.record_migration(db, version, filename)
|
|
|
|
logger.info(f"迁移 {version} 执行成功")
|
|
except Exception as e:
|
|
logger.error(f"迁移 {version} 执行失败: {e}")
|
|
raise
|
|
|
|
logger.info("数据库迁移完成")
|
|
|
|
except Exception as e:
|
|
await db.rollback()
|
|
logger.error(f"数据库迁移失败: {e}")
|
|
raise
|
|
|
|
|
|
async def run_database_migrations():
|
|
"""
|
|
入口函数:运行数据库迁移
|
|
"""
|
|
logger.info("=" * 50)
|
|
logger.info("开始数据库迁移检查...")
|
|
logger.info("=" * 50)
|
|
|
|
try:
|
|
await DatabaseMigration.run_migrations()
|
|
logger.info("数据库迁移检查完成")
|
|
except Exception as e:
|
|
logger.error(f"数据库迁移检查失败: {e}")
|
|
# 不抛出异常,让应用继续启动(可以根据需求调整)
|
|
# raise
|
|
|
|
|
|
# 兼容旧版本的同步调用方式
|
|
def run_migrations_sync():
|
|
"""
|
|
同步方式运行迁移(用于非异步上下文)
|
|
"""
|
|
try:
|
|
asyncio.run(run_database_migrations())
|
|
except Exception as e:
|
|
logger.error(f"同步迁移执行失败: {e}")
|