PCM_Report/db_migration.py

393 lines
13 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库迁移和完整性检查模块
在程序启动时自动检查数据库结构,如果缺少字段则自动添加
"""
import sqlite3
from pathlib import Path
from typing import List, Dict, Tuple
from logger import get_logger
logger = get_logger()
# 数据库版本号
CURRENT_DB_VERSION = 2 # 每次修改表结构时递增
# 数据库文件路径
DB_PATH = Path(__file__).parent / "experiments.db"
class DatabaseMigration:
"""数据库迁移管理器"""
def __init__(self, db_path: Path = DB_PATH):
self.db_path = db_path
self.conn = None
self.cursor = None
def connect(self):
"""连接数据库"""
self.conn = sqlite3.connect(str(self.db_path))
self.cursor = self.conn.cursor()
def close(self):
"""关闭数据库连接"""
if self.conn:
self.conn.close()
def get_db_version(self) -> int:
"""获取当前数据库版本"""
try:
# 检查版本表是否存在
self.cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='db_version'"
)
if not self.cursor.fetchone():
# 版本表不存在,创建它
self.cursor.execute("""
CREATE TABLE db_version (
version INTEGER PRIMARY KEY,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
self.cursor.execute("INSERT INTO db_version (version) VALUES (0)")
self.conn.commit()
return 0
# 读取版本号
self.cursor.execute("SELECT version FROM db_version ORDER BY version DESC LIMIT 1")
result = self.cursor.fetchone()
return result[0] if result else 0
except Exception as e:
logger.error(f"获取数据库版本失败: {e}", exc_info=True)
return 0
def set_db_version(self, version: int):
"""设置数据库版本"""
try:
self.cursor.execute(
"INSERT INTO db_version (version) VALUES (?)",
(version,)
)
self.conn.commit()
logger.info(f"数据库版本已更新到: {version}")
except Exception as e:
logger.error(f"设置数据库版本失败: {e}", exc_info=True)
def get_table_columns(self, table_name: str) -> List[str]:
"""获取表的所有列名"""
try:
self.cursor.execute(f"PRAGMA table_info({table_name})")
columns = [row[1] for row in self.cursor.fetchall()]
return columns
except Exception as e:
logger.error(f"获取表 {table_name} 的列信息失败: {e}", exc_info=True)
return []
def column_exists(self, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
columns = self.get_table_columns(table_name)
return column_name in columns
def add_column(self, table_name: str, column_name: str, column_type: str, default_value=None):
"""添加列到表"""
try:
if self.column_exists(table_name, column_name):
logger.debug(f"{table_name}.{column_name} 已存在,跳过")
return True
# 构建 ALTER TABLE 语句
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
if default_value is not None:
if isinstance(default_value, str):
sql += f" DEFAULT '{default_value}'"
else:
sql += f" DEFAULT {default_value}"
self.cursor.execute(sql)
self.conn.commit()
logger.info(f"✅ 已添加列: {table_name}.{column_name} ({column_type})")
return True
except Exception as e:
logger.error(f"添加列 {table_name}.{column_name} 失败: {e}", exc_info=True)
return False
def table_exists(self, table_name: str) -> bool:
"""检查表是否存在"""
try:
self.cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
(table_name,)
)
return self.cursor.fetchone() is not None
except Exception as e:
logger.error(f"检查表 {table_name} 是否存在失败: {e}", exc_info=True)
return False
def create_table(self, table_name: str, schema: str):
"""创建表"""
try:
if self.table_exists(table_name):
logger.debug(f"{table_name} 已存在,跳过创建")
return True
self.cursor.execute(schema)
self.conn.commit()
logger.info(f"✅ 已创建表: {table_name}")
return True
except Exception as e:
logger.error(f"创建表 {table_name} 失败: {e}", exc_info=True)
return False
def migrate_to_version_1(self):
"""迁移到版本 1: 添加 script_data 字段"""
logger.info("开始迁移到版本 1...")
# 添加 script_data 字段
success = self.add_column(
table_name="experiments",
column_name="script_data",
column_type="TEXT",
default_value=None
)
if success:
logger.info("✅ 版本 1 迁移完成")
else:
logger.warning("⚠️ 版本 1 迁移部分失败")
return success
def migrate_to_version_2(self):
"""迁移到版本 2: 确保所有必要字段存在"""
logger.info("开始迁移到版本 2...")
# 定义 experiments 表应该有的所有字段
required_columns = [
("id", "INTEGER PRIMARY KEY AUTOINCREMENT"),
("work_order_no", "TEXT NOT NULL"),
("config_json", "TEXT NOT NULL"),
("start_ts", "TEXT"),
("end_ts", "TEXT"),
("is_paused", "INTEGER DEFAULT 0"),
("created_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
("script_data", "TEXT"),
]
success_count = 0
for column_name, column_type in required_columns:
# 跳过主键和自增字段(无法通过 ALTER TABLE 添加)
if "PRIMARY KEY" in column_type or "AUTOINCREMENT" in column_type:
continue
# 提取基本类型(去掉 DEFAULT 等)
base_type = column_type.split()[0]
# 提取默认值
default_value = None
if "DEFAULT" in column_type:
parts = column_type.split("DEFAULT")
if len(parts) > 1:
default_str = parts[1].strip()
if default_str == "CURRENT_TIMESTAMP":
default_value = None # SQLite 会自动处理
elif default_str.isdigit():
default_value = int(default_str)
else:
default_value = default_str.strip("'\"")
if self.add_column("experiments", column_name, base_type, default_value):
success_count += 1
logger.info(f"✅ 版本 2 迁移完成 (检查了 {len(required_columns)} 个字段)")
return True
def run_migrations(self):
"""运行所有必要的迁移"""
try:
self.connect()
# 获取当前版本
current_version = self.get_db_version()
logger.info(f"当前数据库版本: {current_version}")
logger.info(f"目标数据库版本: {CURRENT_DB_VERSION}")
if current_version >= CURRENT_DB_VERSION:
logger.info("✅ 数据库已是最新版本,无需迁移")
return True
# 执行迁移
migrations = [
(1, self.migrate_to_version_1),
(2, self.migrate_to_version_2),
]
for version, migration_func in migrations:
if current_version < version:
logger.info(f"执行迁移: 版本 {version}")
try:
if migration_func():
self.set_db_version(version)
current_version = version
else:
logger.error(f"迁移到版本 {version} 失败")
return False
except Exception as e:
logger.error(f"迁移到版本 {version} 时发生异常: {e}", exc_info=True)
return False
logger.info(f"✅ 数据库迁移完成,当前版本: {current_version}")
return True
except Exception as e:
logger.error(f"运行数据库迁移失败: {e}", exc_info=True)
return False
finally:
self.close()
def check_and_repair(self):
"""检查并修复数据库完整性"""
try:
self.connect()
logger.info("开始检查数据库完整性...")
# 1. 检查 experiments 表是否存在
if not self.table_exists("experiments"):
logger.warning("experiments 表不存在,需要创建")
# 这里应该创建完整的表结构
# 但通常这种情况不应该发生,因为程序首次运行时会创建表
return False
# 2. 检查所有必要的列
required_columns = {
"id": "INTEGER",
"work_order_no": "TEXT",
"config_json": "TEXT",
"start_ts": "TEXT",
"end_ts": "TEXT",
"is_paused": "INTEGER",
"created_at": "TIMESTAMP",
"script_data": "TEXT",
}
existing_columns = self.get_table_columns("experiments")
missing_columns = []
for col_name, col_type in required_columns.items():
if col_name not in existing_columns:
missing_columns.append((col_name, col_type))
if missing_columns:
logger.warning(f"发现缺失的列: {[col[0] for col in missing_columns]}")
for col_name, col_type in missing_columns:
self.add_column("experiments", col_name, col_type)
else:
logger.info("✅ 所有必要的列都存在")
# 3. 检查索引(可选)
# TODO: 如果需要,可以在这里检查和创建索引
logger.info("✅ 数据库完整性检查完成")
return True
except Exception as e:
logger.error(f"检查数据库完整性失败: {e}", exc_info=True)
return False
finally:
self.close()
def initialize_database():
"""
初始化数据库
在程序启动时调用,确保数据库结构是最新的
"""
logger.info("=" * 80)
logger.info("开始数据库初始化...")
logger.info("=" * 80)
migration = DatabaseMigration()
# 1. 运行迁移
if not migration.run_migrations():
logger.error("❌ 数据库迁移失败")
return False
# 2. 检查并修复
if not migration.check_and_repair():
logger.error("❌ 数据库完整性检查失败")
return False
logger.info("=" * 80)
logger.info("✅ 数据库初始化完成")
logger.info("=" * 80)
return True
def get_database_info():
"""获取数据库信息(用于调试)"""
try:
migration = DatabaseMigration()
migration.connect()
info = {
"version": migration.get_db_version(),
"tables": [],
}
# 获取所有表
migration.cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
)
tables = [row[0] for row in migration.cursor.fetchall()]
for table in tables:
columns = migration.get_table_columns(table)
info["tables"].append({
"name": table,
"columns": columns,
"column_count": len(columns)
})
migration.close()
return info
except Exception as e:
logger.error(f"获取数据库信息失败: {e}", exc_info=True)
return None
if __name__ == "__main__":
# 测试迁移
print("=" * 80)
print("数据库迁移测试")
print("=" * 80)
# 运行初始化
success = initialize_database()
if success:
print("\n" + "=" * 80)
print("数据库信息:")
print("=" * 80)
info = get_database_info()
if info:
print(f"\n数据库版本: {info['version']}")
print(f"表数量: {len(info['tables'])}")
for table in info['tables']:
print(f"\n表: {table['name']}")
print(f" 列数: {table['column_count']}")
print(f" 列名: {', '.join(table['columns'])}")
print("\n" + "=" * 80)