# _*_ coding : UTF-8 _*_
# @Time : 2025/01/19 01:42
# @UpdateTime : 2025/01/19 01:42
# @Author : sonder
# @File : login.py
# @Software : PyCharm
# @Comment : Login, JWT token handling, user route tree and online-user lookup.
import ast
import uuid
from datetime import timedelta, datetime, timezone
from typing import Optional, Union

from fastapi import Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from jose.exceptions import JWEInvalidAuth, ExpiredSignatureError, JWEError, JWTError

from config.constant import RedisKeyConfig
from config.env import JwtConfig
from controller.query import QueryController
from exceptions.exception import AuthException
from models import LoginLog
from schemas.login import LoginParams
from utils.log import logger
from utils.password import Password

oauth2_scheme = OAuth2PasswordBearer(tokenUrl='login')

# Exceptions meaning "this token cannot be trusted". ``jwt.decode`` reports bad
# signatures / malformed tokens as ``JWTError`` (``ExpiredSignatureError`` is a
# subclass of it); the JWE errors are kept so previously-caught cases still are.
_TOKEN_ERRORS = (JWTError, JWEInvalidAuth, ExpiredSignatureError, JWEError)

# Front-end source suffixes stripped from menu ``component`` paths. Longer
# suffixes come first so ``.tsx``/``.jsx`` are not cut as ``.ts``/``.js`` + "x".
_COMPONENT_SUFFIXES = (".vue", ".tsx", ".ts", ".jsx", ".js")


class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
    """
    OAuth2 password form extended with a captcha ``code``, the captcha
    session ``uuid`` and the requested ``loginDays``.
    """

    def __init__(
            self,
            grant_type: str = Form(default=None, regex='password'),
            username: str = Form(..., description="用户账号"),
            password: str = Form(..., description="用户密码"),
            scope: str = Form(default=''),
            client_id: Optional[str] = Form(default=None),
            client_secret: Optional[str] = Form(default=None),
            loginDays: Optional[int] = Form(default=1),
            code: Optional[str] = Form(default=''),
            # NOTE: this parameter shadows the ``uuid`` module, but only inside __init__.
            uuid: Optional[str] = Form(default=''),
    ):
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
        self.code = code
        self.uuid = uuid
        self.loginDays = loginDays


