ETest-Vue-FastAPI/ruoyi-fastapi-backend/module_admin/dao/warehouse_receipt_dao.py

206 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import time
from datetime import datetime
from typing import List, Optional

from sqlalchemy import select, func, and_, or_, update, delete
from sqlalchemy.ext.asyncio import AsyncSession

from module_admin.entity.do.warehouse_receipt_do import WarehouseReceipt
from module_admin.entity.vo.warehouse_receipt_vo import WarehouseReceiptPageQueryModel, WarehouseReceiptModel
_logger = logging.getLogger(__name__)


class WarehouseReceiptDao:
    """
    Data-access layer for warehouse receipts.

    Read queries only consider rows whose ``del_flag`` equals ``'0'``.
    Write methods flush the session but never commit; committing is left
    to the caller.
    """

    # Type code embedded in newly generated receipt numbers, e.g. ``2025RJ001``.
    # ASCII on purpose so the generated number contains no Chinese characters.
    _RECEIPT_TYPE_CODE = 'RJ'
    # Type code of the legacy number format, e.g. ``2026内检001``.
    _LEGACY_TYPE_CODE = '内检'

    @classmethod
    async def get_receipt_by_id(cls, db: AsyncSession, receipt_id: int) -> Optional[WarehouseReceipt]:
        """
        Fetch a single non-deleted receipt by its primary key.

        :param db: active async ORM session
        :param receipt_id: value matched against ``WarehouseReceipt.receipt_id``
        :return: the first matching receipt, or ``None`` when there is none
        """
        query = select(WarehouseReceipt).where(
            WarehouseReceipt.receipt_id == receipt_id,
            WarehouseReceipt.del_flag == '0',
        )
        result = await db.execute(query)
        return result.scalars().first()

    @classmethod
    async def get_receipt_by_no(cls, db: AsyncSession, receipt_no: str) -> Optional[WarehouseReceipt]:
        """
        Fetch a single non-deleted receipt by its receipt number.

        :param db: active async ORM session
        :param receipt_no: exact receipt number to look up
        :return: the first matching receipt, or ``None`` when there is none
        """
        query = select(WarehouseReceipt).where(
            WarehouseReceipt.receipt_no == receipt_no,
            WarehouseReceipt.del_flag == '0',
        )
        result = await db.execute(query)
        return result.scalars().first()

    @staticmethod
    def _apply_filters(query, query_object: WarehouseReceiptPageQueryModel):
        """
        Append the optional search conditions of ``query_object`` to ``query``.

        Shared by the list and count queries so both always use exactly the
        same conditions. Falsy fields are skipped; the date range is applied
        only when both bounds are supplied.

        :param query: a select statement over ``WarehouseReceipt``
        :param query_object: search criteria
        :return: the statement with the extra ``WHERE`` clauses attached
        """
        # Substring matches.
        if query_object.receipt_no:
            query = query.where(WarehouseReceipt.receipt_no.like(f'%{query_object.receipt_no}%'))
        if query_object.client_unit:
            query = query.where(WarehouseReceipt.client_unit.like(f'%{query_object.client_unit}%'))
        if query_object.client_contact:
            query = query.where(WarehouseReceipt.client_contact.like(f'%{query_object.client_contact}%'))
        # Inclusive date range.
        if query_object.receipt_date_start and query_object.receipt_date_end:
            query = query.where(
                WarehouseReceipt.receipt_date >= query_object.receipt_date_start,
                WarehouseReceipt.receipt_date <= query_object.receipt_date_end,
            )
        # Exact matches.
        if query_object.purpose:
            query = query.where(WarehouseReceipt.purpose == query_object.purpose)
        if query_object.status:
            query = query.where(WarehouseReceipt.status == query_object.status)
        return query

    @classmethod
    async def get_receipt_list(cls, db: AsyncSession, query_object: WarehouseReceiptPageQueryModel, is_page: bool = False):
        """
        List non-deleted receipts matching ``query_object``, newest first.

        :param db: active async ORM session
        :param query_object: search criteria plus ``page_num`` / ``page_size``
        :param is_page: when true, return only the requested page
        :return: the matching ``WarehouseReceipt`` rows ordered by
            ``create_time`` descending
        """
        _logger.debug(
            'get_receipt_list: page_num=%s, page_size=%s, is_page=%s',
            query_object.page_num,
            query_object.page_size,
            is_page,
        )
        query = cls._apply_filters(
            select(WarehouseReceipt).where(WarehouseReceipt.del_flag == '0'),
            query_object,
        ).order_by(WarehouseReceipt.create_time.desc())
        if is_page:
            # Clamp so a page number below 1 cannot produce a negative OFFSET.
            offset = max((query_object.page_num - 1) * query_object.page_size, 0)
            query = query.offset(offset).limit(query_object.page_size)
        result = await db.execute(query)
        return result.scalars().all()

    @classmethod
    async def get_receipt_count(cls, db: AsyncSession, query_object: WarehouseReceiptPageQueryModel):
        """
        Count non-deleted receipts matching ``query_object``.

        Uses the same conditions as :meth:`get_receipt_list`; paging fields
        are ignored.

        :param db: active async ORM session
        :param query_object: search criteria
        :return: the scalar result of the ``COUNT`` query
        """
        query = cls._apply_filters(
            select(func.count()).select_from(WarehouseReceipt).where(WarehouseReceipt.del_flag == '0'),
            query_object,
        )
        result = await db.execute(query)
        return result.scalar()

    @classmethod
    async def add_receipt(cls, db: AsyncSession, receipt: WarehouseReceipt) -> WarehouseReceipt:
        """
        Add a new receipt to the session and flush it.

        :param db: active async ORM session
        :param receipt: entity to persist
        :return: the same ``receipt`` instance, after the flush
        """
        db.add(receipt)
        await db.flush()
        return receipt

    @classmethod
    async def edit_receipt(cls, db: AsyncSession, receipt: WarehouseReceiptModel) -> None:
        """
        Update the row identified by ``receipt.receipt_id``.

        Only fields explicitly set on ``receipt`` are written;
        ``receipt_id``, ``samples`` and ``sample_count`` are never written.

        :param db: active async ORM session
        :param receipt: model carrying the target id and the changed fields
        """
        update_data = receipt.model_dump(exclude_unset=True, exclude={'receipt_id', 'samples', 'sample_count'})
        if not update_data:
            # Nothing to change: do not issue an UPDATE without any values.
            return
        stmt = update(WarehouseReceipt).where(WarehouseReceipt.receipt_id == receipt.receipt_id).values(**update_data)
        await db.execute(stmt)
        await db.flush()

    @classmethod
    async def delete_receipt(cls, db: AsyncSession, receipt_ids: List[int]) -> None:
        """
        Soft-delete receipts by setting ``del_flag`` to ``'1'``.

        :param db: active async ORM session
        :param receipt_ids: primary keys of the receipts to mark as deleted
        """
        if not receipt_ids:
            # Nothing to delete: skip the statement entirely.
            return
        stmt = update(WarehouseReceipt).where(WarehouseReceipt.receipt_id.in_(receipt_ids)).values(del_flag='1')
        await db.execute(stmt)
        await db.flush()

    @staticmethod
    def _parse_sequence(receipt_no: Optional[str], prefixes) -> int:
        """
        Extract the numeric sequence that follows a known prefix.

        :param receipt_no: an existing receipt number, or ``None``
        :param prefixes: accepted prefixes, checked in order
        :return: the parsed sequence, or ``0`` when ``receipt_no`` is empty,
            starts with none of the prefixes, or has a non-numeric suffix
        """
        if not receipt_no:
            return 0
        for prefix in prefixes:
            if receipt_no.startswith(prefix):
                try:
                    return int(receipt_no[len(prefix):])
                except ValueError:
                    return 0
        return 0

    @classmethod
    async def _fetch_latest_receipt_no(cls, db: AsyncSession, prefixes) -> Optional[str]:
        """
        Return the number of the most recently inserted non-deleted receipt
        whose number starts with one of ``prefixes``.

        The row is first selected with ``FOR UPDATE``; if that statement
        fails, the same query is retried without the lock.
        """
        query = (
            select(WarehouseReceipt.receipt_no)
            .where(
                or_(*(WarehouseReceipt.receipt_no.like(f'{prefix}%') for prefix in prefixes)),
                WarehouseReceipt.del_flag == '0',
            )
            .order_by(WarehouseReceipt.receipt_id.desc())
            .limit(1)
        )
        try:
            result = await db.execute(query.with_for_update())
        except Exception:
            # Best effort: fall back to an unlocked read if the locking read fails.
            _logger.warning('Locking read of latest receipt number failed; retrying without lock', exc_info=True)
            result = await db.execute(query)
        return result.scalar()

    @classmethod
    async def generate_receipt_no(cls, db: AsyncSession, year: int, retry_count: int = 0) -> str:
        """
        Generate the next receipt number for ``year``.

        Format: ``<year>RJ<seq>`` with ``seq`` zero-padded to at least three
        digits, e.g. ``2025RJ001``. The sequence continues from the most
        recently inserted non-deleted receipt of that year, in either the
        current format (``2026RJxxx``) or the legacy one (``2026内检xxx``).

        This method never raises: if the lookup fails, a number derived from
        the current Unix time is returned instead.

        NOTE(review): "latest" is decided by ``receipt_id``, not by the
        numeric sequence, and soft-deleted rows are ignored, so the result is
        not guaranteed to be unique; callers presumably retry with a higher
        ``retry_count`` on conflict - confirm against the service layer.

        :param db: active async ORM session
        :param year: year used as the number prefix
        :param retry_count: extra offset added to the sequence so a retry
            yields a different number
        :return: the generated receipt number
        """
        new_prefix = f'{year}{cls._RECEIPT_TYPE_CODE}'
        prefixes = (new_prefix, f'{year}{cls._LEGACY_TYPE_CODE}')
        try:
            latest_no = await cls._fetch_latest_receipt_no(db, prefixes)
        except Exception:
            # Deliberate best effort: still hand back a number, but record why.
            _logger.exception('Could not read latest receipt number; using time-based fallback')
            fallback_seq = int(time.time()) % 1000 + retry_count
            return f'{new_prefix}{fallback_seq:03d}'
        seq = cls._parse_sequence(latest_no, prefixes) + 1 + retry_count
        return f'{new_prefix}{seq:03d}'