139 lines
5.0 KiB
Python
139 lines
5.0 KiB
Python
# _*_ coding : UTF-8 _*_
|
|
# @Time : 2025/01/18 02:00
|
|
# @UpdateTime : 2025/01/18 02:00
|
|
# @Author : sonder
|
|
# @File : database.py
|
|
# @Software : PyCharm
|
|
# @Comment : 本程序
|
|
import logging
|
|
import sys
|
|
from logging.handlers import RotatingFileHandler
|
|
|
|
from tortoise import Tortoise
|
|
|
|
from config.env import DataBaseConfig
|
|
from utils.log import logger, log_path_sql
|
|
|
|
|
|
async def init_db():
|
|
"""
|
|
异步初始化数据库连接。
|
|
"""
|
|
# 在数据库连接 URL 中添加时区参数(东八区)
|
|
db_url = (
|
|
f"mysql://{DataBaseConfig.db_username}:{DataBaseConfig.db_password}@"
|
|
f"{DataBaseConfig.db_host}:{DataBaseConfig.db_port}/{DataBaseConfig.db_database}"
|
|
"?charset=utf8mb4" # 指定时区为东八区,
|
|
)
|
|
|
|
await Tortoise.init(
|
|
db_url=db_url,
|
|
modules={"models": ["models"]}, # 指向 models 目录,
|
|
timezone="Asia/Shanghai",
|
|
)
|
|
|
|
# 根据 db_echo 配置是否打印 SQL 查询日志
|
|
if DataBaseConfig.db_echo:
|
|
logger.info("SQL 查询日志已启用")
|
|
await configure_tortoise_logging(enable_logging=True, log_level=DataBaseConfig.db_log_level)
|
|
else:
|
|
logger.info("SQL 查询日志已禁用")
|
|
# 禁用 SQL 查询日志
|
|
logger.remove(log_path_sql)
|
|
|
|
# 生成数据库表结构
|
|
await Tortoise.generate_schemas()
|
|
logger.success("数据库连接成功!")
|
|
|
|
|
|
async def close_db():
|
|
"""
|
|
关闭数据库连接。
|
|
"""
|
|
await Tortoise.close_connections()
|
|
logger.success("数据库连接关闭!")
|
|
|
|
|
|
async def configure_tortoise_logging(enable_logging: bool = True, log_level: int = logging.DEBUG):
|
|
"""
|
|
异步配置 Tortoise ORM 日志输出。
|
|
|
|
:param enable_logging: 是否启用日志输出
|
|
:param log_level: 日志输出级别,默认为 DEBUG
|
|
"""
|
|
aiomysql_logger = logging.getLogger("aiomysql")
|
|
tortoise_logger = logging.getLogger("tortoise")
|
|
|
|
# 清除之前的处理器,避免重复添加
|
|
if tortoise_logger.hasHandlers():
|
|
tortoise_logger.handlers.clear()
|
|
|
|
if aiomysql_logger.hasHandlers():
|
|
aiomysql_logger.handlers.clear()
|
|
|
|
# if enable_logging:
|
|
# # 设置日志格式
|
|
# fmt = logging.Formatter(
|
|
# fmt="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s",
|
|
# datefmt="%Y-%m-%d %H:%M:%S",
|
|
# )
|
|
#
|
|
# # 创建控制台处理器(输出到控制台)
|
|
# console_handler = logging.StreamHandler(sys.stdout)
|
|
# console_handler.setLevel(log_level)
|
|
# console_handler.setFormatter(fmt)
|
|
#
|
|
# # 创建文件处理器(输出到文件)
|
|
# file_handler = RotatingFileHandler(
|
|
# filename=log_path_sql,
|
|
# maxBytes=50 * 1024 * 1024, # 日志文件大小达到 50MB 时轮换
|
|
# backupCount=5, # 保留 5 个旧日志文件
|
|
# encoding="utf-8",
|
|
# )
|
|
# file_handler.setLevel(log_level)
|
|
# file_handler.setFormatter(fmt)
|
|
#
|
|
# # 配置 tortoise 顶级日志记录器
|
|
# tortoise_logger.setLevel(log_level)
|
|
# tortoise_logger.addHandler(console_handler) # 添加控制台处理器
|
|
# tortoise_logger.addHandler(file_handler) # 添加文件处理器
|
|
#
|
|
# # 配置 aiomysql 日志记录器
|
|
# aiomysql_logger.setLevel(log_level)
|
|
# aiomysql_logger.addHandler(console_handler) # 添加控制台处理器
|
|
# aiomysql_logger.addHandler(file_handler) # 添加文件处理器
|
|
# # 配置 SQL 查询日志记录器
|
|
# sql_logger = logging.getLogger("tortoise.db_client")
|
|
# sql_logger.setLevel(log_level)
|
|
#
|
|
# class SQLResultLogger(logging.Handler):
|
|
# async def emit(self, record):
|
|
# # 只处理 SQL 查询相关的日志
|
|
# if "SELECT" in record.getMessage() or "INSERT" in record.getMessage() or "UPDATE" in record.getMessage() or "DELETE" in record.getMessage():
|
|
# # 输出 SQL 查询语句
|
|
# console_handler.emit(record)
|
|
# file_handler.emit(record)
|
|
#
|
|
# # 异步获取并记录查询结果
|
|
# await self.log_query_result(record)
|
|
#
|
|
# async def log_query_result(self, record):
|
|
# """
|
|
# 执行查询并返回结果。
|
|
# """
|
|
# try:
|
|
# from tortoise import Tortoise
|
|
# connection = Tortoise.get_connection("default")
|
|
# result = await connection.execute_query_dict(record.getMessage())
|
|
# return result
|
|
# except Exception as e:
|
|
# return f"获取查询结果失败: {str(e)}"
|
|
#
|
|
# # 添加自定义 SQL 查询日志处理器
|
|
# sql_result_handler = SQLResultLogger()
|
|
# sql_result_handler.setLevel(log_level)
|
|
# sql_logger.addHandler(sql_result_handler)
|
|
# else:
|
|
# # 如果禁用日志,设置日志级别为 WARNING 以抑制大部分输出
|
|
# tortoise_logger.setLevel(logging.WARNING)
|