Files
fastapi-project-template/config/database.py

181 lines
6.2 KiB
Python

# _*_ coding : UTF-8 _*_
# @Time : 2025/01/18 02:00
# @UpdateTime : 2025/01/18 02:00
# @Author : sonder
# @File : database.py
# @Software : PyCharm
# @Comment : 本程序
import asyncio
import logging
import subprocess
import sys
from datetime import datetime
from logging.handlers import RotatingFileHandler
from pathlib import Path
from tortoise import Tortoise
from config.env import DataBaseConfig
from utils.log import logger, log_path_sql
async def init_db():
"""
异步初始化数据库连接。
"""
# 在数据库连接 URL 中添加时区参数(东八区)
db_url = (
f"mysql://{DataBaseConfig.db_username}:{DataBaseConfig.db_password}@"
f"{DataBaseConfig.db_host}:{DataBaseConfig.db_port}/{DataBaseConfig.db_database}"
"?charset=utf8mb4" # 指定时区为东八区,
)
await Tortoise.init(
db_url=db_url,
modules={"models": ["models"]}, # 指向 models 目录,
timezone="Asia/Shanghai",
)
# 根据 db_echo 配置是否打印 SQL 查询日志
if DataBaseConfig.db_echo:
logger.info("SQL 查询日志已启用")
await configure_tortoise_logging(enable_logging=True, log_level=DataBaseConfig.db_log_level)
else:
logger.info("SQL 查询日志已禁用")
# 禁用 SQL 查询日志
logger.remove(log_path_sql)
# 生成数据库表结构
await Tortoise.generate_schemas()
logger.success("数据库连接成功!")
async def close_db():
"""
关闭数据库连接。
"""
await Tortoise.close_connections()
logger.success("数据库连接关闭!")
async def configure_tortoise_logging(enable_logging: bool = True, log_level: int = logging.DEBUG):
"""
异步配置 Tortoise ORM 日志输出。
:param enable_logging: 是否启用日志输出
:param log_level: 日志输出级别,默认为 DEBUG
"""
aiomysql_logger = logging.getLogger("aiomysql")
tortoise_logger = logging.getLogger("tortoise")
# 清除之前的处理器,避免重复添加
if tortoise_logger.hasHandlers():
tortoise_logger.handlers.clear()
if aiomysql_logger.hasHandlers():
aiomysql_logger.handlers.clear()
if enable_logging:
# 设置日志格式
fmt = logging.Formatter(
fmt="%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# 创建控制台处理器(输出到控制台)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(log_level)
console_handler.setFormatter(fmt)
# 创建文件处理器(输出到文件)
file_handler = RotatingFileHandler(
filename=log_path_sql,
maxBytes=50 * 1024 * 1024, # 日志文件大小达到 50MB 时轮换
backupCount=5, # 保留 5 个旧日志文件
encoding="utf-8",
)
file_handler.setLevel(log_level)
file_handler.setFormatter(fmt)
# 配置 tortoise 顶级日志记录器
tortoise_logger.setLevel(log_level)
tortoise_logger.addHandler(console_handler) # 添加控制台处理器
tortoise_logger.addHandler(file_handler) # 添加文件处理器
# 配置 aiomysql 日志记录器
aiomysql_logger.setLevel(log_level)
aiomysql_logger.addHandler(console_handler) # 添加控制台处理器
aiomysql_logger.addHandler(file_handler) # 添加文件处理器
# 配置 SQL 查询日志记录器
sql_logger = logging.getLogger("tortoise.db_client")
sql_logger.setLevel(log_level)
class SQLResultLogger(logging.Handler):
async def emit(self, record):
# 只处理 SQL 查询相关的日志
if "SELECT" in record.getMessage() or "INSERT" in record.getMessage() or "UPDATE" in record.getMessage() or "DELETE" in record.getMessage():
# 输出 SQL 查询语句
console_handler.emit(record)
file_handler.emit(record)
# 异步获取并记录查询结果
await self.log_query_result(record)
async def log_query_result(self, record):
"""
执行查询并返回结果。
"""
try:
from tortoise import Tortoise
connection = Tortoise.get_connection("default")
result = await connection.execute_query_dict(record.getMessage())
return result
except Exception as e:
return f"获取查询结果失败: {str(e)}"
# 添加自定义 SQL 查询日志处理器
sql_result_handler = SQLResultLogger()
sql_result_handler.setLevel(log_level)
sql_logger.addHandler(sql_result_handler)
else:
# 如果禁用日志,设置日志级别为 WARNING 以抑制大部分输出
tortoise_logger.setLevel(logging.WARNING)
async def backup_database():
"""
备份数据库
"""
logger.info("开始备份数据库")
# 配置数据库连接信息
backup_dir = Path().cwd() / "sql" # 备份文件存储的目录
# 如果 migrations 目录不存在,则创建
backup_dir.mkdir(parents=True, exist_ok=True)
# 生成备份文件名,格式为 dbYYYYMMDDHHMMSS.sql
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
backup_filename = f"db{timestamp}.sql"
backup_filepath = backup_dir / backup_filename # 使用 Path 对象组合路径
# 构造 mysqldump 命令
command = [
"mysqldump",
"-u", DataBaseConfig.db_username,
f"-p{DataBaseConfig.db_password}", # 直接传递密码
DataBaseConfig.db_database,
"--result-file=" + str(backup_filepath) # 指定备份文件路径
]
# 使用 asyncio.to_thread 来在线程中执行阻塞操作
await asyncio.to_thread(run_mysqldump, command)
def run_mysqldump(command):
"""在阻塞线程中执行 mysqldump 命令"""
try:
subprocess.run(command, check=True)
logger.info(f"数据库备份已完成,文件保存为 {command[-1]}")
except subprocess.CalledProcessError as e:
logger.error(f"备份失败,错误: {e}")