diff --git a/api/login.py b/api/login.py index eedf67a..79a6fef 100644 --- a/api/login.py +++ b/api/login.py @@ -45,6 +45,22 @@ async def login( code=params.code, uuid=params.uuid ) + captcha_enabled = ( + True + if await request.app.state.redis.get(f'{RedisKeyConfig.SYSTEM_CONFIG.key}:account.captcha_enabled') + == 'true' + else False + ) + # 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug + request_from_swagger = request.headers.get('referer').endswith('docs') if request.headers.get( + 'referer') else False + request_from_redoc = request.headers.get('referer').endswith('redoc') if request.headers.get( + 'referer') else False + # 验证码校验,如果开启验证码校验,则进行验证码校验,如果关闭则跳过验证码校验. 如果请求来自api文档,则跳过验证码校验 + if captcha_enabled and not request_from_redoc and not request_from_swagger: + result = await Captcha.verify_code(request, code=user.code, session_id=user.uuid) + if not result["status"]: + return Response.error(msg=result["msg"]) result = await LoginController.login(user) if result["status"]: await request.app.state.redis.set( @@ -59,11 +75,6 @@ async def login( ex=timedelta(minutes=5), ) request.app.state.session_id = result["session_id"] - # 判断请求是否来自于api文档,如果是返回指定格式的结果,用于修复api文档认证成功后token显示undefined的bug - request_from_swagger = request.headers.get('referer').endswith('docs') if request.headers.get( - 'referer') else False - request_from_redoc = request.headers.get('referer').endswith('redoc') if request.headers.get( - 'referer') else False if request_from_swagger or request_from_redoc: return {'access_token': result["accessToken"], 'token_type': 'Bearer', "expires_in": result["expiresIn"] * 60} @@ -77,6 +88,14 @@ async def login( @loginAPI.post("/register", response_class=JSONResponse, response_model=LoginResponse, summary="用户注册") async def register(request: Request, params: RegisterUserParams): + register_enabled = ( + True + if await request.app.state.redis.get(f'{RedisKeyConfig.SYSTEM_CONFIG.key}:register_enabled') + == 'true' + else False + ) + if not register_enabled: + return Response.error(msg="注册功能已关闭!") result = await Email.verify_code(request, username=params.username, mail=params.email, code=params.code) if not result["status"]: return Response.error(msg=result["msg"]) @@ -130,6 +149,12 @@ async def get_captcha(request: Request): == 'true' else False ) + register_enabled = ( + True + if await request.app.state.redis.get(f'{RedisKeyConfig.SYSTEM_CONFIG.key}:register_enabled') + == 'true' + else False + ) if captcha_enabled: captcha_type = ( await request.app.state.redis.get(f'{RedisKeyConfig.SYSTEM_CONFIG.key}:account_captcha_type') @@ -149,12 +174,14 @@ async def get_captcha(request: Request): "uuid": session_id, "captcha": captcha, "captcha_enabled": captcha_enabled, + "register_enabled": register_enabled, }) else: return Response.success(data={ "uuid": None, "captcha": None, "captcha_enabled": captcha_enabled, + "register_enabled": register_enabled, }) diff --git a/utils/captcha.py b/utils/captcha.py index 4ef2702..f4ceba9 100644 --- a/utils/captcha.py +++ b/utils/captcha.py @@ -12,6 +12,9 @@ import random import string from PIL import Image, ImageDraw, ImageFont +from fastapi import Request + +from config.constant import RedisKeyConfig class Captcha: @@ -110,3 +113,28 @@ class Captcha: draw.line((x1, y1, x2, y2), fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), width=1) + + @classmethod + async def verify_code(cls, request: Request, code: str, session_id: str) -> dict: + """ + 验证验证码 + :param request + :param code: 验证码 + :param session_id: 会话ID + """ + redis_code = await request.app.state.redis.get(f"{RedisKeyConfig.CAPTCHA_CODES.key}:{session_id}") + if redis_code is None: + return { + "status": False, + "msg": "验证码已过期" + } + if str(redis_code).lower() == code.lower(): + await request.app.state.redis.delete(f"{RedisKeyConfig.CAPTCHA_CODES.key}:{session_id}") + return { + "status": True, + "msg": "验证码正确" + } + return { + "status": False, + "msg": "验证码错误" + }