class LoginController:
    """
    Login controller: authentication, token issuing/validation, logout,
    route-tree building and online-user listing.
    """

    @staticmethod
    def _strip_bearer(token: str) -> str:
        """
        Remove a leading ``Bearer `` scheme from a token, if present.

        :param token: raw token string
        :return: the token without the scheme prefix (unchanged if there is none)
        """
        scheme, sep, rest = token.partition(' ')
        if scheme == 'Bearer' and sep:
            return rest.strip()
        return token

    @staticmethod
    def _decode_token(token: str) -> dict:
        """
        Verify and decode a JWT signed with the configured key/algorithm.

        :param token: token without the ``Bearer`` prefix
        :return: decoded claims
        :raises AuthException: if the token is malformed, badly signed or expired
        """
        try:
            return jwt.decode(token=token, key=JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
        except _TOKEN_ERRORS:
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录')

    @staticmethod
    def _parse_cached_user_info(raw: Union[str, bytes, None]):
        """
        Parse a user-info cache entry written as ``str(jsonable_encoder(...))``.

        Uses ``ast.literal_eval`` instead of ``eval`` so cache contents can never
        execute code. Unparseable entries are treated as a cache miss.

        :param raw: value read from Redis (may be ``None``, ``str`` or ``bytes``)
        :return: the parsed Python literal, or ``None`` on miss / parse failure
        """
        if not raw:
            return None
        if isinstance(raw, bytes):
            raw = raw.decode('utf-8')
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            return None

    @staticmethod
    def _normalize_component(component: str) -> str:
        """
        Strip surrounding whitespace and one known source-file suffix from a component path.

        :param component: component path as stored on the menu item
        :return: component path without its file extension
        """
        component = component.strip()
        for suffix in _COMPONENT_SUFFIXES:
            if component.endswith(suffix):
                return component[:-len(suffix)].strip()
        return component

    @classmethod
    async def login(cls, params: LoginParams):
        """
        Log a user in.

        :param params: login parameters (username, password, loginDays)
        :return: ``{"status": False}`` on failure; otherwise the access/refresh
                 tokens, user info, session id and expiry information
        """
        result = await QueryController.get_user_by_username(params.username)
        if not (result and await Password.verify_password(params.password, result.password)):
            logger.error(f"用户 {params.username} 登录失败")
            return {"status": False}

        user_id = str(result.id)
        userInfo = await QueryController.get_user_info(user_id=user_id)
        logger.success(f"用户 {params.username} 登录成功")
        session_id = str(uuid.uuid4())
        # create_token copies its input, so the same claims dict can back both tokens.
        claims = {"user": jsonable_encoder(userInfo), "id": user_id, "session_id": session_id}
        expires_minutes = params.loginDays * 24 * 60
        accessToken = await cls.create_token(data=claims, expires_delta=timedelta(minutes=expires_minutes))
        expiresTime = (datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)).timestamp()
        # The refresh token outlives the access token by two hours.
        refreshToken = await cls.create_token(data=claims, expires_delta=timedelta(minutes=expires_minutes + 2 * 60))
        return {"status": True, "accessToken": accessToken, "refreshToken": refreshToken, "userInfo": userInfo,
                "expiresTime": expiresTime, "session_id": session_id, "expiresIn": expires_minutes}

    @classmethod
    async def create_token(cls, data: dict, expires_delta: Union[timedelta, None] = None) -> str:
        """
        Create a signed JWT.

        :param data: claims to store in the token (not mutated)
        :param expires_delta: token lifetime; ``JwtConfig.jwt_expire_minutes`` when omitted or zero
        :return: encoded token
        """
        to_encode = data.copy()
        lifetime = expires_delta or timedelta(minutes=JwtConfig.jwt_expire_minutes)
        # jose converts ``exp`` with ``utctimetuple()``, which treats a naive datetime
        # as UTC; an aware UTC datetime keeps the expiry right in any server timezone.
        to_encode["exp"] = datetime.now(timezone.utc) + lifetime
        return jwt.encode(claims=to_encode, key=JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm)

    @classmethod
    async def get_current_user(cls, request: Request = Request, token: str = Depends(oauth2_scheme)):
        """
        Resolve the user behind a bearer token.

        User info is cached in Redis for two minutes. The session must still have
        an access-token entry in Redis, so logged-out sessions are rejected.

        :param request: current request (used to reach ``app.state.redis``)
        :param token: bearer token
        :return: user info
        :raises AuthException: if the token is invalid/expired, the user is unknown
                               or the session is no longer active
        """
        payload = cls._decode_token(cls._strip_bearer(token))
        user_id: str = payload.get("id", "")
        session_id: str = payload.get('session_id', "")
        if not user_id:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis = request.app.state.redis
        cache_key = f'{RedisKeyConfig.USER_INFO.key}:{user_id}'
        userInfo = cls._parse_cached_user_info(await redis.get(cache_key))
        if not userInfo:
            userInfo = await QueryController.get_user_info(user_id=user_id)
            if userInfo:
                await redis.set(cache_key, str(jsonable_encoder(userInfo)), ex=timedelta(minutes=2))
        if not userInfo:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis_token = await redis.get(f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}')
        if not redis_token:
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录')
        return userInfo

    @classmethod
    async def logout(cls, request: Request = Request, token: str = Depends(oauth2_scheme)) -> bool:
        """
        Log out by deleting the session's access-token entry from Redis.

        :param request: current request (used to reach ``app.state.redis``)
        :param token: bearer token
        :return: ``True`` if the stored token matched and was deleted, else ``False``
        :raises AuthException: if the token is invalid or expired
        """
        token = cls._strip_bearer(token)
        payload = cls._decode_token(token)
        session_id: str = payload.get('session_id', "")
        redis = request.app.state.redis
        redis_key = f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}'
        if await redis.get(redis_key) == token:
            await redis.delete(redis_key)
            return True
        return False

    @classmethod
    async def get_user_routes(cls, user_id: str, sub_departments: list) -> Union[list, None]:
        """
        Build the route tree for a user.

        :param user_id: user id
        :param sub_departments: department ids forwarded to the permission query
        :return: list of root route nodes
        """
        permissions = await QueryController.get_user_permissions(user_id=user_id, sub_departments=sub_departments)
        for permission in permissions:
            # Ids are stringified so parent/child comparisons work on one type.
            permission["id"] = str(permission["id"])
            permission["parentId"] = str(permission["parentId"]) if permission.get("parentId") else ""
        return await cls.find_complete_data(permissions)

    @classmethod
    async def find_node_recursive(cls, node_id: str, data: list) -> dict:
        """
        Recursively build the route node for ``node_id`` and its descendants.

        Only items with ``type == 0`` take part. Falsy meta values are dropped.
        Children are sorted by ``meta.rank``, with a missing rank sorting as 0.

        :param node_id: id of the node to build
        :param data: flat list of permission dicts
        :return: route node dict, or ``{}`` if ``node_id`` is not a type-0 item
        """
        menus = [x for x in data if x.get('type') == 0]
        item = next((x for x in menus if x["id"] == node_id), None)
        if item is None:
            return {}

        children = []
        for child_item in menus:
            if child_item["parentId"] == node_id:
                child_node = await cls.find_node_recursive(child_item["id"], menus)
                if child_node:
                    children.append(child_node)

        meta = {
            k: v for k, v in {
                "title": item["title"],
                "rank": item["rank"],
                "icon": item["icon"],
                "extraIcon": item["extraIcon"],
                "showParent": item["showParent"],
                "keepAlive": item["keepAlive"],
                "frameSrc": item["frameSrc"],
                "frameLoading": item["frameLoading"],
                "permissions": [item["auths"]],
            }.items() if v
        }
        meta["showLink"] = bool(item["showLink"])

        result = {
            "name": item["name"],
            "path": item["path"],
            "meta": meta,
            "children": children,
        }
        if item["component"]:
            result["component"] = cls._normalize_component(item["component"])
        if item["redirect"]:
            result["redirect"] = item["redirect"]
        if result["name"] == "":
            result.pop("name")
        if children:
            # A rank of 0/None was filtered out of meta above, so default it here.
            result["children"] = sorted(children, key=lambda x: x["meta"].get("rank") or 0)
        else:
            result.pop("children")
        return result

    @classmethod
    async def find_complete_data(cls, data: list) -> list:
        """
        Build the complete route tree from a flat permission list.

        :param data: flat list of permission dicts (``parentId`` empty for roots)
        :return: list of root route nodes (roots that are not type-0 items are skipped)
        """
        complete_data = []
        for root_id in [item["id"] for item in data if not item["parentId"]]:
            node = await cls.find_node_recursive(root_id, data)
            if node:
                complete_data.append(node)
        return complete_data

    @classmethod
    async def get_online_user(cls, request: Request, sub_departments: list) -> list:
        """
        List online users from the access tokens stored in Redis.

        Tokens that have vanished, or no longer decode (e.g. expired), are skipped.

        :param request: current request (used to reach ``app.state.redis``)
        :param sub_departments: department ids the caller may see
        :return: list of login-log value dicts
        """
        redis = request.app.state.redis
        # Access tokens are stored as "<ACCESS_TOKEN.key>:<session_id>".
        access_token_keys = await redis.keys(f'{RedisKeyConfig.ACCESS_TOKEN.key}:*') or []
        online_info_list = []
        for key in access_token_keys:
            token = await redis.get(key)
            if not token:
                # Key expired between KEYS and GET.
                continue
            try:
                payload = jwt.decode(token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
            except _TOKEN_ERRORS:
                continue
            session_id = payload.get("session_id")
            result = await LoginLog.get_or_none(session_id=session_id, user__department__id__in=sub_departments,
                                                del_flag=1).values(
                id="id",
                user_id="user__id",
                username="user__username",
                user_nickname="user__nickname",
                department_id="user__department__id",
                department_name="user__department__name",
                login_ip="login_ip",
                login_location="login_location",
                browser="browser",
                os="os",
                status="status",
                login_time="login_time",
                session_id="session_id",
                create_time="create_time",
                update_time="update_time"
            )
            if result:
                online_info_list.append(result)
        return online_info_list