393 lines
13 KiB
Python
393 lines
13 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
数据库迁移和完整性检查模块
|
|
|
|
在程序启动时自动检查数据库结构,如果缺少字段则自动添加
|
|
"""
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import List, Dict, Tuple
|
|
from logger import get_logger
|
|
|
|
logger = get_logger()
|
|
|
|
# 数据库版本号
|
|
CURRENT_DB_VERSION = 2 # 每次修改表结构时递增
|
|
|
|
# 数据库文件路径
|
|
DB_PATH = Path(__file__).parent / "experiments.db"
|
|
|
|
|
|
class DatabaseMigration:
|
|
"""数据库迁移管理器"""
|
|
|
|
def __init__(self, db_path: Path = DB_PATH):
|
|
self.db_path = db_path
|
|
self.conn = None
|
|
self.cursor = None
|
|
|
|
def connect(self):
|
|
"""连接数据库"""
|
|
self.conn = sqlite3.connect(str(self.db_path))
|
|
self.cursor = self.conn.cursor()
|
|
|
|
def close(self):
|
|
"""关闭数据库连接"""
|
|
if self.conn:
|
|
self.conn.close()
|
|
|
|
def get_db_version(self) -> int:
|
|
"""获取当前数据库版本"""
|
|
try:
|
|
# 检查版本表是否存在
|
|
self.cursor.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='db_version'"
|
|
)
|
|
if not self.cursor.fetchone():
|
|
# 版本表不存在,创建它
|
|
self.cursor.execute("""
|
|
CREATE TABLE db_version (
|
|
version INTEGER PRIMARY KEY,
|
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
""")
|
|
self.cursor.execute("INSERT INTO db_version (version) VALUES (0)")
|
|
self.conn.commit()
|
|
return 0
|
|
|
|
# 读取版本号
|
|
self.cursor.execute("SELECT version FROM db_version ORDER BY version DESC LIMIT 1")
|
|
result = self.cursor.fetchone()
|
|
return result[0] if result else 0
|
|
|
|
except Exception as e:
|
|
logger.error(f"获取数据库版本失败: {e}", exc_info=True)
|
|
return 0
|
|
|
|
def set_db_version(self, version: int):
|
|
"""设置数据库版本"""
|
|
try:
|
|
self.cursor.execute(
|
|
"INSERT INTO db_version (version) VALUES (?)",
|
|
(version,)
|
|
)
|
|
self.conn.commit()
|
|
logger.info(f"数据库版本已更新到: {version}")
|
|
except Exception as e:
|
|
logger.error(f"设置数据库版本失败: {e}", exc_info=True)
|
|
|
|
def get_table_columns(self, table_name: str) -> List[str]:
|
|
"""获取表的所有列名"""
|
|
try:
|
|
self.cursor.execute(f"PRAGMA table_info({table_name})")
|
|
columns = [row[1] for row in self.cursor.fetchall()]
|
|
return columns
|
|
except Exception as e:
|
|
logger.error(f"获取表 {table_name} 的列信息失败: {e}", exc_info=True)
|
|
return []
|
|
|
|
def column_exists(self, table_name: str, column_name: str) -> bool:
|
|
"""检查列是否存在"""
|
|
columns = self.get_table_columns(table_name)
|
|
return column_name in columns
|
|
|
|
def add_column(self, table_name: str, column_name: str, column_type: str, default_value=None):
|
|
"""添加列到表"""
|
|
try:
|
|
if self.column_exists(table_name, column_name):
|
|
logger.debug(f"列 {table_name}.{column_name} 已存在,跳过")
|
|
return True
|
|
|
|
# 构建 ALTER TABLE 语句
|
|
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
|
|
if default_value is not None:
|
|
if isinstance(default_value, str):
|
|
sql += f" DEFAULT '{default_value}'"
|
|
else:
|
|
sql += f" DEFAULT {default_value}"
|
|
|
|
self.cursor.execute(sql)
|
|
self.conn.commit()
|
|
logger.info(f"✅ 已添加列: {table_name}.{column_name} ({column_type})")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"添加列 {table_name}.{column_name} 失败: {e}", exc_info=True)
|
|
return False
|
|
|
|
def table_exists(self, table_name: str) -> bool:
|
|
"""检查表是否存在"""
|
|
try:
|
|
self.cursor.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?",
|
|
(table_name,)
|
|
)
|
|
return self.cursor.fetchone() is not None
|
|
except Exception as e:
|
|
logger.error(f"检查表 {table_name} 是否存在失败: {e}", exc_info=True)
|
|
return False
|
|
|
|
def create_table(self, table_name: str, schema: str):
|
|
"""创建表"""
|
|
try:
|
|
if self.table_exists(table_name):
|
|
logger.debug(f"表 {table_name} 已存在,跳过创建")
|
|
return True
|
|
|
|
self.cursor.execute(schema)
|
|
self.conn.commit()
|
|
logger.info(f"✅ 已创建表: {table_name}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"创建表 {table_name} 失败: {e}", exc_info=True)
|
|
return False
|
|
|
|
def migrate_to_version_1(self):
|
|
"""迁移到版本 1: 添加 script_data 字段"""
|
|
logger.info("开始迁移到版本 1...")
|
|
|
|
# 添加 script_data 字段
|
|
success = self.add_column(
|
|
table_name="experiments",
|
|
column_name="script_data",
|
|
column_type="TEXT",
|
|
default_value=None
|
|
)
|
|
|
|
if success:
|
|
logger.info("✅ 版本 1 迁移完成")
|
|
else:
|
|
logger.warning("⚠️ 版本 1 迁移部分失败")
|
|
|
|
return success
|
|
|
|
def migrate_to_version_2(self):
|
|
"""迁移到版本 2: 确保所有必要字段存在"""
|
|
logger.info("开始迁移到版本 2...")
|
|
|
|
# 定义 experiments 表应该有的所有字段
|
|
required_columns = [
|
|
("id", "INTEGER PRIMARY KEY AUTOINCREMENT"),
|
|
("work_order_no", "TEXT NOT NULL"),
|
|
("config_json", "TEXT NOT NULL"),
|
|
("start_ts", "TEXT"),
|
|
("end_ts", "TEXT"),
|
|
("is_paused", "INTEGER DEFAULT 0"),
|
|
("created_at", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
|
|
("script_data", "TEXT"),
|
|
]
|
|
|
|
success_count = 0
|
|
for column_name, column_type in required_columns:
|
|
# 跳过主键和自增字段(无法通过 ALTER TABLE 添加)
|
|
if "PRIMARY KEY" in column_type or "AUTOINCREMENT" in column_type:
|
|
continue
|
|
|
|
# 提取基本类型(去掉 DEFAULT 等)
|
|
base_type = column_type.split()[0]
|
|
|
|
# 提取默认值
|
|
default_value = None
|
|
if "DEFAULT" in column_type:
|
|
parts = column_type.split("DEFAULT")
|
|
if len(parts) > 1:
|
|
default_str = parts[1].strip()
|
|
if default_str == "CURRENT_TIMESTAMP":
|
|
default_value = None # SQLite 会自动处理
|
|
elif default_str.isdigit():
|
|
default_value = int(default_str)
|
|
else:
|
|
default_value = default_str.strip("'\"")
|
|
|
|
if self.add_column("experiments", column_name, base_type, default_value):
|
|
success_count += 1
|
|
|
|
logger.info(f"✅ 版本 2 迁移完成 (检查了 {len(required_columns)} 个字段)")
|
|
return True
|
|
|
|
def run_migrations(self):
|
|
"""运行所有必要的迁移"""
|
|
try:
|
|
self.connect()
|
|
|
|
# 获取当前版本
|
|
current_version = self.get_db_version()
|
|
logger.info(f"当前数据库版本: {current_version}")
|
|
logger.info(f"目标数据库版本: {CURRENT_DB_VERSION}")
|
|
|
|
if current_version >= CURRENT_DB_VERSION:
|
|
logger.info("✅ 数据库已是最新版本,无需迁移")
|
|
return True
|
|
|
|
# 执行迁移
|
|
migrations = [
|
|
(1, self.migrate_to_version_1),
|
|
(2, self.migrate_to_version_2),
|
|
]
|
|
|
|
for version, migration_func in migrations:
|
|
if current_version < version:
|
|
logger.info(f"执行迁移: 版本 {version}")
|
|
try:
|
|
if migration_func():
|
|
self.set_db_version(version)
|
|
current_version = version
|
|
else:
|
|
logger.error(f"迁移到版本 {version} 失败")
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"迁移到版本 {version} 时发生异常: {e}", exc_info=True)
|
|
return False
|
|
|
|
logger.info(f"✅ 数据库迁移完成,当前版本: {current_version}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"运行数据库迁移失败: {e}", exc_info=True)
|
|
return False
|
|
|
|
finally:
|
|
self.close()
|
|
|
|
def check_and_repair(self):
|
|
"""检查并修复数据库完整性"""
|
|
try:
|
|
self.connect()
|
|
|
|
logger.info("开始检查数据库完整性...")
|
|
|
|
# 1. 检查 experiments 表是否存在
|
|
if not self.table_exists("experiments"):
|
|
logger.warning("experiments 表不存在,需要创建")
|
|
# 这里应该创建完整的表结构
|
|
# 但通常这种情况不应该发生,因为程序首次运行时会创建表
|
|
return False
|
|
|
|
# 2. 检查所有必要的列
|
|
required_columns = {
|
|
"id": "INTEGER",
|
|
"work_order_no": "TEXT",
|
|
"config_json": "TEXT",
|
|
"start_ts": "TEXT",
|
|
"end_ts": "TEXT",
|
|
"is_paused": "INTEGER",
|
|
"created_at": "TIMESTAMP",
|
|
"script_data": "TEXT",
|
|
}
|
|
|
|
existing_columns = self.get_table_columns("experiments")
|
|
missing_columns = []
|
|
|
|
for col_name, col_type in required_columns.items():
|
|
if col_name not in existing_columns:
|
|
missing_columns.append((col_name, col_type))
|
|
|
|
if missing_columns:
|
|
logger.warning(f"发现缺失的列: {[col[0] for col in missing_columns]}")
|
|
for col_name, col_type in missing_columns:
|
|
self.add_column("experiments", col_name, col_type)
|
|
else:
|
|
logger.info("✅ 所有必要的列都存在")
|
|
|
|
# 3. 检查索引(可选)
|
|
# TODO: 如果需要,可以在这里检查和创建索引
|
|
|
|
logger.info("✅ 数据库完整性检查完成")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"检查数据库完整性失败: {e}", exc_info=True)
|
|
return False
|
|
|
|
finally:
|
|
self.close()
|
|
|
|
|
|
def initialize_database():
|
|
"""
|
|
初始化数据库
|
|
在程序启动时调用,确保数据库结构是最新的
|
|
"""
|
|
logger.info("=" * 80)
|
|
logger.info("开始数据库初始化...")
|
|
logger.info("=" * 80)
|
|
|
|
migration = DatabaseMigration()
|
|
|
|
# 1. 运行迁移
|
|
if not migration.run_migrations():
|
|
logger.error("❌ 数据库迁移失败")
|
|
return False
|
|
|
|
# 2. 检查并修复
|
|
if not migration.check_and_repair():
|
|
logger.error("❌ 数据库完整性检查失败")
|
|
return False
|
|
|
|
logger.info("=" * 80)
|
|
logger.info("✅ 数据库初始化完成")
|
|
logger.info("=" * 80)
|
|
return True
|
|
|
|
|
|
def get_database_info():
|
|
"""获取数据库信息(用于调试)"""
|
|
try:
|
|
migration = DatabaseMigration()
|
|
migration.connect()
|
|
|
|
info = {
|
|
"version": migration.get_db_version(),
|
|
"tables": [],
|
|
}
|
|
|
|
# 获取所有表
|
|
migration.cursor.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
|
)
|
|
tables = [row[0] for row in migration.cursor.fetchall()]
|
|
|
|
for table in tables:
|
|
columns = migration.get_table_columns(table)
|
|
info["tables"].append({
|
|
"name": table,
|
|
"columns": columns,
|
|
"column_count": len(columns)
|
|
})
|
|
|
|
migration.close()
|
|
return info
|
|
|
|
except Exception as e:
|
|
logger.error(f"获取数据库信息失败: {e}", exc_info=True)
|
|
return None
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 测试迁移
|
|
print("=" * 80)
|
|
print("数据库迁移测试")
|
|
print("=" * 80)
|
|
|
|
# 运行初始化
|
|
success = initialize_database()
|
|
|
|
if success:
|
|
print("\n" + "=" * 80)
|
|
print("数据库信息:")
|
|
print("=" * 80)
|
|
|
|
info = get_database_info()
|
|
if info:
|
|
print(f"\n数据库版本: {info['version']}")
|
|
print(f"表数量: {len(info['tables'])}")
|
|
|
|
for table in info['tables']:
|
|
print(f"\n表: {table['name']}")
|
|
print(f" 列数: {table['column_count']}")
|
|
print(f" 列名: {', '.join(table['columns'])}")
|
|
|
|
print("\n" + "=" * 80)
|