0.0.1
This commit is contained in:
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
143
app/utils/browser_api.py
Normal file
143
app/utils/browser_api.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import datetime
|
||||
import asyncio
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from utils.decorators import handle_exceptions_unified
|
||||
|
||||
|
||||
class BrowserApi:
|
||||
"""
|
||||
浏览器接口
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.local_url = 'http://127.0.0.1:54345'
|
||||
self.headers = {'Content-Type': 'application/json'}
|
||||
# 使用异步 HTTP 客户端,启用连接池和超时设置
|
||||
self.client = httpx.AsyncClient(
|
||||
base_url=self.local_url,
|
||||
headers=self.headers,
|
||||
timeout=httpx.Timeout(30.0, connect=10.0), # 总超时30秒,连接超时10秒
|
||||
limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), # 连接池配置
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口,关闭客户端"""
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self):
|
||||
"""关闭 HTTP 客户端"""
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
# 打开指纹浏览器
|
||||
@handle_exceptions_unified()
|
||||
async def open_browser(self, id: str, jc: int = 0):
|
||||
"""
|
||||
打开指纹浏览器(异步优化版本)
|
||||
:param jc: 计次
|
||||
:param id: 浏览器id
|
||||
:return:http, pid
|
||||
"""
|
||||
if jc > 3:
|
||||
return None, None
|
||||
url = '/browser/open'
|
||||
data = {
|
||||
'id': id
|
||||
}
|
||||
try:
|
||||
res = await self.client.post(url, json=data)
|
||||
res.raise_for_status() # 检查 HTTP 状态码
|
||||
res_data = res.json()
|
||||
logger.info(f'打开指纹浏览器: {res_data}')
|
||||
if not res_data.get('success'):
|
||||
logger.error(f'打开指纹浏览器失败: {res_data}')
|
||||
return await self.open_browser(id, jc + 1)
|
||||
data = res_data.get('data')
|
||||
http = data.get('http')
|
||||
pid = data.get('pid')
|
||||
logger.info(f'打开指纹浏览器成功: {http}, {pid}')
|
||||
return http, pid
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f'打开指纹浏览器超时: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f'打开指纹浏览器请求错误: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f'打开指纹浏览器异常: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
|
||||
# 关闭指纹浏览器
|
||||
@handle_exceptions_unified()
|
||||
async def close_browser(self, id: str, jc: int = 0):
|
||||
"""
|
||||
关闭指纹浏览器(异步优化版本)
|
||||
:param jc: 计次
|
||||
:param id: 浏览器id
|
||||
:return:
|
||||
"""
|
||||
if jc > 3:
|
||||
return None
|
||||
url = '/browser/close'
|
||||
data = {
|
||||
'id': id
|
||||
}
|
||||
try:
|
||||
res = await self.client.post(url, json=data)
|
||||
res.raise_for_status() # 检查 HTTP 状态码
|
||||
res_data = res.json()
|
||||
logger.info(f'关闭指纹浏览器: {res_data}')
|
||||
if not res_data.get('success'):
|
||||
msg = res_data.get('msg', '')
|
||||
# 如果浏览器正在打开中,等待后重试(不是真正的错误)
|
||||
if '正在打开中' in msg or 'opening' in msg.lower():
|
||||
if jc < 3:
|
||||
# 等待 1-3 秒后重试(根据重试次数递增等待时间)
|
||||
wait_time = (jc + 1) * 1.0 # 第1次重试等1秒,第2次等2秒,第3次等3秒
|
||||
logger.info(f'浏览器正在打开中,等待 {wait_time} 秒后重试关闭: browser_id={id}')
|
||||
await asyncio.sleep(wait_time)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
else:
|
||||
# 超过重试次数,记录警告但不作为错误
|
||||
logger.warning(f'关闭指纹浏览器失败(浏览器正在打开中,已重试3次): browser_id={id}')
|
||||
return None
|
||||
else:
|
||||
# 其他错误,记录为错误并重试
|
||||
logger.error(f'关闭指纹浏览器失败: {res_data}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(0.5) # 短暂等待后重试
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
logger.info(f'关闭指纹浏览器成功: browser_id={id}')
|
||||
return True
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f'关闭指纹浏览器超时: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f'关闭指纹浏览器请求错误: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f'关闭指纹浏览器异常: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
|
||||
browser_api = BrowserApi()
|
||||
165
app/utils/decorators.py
Normal file
165
app/utils/decorators.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
from typing import Callable, Any, Optional
|
||||
import logging
|
||||
import asyncio
|
||||
from tortoise.exceptions import OperationalError
|
||||
|
||||
# 获取日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_exceptions_unified(
|
||||
max_retries: int = 0,
|
||||
retry_delay: float = 1.0,
|
||||
status_code: int = 500,
|
||||
custom_message: Optional[str] = None,
|
||||
is_background_task: bool = False
|
||||
):
|
||||
"""
|
||||
统一的异常处理装饰器
|
||||
|
||||
集成了所有异常处理功能:数据库重试、自定义状态码、自定义消息、后台任务处理
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数,默认0(不重试)
|
||||
retry_delay: 重试间隔时间(秒),默认1秒
|
||||
status_code: HTTP状态码,默认500
|
||||
custom_message: 自定义错误消息前缀
|
||||
is_background_task: 是否为后台任务(不抛出HTTPException)
|
||||
|
||||
使用方法:
|
||||
# 基础异常处理
|
||||
@handle_exceptions_unified()
|
||||
async def basic_function(...):
|
||||
pass
|
||||
|
||||
# 带数据库重试
|
||||
@handle_exceptions_unified(max_retries=3, retry_delay=1.0)
|
||||
async def db_function(...):
|
||||
pass
|
||||
|
||||
# 自定义状态码和消息
|
||||
@handle_exceptions_unified(status_code=400, custom_message="参数错误")
|
||||
async def validation_function(...):
|
||||
pass
|
||||
|
||||
# 后台任务处理
|
||||
@handle_exceptions_unified(is_background_task=True)
|
||||
async def background_function(...):
|
||||
pass
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except HTTPException as e:
|
||||
# HTTPException 直接抛出,不重试
|
||||
if is_background_task:
|
||||
logger.error(f"后台任务 {func.__name__} HTTPException: {str(e)}")
|
||||
return False
|
||||
raise
|
||||
except OperationalError as e:
|
||||
last_exception = e
|
||||
error_msg = str(e).lower()
|
||||
|
||||
# 检查是否是连接相关的错误
|
||||
if any(keyword in error_msg for keyword in [
|
||||
'lost connection', 'connection', 'timeout',
|
||||
'server has gone away', 'broken pipe'
|
||||
]):
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 数据库连接错误 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
|
||||
)
|
||||
# 等待一段时间后重试,使用指数退避
|
||||
await asyncio.sleep(retry_delay * (2 ** attempt))
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"函数 {func.__name__} 数据库连接错误,已达到最大重试次数: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# 非连接错误,直接处理
|
||||
logger.error(f"函数 {func.__name__} 发生数据库错误: {str(e)}")
|
||||
if is_background_task:
|
||||
return False
|
||||
error_detail = f"{custom_message}: {str(e)}" if custom_message else f"数据库操作失败: {str(e)}"
|
||||
raise HTTPException(status_code=status_code, detail=error_detail)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 发生异常 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay * (2 ** attempt))
|
||||
continue
|
||||
else:
|
||||
logger.error(f"函数 {func.__name__} 发生异常: {str(e)}", exc_info=True)
|
||||
if is_background_task:
|
||||
return False
|
||||
break
|
||||
|
||||
# 所有重试都失败了,处理最后一个异常
|
||||
if is_background_task:
|
||||
return False
|
||||
|
||||
if isinstance(last_exception, OperationalError):
|
||||
error_detail = f"{custom_message}: 数据库连接失败: {str(last_exception)}" if custom_message else f"数据库连接失败: {str(last_exception)}"
|
||||
else:
|
||||
error_detail = f"{custom_message}: {str(last_exception)}" if custom_message else str(last_exception)
|
||||
|
||||
raise HTTPException(status_code=status_code, detail=error_detail)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# 向后兼容的别名函数
|
||||
def handle_exceptions_with_db_retry(max_retries: int = 3, retry_delay: float = 1.0):
|
||||
"""
|
||||
带数据库连接重试的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
|
||||
def handle_exceptions(func: Callable) -> Callable:
|
||||
"""
|
||||
基础异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified() 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified()(func)
|
||||
|
||||
|
||||
def handle_background_task_exceptions(func: Callable) -> Callable:
|
||||
"""
|
||||
后台任务异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(is_background_task=True)(func)
|
||||
|
||||
|
||||
def handle_exceptions_with_custom_message(message: str = "操作失败"):
|
||||
"""
|
||||
带自定义错误消息的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(custom_message=message)
|
||||
|
||||
|
||||
def handle_exceptions_with_status_code(status_code: int = 500, message: str = None):
|
||||
"""
|
||||
带自定义状态码和错误消息的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(status_code=status_code, custom_message=message)
|
||||
47
app/utils/exceptions.py
Normal file
47
app/utils/exceptions.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from fastapi import Request, status
|
||||
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from .logs import getLogger
|
||||
|
||||
logger = getLogger(os.environ.get('APP_NAME'))
|
||||
|
||||
|
||||
def global_http_exception_handler(request: Request, exc):
|
||||
"""
|
||||
全局HTTP请求处理异常
|
||||
:param request: HTTP请求对象
|
||||
:param exc: 本次发生的异常对象
|
||||
:return:
|
||||
"""
|
||||
|
||||
# 使用日志记录异常
|
||||
logger.error(f"发生异常:{exc.detail}")
|
||||
|
||||
# 直接返回JSONResponse,避免重新抛出异常导致循环
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
'err_msg': exc.detail,
|
||||
'status': False
|
||||
},
|
||||
headers=getattr(exc, 'headers', None)
|
||||
)
|
||||
|
||||
|
||||
def global_request_exception_handler(request: Request, exc):
|
||||
"""
|
||||
全局请求校验异常处理函数
|
||||
:param request: HTTP请求对象
|
||||
:param exc: 本次发生的异常对象
|
||||
:return:
|
||||
"""
|
||||
|
||||
# 直接返回JSONResponse,避免重新抛出异常
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
'err_msg': exc.errors()[0],
|
||||
'status': False
|
||||
}
|
||||
)
|
||||
218
app/utils/logs.py
Normal file
218
app/utils/logs.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
import os
|
||||
from logging import Logger
|
||||
from concurrent_log_handler import ConcurrentRotatingFileHandler
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
import gzip
|
||||
import shutil
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def getLogger(name: str = 'root') -> Logger:
|
||||
"""
|
||||
创建一个按2小时滚动、支持多进程安全、自动压缩日志的 Logger
|
||||
:param name: 日志器名称
|
||||
:return: 单例 Logger 对象
|
||||
"""
|
||||
logger: Logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
if not logger.handlers:
|
||||
# 控制台输出
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# 日志目录
|
||||
log_dir = "logs"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 日志文件路径
|
||||
log_file = os.path.join(log_dir, f"{name}.log")
|
||||
|
||||
# 文件处理器:每2小时滚动一次,保留7天,共84个文件,支持多进程写入
|
||||
file_handler = TimedRotatingFileHandler(
|
||||
filename=log_file,
|
||||
when='H',
|
||||
interval=2, # 每2小时切一次
|
||||
backupCount=84, # 保留7天 = 7 * 24 / 2 = 84个文件
|
||||
encoding='utf-8',
|
||||
delay=False,
|
||||
utc=False # 你也可以改成 True 表示按 UTC 时间切
|
||||
)
|
||||
|
||||
# 设置 Formatter - 简化格式,去掉路径信息
|
||||
formatter = logging.Formatter(
|
||||
fmt="【{name}】{levelname} {asctime} {message}",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
style="{"
|
||||
)
|
||||
console_formatter = logging.Formatter(
|
||||
fmt="{levelname} {asctime} {message}",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
style="{"
|
||||
)
|
||||
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 添加压缩功能(在第一次创建 logger 时执行一次)
|
||||
_compress_old_logs(log_dir, name)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def _compress_old_logs(log_dir: str, name: str):
|
||||
"""
|
||||
将旧日志压缩成 .gz 格式
|
||||
"""
|
||||
pattern = os.path.join(log_dir, f"{name}.log.*")
|
||||
for filepath in glob.glob(pattern):
|
||||
if filepath.endswith('.gz'):
|
||||
continue
|
||||
try:
|
||||
with open(filepath, 'rb') as f_in:
|
||||
with gzip.open(filepath + '.gz', 'wb') as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.remove(filepath)
|
||||
except Exception as e:
|
||||
print(f"日志压缩失败: {filepath}, 原因: {e}")
|
||||
|
||||
|
||||
def compress_old_logs(log_dir: str = None, name: str = "root"):
|
||||
"""
|
||||
压缩旧的日志文件(公共接口)
|
||||
|
||||
Args:
|
||||
log_dir: 日志目录,如果不指定则使用默认目录
|
||||
name: 日志器名称
|
||||
"""
|
||||
if log_dir is None:
|
||||
log_dir = "logs"
|
||||
|
||||
_compress_old_logs(log_dir, name)
|
||||
|
||||
|
||||
def log_api_call(logger: Logger, user_id: str = None, endpoint: str = None, method: str = None, params: dict = None, response_status: int = None, client_ip: str = None):
|
||||
"""
|
||||
记录API调用信息,包含用户ID、接口路径、请求方法、参数、响应状态和来源IP
|
||||
|
||||
Args:
|
||||
logger: 日志器对象
|
||||
user_id: 用户ID
|
||||
endpoint: 接口路径
|
||||
method: 请求方法 (GET, POST, PUT, DELETE等)
|
||||
params: 请求参数
|
||||
response_status: 响应状态码
|
||||
client_ip: 客户端IP地址
|
||||
"""
|
||||
try:
|
||||
# 构建日志信息
|
||||
log_parts = []
|
||||
|
||||
if user_id:
|
||||
log_parts.append(f"用户={user_id}")
|
||||
|
||||
if client_ip:
|
||||
log_parts.append(f"IP={client_ip}")
|
||||
|
||||
if method and endpoint:
|
||||
log_parts.append(f"{method} {endpoint}")
|
||||
elif endpoint:
|
||||
log_parts.append(f"接口={endpoint}")
|
||||
|
||||
if params:
|
||||
# 过滤敏感信息
|
||||
safe_params = {k: v for k, v in params.items()
|
||||
if k.lower() not in ['password', 'token', 'secret', 'key']}
|
||||
if safe_params:
|
||||
log_parts.append(f"参数={safe_params}")
|
||||
|
||||
if response_status:
|
||||
log_parts.append(f"状态码={response_status}")
|
||||
|
||||
if log_parts:
|
||||
log_message = " ".join(log_parts)
|
||||
logger.info(log_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录API调用日志失败: {e}")
|
||||
|
||||
|
||||
def delete_old_compressed_logs(log_dir: str = None, days: int = 7):
|
||||
"""
|
||||
删除超过指定天数的压缩日志文件
|
||||
|
||||
Args:
|
||||
log_dir: 日志目录,如果不指定则使用默认目录
|
||||
days: 保留天数,默认7天
|
||||
"""
|
||||
try:
|
||||
if log_dir is None:
|
||||
log_dir = "logs"
|
||||
|
||||
log_path = Path(log_dir)
|
||||
if not log_path.exists():
|
||||
return
|
||||
|
||||
# 计算截止时间
|
||||
cutoff_time = datetime.now() - timedelta(days=days)
|
||||
|
||||
# 获取所有压缩日志文件
|
||||
gz_files = [f for f in log_path.iterdir()
|
||||
if f.is_file() and f.name.endswith('.log.gz')]
|
||||
|
||||
deleted_count = 0
|
||||
for gz_file in gz_files:
|
||||
# 获取文件修改时间
|
||||
file_mtime = datetime.fromtimestamp(gz_file.stat().st_mtime)
|
||||
|
||||
# 如果文件超过保留期限,删除它
|
||||
if file_mtime < cutoff_time:
|
||||
gz_file.unlink()
|
||||
print(f"删除旧压缩日志文件: {gz_file}")
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
print(f"总共删除了 {deleted_count} 个旧压缩日志文件")
|
||||
|
||||
except Exception as e:
|
||||
print(f"删除旧压缩日志文件失败: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger = getLogger('WebAPI')
|
||||
|
||||
# 基础日志测试
|
||||
logger.info("系统启动")
|
||||
logger.debug("调试信息")
|
||||
logger.warning("警告信息")
|
||||
logger.error("错误信息")
|
||||
|
||||
# API调用日志测试
|
||||
log_api_call(
|
||||
logger=logger,
|
||||
user_id="user123",
|
||||
endpoint="/api/users/info",
|
||||
method="GET",
|
||||
params={"id": 123, "fields": ["name", "email"]},
|
||||
response_status=200,
|
||||
client_ip="192.168.1.100"
|
||||
)
|
||||
|
||||
log_api_call(
|
||||
logger=logger,
|
||||
user_id="user456",
|
||||
endpoint="/api/users/login",
|
||||
method="POST",
|
||||
params={"username": "test", "password": "hidden"}, # password会被过滤
|
||||
response_status=401,
|
||||
client_ip="10.0.0.50"
|
||||
)
|
||||
|
||||
# 单例验证
|
||||
logger2 = getLogger('WebAPI')
|
||||
print(f"Logger单例验证: {id(logger) == id(logger2)}")
|
||||
8
app/utils/out_base.py
Normal file
8
app/utils/out_base.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommonOut(BaseModel):
|
||||
"""操作结果详情模型"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
count: int = Field(0, description='操作影响的记录数')
|
||||
96
app/utils/redis_tool.py
Normal file
96
app/utils/redis_tool.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import redis
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RedisClient:
|
||||
def __init__(self, host: str = 'localhost', port: int = 6379, password: str = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.password = password
|
||||
self.browser_client = None
|
||||
self.task_client = None
|
||||
self.cache_client = None
|
||||
self.ok_client = None
|
||||
self.init()
|
||||
|
||||
# 初始化
|
||||
def init(self):
|
||||
"""
|
||||
初始化Redis客户端
|
||||
:return:
|
||||
"""
|
||||
if self.browser_client is None:
|
||||
self.browser_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=0,
|
||||
decode_responses=True)
|
||||
|
||||
if self.task_client is None:
|
||||
self.task_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=1,
|
||||
decode_responses=True)
|
||||
|
||||
if self.cache_client is None:
|
||||
self.cache_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=2,
|
||||
decode_responses=True)
|
||||
|
||||
if self.ok_client is None:
|
||||
self.ok_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=3,
|
||||
decode_responses=True)
|
||||
|
||||
logger.info("Redis连接已初始化")
|
||||
|
||||
# 关闭连接
|
||||
def close(self):
|
||||
self.browser_client.close()
|
||||
self.task_client.close()
|
||||
self.cache_client.close()
|
||||
self.ok_client.close()
|
||||
logger.info("Redis连接已关闭")
|
||||
|
||||
"""browser_client"""
|
||||
|
||||
# 写入浏览器信息
|
||||
async def set_browser(self, browser_id: str, data: dict):
|
||||
try:
|
||||
# 处理None值,将其转换为空字符串
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
processed_data[key] = ""
|
||||
else:
|
||||
processed_data[key] = value
|
||||
|
||||
self.browser_client.hset(browser_id, mapping=processed_data)
|
||||
logger.info(f"写入浏览器信息: {browser_id} - {processed_data}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"写入浏览器信息失败: {browser_id} - {e}")
|
||||
return False
|
||||
|
||||
# 获取浏览器信息
|
||||
async def get_browser(self, browser_id: str = None):
|
||||
try:
|
||||
if browser_id is None:
|
||||
# 获取全部数据
|
||||
data = self.browser_client.hgetall()
|
||||
else:
|
||||
data = self.browser_client.hgetall(browser_id)
|
||||
logger.info(f"获取浏览器信息: {browser_id} - {data}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"获取浏览器信息失败: {browser_id} - {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
host = '183.66.27.14'
|
||||
port = 50086
|
||||
password = 'redis_AdJsBP'
|
||||
redis_client = RedisClient(host, port, password)
|
||||
# await redis_client.set_browser('9eac7f95ca2d47359ace4083a566e119', {'status': 'online', 'current_task_id': None})
|
||||
await redis_client.get_browser('9eac7f95ca2d47359ace4083a566e119')
|
||||
# 关闭连接
|
||||
redis_client.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
177
app/utils/session_store.py
Normal file
177
app/utils/session_store.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
会话持久化存储(日志文件版 + 内存缓存)
|
||||
|
||||
优化方案:
|
||||
1. 使用日志文件记录(追加模式,性能好,不会因为文件变大而变慢)
|
||||
2. 在内存中保留最近的会话记录(用于快速查询)
|
||||
3. 定期清理过期的内存记录(保留最近1小时或最多1000条)
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str = 'logs/sessions.log', enable_log: bool = True, max_memory_records: int = 1000):
|
||||
"""
|
||||
初始化会话存储。
|
||||
|
||||
Args:
|
||||
file_path (str): 日志文件路径(默认 logs/sessions.log)
|
||||
enable_log (bool): 是否启用日志记录,False 则不记录到文件
|
||||
max_memory_records (int): 内存中保留的最大记录数,默认1000
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.enable_log = enable_log
|
||||
self.max_memory_records = max_memory_records
|
||||
self._lock = threading.Lock()
|
||||
# 内存中的会话记录 {pid: record}
|
||||
self._memory_cache: Dict[int, Dict[str, Any]] = {}
|
||||
# 记录创建时间,用于清理过期记录
|
||||
self._cache_timestamps: Dict[int, datetime] = {}
|
||||
|
||||
if enable_log:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
def _write_log(self, action: str, record: Dict[str, Any]) -> None:
|
||||
"""
|
||||
写入日志文件(追加模式,性能好)
|
||||
|
||||
Args:
|
||||
action (str): 操作类型(CREATE/UPDATE)
|
||||
record (Dict[str, Any]): 会话记录
|
||||
"""
|
||||
if not self.enable_log:
|
||||
return
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
log_line = json.dumps({
|
||||
'action': action,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'data': record
|
||||
}, ensure_ascii=False)
|
||||
with open(self.file_path, 'a', encoding='utf-8') as f:
|
||||
f.write(log_line + '\n')
|
||||
except Exception as e:
|
||||
# 静默处理日志写入错误,避免影响主流程
|
||||
logger.debug(f"写入会话日志失败: {e}")
|
||||
|
||||
def _cleanup_old_cache(self) -> None:
|
||||
"""
|
||||
清理过期的内存缓存记录
|
||||
- 保留最近1小时的记录
|
||||
- 最多保留 max_memory_records 条记录
|
||||
"""
|
||||
now = datetime.now()
|
||||
expire_time = now - timedelta(hours=1)
|
||||
|
||||
# 清理过期记录
|
||||
expired_pids = [
|
||||
pid for pid, timestamp in self._cache_timestamps.items()
|
||||
if timestamp < expire_time
|
||||
]
|
||||
for pid in expired_pids:
|
||||
self._memory_cache.pop(pid, None)
|
||||
self._cache_timestamps.pop(pid, None)
|
||||
|
||||
# 如果记录数仍然超过限制,删除最旧的记录
|
||||
if len(self._memory_cache) > self.max_memory_records:
|
||||
# 按时间戳排序,删除最旧的
|
||||
sorted_pids = sorted(
|
||||
self._cache_timestamps.items(),
|
||||
key=lambda x: x[1]
|
||||
)
|
||||
# 计算需要删除的数量
|
||||
to_remove = len(self._memory_cache) - self.max_memory_records
|
||||
for pid, _ in sorted_pids[:to_remove]:
|
||||
self._memory_cache.pop(pid, None)
|
||||
self._cache_timestamps.pop(pid, None)
|
||||
|
||||
def create_session(self, record: Dict[str, Any]) -> None:
|
||||
"""
|
||||
创建新会话记录。
|
||||
|
||||
Args:
|
||||
record (Dict[str, Any]): 会话信息字典
|
||||
"""
|
||||
record = dict(record)
|
||||
record.setdefault('created_at', datetime.now().isoformat())
|
||||
pid = record.get('pid')
|
||||
|
||||
if pid is not None:
|
||||
with self._lock:
|
||||
# 保存到内存缓存
|
||||
self._memory_cache[pid] = record
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
# 清理过期记录
|
||||
self._cleanup_old_cache()
|
||||
|
||||
# 写入日志文件(追加模式,性能好)
|
||||
self._write_log('CREATE', record)
|
||||
|
||||
def update_session(self, pid: int, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
按 PID 更新会话记录。
|
||||
|
||||
Args:
|
||||
pid (int): 进程ID
|
||||
updates (Dict[str, Any]): 更新字段字典
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 更新后的会话记录
|
||||
"""
|
||||
with self._lock:
|
||||
# 从内存缓存获取
|
||||
record = self._memory_cache.get(pid)
|
||||
if record:
|
||||
record.update(updates)
|
||||
record.setdefault('updated_at', datetime.now().isoformat())
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
else:
|
||||
# 如果内存中没有,创建一个新记录
|
||||
record = {'pid': pid}
|
||||
record.update(updates)
|
||||
record.setdefault('created_at', datetime.now().isoformat())
|
||||
record.setdefault('updated_at', datetime.now().isoformat())
|
||||
self._memory_cache[pid] = record
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
|
||||
if record:
|
||||
# 写入日志文件
|
||||
self._write_log('UPDATE', record)
|
||||
|
||||
return record
|
||||
|
||||
def get_session_by_pid(self, pid: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
按 PID 查询会话记录(仅从内存缓存查询,性能好)
|
||||
|
||||
Args:
|
||||
pid (int): 进程ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 会话记录
|
||||
"""
|
||||
with self._lock:
|
||||
return self._memory_cache.get(pid)
|
||||
|
||||
def list_sessions(self, status: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出会话记录,可按状态过滤(仅从内存缓存查询)
|
||||
|
||||
Args:
|
||||
status (Optional[int]): 状态码过滤(如 100 运行中、200 已结束、500 失败)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 会话记录列表
|
||||
"""
|
||||
with self._lock:
|
||||
records = list(self._memory_cache.values())
|
||||
if status is None:
|
||||
return records
|
||||
return [r for r in records if r.get('status') == status]
|
||||
56
app/utils/time_tool.py
Normal file
56
app/utils/time_tool.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pydantic import BaseModel, field_serializer
|
||||
CN_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
def now_cn() -> datetime:
|
||||
"""
|
||||
获取中国时区的当前时间
|
||||
返回带有中国时区信息的 datetime 对象
|
||||
"""
|
||||
return datetime.now(CN_TZ)
|
||||
|
||||
def parse_time(val: str | int, is_end: bool = False) -> datetime:
|
||||
"""
|
||||
将传入的字符串或时间戳解析为中国时区的 datetime,用于数据库查询时间比较。
|
||||
支持格式:
|
||||
- "YYYY-MM-DD"
|
||||
- "YYYY-MM-DD HH:mm:ss"
|
||||
- 10 位时间戳(秒)
|
||||
- 13 位时间戳(毫秒)
|
||||
"""
|
||||
dt_cn: datetime
|
||||
|
||||
if isinstance(val, int) or (isinstance(val, str) and val.isdigit()):
|
||||
ts = int(val)
|
||||
# 根据量级判断是秒还是毫秒
|
||||
if ts >= 10**12:
|
||||
dt_cn = datetime.fromtimestamp(ts / 1000, CN_TZ)
|
||||
else:
|
||||
dt_cn = datetime.fromtimestamp(ts, CN_TZ)
|
||||
else:
|
||||
try:
|
||||
dt_cn = datetime.strptime(val, "%Y-%m-%d").replace(tzinfo=CN_TZ)
|
||||
if is_end:
|
||||
dt_cn = dt_cn.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
try:
|
||||
dt_cn = datetime.strptime(val, "%Y-%m-%d %H:%M:%S").replace(tzinfo=CN_TZ)
|
||||
except ValueError:
|
||||
raise ValueError("时间格式错误,支持 'YYYY-MM-DD' 或 'YYYY-MM-DD HH:mm:ss' 或 10/13位时间戳")
|
||||
|
||||
# 与 ORM 配置保持一致(use_tz=False),返回本地时区的“朴素”时间
|
||||
return dt_cn.replace(tzinfo=None)
|
||||
|
||||
|
||||
# 自动把 datetime 序列化为 13位时间戳的基类
|
||||
class TimestampModel(BaseModel):
|
||||
"""自动把 datetime 序列化为 13位时间戳的基类"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@field_serializer("*", when_used="json", check_fields=False) # "*" 表示作用于所有字段
|
||||
def serialize_datetime(self, value):
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp()*1000) # 转成 13 位 int 时间戳
|
||||
return value
|
||||
Reference in New Issue
Block a user