2077 lines
80 KiB
Python
2077 lines
80 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
探测器信号模拟器 V4.2 - 优化版
|
||
增加打包方式选择:单独打包 vs 整体打包
|
||
"""
|
||
|
||
import yaml
|
||
import numpy as np
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple, Any, Optional
|
||
import logging
|
||
import warnings
|
||
import struct
|
||
from datetime import datetime
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
# ==================== 中文显示设置 ====================
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
|
||
# 设置中文字体
|
||
try:
|
||
import matplotlib.font_manager as fm
|
||
font_list = [f.name for f in fm.fontManager.ttflist]
|
||
chinese_fonts = ['Microsoft YaHei', 'SimHei', 'SimSun', 'KaiTi']
|
||
|
||
available_font = None
|
||
for font_name in chinese_fonts:
|
||
for font in font_list:
|
||
if font_name.lower() in font.lower():
|
||
available_font = font_name
|
||
break
|
||
if available_font:
|
||
break
|
||
|
||
if available_font:
|
||
plt.rcParams['font.sans-serif'] = [available_font]
|
||
else:
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
except Exception as e:
|
||
print(f"字体设置警告: {e}")
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 指令配置类(V4版本) ====================
|
||
class InstructionConfigV4:
|
||
"""指令配置类 V4"""
|
||
|
||
# 指令部分各字段定义
|
||
INSTRUCTION_FORMAT = {
|
||
'header': {
|
||
'bytes': 2,
|
||
'value': [0xAA, 0xAA],
|
||
'description': '包头'
|
||
},
|
||
'signal_type': {
|
||
'bytes': 2,
|
||
'value': [0x00, 0x07],
|
||
'description': '信号类型'
|
||
},
|
||
'ddr_start_addr': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x00, 0x00],
|
||
'description': 'DDR起始地址'
|
||
},
|
||
'ddr_end_addr': {
|
||
'bytes': 4,
|
||
'value': None, # 根据数据长度计算
|
||
'description': 'DDR结束地址'
|
||
},
|
||
'sync_pulse1_count': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x01, 0xDB], # 475
|
||
'description': '同步脉冲1个数'
|
||
},
|
||
'sync_pulse1_period': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x3E, 0x80], # 16000
|
||
'description': '同步脉冲1周期'
|
||
},
|
||
'sync_pulse2_count': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x00, 0x14], # 20
|
||
'description': '同步脉冲2个数'
|
||
},
|
||
'sync_pulse2_period': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x01, 0x86, 0xA0], # 100000
|
||
'description': '同步脉冲2周期'
|
||
},
|
||
'sync_pulse3_count': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x00, 0x01], # 1
|
||
'description': '同步脉冲3个数'
|
||
},
|
||
'sync_pulse3_period': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x06, 0x1A, 0x80], # 400000
|
||
'description': '同步脉冲3周期'
|
||
},
|
||
'sync_pulse4_count': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x00, 0x00], # 0
|
||
'description': '同步脉冲4个数'
|
||
},
|
||
'sync_pulse4_period': {
|
||
'bytes': 4,
|
||
'value': [0x00, 0x00, 0x00, 0x00], # 0
|
||
'description': '同步脉冲4周期'
|
||
},
|
||
'start_stop_control': {
|
||
'bytes': 2,
|
||
'value': [0x00, 0x00],
|
||
'description': '启停控制'
|
||
},
|
||
'crc16': {
|
||
'bytes': 2,
|
||
'value': None, # 根据前面数据计算
|
||
'description': 'Modbus CRC16校验和'
|
||
}
|
||
}
|
||
|
||
# 打包配置
|
||
PACKING_CONFIG = {
|
||
'signals_per_packet': 800, # 每包800个点火信号
|
||
'bytes_per_signal': 8, # 每个信号8字节
|
||
'crc_bytes': 2, # CRC16校验和2字节
|
||
'bytes_per_packet': 6402, # 800*8 + 2 = 6402字节
|
||
'group_size': 8, # 8个信号为一组
|
||
'padding_byte': 0x00 # 填充字节
|
||
}
|
||
|
||
@classmethod
|
||
def get_instruction_length(cls) -> int:
|
||
"""获取指令部分的总长度(字节)"""
|
||
total = 0
|
||
for field_info in cls.INSTRUCTION_FORMAT.values():
|
||
total += field_info['bytes']
|
||
return total
|
||
|
||
@classmethod
|
||
def calculate_ddr_end_addr(cls, total_data_bytes: int) -> List[int]:
|
||
"""
|
||
计算DDR结束地址
|
||
DDR结束地址+8必须是800的整数倍
|
||
|
||
参数:
|
||
total_data_bytes: 总数据字节数(不包括指令部分)
|
||
|
||
返回:
|
||
DDR结束地址的4字节列表
|
||
"""
|
||
# DDR结束地址 = 起始地址 + 总字节数
|
||
end_addr = total_data_bytes
|
||
|
||
# 调整结束地址,使得end_addr + 8是800的整数倍
|
||
remainder = (end_addr + 8) % 800
|
||
if remainder != 0:
|
||
end_addr += (800 - remainder)
|
||
|
||
# 转换为4字节列表(大端序)
|
||
return [(end_addr >> (8 * (3 - i))) & 0xFF for i in range(4)]
|
||
|
||
@classmethod
|
||
def calculate_crc16(cls, data: bytes) -> List[int]:
|
||
"""
|
||
计算Modbus CRC16校验和
|
||
使用标准Modbus CRC16算法
|
||
"""
|
||
crc = 0xFFFF
|
||
|
||
for byte in data:
|
||
crc ^= byte
|
||
for _ in range(8):
|
||
if crc & 0x0001:
|
||
crc = (crc >> 1) ^ 0xA001
|
||
else:
|
||
crc >>= 1
|
||
|
||
# 返回低字节在前
|
||
return [crc & 0xFF, (crc >> 8) & 0xFF]
|
||
|
||
@classmethod
|
||
def generate_instruction_bytes(cls, total_data_bytes: int) -> bytes:
|
||
"""
|
||
生成完整的指令部分字节流
|
||
|
||
参数:
|
||
total_data_bytes: 总数据字节数(不包括指令部分)
|
||
|
||
返回:
|
||
指令部分的字节流
|
||
"""
|
||
instruction_data = bytearray()
|
||
|
||
# 1. 包头
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['header']['value']))
|
||
|
||
# 2. 信号类型
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['signal_type']['value']))
|
||
|
||
# 3. DDR起始地址
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['ddr_start_addr']['value']))
|
||
|
||
# 4. DDR结束地址(根据总数据大小计算)
|
||
ddr_end_addr = cls.calculate_ddr_end_addr(total_data_bytes)
|
||
instruction_data.extend(bytes(ddr_end_addr))
|
||
|
||
# 5. 同步脉冲1个数
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse1_count']['value']))
|
||
|
||
# 6. 同步脉冲1周期
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse1_period']['value']))
|
||
|
||
# 7. 同步脉冲2个数
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse2_count']['value']))
|
||
|
||
# 8. 同步脉冲2周期
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse2_period']['value']))
|
||
|
||
# 9. 同步脉冲3个数
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse3_count']['value']))
|
||
|
||
# 10. 同步脉冲3周期
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse3_period']['value']))
|
||
|
||
# 11. 同步脉冲4个数
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse4_count']['value']))
|
||
|
||
# 12. 同步脉冲4周期
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['sync_pulse4_period']['value']))
|
||
|
||
# 13. 启停控制
|
||
instruction_data.extend(bytes(cls.INSTRUCTION_FORMAT['start_stop_control']['value']))
|
||
|
||
# 14. 计算CRC16(不包括CRC字段本身)
|
||
crc_bytes = cls.calculate_crc16(instruction_data)
|
||
instruction_data.extend(bytes(crc_bytes))
|
||
|
||
return bytes(instruction_data)
|
||
|
||
@classmethod
|
||
def print_instruction_info(cls, total_data_bytes: int):
|
||
"""打印指令信息"""
|
||
print("=" * 60)
|
||
print("指令配置信息")
|
||
print("=" * 60)
|
||
|
||
# 生成指令字节流
|
||
instruction_bytes = cls.generate_instruction_bytes(total_data_bytes)
|
||
|
||
# 解析并打印每个字段
|
||
offset = 0
|
||
for field_name, field_info in cls.INSTRUCTION_FORMAT.items():
|
||
field_bytes = field_info['bytes']
|
||
field_data = instruction_bytes[offset:offset + field_bytes]
|
||
|
||
# 格式化显示
|
||
hex_str = ' '.join(f'0x{b:02X}' for b in field_data)
|
||
int_str = ' '.join(f'{b:3d}' for b in field_data)
|
||
|
||
if field_name == 'ddr_end_addr' and field_info['value'] is None:
|
||
# 计算DDR结束地址
|
||
end_addr_bytes = cls.calculate_ddr_end_addr(total_data_bytes)
|
||
end_addr = sum(end_addr_bytes[i] << (8 * (3 - i)) for i in range(4))
|
||
print(f"{field_info['description']:20s}: {hex_str} ({int_str}) -> 地址: 0x{end_addr:08X} ({end_addr})")
|
||
elif field_name == 'crc16' and field_info['value'] is None:
|
||
crc_value = (field_data[1] << 8) | field_data[0]
|
||
print(f"{field_info['description']:20s}: {hex_str} ({int_str}) -> CRC16: 0x{crc_value:04X}")
|
||
else:
|
||
print(f"{field_info['description']:20s}: {hex_str} ({int_str})")
|
||
|
||
offset += field_bytes
|
||
|
||
# 打印内存信息
|
||
end_addr_bytes = cls.calculate_ddr_end_addr(total_data_bytes)
|
||
end_addr = sum(end_addr_bytes[i] << (8 * (3 - i)) for i in range(4))
|
||
|
||
print("\n内存信息:")
|
||
print(f" 总数据字节: {total_data_bytes} 字节")
|
||
print(f" DDR起始地址: 0x{0:08X}")
|
||
print(f" DDR结束地址: 0x{end_addr:08X} ({end_addr})")
|
||
print(f" 验证 (end_addr + 8) % 800: {(end_addr + 8) % 800} (应为0)")
|
||
print("=" * 60)
|
||
|
||
# ==================== 工具函数 ====================
|
||
# class BitFieldEncoder:
|
||
# """64位整数字段编码器"""
|
||
|
||
# def encode(self, timestamp: int, energies: List[float], event_end: bool = False) -> int:
|
||
# """编码探测器信号为64位整数"""
|
||
# if len(energies) != 3:
|
||
# energies = energies + [0] * (3 - len(energies))
|
||
|
||
# # 1. 事件结束标志
|
||
# value = int(event_end) << 63
|
||
|
||
# # 2. 时间戳 (18 bits) - 检查溢出
|
||
# if timestamp > ((1 << 18) - 1):
|
||
# raise ValueError(f"时间戳溢出: {timestamp} > {((1 << 18) - 1)}")
|
||
# timestamp_int = int(timestamp) & ((1 << 18) - 1)
|
||
# value |= timestamp_int << 45
|
||
|
||
# # 3. 探测器掩码和能量 - 检查能量溢出
|
||
# detector_mask = 0
|
||
# for i in range(3):
|
||
# if energies[i] > 0:
|
||
# # 检查能量溢出
|
||
# if energies[i] > ((1 << 14) - 1):
|
||
# raise ValueError(f"能量溢出: 探测器{i+1}能量{energies[i]} > {((1 << 14) - 1)}")
|
||
|
||
# # 设置探测器掩码位
|
||
# detector_mask |= 1 << (2 - i)
|
||
|
||
# # 编码能量 (14 bits)
|
||
# energy_int = int(round(energies[i])) & ((1 << 14) - 1)
|
||
# if i == 0: # det1
|
||
# value |= energy_int << 28
|
||
# elif i == 1: # det2
|
||
# value |= energy_int << 14
|
||
# else: # det3
|
||
# value |= energy_int
|
||
|
||
# # 设置探测器掩码
|
||
# value |= detector_mask << 42
|
||
|
||
# return value
|
||
|
||
# def decode(self, encoded_value: int) -> dict:
|
||
# """解码64位整数"""
|
||
# result = {}
|
||
# result['event_end'] = bool((encoded_value >> 63) & 1)
|
||
# result['timestamp'] = (encoded_value >> 45) & ((1 << 18) - 1)
|
||
# result['detector_mask'] = (encoded_value >> 42) & 0b111
|
||
|
||
# energies = []
|
||
# mask_bits = [(result['detector_mask'] >> 2) & 1,
|
||
# (result['detector_mask'] >> 1) & 1,
|
||
# result['detector_mask'] & 1]
|
||
|
||
# energy1 = (encoded_value >> 28) & ((1 << 14) - 1)
|
||
# energies.append(energy1 if mask_bits[0] else 0.0)
|
||
|
||
# energy2 = (encoded_value >> 14) & ((1 << 14) - 1)
|
||
# energies.append(energy2 if mask_bits[1] else 0.0)
|
||
|
||
# energy3 = encoded_value & ((1 << 14) - 1)
|
||
# energies.append(energy3 if mask_bits[2] else 0.0)
|
||
|
||
# result['energies'] = energies
|
||
# return result
|
||
|
||
# def to_binary_string(self, encoded_value: int) -> str:
|
||
# return format(encoded_value, '064b')
|
||
|
||
# def to_hex_string(self, encoded_value: int) -> str:
|
||
# return format(encoded_value, '016x')
|
||
|
||
# def set_event_end_flag(self, encoded_value: int) -> int:
|
||
# """设置事件结束标志位"""
|
||
# return encoded_value | (1 << 63)
|
||
|
||
# def clear_event_end_flag(self, encoded_value: int) -> int:
|
||
# """清除事件结束标志位"""
|
||
# return encoded_value & ~(1 << 63)
|
||
|
||
|
||
class BitFieldEncoder:
|
||
"""64位整数字段编码器(修改版:0表示点火,1表示未点火)"""
|
||
|
||
def encode(self, timestamp: int, energies: List[float], event_end: bool = False) -> int:
|
||
"""
|
||
编码探测器信号为64位整数
|
||
|
||
修改:探测器掩码位 - 0表示点火,1表示未点火
|
||
|
||
位域分配:
|
||
bit63: 事件结束标志 (1: 结束, 0: 继续)
|
||
bit62-45: 18位时间戳 (us)
|
||
bit44-42: 3位探测器掩码 (0: 点火, 1: 未点火)
|
||
bit41-28: 14位探测器1能量 (keV)
|
||
bit27-14: 14位探测器2能量 (keV)
|
||
bit13-0: 14位探测器3能量 (keV)
|
||
"""
|
||
if len(energies) != 3:
|
||
energies = energies + [0] * (3 - len(energies))
|
||
|
||
# 1. 事件结束标志
|
||
value = int(event_end) << 63
|
||
|
||
# 2. 时间戳 (18 bits) - 检查溢出
|
||
if timestamp > ((1 << 18) - 1):
|
||
raise ValueError(f"时间戳溢出: {timestamp} > {((1 << 18) - 1)}")
|
||
timestamp_int = int(timestamp) & ((1 << 18) - 1)
|
||
value |= timestamp_int << 45
|
||
|
||
# 3. 探测器掩码和能量
|
||
# 修改:默认所有探测器都未点火(掩码位为1)
|
||
detector_mask = 0b111 # 二进制111表示三个探测器都未点火
|
||
|
||
for i in range(3):
|
||
if energies[i] > 0:
|
||
# 修改:该探测器点火,对应掩码位清零(0表示点火)
|
||
# 掩码位顺序:bit44=det1, bit43=det2, bit42=det3
|
||
detector_mask &= ~(1 << (2 - i)) # 清除对应位
|
||
|
||
# 检查能量溢出
|
||
if energies[i] > ((1 << 14) - 1):
|
||
raise ValueError(f"能量溢出: 探测器{i+1}能量{energies[i]} > {((1 << 14) - 1)}")
|
||
|
||
# 编码能量 (14 bits)
|
||
energy_int = int(round(energies[i])) & ((1 << 14) - 1)
|
||
if i == 0: # det1
|
||
value |= energy_int << 28
|
||
elif i == 1: # det2
|
||
value |= energy_int << 14
|
||
else: # det3
|
||
value |= energy_int
|
||
|
||
# 设置探测器掩码(修改:1表示未点火,0表示点火)
|
||
value |= detector_mask << 42
|
||
|
||
return value
|
||
|
||
def decode(self, encoded_value: int) -> dict:
|
||
"""
|
||
解码64位整数
|
||
|
||
修改:探测器掩码位 - 0表示点火,1表示未点火
|
||
"""
|
||
result = {}
|
||
|
||
# 1. 事件结束标志
|
||
result['event_end'] = bool((encoded_value >> 63) & 1)
|
||
|
||
# 2. 时间戳
|
||
result['timestamp'] = (encoded_value >> 45) & ((1 << 18) - 1)
|
||
|
||
# 3. 探测器掩码
|
||
detector_mask = (encoded_value >> 42) & 0b111
|
||
result['detector_mask'] = detector_mask
|
||
|
||
# 修改:掩码位解释 - 0表示点火,1表示未点火
|
||
# 所以要反转逻辑:如果掩码位为0,表示该探测器点火
|
||
mask_bits = [((detector_mask >> 2) & 1) == 0, # det1: bit44=0表示点火
|
||
((detector_mask >> 1) & 1) == 0, # det2: bit43=0表示点火
|
||
(detector_mask & 1) == 0] # det3: bit42=0表示点火
|
||
|
||
# 4. 各探测器能量
|
||
energies = []
|
||
|
||
# 探测器1能量
|
||
energy1 = (encoded_value >> 28) & ((1 << 14) - 1)
|
||
energies.append(energy1 if mask_bits[0] else 0.0) # mask_bits[0]=True表示点火
|
||
|
||
# 探测器2能量
|
||
energy2 = (encoded_value >> 14) & ((1 << 14) - 1)
|
||
energies.append(energy2 if mask_bits[1] else 0.0) # mask_bits[1]=True表示点火
|
||
|
||
# 探测器3能量
|
||
energy3 = encoded_value & ((1 << 14) - 1)
|
||
energies.append(energy3 if mask_bits[2] else 0.0) # mask_bits[2]=True表示点火
|
||
|
||
result['energies'] = energies
|
||
|
||
# 添加额外的解析信息
|
||
result['detector_fired'] = mask_bits # True表示点火,False表示未点火
|
||
result['detector_mask_binary'] = format(detector_mask, '03b')
|
||
|
||
return result
|
||
|
||
def to_binary_string(self, encoded_value: int) -> str:
|
||
"""转换为64位二进制字符串"""
|
||
return format(encoded_value, '064b')
|
||
|
||
def to_hex_string(self, encoded_value: int) -> str:
|
||
"""转换为16位十六进制字符串"""
|
||
return format(encoded_value, '016x')
|
||
|
||
def set_event_end_flag(self, encoded_value: int) -> int:
|
||
"""设置事件结束标志位"""
|
||
return encoded_value | (1 << 63)
|
||
|
||
def clear_event_end_flag(self, encoded_value: int) -> int:
|
||
"""清除事件结束标志位"""
|
||
return encoded_value & ~(1 << 63)
|
||
|
||
def decode_detailed(self, encoded_value: int) -> dict:
|
||
"""
|
||
详细解码64位整数,包含更多信息
|
||
|
||
返回:
|
||
包含详细解析信息的字典
|
||
"""
|
||
basic = self.decode(encoded_value)
|
||
|
||
# 添加详细解析
|
||
detailed = {
|
||
'hex': f"0x{self.to_hex_string(encoded_value)}",
|
||
'binary': self.to_binary_string(encoded_value),
|
||
'event_end': basic['event_end'],
|
||
'timestamp_us': basic['timestamp'],
|
||
'detector_mask': basic['detector_mask'],
|
||
'detector_mask_binary': basic['detector_mask_binary'],
|
||
'detector_fired': basic['detector_fired'],
|
||
'energies_kev': basic['energies'],
|
||
'bit_fields': {
|
||
'event_end': (63, 63, basic['event_end']),
|
||
'timestamp': (62, 45, basic['timestamp']),
|
||
'detector_mask': (44, 42, basic['detector_mask']),
|
||
'detector1_energy': (41, 28, basic['energies'][0]),
|
||
'detector2_energy': (27, 14, basic['energies'][1]),
|
||
'detector3_energy': (13, 0, basic['energies'][2])
|
||
}
|
||
}
|
||
|
||
# 添加探测器状态说明
|
||
detector_status = []
|
||
for i in range(3):
|
||
if basic['detector_fired'][i]:
|
||
status = f"探测器{i+1}: 点火 ({basic['energies'][i]} keV)"
|
||
else:
|
||
status = f"探测器{i+1}: 未点火"
|
||
detector_status.append(status)
|
||
|
||
detailed['detector_status'] = detector_status
|
||
|
||
return detailed
|
||
|
||
def calculate_crc16(data: bytes) -> bytes:
|
||
"""计算Modbus CRC16校验和"""
|
||
crc = 0xFFFF
|
||
|
||
for byte in data:
|
||
crc ^= byte
|
||
for _ in range(8):
|
||
if crc & 0x0001:
|
||
crc = (crc >> 1) ^ 0xA001
|
||
else:
|
||
crc >>= 1
|
||
|
||
# 返回低字节在前的2字节
|
||
return struct.pack('<H', crc)
|
||
|
||
# ==================== 默认配置(增加打包模式参数) ====================
|
||
DEFAULT_CONFIG_YAML = """
|
||
# 探测器信号模拟器配置 V4.3
|
||
# 增加事件生成模式选择
|
||
|
||
simulation:
|
||
num_events: 8 # 生成的事件数量
|
||
output_file: "detector_output" # 去掉扩展名,会根据格式自动添加
|
||
output_format: "all" # 可选: binary, text, debug, all
|
||
packing_mode: "separate" # 打包模式: separate(单独打包), combined(整体打包)
|
||
event_generation_mode: "pulse" # 事件生成模式: random(随机), fixed(固定), pulse(脉冲)
|
||
|
||
detectors:
|
||
detector1:
|
||
sample_space_size: 5000
|
||
allow_replacement: true
|
||
energy_distribution:
|
||
type: "normal"
|
||
mean: 1000.0
|
||
std: 200.0
|
||
timestamp_distribution:
|
||
type: "exponential"
|
||
scale: 50.0
|
||
|
||
detector2:
|
||
sample_space_size: 4000
|
||
allow_replacement: true
|
||
energy_distribution:
|
||
type: "normal"
|
||
mean: 800.0
|
||
std: 150.0
|
||
timestamp_distribution:
|
||
type: "uniform"
|
||
low: 0.0
|
||
high: 200.0
|
||
|
||
detector3:
|
||
sample_space_size: 6000
|
||
allow_replacement: true
|
||
energy_distribution:
|
||
type: "gamma"
|
||
shape: 2.0
|
||
scale: 500.0
|
||
timestamp_distribution:
|
||
type: "normal"
|
||
mean: 100.0
|
||
std: 30.0
|
||
|
||
sampling:
|
||
min_signals_per_detector: 1 # 每个探测器最少信号数(随机模式使用)
|
||
max_signals_per_detector: 10 # 每个探测器最多信号数(随机模式使用)
|
||
require_signal: true
|
||
|
||
# 固定事件模式配置
|
||
fixed_events:
|
||
num_signals_per_event: 5 # 每个事件的信号数量
|
||
energy_levels: # 各探测器能量水平 (keV)
|
||
detector1: [1000, 2000, 3000, 4000, 5000]
|
||
detector2: [800, 1600, 2400, 3200, 4000]
|
||
detector3: [500, 1000, 1500, 2000, 2500]
|
||
timestamps: [10, 50, 100, 200, 300] # 时间戳序列 (us)
|
||
repeat_pattern: true # 是否重复使用模式
|
||
|
||
# 脉冲信号模式配置
|
||
pulse_signals:
|
||
pulse_count: 100 # 脉冲数量
|
||
pulse_interval: 2 # 脉冲间隔 (us)
|
||
pulse_width: 1 # 脉冲宽度 (信号数量)
|
||
energy_levels: # 各探测器能量水平 (keV)
|
||
detector1: 1500
|
||
detector2: 1200
|
||
detector3: 900
|
||
jitter: 0 # 时间抖动 (us)
|
||
energy_noise: 0 # 能量噪声 (keV)
|
||
"""
|
||
|
||
# ==================== 数据类定义 ====================
|
||
from dataclasses import dataclass
|
||
|
||
@dataclass
|
||
class DistributionConfig:
|
||
"""分布配置数据类"""
|
||
type: str
|
||
parameters: Dict[str, float]
|
||
|
||
@dataclass
|
||
class DetectorConfig:
|
||
"""探测器配置数据类"""
|
||
name: str
|
||
sample_space_size: int
|
||
allow_replacement: bool # 抽样是否可重复
|
||
energy_dist: DistributionConfig
|
||
timestamp_dist: DistributionConfig
|
||
|
||
# ==================== 核心类定义(保持原有逻辑) ====================
|
||
class SampleSpaceGenerator:
|
||
"""样本空间生成器(能量单位:keV)"""
|
||
|
||
@staticmethod
|
||
def generate_from_distribution(dist_config: DistributionConfig, size: int) -> np.ndarray:
|
||
"""根据配置生成样本(能量单位:keV)"""
|
||
dist_type = dist_config.type
|
||
params = dist_config.parameters
|
||
|
||
if dist_type == 'normal':
|
||
return np.random.normal(params.get('mean', 0), params.get('std', 1), size)
|
||
elif dist_type == 'uniform':
|
||
return np.random.uniform(params.get('low', 0), params.get('high', 1), size)
|
||
elif dist_type == 'exponential':
|
||
return np.random.exponential(params.get('scale', 1), size)
|
||
elif dist_type == 'gamma':
|
||
return np.random.gamma(params.get('shape', 1), params.get('scale', 1), size)
|
||
elif dist_type == 'poisson':
|
||
return np.random.poisson(params.get('lam', 1), size)
|
||
elif dist_type == 'lognormal':
|
||
return np.random.lognormal(params.get('mean', 0), params.get('sigma', 1), size)
|
||
else:
|
||
raise ValueError(f"不支持的分布类型: {dist_type}")
|
||
|
||
def generate_sample_space(self, detector_config: DetectorConfig) -> np.ndarray:
|
||
"""生成样本空间(能量单位:keV)"""
|
||
size = detector_config.sample_space_size
|
||
|
||
# 生成能量样本 (keV)
|
||
energies = self.generate_from_distribution(detector_config.energy_dist, size)
|
||
energies = np.abs(energies)
|
||
|
||
# 生成时间戳样本 (us)
|
||
timestamps = self.generate_from_distribution(detector_config.timestamp_dist, size)
|
||
timestamps = np.abs(timestamps)
|
||
|
||
# 组合成二维向量
|
||
sample_space = np.column_stack((energies, timestamps))
|
||
|
||
logger.info(f"为{detector_config.name}生成了{size}个样本 (能量单位: keV)")
|
||
return sample_space
|
||
|
||
class DetectorSampler:
|
||
"""探测器采样器"""
|
||
|
||
def __init__(self, sample_space: np.ndarray, allow_replacement: bool = True):
|
||
self.sample_space = sample_space
|
||
self.size = len(sample_space)
|
||
self.allow_replacement = allow_replacement
|
||
|
||
def sample(self, num_samples: int) -> np.ndarray:
|
||
"""从样本空间抽样"""
|
||
if num_samples <= 0:
|
||
return np.array([])
|
||
|
||
if self.allow_replacement:
|
||
indices = np.random.choice(self.size, num_samples, replace=True)
|
||
else:
|
||
if num_samples > self.size:
|
||
logger.warning(f"请求样本数({num_samples})超过样本空间大小({self.size})")
|
||
num_samples = self.size
|
||
indices = np.random.choice(self.size, num_samples, replace=False)
|
||
|
||
return self.sample_space[indices].copy()
|
||
|
||
class EventSimulatorV4:
|
||
"""事件模拟器 V4(支持三种事件生成模式)"""
|
||
|
||
def __init__(self, config: Dict[str, Any]):
|
||
self.config = config
|
||
self.sample_generator = SampleSpaceGenerator()
|
||
self.encoder = BitFieldEncoder()
|
||
|
||
# 获取事件生成模式
|
||
self.event_generation_mode = config['simulation'].get('event_generation_mode', 'random')
|
||
logger.info(f"使用事件生成模式: {self.event_generation_mode}")
|
||
|
||
# 初始化探测器配置(所有模式都需要)
|
||
self.detectors = self._initialize_detectors()
|
||
|
||
# 对于随机模式,需要样本空间和采样器
|
||
if self.event_generation_mode == 'random':
|
||
self.sample_spaces = self._generate_sample_spaces()
|
||
self.samplers = {}
|
||
for name, sample_space in self.sample_spaces.items():
|
||
allow_replacement = self.detectors[name].allow_replacement
|
||
self.samplers[name] = DetectorSampler(sample_space, allow_replacement)
|
||
else:
|
||
# 固定模式和脉冲模式不需要样本空间
|
||
self.sample_spaces = {}
|
||
self.samplers = {}
|
||
|
||
# 加载固定事件配置
|
||
self.fixed_events_config = config.get('fixed_events', {})
|
||
|
||
# 加载脉冲信号配置
|
||
self.pulse_signals_config = config.get('pulse_signals', {})
|
||
|
||
def _initialize_detectors(self) -> Dict[str, DetectorConfig]:
|
||
"""初始化探测器配置"""
|
||
detectors = {}
|
||
|
||
for det_name, det_config in self.config['detectors'].items():
|
||
energy_dist = DistributionConfig(
|
||
type=det_config['energy_distribution']['type'],
|
||
parameters={k: v for k, v in det_config['energy_distribution'].items() if k != 'type'}
|
||
)
|
||
|
||
timestamp_dist = DistributionConfig(
|
||
type=det_config['timestamp_distribution']['type'],
|
||
parameters={k: v for k, v in det_config['timestamp_distribution'].items() if k != 'type'}
|
||
)
|
||
|
||
detectors[det_name] = DetectorConfig(
|
||
name=det_name,
|
||
sample_space_size=det_config['sample_space_size'],
|
||
allow_replacement=det_config.get('allow_replacement', True),
|
||
energy_dist=energy_dist,
|
||
timestamp_dist=timestamp_dist
|
||
)
|
||
|
||
return detectors
|
||
|
||
def _generate_sample_spaces(self) -> Dict[str, np.ndarray]:
|
||
"""为所有探测器生成样本空间(仅随机模式使用)"""
|
||
sample_spaces = {}
|
||
|
||
for name, detector_config in self.detectors.items():
|
||
sample_spaces[name] = self.sample_generator.generate_sample_space(detector_config)
|
||
|
||
return sample_spaces
|
||
|
||
def generate_event_random(self, event_id: int) -> List[int]:
|
||
"""
|
||
随机模式:生成一个随机事件
|
||
|
||
返回:
|
||
事件信号列表
|
||
"""
|
||
sampling_config = self.config['sampling']
|
||
require_signal = sampling_config.get('require_signal', True)
|
||
|
||
# 为每个探测器随机确定信号数量
|
||
num_signals = {
|
||
name: np.random.randint(
|
||
sampling_config['min_signals_per_detector'],
|
||
sampling_config['max_signals_per_detector'] + 1
|
||
)
|
||
for name in self.detectors.keys()
|
||
}
|
||
|
||
# 从每个探测器抽样
|
||
detector_signals = {}
|
||
for name, sampler in self.samplers.items():
|
||
samples = sampler.sample(num_signals[name])
|
||
|
||
if len(samples) > 0:
|
||
sorted_indices = np.argsort(samples[:, 1])
|
||
detector_signals[name] = samples[sorted_indices]
|
||
else:
|
||
detector_signals[name] = np.array([])
|
||
|
||
# 合并所有探测器信号并编码为64位整数
|
||
encoded_events = self._merge_and_encode_signals(detector_signals, require_signal)
|
||
|
||
# 获取最大时间戳用于事件结束标记
|
||
max_timestamp = 0
|
||
if encoded_events:
|
||
for encoded in encoded_events:
|
||
decoded = self.encoder.decode(encoded)
|
||
if decoded['timestamp'] > max_timestamp:
|
||
max_timestamp = decoded['timestamp']
|
||
|
||
# 设置最后一个信号的事件结束标志位
|
||
if encoded_events:
|
||
encoded_events[-1] = self.encoder.set_event_end_flag(encoded_events[-1])
|
||
else:
|
||
# 如果没有信号,创建一个结束标记信号
|
||
end_signal = self.encoder.encode(max_timestamp, [0, 0, 0], event_end=True)
|
||
encoded_events.append(end_signal)
|
||
|
||
return encoded_events
|
||
|
||
def generate_event_fixed(self, event_id: int) -> List[int]:
|
||
"""
|
||
固定模式:生成一个固定模式的事件
|
||
|
||
返回:
|
||
事件信号列表
|
||
"""
|
||
config = self.fixed_events_config
|
||
|
||
# 获取配置参数
|
||
num_signals = config.get('num_signals_per_event', 5)
|
||
energy_levels = config.get('energy_levels', {})
|
||
timestamps = config.get('timestamps', [])
|
||
repeat_pattern = config.get('repeat_pattern', True)
|
||
|
||
# 准备能量数据
|
||
det1_energies = energy_levels.get('detector1', [1000] * num_signals)
|
||
det2_energies = energy_levels.get('detector2', [800] * num_signals)
|
||
det3_energies = energy_levels.get('detector3', [500] * num_signals)
|
||
|
||
# 准备时间戳
|
||
if repeat_pattern and len(timestamps) > 0:
|
||
# 重复时间戳模式
|
||
timestamps_cycle = timestamps * ((num_signals // len(timestamps)) + 1)
|
||
timestamps_used = timestamps_cycle[:num_signals]
|
||
else:
|
||
# 使用递增时间戳
|
||
if len(timestamps) >= num_signals:
|
||
timestamps_used = timestamps[:num_signals]
|
||
else:
|
||
base_timestamp = 10
|
||
interval = 50
|
||
timestamps_used = [base_timestamp + i * interval for i in range(num_signals)]
|
||
|
||
# 确保数组长度一致
|
||
det1_energies = self._ensure_length(det1_energies, num_signals, 1000)
|
||
det2_energies = self._ensure_length(det2_energies, num_signals, 800)
|
||
det3_energies = self._ensure_length(det3_energies, num_signals, 500)
|
||
|
||
# 编码信号
|
||
encoded_signals = []
|
||
for i in range(num_signals):
|
||
timestamp = timestamps_used[i]
|
||
energies = [det1_energies[i], det2_energies[i], det3_energies[i]]
|
||
|
||
try:
|
||
encoded = self.encoder.encode(timestamp, energies, event_end=False)
|
||
encoded_signals.append(encoded)
|
||
except Exception as e:
|
||
logger.warning(f"固定模式编码信号时出错: {e}")
|
||
# 创建一个默认信号
|
||
default_signal = self.encoder.encode(timestamp, [0, 0, 0], event_end=False)
|
||
encoded_signals.append(default_signal)
|
||
|
||
# 设置最后一个信号的事件结束标志位
|
||
if encoded_signals:
|
||
encoded_signals[-1] = self.encoder.set_event_end_flag(encoded_signals[-1])
|
||
|
||
logger.info(f"固定模式生成事件 {event_id+1}: {num_signals}个信号")
|
||
return encoded_signals
|
||
|
||
def generate_event_pulse(self, event_id: int) -> List[int]:
|
||
"""
|
||
脉冲模式:生成一个脉冲信号事件
|
||
|
||
返回:
|
||
事件信号列表
|
||
"""
|
||
config = self.pulse_signals_config
|
||
|
||
# 获取配置参数
|
||
pulse_count = config.get('pulse_count', 10)
|
||
pulse_interval = config.get('pulse_interval', 100)
|
||
pulse_width = config.get('pulse_width', 5)
|
||
energy_levels = config.get('energy_levels', {})
|
||
jitter = config.get('jitter', 10.0)
|
||
energy_noise = config.get('energy_noise', 100.0)
|
||
|
||
# 获取各探测器基础能量
|
||
det1_energy = energy_levels.get('detector1', 1500)
|
||
det2_energy = energy_levels.get('detector2', 1200)
|
||
det3_energy = energy_levels.get('detector3', 900)
|
||
|
||
# 生成脉冲信号
|
||
encoded_signals = []
|
||
|
||
for pulse_idx in range(pulse_count):
|
||
# 计算脉冲起始时间
|
||
base_time = pulse_idx * pulse_interval
|
||
|
||
# 生成脉冲内的多个信号
|
||
for sub_idx in range(pulse_width):
|
||
# 添加时间抖动
|
||
time_jitter = np.random.uniform(-jitter, jitter) if jitter > 0 else 0
|
||
timestamp = base_time + sub_idx * (pulse_interval / pulse_width) + time_jitter
|
||
timestamp = max(0, timestamp) # 确保时间戳非负
|
||
|
||
# 添加能量噪声
|
||
det1_noisy = det1_energy + np.random.uniform(-energy_noise, energy_noise) if energy_noise > 0 else det1_energy
|
||
det2_noisy = det2_energy + np.random.uniform(-energy_noise, energy_noise) if energy_noise > 0 else det2_energy
|
||
det3_noisy = det3_energy + np.random.uniform(-energy_noise, energy_noise) if energy_noise > 0 else det3_energy
|
||
|
||
# 确保能量非负
|
||
det1_noisy = max(0, det1_noisy)
|
||
det2_noisy = max(0, det2_noisy)
|
||
det3_noisy = max(0, det3_noisy)
|
||
|
||
energies = [det1_noisy, det2_noisy, det3_noisy]
|
||
|
||
try:
|
||
encoded = self.encoder.encode(int(timestamp), energies, event_end=False)
|
||
encoded_signals.append(encoded)
|
||
except Exception as e:
|
||
logger.warning(f"脉冲模式编码信号时出错: {e}")
|
||
# 创建一个默认信号
|
||
default_signal = self.encoder.encode(int(timestamp), [0, 0, 0], event_end=False)
|
||
encoded_signals.append(default_signal)
|
||
|
||
# 设置最后一个信号的事件结束标志位
|
||
if encoded_signals:
|
||
encoded_signals[-1] = self.encoder.set_event_end_flag(encoded_signals[-1])
|
||
|
||
logger.info(f"脉冲模式生成事件 {event_id+1}: {len(encoded_signals)}个信号")
|
||
return encoded_signals
|
||
|
||
def _ensure_length(self, array: List[float], target_length: int, default_value: float) -> List[float]:
|
||
"""确保数组达到目标长度,不足时用默认值填充"""
|
||
if len(array) >= target_length:
|
||
return array[:target_length]
|
||
else:
|
||
return array + [default_value] * (target_length - len(array))
|
||
|
||
def _merge_and_encode_signals(self, detector_signals: Dict[str, np.ndarray], require_signal: bool = True) -> List[int]:
|
||
"""
|
||
合并探测器信号并编码为64位整数
|
||
|
||
参数:
|
||
detector_signals: 各探测器的信号数组
|
||
require_signal: 是否只输出有信号的探测器
|
||
|
||
返回:
|
||
编码后的64位整数列表
|
||
"""
|
||
# 创建一个时间戳到信号的映射
|
||
time_signal_map = {}
|
||
|
||
for det_idx, det_name in enumerate(['detector1', 'detector2', 'detector3']):
|
||
signals = detector_signals.get(det_name, np.array([]))
|
||
|
||
for signal in signals:
|
||
if len(signal) == 2:
|
||
energy_kev = signal[0]
|
||
timestamp_us = signal[1]
|
||
|
||
try:
|
||
# 检查溢出
|
||
if energy_kev > ((1 << 14) - 1):
|
||
logger.warning(f"能量溢出,放弃信号: {energy_kev}keV")
|
||
continue
|
||
|
||
if timestamp_us > ((1 << 18) - 1):
|
||
logger.warning(f"时间戳溢出,放弃信号: {timestamp_us}us")
|
||
continue
|
||
|
||
# 量化能量
|
||
energy_int = int(round(energy_kev)) & ((1 << 14) - 1)
|
||
|
||
# 添加到时间映射
|
||
timestamp_int = int(round(timestamp_us))
|
||
if timestamp_int not in time_signal_map:
|
||
time_signal_map[timestamp_int] = {
|
||
'timestamp': timestamp_int,
|
||
'energies': [0.0, 0.0, 0.0],
|
||
'has_signal': [False, False, False]
|
||
}
|
||
|
||
time_signal_map[timestamp_int]['energies'][det_idx] = energy_int
|
||
time_signal_map[timestamp_int]['has_signal'][det_idx] = True
|
||
|
||
except Exception as e:
|
||
logger.warning(f"编码信号时出错,放弃: {e}")
|
||
continue
|
||
|
||
# 按时间戳排序
|
||
sorted_timestamps = sorted(time_signal_map.keys())
|
||
|
||
# 编码每个时间点的信号
|
||
encoded_events = []
|
||
|
||
for timestamp in sorted_timestamps:
|
||
data = time_signal_map[timestamp]
|
||
energies = data['energies']
|
||
has_signal = data['has_signal']
|
||
|
||
if require_signal and not any(has_signal):
|
||
continue
|
||
|
||
try:
|
||
encoded = self.encoder.encode(timestamp, energies, event_end=False)
|
||
encoded_events.append(encoded)
|
||
except Exception as e:
|
||
logger.warning(f"编码时间点{timestamp}信号时出错,放弃: {e}")
|
||
continue
|
||
|
||
return encoded_events
|
||
|
||
def generate_event(self, event_id: int) -> List[int]:
|
||
"""
|
||
根据当前模式生成一个事件
|
||
|
||
返回:
|
||
事件信号列表
|
||
"""
|
||
if self.event_generation_mode == 'random':
|
||
return self.generate_event_random(event_id)
|
||
elif self.event_generation_mode == 'fixed':
|
||
return self.generate_event_fixed(event_id)
|
||
elif self.event_generation_mode == 'pulse':
|
||
return self.generate_event_pulse(event_id)
|
||
else:
|
||
logger.warning(f"未知的事件生成模式: {self.event_generation_mode},使用随机模式")
|
||
return self.generate_event_random(event_id)
|
||
|
||
def simulate_events(self, num_events: Optional[int] = None) -> List[List[int]]:
|
||
"""
|
||
模拟多个事件,根据配置的模式生成
|
||
|
||
参数:
|
||
num_events: 事件数量(覆盖配置)
|
||
|
||
返回:
|
||
所有事件的列表
|
||
"""
|
||
if num_events is None:
|
||
num_events = self.config['simulation']['num_events']
|
||
|
||
all_events = []
|
||
|
||
print(f"开始模拟事件...")
|
||
print(f" 模式: {self.event_generation_mode}")
|
||
print(f" 数量: {num_events}个事件")
|
||
|
||
for i in range(num_events):
|
||
if self.event_generation_mode == 'random':
|
||
event = self.generate_event_random(i)
|
||
valid_signals = sum(1 for sig in event if sig != 0)
|
||
logger.info(f"已生成随机事件 {i + 1}/{num_events} - 有效信号: {valid_signals}")
|
||
elif self.event_generation_mode == 'fixed':
|
||
event = self.generate_event_fixed(i)
|
||
logger.info(f"已生成固定事件 {i + 1}/{num_events} - 信号数: {len(event)}")
|
||
elif self.event_generation_mode == 'pulse':
|
||
event = self.generate_event_pulse(i)
|
||
logger.info(f"已生成脉冲事件 {i + 1}/{num_events} - 信号数: {len(event)}")
|
||
else:
|
||
event = self.generate_event_random(i)
|
||
valid_signals = sum(1 for sig in event if sig != 0)
|
||
logger.info(f"已生成事件 {i + 1}/{num_events} - 有效信号: {valid_signals}")
|
||
|
||
all_events.append(event)
|
||
|
||
# 汇总统计信息
|
||
total_signals = sum(len(event) for event in all_events)
|
||
total_valid = sum(sum(1 for sig in event if sig != 0) for event in all_events)
|
||
|
||
print(f"\n模拟完成!")
|
||
print(f" 总事件数: {len(all_events)}")
|
||
print(f" 总信号数: {total_signals}")
|
||
print(f" 有效信号数: {total_valid}")
|
||
print(f" 平均每事件信号数: {total_signals/len(all_events):.1f}")
|
||
|
||
return all_events
|
||
|
||
def visualize_event_generation_mode(self):
|
||
"""可视化当前事件生成模式的特性"""
|
||
mode = self.event_generation_mode
|
||
|
||
if mode == 'random':
|
||
print("随机模式特性:")
|
||
print(" - 从样本空间随机抽样")
|
||
print(" - 每个探测器独立抽样")
|
||
print(" - 信号数量随机变化")
|
||
print(" - 能量和时间戳符合配置的分布")
|
||
|
||
# 显示样本空间统计
|
||
if self.sample_spaces:
|
||
print("\n样本空间统计:")
|
||
for name, space in self.sample_spaces.items():
|
||
energies = space[:, 0]
|
||
timestamps = space[:, 1]
|
||
print(f" {name}:")
|
||
print(f" 能量: 平均{energies.mean():.0f}keV, 范围{energies.min():.0f}-{energies.max():.0f}keV")
|
||
print(f" 时间戳: 平均{timestamps.mean():.0f}us, 范围{timestamps.min():.0f}-{timestamps.max():.0f}us")
|
||
|
||
elif mode == 'fixed':
|
||
print("固定模式特性:")
|
||
print(" - 使用预定义的固定模式")
|
||
print(" - 信号数量、能量、时间戳都是固定的")
|
||
print(" - 便于调试和测试")
|
||
|
||
# 显示固定模式配置
|
||
config = self.fixed_events_config
|
||
print(f"\n固定模式配置:")
|
||
print(f" 每事件信号数: {config.get('num_signals_per_event', 5)}")
|
||
|
||
energy_levels = config.get('energy_levels', {})
|
||
for det, energies in energy_levels.items():
|
||
print(f" {det}能量: {energies[:5]}..." if len(energies) > 5 else f" {det}能量: {energies}")
|
||
|
||
timestamps = config.get('timestamps', [])
|
||
print(f" 时间戳序列: {timestamps[:5]}..." if len(timestamps) > 5 else f" 时间戳序列: {timestamps}")
|
||
print(f" 重复模式: {config.get('repeat_pattern', True)}")
|
||
|
||
elif mode == 'pulse':
|
||
print("脉冲模式特性:")
|
||
print(" - 生成周期性脉冲信号")
|
||
print(" - 脉冲数量、间隔、宽度可配置")
|
||
print(" - 可添加时间抖动和能量噪声")
|
||
|
||
# 显示脉冲模式配置
|
||
config = self.pulse_signals_config
|
||
print(f"\n脉冲模式配置:")
|
||
print(f" 脉冲数量: {config.get('pulse_count', 10)}")
|
||
print(f" 脉冲间隔: {config.get('pulse_interval', 100)}us")
|
||
print(f" 脉冲宽度: {config.get('pulse_width', 5)}个信号")
|
||
|
||
energy_levels = config.get('energy_levels', {})
|
||
for det, energy in energy_levels.items():
|
||
print(f" {det}基础能量: {energy}keV")
|
||
|
||
print(f" 时间抖动: ±{config.get('jitter', 10.0)}us")
|
||
print(f" 能量噪声: ±{config.get('energy_noise', 100.0)}keV")
|
||
|
||
else:
|
||
print(f"未知模式: {mode}")
|
||
|
||
def visualize_sample_spaces(self):
|
||
"""可视化样本空间(仅随机模式有效)"""
|
||
if self.event_generation_mode != 'random':
|
||
print(f"当前模式为 {self.event_generation_mode},无需可视化样本空间")
|
||
return
|
||
|
||
if not self.sample_spaces:
|
||
print("样本空间未初始化")
|
||
return
|
||
|
||
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
|
||
axes = axes.flatten()
|
||
|
||
for idx, (det_name, sample_space) in enumerate(self.sample_spaces.items()):
|
||
if idx >= 6:
|
||
break
|
||
|
||
ax = axes[idx]
|
||
energies = sample_space[:, 0]
|
||
timestamps = sample_space[:, 1]
|
||
|
||
ax.scatter(timestamps, energies, alpha=0.5, s=5)
|
||
ax.set_xlabel('时间戳 (us)')
|
||
ax.set_ylabel('能量 (keV)')
|
||
ax.set_title(f'{det_name} 样本空间')
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
def visualize_event(self, event_signals: List[int], event_id: int = 0):
|
||
"""可视化事件信号"""
|
||
if not event_signals:
|
||
print(f"事件 {event_id+1} 没有信号")
|
||
return
|
||
|
||
# 解码有效信号
|
||
decoded_signals = []
|
||
for encoded in event_signals:
|
||
if encoded != 0:
|
||
try:
|
||
decoded = self.encoder.decode(encoded)
|
||
decoded_signals.append(decoded)
|
||
except:
|
||
continue
|
||
|
||
if not decoded_signals:
|
||
print(f"事件 {event_id+1} 没有有效信号")
|
||
return
|
||
|
||
# 提取数据
|
||
timestamps = [event['timestamp'] for event in decoded_signals]
|
||
energies = np.array([event['energies'] for event in decoded_signals])
|
||
|
||
# 根据模式调整标题
|
||
mode = self.event_generation_mode
|
||
if mode == 'random':
|
||
mode_text = '随机'
|
||
elif mode == 'fixed':
|
||
mode_text = '固定'
|
||
elif mode == 'pulse':
|
||
mode_text = '脉冲'
|
||
else:
|
||
mode_text = ''
|
||
|
||
# 创建可视化
|
||
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
||
|
||
# 1. 时间-能量散点图
|
||
ax1 = axes[0, 0]
|
||
colors = ['red', 'green', 'blue']
|
||
labels = ['探测器1', '探测器2', '探测器3']
|
||
|
||
for det_idx in range(3):
|
||
det_energies = energies[:, det_idx]
|
||
valid_mask = det_energies > 0
|
||
if np.any(valid_mask):
|
||
ax1.scatter(np.array(timestamps)[valid_mask], det_energies[valid_mask],
|
||
color=colors[det_idx], label=labels[det_idx], alpha=0.6, s=20)
|
||
|
||
ax1.set_xlabel('时间戳 (us)')
|
||
ax1.set_ylabel('能量 (keV)')
|
||
ax1.set_title(f'{mode_text}模式 - 事件 {event_id+1} - 探测器信号分布')
|
||
ax1.legend()
|
||
ax1.grid(True, alpha=0.3)
|
||
|
||
# 2. 能量分布直方图
|
||
ax2 = axes[0, 1]
|
||
for det_idx in range(3):
|
||
det_energies = energies[:, det_idx]
|
||
valid_energies = det_energies[det_energies > 0]
|
||
if len(valid_energies) > 0:
|
||
ax2.hist(valid_energies, bins=20, alpha=0.5,
|
||
color=colors[det_idx], label=labels[det_idx])
|
||
|
||
ax2.set_xlabel('能量 (keV)')
|
||
ax2.set_ylabel('计数')
|
||
ax2.set_title('能量分布')
|
||
ax2.legend()
|
||
ax2.grid(True, alpha=0.3)
|
||
|
||
# 3. 探测器激活统计
|
||
ax3 = axes[1, 0]
|
||
signal_counts = [(energies[:, i] > 0).sum() for i in range(3)]
|
||
bars = ax3.bar(labels, signal_counts, color=colors, alpha=0.7)
|
||
ax3.set_ylabel('信号数量')
|
||
ax3.set_title('各探测器信号数量')
|
||
|
||
for bar, count in zip(bars, signal_counts):
|
||
height = bar.get_height()
|
||
ax3.text(bar.get_x() + bar.get_width()/2., height,
|
||
f'{count}', ha='center', va='bottom')
|
||
|
||
ax3.grid(True, alpha=0.3, axis='y')
|
||
|
||
# 4. 事件信息
|
||
ax4 = axes[1, 1]
|
||
ax4.axis('off')
|
||
|
||
info_text = f"{mode_text}模式 - 事件 {event_id+1} 信息:\n\n"
|
||
info_text += f"总信号数: {len(event_signals)}\n"
|
||
info_text += f"有效信号数: {len(decoded_signals)}\n"
|
||
|
||
# 检查事件结束标志
|
||
last_signal = event_signals[-1]
|
||
if last_signal != 0:
|
||
last_decoded = self.encoder.decode(last_signal)
|
||
info_text += f"事件结束标志: {'是' if last_decoded['event_end'] else '否'}\n"
|
||
if last_decoded['event_end']:
|
||
info_text += f"结束时间戳: {last_decoded['timestamp']}us\n"
|
||
|
||
if decoded_signals:
|
||
all_timestamps = [d['timestamp'] for d in decoded_signals]
|
||
info_text += f"时间戳范围: {min(all_timestamps)} - {max(all_timestamps)}us\n"
|
||
|
||
# 探测器统计
|
||
for det_idx in range(3):
|
||
det_energies = energies[:, det_idx]
|
||
valid_count = (det_energies > 0).sum()
|
||
if valid_count > 0:
|
||
avg_energy = det_energies[det_energies > 0].mean()
|
||
info_text += f"探测器{det_idx+1}: {valid_count}信号,平均{avg_energy:.0f}keV\n"
|
||
|
||
ax4.text(0.05, 0.95, info_text, transform=ax4.transAxes,
|
||
fontsize=10, verticalalignment='top',
|
||
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
# ==================== 优化的数据写入器 V4 ====================
|
||
class DataWriterV4:
|
||
"""数据写入器 V4(支持两种打包模式)"""
|
||
|
||
def __init__(self, config: Dict[str, Any]):
|
||
self.config = config
|
||
self.encoder = BitFieldEncoder()
|
||
self.instruction_config = InstructionConfigV4()
|
||
self.packing_config = self.instruction_config.PACKING_CONFIG
|
||
|
||
# 获取打包模式
|
||
self.packing_mode = config['simulation'].get('packing_mode', 'separate')
|
||
if self.packing_mode not in ['separate', 'combined']:
|
||
raise ValueError(f"不支持的打包模式: {self.packing_mode}")
|
||
|
||
logger.info(f"使用打包模式: {self.packing_mode}")
|
||
|
||
def _encode_signal_to_bytes(self, signal: int) -> bytes:
|
||
"""
|
||
将64位整数编码为8字节(高字节在前)
|
||
"""
|
||
return struct.pack('>Q', signal) # 大端序
|
||
|
||
def _reorganize_signals_for_output(self, signals: List[int]) -> List[int]:
|
||
"""
|
||
重新组织信号用于输出
|
||
8个信号为一组,每组内从后往前
|
||
"""
|
||
group_size = 8
|
||
num_groups = len(signals) // group_size
|
||
|
||
reorganized = []
|
||
|
||
for group_idx in range(num_groups):
|
||
start_idx = group_idx * group_size
|
||
end_idx = start_idx + group_size
|
||
group = signals[start_idx:end_idx]
|
||
|
||
# 反转组内顺序
|
||
reversed_group = list(reversed(group))
|
||
reorganized.extend(reversed_group)
|
||
|
||
return reorganized
|
||
|
||
def _create_packet_bytes(self, packet_signals: List[int]) -> bytes:
|
||
"""
|
||
创建一个数据包的字节流(6402字节)
|
||
|
||
参数:
|
||
packet_signals: 一个数据包的信号列表(必须为800个)
|
||
|
||
返回:
|
||
数据包的字节流(6400数据字节 + 2字节CRC16)
|
||
"""
|
||
if len(packet_signals) != 800:
|
||
raise ValueError(f"数据包必须包含800个信号,当前为{len(packet_signals)}")
|
||
|
||
# 1. 重新组织信号(8个一组,从后往前)
|
||
reorganized_signals = self._reorganize_signals_for_output(packet_signals)
|
||
|
||
# 2. 编码为字节流
|
||
packet_data = bytearray()
|
||
for signal in reorganized_signals:
|
||
packet_data.extend(self._encode_signal_to_bytes(signal))
|
||
|
||
# 3. 计算CRC16(对整个6400字节数据)
|
||
crc_bytes = calculate_crc16(packet_data)
|
||
|
||
# 4. 添加CRC16
|
||
packet_data.extend(crc_bytes)
|
||
|
||
# 验证大小
|
||
if len(packet_data) != 6402:
|
||
raise ValueError(f"数据包字节流大小错误: {len(packet_data)} != 6402")
|
||
|
||
return bytes(packet_data)
|
||
|
||
def _pack_events_separately(self, all_events: List[List[int]]) -> List[bytes]:
|
||
"""
|
||
单独打包模式:每个事件单独打包成一个数据包
|
||
|
||
参数:
|
||
all_events: 所有事件的信号列表
|
||
|
||
返回:
|
||
数据包字节流列表
|
||
"""
|
||
packets = []
|
||
|
||
for event_idx, event_signals in enumerate(all_events):
|
||
# 每个事件单独处理
|
||
signals_needed = 800
|
||
|
||
# 如果事件信号不足800个,补零
|
||
if len(event_signals) < signals_needed:
|
||
padding_needed = signals_needed - len(event_signals)
|
||
padded_signals = event_signals + [0] * padding_needed
|
||
elif len(event_signals) > signals_needed:
|
||
# 如果超过800个,截断(理论上不应该发生)
|
||
logger.warning(f"事件{event_idx+1}信号数({len(event_signals)})超过800,将截断")
|
||
padded_signals = event_signals[:signals_needed]
|
||
else:
|
||
padded_signals = event_signals
|
||
|
||
# 设置最后一个信号的事件结束标志位
|
||
if padded_signals[-1] != 0:
|
||
padded_signals[-1] = self.encoder.set_event_end_flag(padded_signals[-1])
|
||
else:
|
||
# 如果最后一个信号是0,设置为事件结束标记
|
||
padded_signals[-1] = self.encoder.encode(0, [0, 0, 0], event_end=True)
|
||
|
||
# 创建数据包
|
||
packet_bytes = self._create_packet_bytes(padded_signals)
|
||
packets.append(packet_bytes)
|
||
|
||
logger.info(f"已打包事件 {event_idx + 1}/{len(all_events)} (单独打包模式)")
|
||
|
||
return packets
|
||
|
||
def _pack_events_combined(self, all_events: List[List[int]]) -> List[bytes]:
|
||
"""
|
||
整体打包模式:所有事件信号拼接,每800个信号一包
|
||
|
||
参数:
|
||
all_events: 所有事件的信号列表
|
||
|
||
返回:
|
||
数据包字节流列表
|
||
"""
|
||
# 1. 收集所有信号,并设置每个事件的最后一个信号为结束标志
|
||
all_signals = []
|
||
|
||
for event_idx, event_signals in enumerate(all_events):
|
||
if event_signals:
|
||
# 设置最后一个信号的事件结束标志位
|
||
if event_signals[-1] != 0:
|
||
event_signals[-1] = self.encoder.set_event_end_flag(event_signals[-1])
|
||
else:
|
||
# 如果最后一个信号是0,设置为事件结束标记
|
||
event_signals[-1] = self.encoder.encode(0, [0, 0, 0], event_end=True)
|
||
|
||
all_signals.extend(event_signals)
|
||
else:
|
||
# 如果事件没有信号,添加一个结束标记信号
|
||
end_signal = self.encoder.encode(0, [0, 0, 0], event_end=True)
|
||
all_signals.append(end_signal)
|
||
|
||
logger.info(f"已处理事件 {event_idx + 1}/{len(all_events)},信号数: {len(event_signals)}")
|
||
|
||
# 2. 将信号分组成800个一包
|
||
packets = []
|
||
signals_per_packet = 800
|
||
|
||
total_signals = len(all_signals)
|
||
num_full_packets = total_signals // signals_per_packet
|
||
remaining_signals = total_signals % signals_per_packet
|
||
|
||
logger.info(f"总信号数: {total_signals}, 完整包: {num_full_packets}, 剩余信号: {remaining_signals}")
|
||
|
||
# 3. 处理完整的数据包
|
||
for i in range(num_full_packets):
|
||
start_idx = i * signals_per_packet
|
||
end_idx = start_idx + signals_per_packet
|
||
packet_signals = all_signals[start_idx:end_idx]
|
||
|
||
# 创建数据包
|
||
packet_bytes = self._create_packet_bytes(packet_signals)
|
||
packets.append(packet_bytes)
|
||
|
||
logger.info(f"已打包完整数据包 {i + 1}/{num_full_packets}")
|
||
|
||
# 4. 处理最后一个不完整的数据包(如果有)
|
||
if remaining_signals > 0:
|
||
start_idx = num_full_packets * signals_per_packet
|
||
last_packet_signals = all_signals[start_idx:]
|
||
|
||
# 补零到800个信号
|
||
padding_needed = signals_per_packet - remaining_signals
|
||
padded_signals = last_packet_signals + [0] * padding_needed
|
||
|
||
# 设置最后一个信号的事件结束标志位
|
||
if padded_signals[-1] != 0:
|
||
padded_signals[-1] = self.encoder.set_event_end_flag(padded_signals[-1])
|
||
else:
|
||
padded_signals[-1] = self.encoder.encode(0, [0, 0, 0], event_end=True)
|
||
|
||
# 创建最后一个数据包
|
||
packet_bytes = self._create_packet_bytes(padded_signals)
|
||
packets.append(packet_bytes)
|
||
|
||
logger.info(f"已打包最后一个数据包 (补零{padding_needed}个信号)")
|
||
|
||
return packets
|
||
|
||
def write_events_binary_v4(self, all_events: List[List[int]], output_file: str):
|
||
"""
|
||
将事件写入二进制文件 V4
|
||
"""
|
||
# 根据打包模式创建数据包
|
||
if self.packing_mode == 'separate':
|
||
packets = self._pack_events_separately(all_events)
|
||
print(f"打包模式: 单独打包 - 每个事件一个数据包")
|
||
else: # combined
|
||
packets = self._pack_events_combined(all_events)
|
||
print(f"打包模式: 整体打包 - 所有事件信号拼接后分包")
|
||
|
||
# 计算总数据字节数(不包括指令部分)
|
||
total_data_bytes = sum(len(packet) for packet in packets)
|
||
|
||
# 生成指令部分
|
||
instruction_bytes = self.instruction_config.generate_instruction_bytes(total_data_bytes)
|
||
|
||
# 写入文件
|
||
with open(output_file, 'wb') as f:
|
||
# 写入指令部分
|
||
# f.write(instruction_bytes)
|
||
|
||
# 写入所有数据包
|
||
for packet_idx, packet in enumerate(packets):
|
||
f.write(packet)
|
||
|
||
# 统计信息
|
||
print(f"二进制文件结构:")
|
||
print(f" 指令部分: {len(instruction_bytes)} 字节")
|
||
print(f" 数据包数量: {len(packets)}")
|
||
print(f" 总数据字节: {total_data_bytes} 字节")
|
||
print(f" 每数据包: {self.packing_config['bytes_per_packet']} 字节")
|
||
print(f" 总大小: {len(instruction_bytes) + total_data_bytes} 字节")
|
||
|
||
logger.info(f"已将 {len(packets)} 个数据包写入二进制文件 V4: {output_file}")
|
||
|
||
def write_events_text_v4(self, all_events: List[List[int]], output_file: str):
|
||
"""
|
||
将事件写入文本文件 V4
|
||
每行一个十进制整数(0-255)
|
||
"""
|
||
# 根据打包模式创建数据包
|
||
if self.packing_mode == 'separate':
|
||
packets = self._pack_events_separately(all_events)
|
||
print(f"打包模式: 单独打包 - 每个事件一个数据包")
|
||
else: # combined
|
||
packets = self._pack_events_combined(all_events)
|
||
print(f"打包模式: 整体打包 - 所有事件信号拼接后分包")
|
||
|
||
# 计算总数据字节数(不包括指令部分)
|
||
total_data_bytes = sum(len(packet) for packet in packets)
|
||
|
||
# 生成指令部分
|
||
instruction_bytes = self.instruction_config.generate_instruction_bytes(total_data_bytes)
|
||
|
||
# 将所有字节合并
|
||
all_bytes = instruction_bytes
|
||
for packet in packets:
|
||
all_bytes += packet
|
||
|
||
# 转换为整数列表
|
||
int_list = list(all_bytes)
|
||
|
||
# 写入文件(每行一个十进制整数)
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
for i, byte_value in enumerate(int_list):
|
||
f.write(f"{byte_value}\n")
|
||
|
||
# 打印信息
|
||
print(f"文本文件结构:")
|
||
print(f" 总字节数: {len(int_list)}")
|
||
print(f" 总行数: {len(int_list)}")
|
||
print(f" 指令部分: {len(instruction_bytes)} 字节 (行 1-{len(instruction_bytes)})")
|
||
print(f" 数据部分: {total_data_bytes} 字节 (行 {len(instruction_bytes)+1}-{len(int_list)})")
|
||
print(f" 数据包数量: {len(packets)}")
|
||
print(f" 每数据包: {self.packing_config['bytes_per_packet']} 字节 = {self.packing_config['bytes_per_packet']} 行")
|
||
print(f" 每行一个0-255的十进制整数")
|
||
|
||
logger.info(f"已将 {len(int_list)} 个字节写入文本文件: {output_file}")
|
||
|
||
def write_events_debug_text(self, all_events: List[List[int]], output_file: str):
|
||
"""
|
||
写入调试文本文件(详细事件信息)
|
||
"""
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
f.write(f"# 探测器信号数据 V4.2 - 打包模式: {self.packing_mode}\n")
|
||
f.write("# 每个数据包: 800个点火信号 + 2字节CRC16 = 6402字节\n")
|
||
f.write("# 8个信号为一组,每组内从后往前导出\n")
|
||
f.write("# 格式: 事件ID | 信号序号 | 十六进制编码 | 解析结果\n")
|
||
f.write("#" * 120 + "\n")
|
||
|
||
for event_idx, event_signals in enumerate(all_events):
|
||
f.write(f"\n# 事件 {event_idx + 1} ({len(event_signals)}个信号)\n")
|
||
|
||
valid_count = 0
|
||
|
||
for sig_idx, encoded in enumerate(event_signals):
|
||
hex_str = self.encoder.to_hex_string(encoded)
|
||
|
||
if encoded == 0:
|
||
info = "填充信号"
|
||
else:
|
||
valid_count += 1
|
||
decoded = self.encoder.decode(encoded)
|
||
|
||
info_parts = []
|
||
if decoded['event_end']:
|
||
info_parts.append("事件结束")
|
||
info_parts.append(f"时间戳:{decoded['timestamp']}us")
|
||
|
||
energies = decoded['energies']
|
||
active_dets = []
|
||
for i in range(3):
|
||
if energies[i] > 0:
|
||
active_dets.append(f"Det{i+1}:{energies[i]}keV")
|
||
|
||
if active_dets:
|
||
info_parts.append(f"信号:{', '.join(active_dets)}")
|
||
|
||
info = ", ".join(info_parts)
|
||
|
||
# 标记是否为事件最后一个信号
|
||
marker = " [事件结束]" if sig_idx == len(event_signals) - 1 else ""
|
||
|
||
f.write(f"{event_idx:4d} | {sig_idx+1:4d} | 0x{hex_str} | {info}{marker}\n")
|
||
|
||
f.write(f"# 总结: 有效信号={valid_count}, 总信号={len(event_signals)}\n")
|
||
|
||
logger.info(f"已将 {len(all_events)} 个事件写入调试文本文件: {output_file}")
|
||
|
||
def write_events(self, all_events: List[List[int]], output_file: str = None):
|
||
"""
|
||
根据配置写入所有格式的文件
|
||
"""
|
||
if output_file is None:
|
||
output_file = self.config['simulation']['output_file']
|
||
|
||
output_format = self.config['simulation'].get('output_format', 'all')
|
||
base_name = Path(output_file).stem
|
||
|
||
if output_format == 'binary':
|
||
binary_file = f"{base_name}.bin"
|
||
self.write_events_binary_v4(all_events, binary_file)
|
||
|
||
elif output_format == 'text':
|
||
text_file = f"{base_name}_v4.txt"
|
||
self.write_events_text_v4(all_events, text_file)
|
||
|
||
elif output_format == 'debug':
|
||
debug_file = f"{base_name}_debug.txt"
|
||
self.write_events_debug_text(all_events, debug_file)
|
||
|
||
elif output_format == 'all':
|
||
binary_file = f"{base_name}.bin"
|
||
text_file = f"{base_name}_v4.txt"
|
||
debug_file = f"{base_name}_debug.txt"
|
||
|
||
print("生成所有格式文件:")
|
||
print(f" 1. {binary_file} - 二进制格式(指令+数据)")
|
||
print(f" 2. {text_file} - 文本格式(每行一个十进制整数)")
|
||
print(f" 3. {debug_file} - 调试文本格式(详细解析)")
|
||
print()
|
||
|
||
self.write_events_binary_v4(all_events, binary_file)
|
||
print()
|
||
self.write_events_text_v4(all_events, text_file)
|
||
print()
|
||
self.write_events_debug_text(all_events, debug_file)
|
||
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {output_format}")
|
||
|
||
def print_file_info(self, all_events: List[List[int]]):
|
||
"""
|
||
打印文件信息
|
||
"""
|
||
num_events = len(all_events)
|
||
|
||
print("=" * 80)
|
||
print(f"文件输出信息 V4.2 - 打包模式: {self.packing_mode}")
|
||
print("=" * 80)
|
||
|
||
# 计算数据包数量
|
||
if self.packing_mode == 'separate':
|
||
num_packets = num_events
|
||
else: # combined
|
||
total_signals = sum(len(event) for event in all_events)
|
||
num_full_packets = total_signals // 800
|
||
remaining_signals = total_signals % 800
|
||
num_packets = num_full_packets + (1 if remaining_signals > 0 else 0)
|
||
|
||
# 计算总数据字节数
|
||
total_data_bytes = num_packets * self.packing_config['bytes_per_packet']
|
||
|
||
# 打印指令信息
|
||
self.instruction_config.print_instruction_info(total_data_bytes)
|
||
|
||
# 打印事件统计
|
||
total_valid_signals = 0
|
||
total_all_signals = 0
|
||
|
||
for event_idx, event_signals in enumerate(all_events):
|
||
valid_in_event = sum(1 for sig in event_signals if sig != 0)
|
||
total_valid_signals += valid_in_event
|
||
total_all_signals += len(event_signals)
|
||
|
||
# 检查最后一个信号的事件结束标志
|
||
if event_signals:
|
||
last_signal = event_signals[-1]
|
||
if last_signal != 0:
|
||
last_decoded = self.encoder.decode(last_signal)
|
||
if not last_decoded['event_end']:
|
||
print(f"警告: 事件{event_idx+1}的最后一个信号没有设置事件结束标志")
|
||
|
||
print("\n事件统计:")
|
||
print(f" 总事件数: {num_events}")
|
||
print(f" 总信号数: {total_all_signals}")
|
||
print(f" 有效信号数: {total_valid_signals}")
|
||
print(f" 平均每事件信号数: {total_all_signals/num_events:.1f}")
|
||
|
||
print(f"\n数据包统计 ({self.packing_mode}模式):")
|
||
print(f" 数据包数量: {num_packets}")
|
||
print(f" 每数据包: {self.packing_config['bytes_per_packet']} 字节")
|
||
print(f" 总数据字节: {total_data_bytes} 字节")
|
||
|
||
if self.packing_mode == 'combined':
|
||
print(f" 完整数据包: {total_all_signals // 800}")
|
||
print(f" 最后一个数据包信号数: {total_all_signals % 800 if total_all_signals % 800 != 0 else 800}")
|
||
print(f" 最后一个数据包补零数: {800 - (total_all_signals % 800) if total_all_signals % 800 != 0 else 0}")
|
||
|
||
print("=" * 80)
|
||
|
||
# ==================== 工具函数 ====================
|
||
def load_config_from_yaml(config_str: str = None, config_file: str = None) -> Dict[str, Any]:
|
||
"""从YAML字符串或文件加载配置"""
|
||
if config_str:
|
||
config = yaml.safe_load(config_str)
|
||
elif config_file:
|
||
with open(config_file, 'r', encoding='utf-8') as f:
|
||
config = yaml.safe_load(f)
|
||
else:
|
||
config = yaml.safe_load(DEFAULT_CONFIG_YAML)
|
||
|
||
return config
|
||
|
||
def validate_config(config: Dict[str, Any]):
|
||
"""验证配置"""
|
||
required_sections = ['simulation', 'detectors', 'sampling']
|
||
|
||
for section in required_sections:
|
||
if section not in config:
|
||
raise ValueError(f"配置文件中缺少必需的部分: {section}")
|
||
|
||
if len(config['detectors']) != 3:
|
||
logger.warning(f"配置了{len(config['detectors'])}个探测器,需要3个")
|
||
|
||
# ==================== Jupyter notebook主函数 ====================
|
||
|
||
def run_simulation_v4(config_str: str = None,
|
||
config_file: str = None,
|
||
num_events: int = None,
|
||
output_file: str = None,
|
||
output_format: str = None,
|
||
packing_mode: str = None,
|
||
event_generation_mode: str = None,
|
||
visualize: bool = True):
|
||
"""
|
||
在Jupyter notebook中运行模拟 V4.3
|
||
|
||
参数:
|
||
config_str: YAML配置字符串
|
||
config_file: 配置文件路径
|
||
num_events: 事件数量(覆盖配置)
|
||
output_file: 输出文件路径(覆盖配置)
|
||
output_format: 输出格式(覆盖配置)
|
||
packing_mode: 打包模式(覆盖配置)
|
||
event_generation_mode: 事件生成模式(覆盖配置)
|
||
visualize: 是否可视化结果
|
||
"""
|
||
|
||
print("=" * 80)
|
||
print("探测器信号模拟器 V4.3")
|
||
print("=" * 80)
|
||
print("新增功能: 事件生成模式选择")
|
||
print(" 1. random(随机模式): 从样本空间随机抽样")
|
||
print(" 2. fixed(固定模式): 生成固定模式的事件")
|
||
print(" 3. pulse(脉冲模式): 生成周期性脉冲信号")
|
||
print("=" * 80)
|
||
|
||
# 1. 加载配置
|
||
config = load_config_from_yaml(config_str, config_file)
|
||
validate_config(config)
|
||
|
||
# 覆盖配置参数
|
||
if num_events:
|
||
config['simulation']['num_events'] = num_events
|
||
if output_file:
|
||
config['simulation']['output_file'] = output_file
|
||
if output_format:
|
||
config['simulation']['output_format'] = output_format
|
||
if packing_mode:
|
||
config['simulation']['packing_mode'] = packing_mode
|
||
if event_generation_mode:
|
||
config['simulation']['event_generation_mode'] = event_generation_mode
|
||
|
||
mode = config['simulation'].get('event_generation_mode', 'random')
|
||
|
||
print(f"配置信息:")
|
||
print(f" 事件数量: {config['simulation']['num_events']}")
|
||
print(f" 输出格式: {config['simulation'].get('output_format', 'all')}")
|
||
print(f" 打包模式: {config['simulation'].get('packing_mode', 'separate')}")
|
||
print(f" 事件生成模式: {mode}")
|
||
print(f" 输出文件: {config['simulation']['output_file']}")
|
||
print("=" * 80)
|
||
|
||
# 2. 创建模拟器
|
||
simulator = EventSimulatorV4(config)
|
||
|
||
# 3. 显示当前模式特性
|
||
simulator.visualize_event_generation_mode()
|
||
|
||
# 4. 可视化样本空间(仅随机模式)
|
||
if visualize and mode == 'random':
|
||
print("\n可视化样本空间...")
|
||
simulator.visualize_sample_spaces()
|
||
|
||
# 5. 模拟事件
|
||
print("\n开始模拟事件...")
|
||
all_events = simulator.simulate_events()
|
||
|
||
# 6. 可视化第一个事件
|
||
if visualize and all_events:
|
||
print("\n可视化第一个事件...")
|
||
simulator.visualize_event(all_events[0], 0)
|
||
|
||
# 7. 创建数据写入器并写入文件
|
||
data_writer = DataWriterV4(config)
|
||
|
||
# 打印文件信息
|
||
data_writer.print_file_info(all_events)
|
||
|
||
# 写入文件
|
||
print("\n生成输出文件...")
|
||
data_writer.write_events(all_events)
|
||
|
||
# 8. 转换为DataFrame(用于分析)
|
||
df = events_to_dataframe_v4(all_events)
|
||
|
||
# 9. 显示统计信息
|
||
print("\n" + "=" * 80)
|
||
print("模拟完成!")
|
||
|
||
# 详细统计
|
||
total_valid = 0
|
||
total_all = 0
|
||
events_with_proper_end = 0
|
||
|
||
for event_idx, event in enumerate(all_events):
|
||
valid_in_event = sum(1 for sig in event if sig != 0)
|
||
total_valid += valid_in_event
|
||
total_all += len(event)
|
||
|
||
# 检查事件结束标志
|
||
if event:
|
||
last_signal = event[-1]
|
||
if last_signal != 0:
|
||
last_decoded = simulator.encoder.decode(last_signal)
|
||
if last_decoded['event_end']:
|
||
events_with_proper_end += 1
|
||
|
||
print(f"总事件数: {len(all_events)}")
|
||
print(f"总信号数: {total_all}")
|
||
print(f"有效信号数: {total_valid}")
|
||
print(f"平均每事件信号数: {total_all/len(all_events):.1f}")
|
||
print(f"有效信号比例: {total_valid/total_all*100:.1f}%")
|
||
print(f"正确设置事件结束标志的事件: {events_with_proper_end}/{len(all_events)}")
|
||
|
||
# 根据打包模式显示额外信息
|
||
packing_mode = config['simulation'].get('packing_mode', 'separate')
|
||
if packing_mode == 'separate':
|
||
print(f"数据包数量: {len(all_events)} (每个事件一个数据包)")
|
||
else:
|
||
total_packets = total_all // 800 + (1 if total_all % 800 != 0 else 0)
|
||
print(f"数据包数量: {total_packets} (所有事件信号拼接)")
|
||
print(f"最后一个数据包补零数: {800 - (total_all % 800) if total_all % 800 != 0 else 0}")
|
||
|
||
print("=" * 80)
|
||
|
||
return df, simulator, all_events, data_writer
|
||
|
||
# ==================== 辅助函数 ====================
|
||
def events_to_dataframe_v4(all_events: List[List[int]]) -> pd.DataFrame:
|
||
"""将事件转换为DataFrame(用于分析)"""
|
||
encoder = BitFieldEncoder()
|
||
all_data = []
|
||
|
||
for event_idx, event_signals in enumerate(all_events):
|
||
for sig_idx, encoded in enumerate(event_signals):
|
||
if encoded == 0:
|
||
row = {
|
||
'event_id': event_idx + 1,
|
||
'signal_index': sig_idx + 1,
|
||
'is_padding': True,
|
||
'encoded_hex': "0x0000000000000000",
|
||
'event_end': False,
|
||
'timestamp_us': 0,
|
||
'detector_mask': 0,
|
||
'detector1_energy_kev': 0,
|
||
'detector2_energy_kev': 0,
|
||
'detector3_energy_kev': 0
|
||
}
|
||
else:
|
||
decoded = encoder.decode(encoded)
|
||
row = {
|
||
'event_id': event_idx + 1,
|
||
'signal_index': sig_idx + 1,
|
||
'is_padding': False,
|
||
'encoded_hex': f"0x{encoder.to_hex_string(encoded)}",
|
||
'event_end': decoded['event_end'],
|
||
'timestamp_us': decoded['timestamp'],
|
||
'detector_mask': decoded['detector_mask'],
|
||
'detector1_energy_kev': decoded['energies'][0],
|
||
'detector2_energy_kev': decoded['energies'][1],
|
||
'detector3_energy_kev': decoded['energies'][2]
|
||
}
|
||
all_data.append(row)
|
||
|
||
df = pd.DataFrame(all_data)
|
||
return df
|
||
|
||
# ==================== 验证函数 ====================
|
||
def verify_output_files_v4(base_filename: str):
|
||
"""
|
||
验证生成的输出文件 V4
|
||
"""
|
||
print(f"\n验证输出文件: {base_filename}.*")
|
||
print("-" * 80)
|
||
|
||
files_to_check = [
|
||
f"{base_filename}.bin",
|
||
f"{base_filename}_v4.txt",
|
||
f"{base_filename}_debug.txt"
|
||
]
|
||
|
||
for file_path in files_to_check:
|
||
if Path(file_path).exists():
|
||
file_size = Path(file_path).stat().st_size
|
||
print(f"✓ {file_path}: {file_size} 字节")
|
||
|
||
if file_path.endswith('_v4.txt'):
|
||
with open(file_path, 'r') as f:
|
||
lines = [f.readline().strip() for _ in range(10)]
|
||
print(f" 前5行值: {', '.join(lines[:5])}")
|
||
|
||
with open(file_path, 'r') as f:
|
||
total_lines = sum(1 for _ in f)
|
||
print(f" 总行数: {total_lines}")
|
||
|
||
else:
|
||
print(f"✗ {file_path}: 文件不存在")
|
||
|
||
print("-" * 80)
|
||
|
||
def compare_packing_modes(config_str: str = None, num_events: int = 3):
|
||
"""
|
||
比较两种打包模式
|
||
"""
|
||
print("比较两种打包模式...")
|
||
print("=" * 80)
|
||
|
||
# 1. 单独打包模式
|
||
print("\n1. 单独打包模式 (separate):")
|
||
print("-" * 40)
|
||
|
||
config1 = load_config_from_yaml(config_str)
|
||
if config_str:
|
||
config1 = yaml.safe_load(config_str)
|
||
else:
|
||
config1 = yaml.safe_load(DEFAULT_CONFIG_YAML)
|
||
|
||
config1['simulation']['packing_mode'] = 'separate'
|
||
config1['simulation']['num_events'] = num_events
|
||
config1['simulation']['output_file'] = 'detector_output_separate'
|
||
|
||
simulator1 = EventSimulatorV4(config1)
|
||
events1 = simulator1.simulate_events()
|
||
|
||
writer1 = DataWriterV4(config1)
|
||
writer1.print_file_info(events1)
|
||
|
||
# 2. 整体打包模式
|
||
print("\n2. 整体打包模式 (combined):")
|
||
print("-" * 40)
|
||
|
||
config2 = load_config_from_yaml(config_str)
|
||
if config_str:
|
||
config2 = yaml.safe_load(config_str)
|
||
else:
|
||
config2 = yaml.safe_load(DEFAULT_CONFIG_YAML)
|
||
|
||
config2['simulation']['packing_mode'] = 'combined'
|
||
config2['simulation']['num_events'] = num_events
|
||
config2['simulation']['output_file'] = 'detector_output_combined'
|
||
|
||
simulator2 = EventSimulatorV4(config2)
|
||
events2 = simulator2.simulate_events()
|
||
|
||
writer2 = DataWriterV4(config2)
|
||
writer2.print_file_info(events2)
|
||
|
||
# 3. 比较结果
|
||
print("\n3. 模式比较:")
|
||
print("-" * 40)
|
||
|
||
# 计算信号总数
|
||
total_signals1 = sum(len(event) for event in events1)
|
||
total_signals2 = sum(len(event) for event in events2)
|
||
|
||
# 计算数据包数量
|
||
packets1 = len(events1) # 单独打包:每个事件一个数据包
|
||
packets2 = total_signals2 // 800 + (1 if total_signals2 % 800 != 0 else 0) # 整体打包
|
||
|
||
print(f"单独打包模式:")
|
||
print(f" 事件数: {len(events1)}")
|
||
print(f" 总信号数: {total_signals1}")
|
||
print(f" 数据包数: {packets1}")
|
||
print(f" 总字节数: {packets1 * 6402} 字节")
|
||
|
||
print(f"\n整体打包模式:")
|
||
print(f" 事件数: {len(events2)}")
|
||
print(f" 总信号数: {total_signals2}")
|
||
print(f" 数据包数: {packets2}")
|
||
print(f" 总字节数: {packets2 * 6402} 字节")
|
||
|
||
print(f"\n比较结果:")
|
||
print(f" 信号数量差异: {abs(total_signals1 - total_signals2)}")
|
||
print(f" 数据包数量差异: {abs(packets1 - packets2)}")
|
||
print(f" 字节数量差异: {abs(packets1 * 6402 - packets2 * 6402)}")
|
||
|
||
# 计算效率(更少的数据包通常更高效)
|
||
if packets1 < packets2:
|
||
print(f" 推荐模式: 单独打包 (数据包更少)")
|
||
elif packets2 < packets1:
|
||
print(f" 推荐模式: 整体打包 (数据包更少)")
|
||
else:
|
||
print(f" 推荐模式: 两者相同")
|
||
|
||
print("=" * 80)
|
||
|
||
# ==================== 在Jupyter中直接运行 ====================
|
||
if __name__ == "__main__" or "__file__" not in globals():
|
||
print("探测器信号模拟器 V4.2 已加载")
|
||
print("\n可用函数:")
|
||
print(" 1. run_simulation_v4() - 运行完整模拟")
|
||
print(" 2. verify_output_files_v4() - 验证输出文件")
|
||
print(" 3. compare_packing_modes() - 比较两种打包模式")
|
||
print("\n示例:")
|
||
print(" # 运行模拟(单独打包模式)")
|
||
# print(f" df, simulator, events, writer = {run_simulation_v4(num_events=2, packing_mode='separate')}")
|
||
print(" ")
|
||
print(" # 运行模拟(整体打包模式)")
|
||
print(f" df, simulator, events, writer = {run_simulation_v4(packing_mode='combined')}")
|
||
print(" ")
|
||
print(" # 比较两种打包模式")
|
||
print(f" {compare_packing_modes(num_events=3)}")
|
||
print(" ")
|
||
print(" # 验证输出文件")
|
||
print(f" {verify_output_files_v4('detector_output')}") |