commit 1bd91df9a1
2025-11-18 16:46:04 +08:00
24 changed files with 1954 additions and 0 deletions

app/utils/__init__.py Normal file (+0 lines)

app/utils/browser_api.py Normal file (+143 lines)

@@ -0,0 +1,143 @@
import asyncio

import httpx
from loguru import logger

from utils.decorators import handle_exceptions_unified


class BrowserApi:
    """Client for the local fingerprint-browser HTTP API."""

    def __init__(self):
        self.local_url = 'http://127.0.0.1:54345'
        self.headers = {'Content-Type': 'application/json'}
        # Async HTTP client with a connection pool and timeouts:
        # 30 s total, 10 s to connect.
        self.client = httpx.AsyncClient(
            base_url=self.local_url,
            headers=self.headers,
            timeout=httpx.Timeout(30.0, connect=10.0),
            limits=httpx.Limits(max_keepalive_connections=50, max_connections=100),
        )

    async def __aenter__(self):
        """Async context-manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context-manager exit: close the client."""
        await self.aclose()

    async def aclose(self):
        """Close the HTTP client."""
        if self.client:
            await self.client.aclose()

    @handle_exceptions_unified()
    async def open_browser(self, id: str, jc: int = 0):
        """
        Open a fingerprint browser (async version).

        :param id: browser id
        :param jc: retry counter
        :return: (http, pid) on success, (None, None) on failure
        """
        if jc > 3:
            return None, None
        url = '/browser/open'
        data = {
            'id': id
        }
        try:
            res = await self.client.post(url, json=data)
            res.raise_for_status()  # raise on HTTP error status
            res_data = res.json()
            logger.info(f'Opening fingerprint browser: {res_data}')
            if not res_data.get('success'):
                logger.error(f'Failed to open fingerprint browser: {res_data}')
                return await self.open_browser(id, jc + 1)
            data = res_data.get('data')
            http = data.get('http')
            pid = data.get('pid')
            logger.info(f'Fingerprint browser opened: {http}, {pid}')
            return http, pid
        except httpx.TimeoutException as e:
            logger.error(f'Timeout opening fingerprint browser: {e}')
            if jc < 3:
                return await self.open_browser(id, jc + 1)
            return None, None
        except httpx.RequestError as e:
            logger.error(f'Request error opening fingerprint browser: {e}')
            if jc < 3:
                return await self.open_browser(id, jc + 1)
            return None, None
        except Exception as e:
            logger.error(f'Error opening fingerprint browser: {e}')
            if jc < 3:
                return await self.open_browser(id, jc + 1)
            return None, None

    @handle_exceptions_unified()
    async def close_browser(self, id: str, jc: int = 0):
        """
        Close a fingerprint browser (async version).

        :param id: browser id
        :param jc: retry counter
        :return: True on success, None on failure
        """
        if jc > 3:
            return None
        url = '/browser/close'
        data = {
            'id': id
        }
        try:
            res = await self.client.post(url, json=data)
            res.raise_for_status()  # raise on HTTP error status
            res_data = res.json()
            logger.info(f'Closing fingerprint browser: {res_data}')
            if not res_data.get('success'):
                msg = res_data.get('msg', '')
                # If the browser is still starting up, wait and retry; this is
                # not a real error. '正在打开中' is the upstream API's Chinese
                # message for "currently opening", matched verbatim.
                if '正在打开中' in msg or 'opening' in msg.lower():
                    if jc < 3:
                        # Back off 1-3 s, growing with the retry count
                        # (1 s on the first retry, 2 s, then 3 s).
                        wait_time = (jc + 1) * 1.0
                        logger.info(f'Browser still opening; retrying close in {wait_time} s: browser_id={id}')
                        await asyncio.sleep(wait_time)
                        return await self.close_browser(id, jc + 1)
                    else:
                        # Retries exhausted; log a warning, not an error.
                        logger.warning(f'Failed to close browser: still opening after 3 retries: browser_id={id}')
                        return None
                else:
                    # Any other failure: log as an error and retry.
                    logger.error(f'Failed to close fingerprint browser: {res_data}')
                    if jc < 3:
                        await asyncio.sleep(0.5)  # short pause before retrying
                        return await self.close_browser(id, jc + 1)
                    return None
            logger.info(f'Fingerprint browser closed: browser_id={id}')
            return True
        except httpx.TimeoutException as e:
            logger.error(f'Timeout closing fingerprint browser: {e}')
            if jc < 3:
                await asyncio.sleep(1.0)
                return await self.close_browser(id, jc + 1)
            return None
        except httpx.RequestError as e:
            logger.error(f'Request error closing fingerprint browser: {e}')
            if jc < 3:
                await asyncio.sleep(1.0)
                return await self.close_browser(id, jc + 1)
            return None
        except Exception as e:
            logger.error(f'Error closing fingerprint browser: {e}')
            if jc < 3:
                await asyncio.sleep(1.0)
                return await self.close_browser(id, jc + 1)
            return None


browser_api = BrowserApi()
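A minimal usage sketch (not part of the commit): the module-level `browser_api` singleton shares one connection pool, or the class can be used as an async context manager so the client is closed on exit. Assumes the local fingerprint-browser service is running; `example_browser_id` is a placeholder.

import asyncio

from utils.browser_api import BrowserApi

async def demo():
    # Open, then close, one browser profile.
    async with BrowserApi() as api:
        http, pid = await api.open_browser('example_browser_id')  # placeholder id
        if http:
            print(f'CDP endpoint: {http}, pid: {pid}')
            await api.close_browser('example_browser_id')

asyncio.run(demo())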

app/utils/decorators.py Normal file (+165 lines)

@@ -0,0 +1,165 @@
import asyncio
import logging
from functools import wraps
from typing import Any, Callable, Optional

from fastapi import HTTPException
from tortoise.exceptions import OperationalError

# Module-level logger
logger = logging.getLogger(__name__)


def handle_exceptions_unified(
    max_retries: int = 0,
    retry_delay: float = 1.0,
    status_code: int = 500,
    custom_message: Optional[str] = None,
    is_background_task: bool = False
):
    """
    Unified exception-handling decorator.

    Combines all exception-handling features: database retries, custom
    status codes, custom messages, and background-task handling.

    Args:
        max_retries: maximum number of retries (default 0, no retry)
        retry_delay: delay between retries in seconds (default 1.0)
        status_code: HTTP status code to raise with (default 500)
        custom_message: prefix for the error message
        is_background_task: if True, return False instead of raising HTTPException

    Usage:
        # Basic exception handling
        @handle_exceptions_unified()
        async def basic_function(...):
            pass

        # With database retries
        @handle_exceptions_unified(max_retries=3, retry_delay=1.0)
        async def db_function(...):
            pass

        # Custom status code and message
        @handle_exceptions_unified(status_code=400, custom_message="Invalid parameters")
        async def validation_function(...):
            pass

        # Background-task handling
        @handle_exceptions_unified(is_background_task=True)
        async def background_function(...):
            pass
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            last_exception = None
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except HTTPException as e:
                    # Re-raise HTTPException immediately; never retry it.
                    if is_background_task:
                        logger.error(f"Background task {func.__name__} HTTPException: {str(e)}")
                        return False
                    raise
                except OperationalError as e:
                    last_exception = e
                    error_msg = str(e).lower()
                    # Is this a connection-related error?
                    if any(keyword in error_msg for keyword in [
                        'lost connection', 'connection', 'timeout',
                        'server has gone away', 'broken pipe'
                    ]):
                        if attempt < max_retries:
                            logger.warning(
                                f"{func.__name__}: database connection error "
                                f"(attempt {attempt + 1}/{max_retries + 1}): {str(e)}"
                            )
                            # Exponential backoff before retrying.
                            await asyncio.sleep(retry_delay * (2 ** attempt))
                            continue
                        else:
                            logger.error(
                                f"{func.__name__}: database connection error, max retries reached: {str(e)}"
                            )
                    else:
                        # Not a connection error; handle immediately.
                        logger.error(f"{func.__name__}: database error: {str(e)}")
                        if is_background_task:
                            return False
                        error_detail = f"{custom_message}: {str(e)}" if custom_message else f"Database operation failed: {str(e)}"
                        raise HTTPException(status_code=status_code, detail=error_detail)
                except Exception as e:
                    last_exception = e
                    if attempt < max_retries:
                        logger.warning(
                            f"{func.__name__}: exception (attempt {attempt + 1}/{max_retries + 1}): {str(e)}"
                        )
                        await asyncio.sleep(retry_delay * (2 ** attempt))
                        continue
                    else:
                        logger.error(f"{func.__name__}: exception: {str(e)}", exc_info=True)
                        if is_background_task:
                            return False
                        break
            # All retries failed; handle the last exception.
            if is_background_task:
                return False
            if isinstance(last_exception, OperationalError):
                error_detail = f"{custom_message}: database connection failed: {str(last_exception)}" if custom_message else f"Database connection failed: {str(last_exception)}"
            else:
                error_detail = f"{custom_message}: {str(last_exception)}" if custom_message else str(last_exception)
            raise HTTPException(status_code=status_code, detail=error_detail)
        return wrapper
    return decorator


# Backwards-compatible aliases
def handle_exceptions_with_db_retry(max_retries: int = 3, retry_delay: float = 1.0):
    """
    Exception handling with database-connection retries (backwards compatible).
    Alias for handle_exceptions_unified.
    """
    return handle_exceptions_unified(max_retries=max_retries, retry_delay=retry_delay)


def handle_exceptions(func: Callable) -> Callable:
    """
    Basic exception-handling decorator (backwards compatible).
    Alias for handle_exceptions_unified().
    """
    return handle_exceptions_unified()(func)


def handle_background_task_exceptions(func: Callable) -> Callable:
    """
    Background-task exception-handling decorator (backwards compatible).
    Alias for handle_exceptions_unified.
    """
    return handle_exceptions_unified(is_background_task=True)(func)


def handle_exceptions_with_custom_message(message: str = "Operation failed"):
    """
    Exception handling with a custom error-message prefix (backwards compatible).
    Alias for handle_exceptions_unified.
    """
    return handle_exceptions_unified(custom_message=message)


def handle_exceptions_with_status_code(status_code: int = 500, message: Optional[str] = None):
    """
    Exception handling with a custom status code and message (backwards compatible).
    Alias for handle_exceptions_unified.
    """
    return handle_exceptions_unified(status_code=status_code, custom_message=message)
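A sketch of the retry path on a FastAPI route (not part of the commit; the route and names are hypothetical):

from fastapi import FastAPI

from utils.decorators import handle_exceptions_unified

app = FastAPI()

@app.get('/items/{item_id}')
@handle_exceptions_unified(max_retries=2, retry_delay=0.5, custom_message='Failed to load item')
async def read_item(item_id: int):
    # A transient OperationalError raised here would be retried with
    # exponential backoff (0.5 s, then 1.0 s) before a 500 is raised.
    return {'item_id': item_id}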

app/utils/exceptions.py Normal file (+47 lines)

@@ -0,0 +1,47 @@
import os

from fastapi import Request, status
from fastapi.exceptions import HTTPException, RequestValidationError
from fastapi.responses import JSONResponse

from .logs import getLogger

# Fall back to 'root' when APP_NAME is unset, so the log file gets a valid name.
logger = getLogger(os.environ.get('APP_NAME', 'root'))


def global_http_exception_handler(request: Request, exc):
    """
    Global handler for HTTP exceptions.

    :param request: the HTTP request object
    :param exc: the exception that was raised
    :return: JSONResponse
    """
    # Log the exception.
    logger.error(f"Exception occurred: {exc.detail}")
    # Return a JSONResponse directly; re-raising here could loop.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            'err_msg': exc.detail,
            'status': False
        },
        headers=getattr(exc, 'headers', None)
    )


def global_request_exception_handler(request: Request, exc):
    """
    Global handler for request-validation exceptions.

    :param request: the HTTP request object
    :param exc: the exception that was raised
    :return: JSONResponse
    """
    # Return a JSONResponse directly instead of re-raising.
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content={
            'err_msg': exc.errors()[0],
            'status': False
        }
    )
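A sketch of wiring these handlers into an application (standard FastAPI API; the import path for the app module is an assumption):

from fastapi import FastAPI
from fastapi.exceptions import HTTPException, RequestValidationError

from utils.exceptions import global_http_exception_handler, global_request_exception_handler

app = FastAPI()
# Route all HTTPExceptions and validation failures through the JSON handlers.
app.add_exception_handler(HTTPException, global_http_exception_handler)
app.add_exception_handler(RequestValidationError, global_request_exception_handler)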

app/utils/logs.py Normal file (+218 lines)

@@ -0,0 +1,218 @@
import glob
import gzip
import logging
import os
import shutil
from datetime import datetime, timedelta
from logging import Logger
from logging.handlers import TimedRotatingFileHandler
from pathlib import Path


def getLogger(name: str = 'root') -> Logger:
    """
    Create a Logger that rotates every 2 hours and compresses old logs.

    Note: TimedRotatingFileHandler is not safe for concurrent writes from
    multiple processes; a handler such as ConcurrentRotatingFileHandler
    would be needed for that.

    :param name: logger name
    :return: singleton Logger object
    """
    logger: Logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        # Console output
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        # Log directory
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        # Log file path
        log_file = os.path.join(log_dir, f"{name}.log")
        # File handler: rotate every 2 hours, keep 7 days of files.
        file_handler = TimedRotatingFileHandler(
            filename=log_file,
            when='H',
            interval=2,       # rotate every 2 hours
            backupCount=84,   # 7 days = 7 * 24 / 2 = 84 files
            encoding='utf-8',
            delay=False,
            utc=False         # set True to rotate on UTC time instead
        )
        # Formatters: simplified, no path information.
        formatter = logging.Formatter(
            fmt="{name} {levelname} {asctime} {message}",
            datefmt="%Y-%m-%d %H:%M:%S",
            style="{"
        )
        console_formatter = logging.Formatter(
            fmt="{levelname} {asctime} {message}",
            datefmt="%Y-%m-%d %H:%M:%S",
            style="{"
        )
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        # Compress old logs once, when the logger is first created.
        _compress_old_logs(log_dir, name)
    return logger


def _compress_old_logs(log_dir: str, name: str):
    """
    Compress rotated log files to .gz.
    """
    pattern = os.path.join(log_dir, f"{name}.log.*")
    for filepath in glob.glob(pattern):
        if filepath.endswith('.gz'):
            continue
        try:
            with open(filepath, 'rb') as f_in:
                with gzip.open(filepath + '.gz', 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
            os.remove(filepath)
        except Exception as e:
            print(f"Failed to compress log: {filepath}, reason: {e}")


def compress_old_logs(log_dir: str = None, name: str = "root"):
    """
    Compress old log files (public interface).

    Args:
        log_dir: log directory; defaults to "logs" when not given
        name: logger name
    """
    if log_dir is None:
        log_dir = "logs"
    _compress_old_logs(log_dir, name)


def log_api_call(logger: Logger, user_id: str = None, endpoint: str = None, method: str = None, params: dict = None, response_status: int = None, client_ip: str = None):
    """
    Log an API call: user id, endpoint, HTTP method, parameters,
    response status, and client IP.

    Args:
        logger: logger object
        user_id: user id
        endpoint: endpoint path
        method: HTTP method (GET, POST, PUT, DELETE, ...)
        params: request parameters
        response_status: response status code
        client_ip: client IP address
    """
    try:
        # Build the log message from the available parts.
        log_parts = []
        if user_id:
            log_parts.append(f"user={user_id}")
        if client_ip:
            log_parts.append(f"ip={client_ip}")
        if method and endpoint:
            log_parts.append(f"{method} {endpoint}")
        elif endpoint:
            log_parts.append(f"endpoint={endpoint}")
        if params:
            # Filter out sensitive fields.
            safe_params = {k: v for k, v in params.items()
                           if k.lower() not in ['password', 'token', 'secret', 'key']}
            if safe_params:
                log_parts.append(f"params={safe_params}")
        if response_status:
            log_parts.append(f"status={response_status}")
        if log_parts:
            log_message = " ".join(log_parts)
            logger.info(log_message)
    except Exception as e:
        logger.error(f"Failed to log API call: {e}")


def delete_old_compressed_logs(log_dir: str = None, days: int = 7):
    """
    Delete compressed log files older than the given number of days.

    Args:
        log_dir: log directory; defaults to "logs" when not given
        days: retention in days (default 7)
    """
    try:
        if log_dir is None:
            log_dir = "logs"
        log_path = Path(log_dir)
        if not log_path.exists():
            return
        # Cut-off time for retention.
        cutoff_time = datetime.now() - timedelta(days=days)
        # All compressed log files. Rotated files are compressed as e.g.
        # name.log.2025-11-18_16.gz, so match any .gz produced from a .log.
        gz_files = [f for f in log_path.iterdir()
                    if f.is_file() and '.log.' in f.name and f.name.endswith('.gz')]
        deleted_count = 0
        for gz_file in gz_files:
            # File modification time.
            file_mtime = datetime.fromtimestamp(gz_file.stat().st_mtime)
            # Delete files past the retention period.
            if file_mtime < cutoff_time:
                gz_file.unlink()
                print(f"Deleted old compressed log: {gz_file}")
                deleted_count += 1
        if deleted_count > 0:
            print(f"Deleted {deleted_count} old compressed log files in total")
    except Exception as e:
        print(f"Failed to delete old compressed logs: {e}")


if __name__ == '__main__':
    logger = getLogger('WebAPI')
    # Basic logging checks
    logger.info("System started")
    logger.debug("Debug message")
    logger.warning("Warning message")
    logger.error("Error message")
    # API-call logging checks
    log_api_call(
        logger=logger,
        user_id="user123",
        endpoint="/api/users/info",
        method="GET",
        params={"id": 123, "fields": ["name", "email"]},
        response_status=200,
        client_ip="192.168.1.100"
    )
    log_api_call(
        logger=logger,
        user_id="user456",
        endpoint="/api/users/login",
        method="POST",
        params={"username": "test", "password": "hidden"},  # password is filtered out
        response_status=401,
        client_ip="10.0.0.50"
    )
    # Singleton check
    logger2 = getLogger('WebAPI')
    print(f"Logger singleton check: {id(logger) == id(logger2)}")

app/utils/out_base.py Normal file (+8 lines)

@@ -0,0 +1,8 @@
from pydantic import BaseModel, Field


class CommonOut(BaseModel):
    """Common operation-result model"""
    code: int = Field(200, description='status code')
    message: str = Field('Success', description='message text')
    count: int = Field(0, description='number of records affected')
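A sketch of returning this model from a route (the route and handler body are hypothetical):

from fastapi import FastAPI

from utils.out_base import CommonOut

app = FastAPI()

@app.delete('/items/{item_id}', response_model=CommonOut)
async def delete_item(item_id: int):
    # Placeholder handler; a real one would delete the row and report
    # how many records were affected.
    return CommonOut(message='Deleted', count=1)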

app/utils/redis_tool.py Normal file (+96 lines)

@@ -0,0 +1,96 @@
import redis
from loguru import logger


class RedisClient:
    def __init__(self, host: str = 'localhost', port: int = 6379, password: str = None):
        self.host = host
        self.port = port
        self.password = password
        self.browser_client = None
        self.task_client = None
        self.cache_client = None
        self.ok_client = None
        self.init()

    def init(self):
        """
        Initialize the Redis clients (one per database).
        :return:
        """
        if self.browser_client is None:
            self.browser_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=0,
                                              decode_responses=True)
        if self.task_client is None:
            self.task_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=1,
                                           decode_responses=True)
        if self.cache_client is None:
            self.cache_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=2,
                                            decode_responses=True)
        if self.ok_client is None:
            self.ok_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=3,
                                         decode_responses=True)
        logger.info("Redis connections initialized")

    def close(self):
        """Close all Redis connections."""
        self.browser_client.close()
        self.task_client.close()
        self.cache_client.close()
        self.ok_client.close()
        logger.info("Redis connections closed")

    # --- browser_client ---

    async def set_browser(self, browser_id: str, data: dict):
        """Write browser info as a hash. Note: the underlying client is synchronous."""
        try:
            # Redis hashes cannot store None; replace None with empty strings.
            processed_data = {}
            for key, value in data.items():
                if value is None:
                    processed_data[key] = ""
                else:
                    processed_data[key] = value
            self.browser_client.hset(browser_id, mapping=processed_data)
            logger.info(f"Wrote browser info: {browser_id} - {processed_data}")
            return True
        except Exception as e:
            logger.error(f"Failed to write browser info: {browser_id} - {e}")
            return False

    async def get_browser(self, browser_id: str = None):
        """Read one browser hash, or all of them when browser_id is None."""
        try:
            if browser_id is None:
                # hgetall() requires a key, so collect every hash in db 0
                # by scanning the keyspace.
                data = {key: self.browser_client.hgetall(key)
                        for key in self.browser_client.scan_iter()}
            else:
                data = self.browser_client.hgetall(browser_id)
            logger.info(f"Read browser info: {browser_id} - {data}")
            return data
        except Exception as e:
            logger.error(f"Failed to read browser info: {browser_id} - {e}")


async def main():
    host = '183.66.27.14'
    port = 50086
    password = 'redis_AdJsBP'
    redis_client = RedisClient(host, port, password)
    # await redis_client.set_browser('9eac7f95ca2d47359ace4083a566e119', {'status': 'online', 'current_task_id': None})
    await redis_client.get_browser('9eac7f95ca2d47359ace4083a566e119')
    # Close the connections.
    redis_client.close()


if __name__ == '__main__':
    import asyncio
    asyncio.run(main())

app/utils/session_store.py Normal file (+177 lines)

@@ -0,0 +1,177 @@
import json
import os
import threading
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from loguru import logger


class SessionStore:
    """
    Persistent session store (append-only log file + in-memory cache).

    Design:
    1. Records are appended to a log file (append mode stays fast as the
       file grows).
    2. Recent session records are kept in memory for fast lookups.
    3. Stale in-memory records are pruned regularly (keep the last hour,
       at most 1000 records).
    """

    def __init__(self, file_path: str = 'logs/sessions.log', enable_log: bool = True, max_memory_records: int = 1000):
        """
        Initialize the session store.

        Args:
            file_path (str): log file path (default logs/sessions.log)
            enable_log (bool): whether to log to file; False disables file logging
            max_memory_records (int): maximum records kept in memory (default 1000)
        """
        self.file_path = file_path
        self.enable_log = enable_log
        self.max_memory_records = max_memory_records
        self._lock = threading.Lock()
        # In-memory session records, keyed by pid: {pid: record}
        self._memory_cache: Dict[int, Dict[str, Any]] = {}
        # Creation timestamps, used to expire old records.
        self._cache_timestamps: Dict[int, datetime] = {}
        if enable_log:
            os.makedirs(os.path.dirname(file_path), exist_ok=True)

    def _write_log(self, action: str, record: Dict[str, Any]) -> None:
        """
        Append a record to the log file.

        Args:
            action (str): action type (CREATE/UPDATE)
            record (Dict[str, Any]): session record
        """
        if not self.enable_log:
            return
        try:
            with self._lock:
                log_line = json.dumps({
                    'action': action,
                    'timestamp': datetime.now().isoformat(),
                    'data': record
                }, ensure_ascii=False)
                with open(self.file_path, 'a', encoding='utf-8') as f:
                    f.write(log_line + '\n')
        except Exception as e:
            # Swallow log-write errors so they cannot break the main flow.
            logger.debug(f"Failed to write session log: {e}")

    def _cleanup_old_cache(self) -> None:
        """
        Prune stale in-memory records:
        - keep only the last hour of records
        - keep at most max_memory_records records
        """
        now = datetime.now()
        expire_time = now - timedelta(hours=1)
        # Drop expired records.
        expired_pids = [
            pid for pid, timestamp in self._cache_timestamps.items()
            if timestamp < expire_time
        ]
        for pid in expired_pids:
            self._memory_cache.pop(pid, None)
            self._cache_timestamps.pop(pid, None)
        # If still over the limit, drop the oldest records.
        if len(self._memory_cache) > self.max_memory_records:
            # Sort by timestamp, oldest first.
            sorted_pids = sorted(
                self._cache_timestamps.items(),
                key=lambda x: x[1]
            )
            # Number of records to evict.
            to_remove = len(self._memory_cache) - self.max_memory_records
            for pid, _ in sorted_pids[:to_remove]:
                self._memory_cache.pop(pid, None)
                self._cache_timestamps.pop(pid, None)

    def create_session(self, record: Dict[str, Any]) -> None:
        """
        Create a new session record.

        Args:
            record (Dict[str, Any]): session info dict
        """
        record = dict(record)
        record.setdefault('created_at', datetime.now().isoformat())
        pid = record.get('pid')
        if pid is not None:
            with self._lock:
                # Store in the in-memory cache.
                self._memory_cache[pid] = record
                self._cache_timestamps[pid] = datetime.now()
                # Prune stale records.
                self._cleanup_old_cache()
        # Append to the log file. This must happen outside the lock above,
        # since _write_log takes the same (non-reentrant) lock.
        self._write_log('CREATE', record)

    def update_session(self, pid: int, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Update a session record by PID.

        Args:
            pid (int): process id
            updates (Dict[str, Any]): fields to update

        Returns:
            Optional[Dict[str, Any]]: the updated session record
        """
        with self._lock:
            # Look up the in-memory cache.
            record = self._memory_cache.get(pid)
            if record:
                record.update(updates)
                record.setdefault('updated_at', datetime.now().isoformat())
                self._cache_timestamps[pid] = datetime.now()
            else:
                # Not in memory: create a fresh record.
                record = {'pid': pid}
                record.update(updates)
                record.setdefault('created_at', datetime.now().isoformat())
                record.setdefault('updated_at', datetime.now().isoformat())
                self._memory_cache[pid] = record
                self._cache_timestamps[pid] = datetime.now()
        if record:
            # Append to the log file (outside the lock; see create_session).
            self._write_log('UPDATE', record)
        return record

    def get_session_by_pid(self, pid: int) -> Optional[Dict[str, Any]]:
        """
        Look up a session record by PID (memory cache only, fast).

        Args:
            pid (int): process id

        Returns:
            Optional[Dict[str, Any]]: the session record
        """
        with self._lock:
            return self._memory_cache.get(pid)

    def list_sessions(self, status: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        List session records, optionally filtered by status
        (memory cache only).

        Args:
            status (Optional[int]): status-code filter (e.g. 100 running,
                200 finished, 500 failed)

        Returns:
            List[Dict[str, Any]]: session records
        """
        with self._lock:
            records = list(self._memory_cache.values())
        if status is None:
            return records
        return [r for r in records if r.get('status') == status]
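A minimal usage sketch tracking a hypothetical browser process through its lifecycle (the pid, browser id, and status codes follow the conventions in the docstrings above):

from utils.session_store import SessionStore

store = SessionStore(file_path='logs/sessions.log')
store.create_session({'pid': 12345, 'browser_id': 'example_browser_id', 'status': 100})
store.update_session(12345, {'status': 200})
print(store.get_session_by_pid(12345))
print(store.list_sessions(status=200))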

app/utils/time_tool.py Normal file (+56 lines)

@@ -0,0 +1,56 @@
from datetime import datetime, timedelta, timezone

from pydantic import BaseModel, field_serializer

# China Standard Time (UTC+8)
CN_TZ = timezone(timedelta(hours=8))


def now_cn() -> datetime:
    """
    Current time in China Standard Time.
    Returns a timezone-aware datetime (UTC+8).
    """
    return datetime.now(CN_TZ)


def parse_time(val: str | int, is_end: bool = False) -> datetime:
    """
    Parse a string or timestamp into a China-time datetime for use in
    database time comparisons.

    Supported formats:
    - "YYYY-MM-DD"
    - "YYYY-MM-DD HH:mm:ss"
    - 10-digit timestamp (seconds)
    - 13-digit timestamp (milliseconds)
    """
    dt_cn: datetime
    if isinstance(val, int) or (isinstance(val, str) and val.isdigit()):
        ts = int(val)
        # Magnitude decides seconds vs. milliseconds.
        if ts >= 10**12:
            dt_cn = datetime.fromtimestamp(ts / 1000, CN_TZ)
        else:
            dt_cn = datetime.fromtimestamp(ts, CN_TZ)
    else:
        try:
            dt_cn = datetime.strptime(val, "%Y-%m-%d").replace(tzinfo=CN_TZ)
            if is_end:
                dt_cn = dt_cn.replace(hour=23, minute=59, second=59, microsecond=999999)
        except ValueError:
            try:
                dt_cn = datetime.strptime(val, "%Y-%m-%d %H:%M:%S").replace(tzinfo=CN_TZ)
            except ValueError:
                raise ValueError("Bad time format; expected 'YYYY-MM-DD', 'YYYY-MM-DD HH:mm:ss', or a 10/13-digit timestamp")
    # Match the ORM config (use_tz=False): return a naive local-time datetime.
    return dt_cn.replace(tzinfo=None)


class TimestampModel(BaseModel):
    """Base model that serializes datetime fields to 13-digit timestamps."""
    model_config = {"arbitrary_types_allowed": True}

    @field_serializer("*", when_used="json", check_fields=False)  # "*" applies to every field
    def serialize_datetime(self, value):
        if isinstance(value, datetime):
            return int(value.timestamp() * 1000)  # 13-digit int timestamp (ms)
        return value
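A usage sketch for both helpers (the TaskOut model is hypothetical):

from datetime import datetime

from utils.time_tool import TimestampModel, parse_time

class TaskOut(TimestampModel):
    name: str
    created_at: datetime

start = parse_time("2025-11-18")             # 2025-11-18 00:00:00, naive
end = parse_time("2025-11-18", is_end=True)  # 2025-11-18 23:59:59.999999, naive
task = TaskOut(name='demo', created_at=parse_time(1731916800))  # 10-digit ts, seconds
print(task.model_dump_json())                # created_at serialized as a 13-digit ms timestamp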