273 lines
11 KiB
Python
Raw Normal View History

2025-02-13 02:27:44 +08:00
# _*_ coding : UTF-8 _*_
# @Time : 2025/01/19 01:42
# @UpdateTime : 2025/01/19 01:42
# @Author : sonder
# @File : login.py
# @Software : PyCharm
# @Comment : Login controller - JWT issuing/validation, logout, user route tree, online users
import ast
import uuid
from datetime import timedelta, datetime, timezone
from typing import Optional, Union

from fastapi import Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from jose.exceptions import JWEInvalidAuth, ExpiredSignatureError, JWEError, JWTError

from config.constant import RedisKeyConfig
from config.env import JwtConfig
from controller.query import QueryController
from exceptions.exception import AuthException
from models import LoginLog
from schemas.login import LoginParams
from utils.log import logger
from utils.password import Password
# Bearer-token dependency for protected routes; ``tokenUrl`` names the route
# that issues tokens (FastAPI uses it for the OpenAPI login flow).
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl='login',
)
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
    """
    OAuth2 password-flow form extended with a login-duration field plus the
    verification-code (captcha) value and its session id.

    Extra instance attributes:
        code      -- verification code entered by the user (form field ``code``, default '')
        uuid      -- session id of the verification code (form field ``uuid``, default '')
        loginDays -- requested login duration in days (form field ``loginDays``, default 1)
    """

    def __init__(
            self,
            # Anchored so that only the exact value "password" is accepted;
            # an unanchored pattern also let through values such as "password123".
            grant_type: str = Form(default=None, regex='^password$'),
            username: str = Form(..., description="用户账号"),
            password: str = Form(..., description="用户密码"),
            scope: str = Form(default=''),
            client_id: Optional[str] = Form(default=None),
            client_secret: Optional[str] = Form(default=None),
            loginDays: Optional[int] = Form(default=1),
            code: Optional[str] = Form(default=''),
            # NOTE: shadows the ``uuid`` module inside this method only; the name
            # is kept because it is the public form-field name.
            uuid: Optional[str] = Form(default=''),
    ):
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
        self.code = code
        self.uuid = uuid
        self.loginDays = loginDays
class LoginController:
    """
    Login controller: issues JWTs, validates them for the current request,
    handles logout, builds the per-user route tree and lists online sessions.
    """

    # Front-end source extensions removed from a permission's ``component``.
    # Longer suffixes come first so ".tsx"/".jsx" are not cut as ".ts"/".js".
    _COMPONENT_SUFFIXES = (".vue", ".tsx", ".ts", ".jsx", ".js")

    @classmethod
    async def login(cls, params: LoginParams):
        """
        Authenticate a user by username/password and issue tokens.

        :param params: login parameters; ``username``, ``password`` and ``loginDays`` are read
        :return: on success a dict with ``status`` (True), ``accessToken``, ``refreshToken``,
                 ``userInfo``, ``expiresTime`` (POSIX seconds), ``session_id`` and
                 ``expiresIn`` (minutes); on failure ``{"status": False}``
        """
        result = await QueryController.get_user_by_username(params.username)
        if result and await Password.verify_password(params.password, result.password):
            user_id = str(result.id)
            userInfo = await QueryController.get_user_info(user_id=user_id)
            logger.success(f"用户 {params.username} 登录成功")
            session_id = str(uuid.uuid4())
            expires_minutes = params.loginDays * 24 * 60
            # Both tokens carry the same claims; create_token copies the dict.
            claims = {"user": jsonable_encoder(userInfo), "id": user_id, "session_id": session_id}
            accessToken = await cls.create_token(
                data=claims,
                expires_delta=timedelta(minutes=expires_minutes))
            expiresTime = (datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)).timestamp()
            # The refresh token outlives the access token by two hours.
            refreshToken = await cls.create_token(
                data=claims,
                expires_delta=timedelta(minutes=(params.loginDays * 24 + 2) * 60))
            return {"status": True, "accessToken": accessToken, "refreshToken": refreshToken,
                    "userInfo": userInfo,
                    "expiresTime": expiresTime, "session_id": session_id, "expiresIn": expires_minutes}
        logger.error(f"用户 {params.username} 登录失败")
        return {"status": False}

    @classmethod
    async def create_token(cls, data: dict, expires_delta: Union[timedelta, None] = None) -> str:
        """
        Create a signed JWT.

        :param data: claims to store (copied, not mutated)
        :param expires_delta: lifetime of the token; when falsy (None or zero),
                              ``JwtConfig.jwt_expire_minutes`` is used
        :return: encoded token
        """
        to_encode = data.copy()
        if not expires_delta:
            expires_delta = timedelta(minutes=JwtConfig.jwt_expire_minutes)
        # python-jose converts ``exp`` with ``utctimetuple()``, which treats a naive
        # datetime as UTC. A naive local ``datetime.now()`` therefore shifted the
        # expiry by the server's UTC offset; an aware UTC datetime is exact.
        expire = datetime.now(timezone.utc) + expires_delta
        to_encode.update({"exp": expire})
        return jwt.encode(claims=to_encode, key=JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm)

    @staticmethod
    def _strip_bearer(token: str) -> str:
        """
        Remove a leading ``Bearer <token>`` scheme prefix, if present.

        A bare "Bearer" yields '' so that decoding fails with a JWTError
        (reported as AuthException) rather than an IndexError.
        """
        if token.startswith('Bearer'):
            parts = token.split(' ')
            return parts[1] if len(parts) > 1 else ''
        return token

    @staticmethod
    def _parse_cached_user_info(raw):
        """
        Parse a USER_INFO cache entry written as ``str(jsonable_encoder(user_info))``.

        ``ast.literal_eval`` only accepts Python literals, so a tampered cache value
        cannot execute code (the previous ``eval`` could). Returns None for an
        unparsable entry so the caller falls back to the database.
        """
        try:
            if isinstance(raw, bytes):
                raw = raw.decode('utf-8')
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            return None

    @classmethod
    async def get_current_user(cls, request: Request = Request, token: str = Depends(oauth2_scheme)):
        """
        Resolve the user for the request's bearer token.

        :param request: current request; ``request.app.state.redis`` is used for caching
        :param token: bearer token supplied by ``oauth2_scheme``
        :return: user info (from the Redis cache, else from the database)
        :raises AuthException: token invalid/expired, no user id in the token,
                               unknown user, or the session's access token is not in Redis
        """
        token = cls._strip_bearer(token)
        try:
            payload = jwt.decode(token=token, key=JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
        except (JWTError, JWEError):
            # JWTError covers ExpiredSignatureError, JWTClaimsError and bad signatures
            # or malformed tokens, which previously escaped as unhandled errors.
            logger.warning('用户token已失效请重新登录')
            raise AuthException(data='', message='用户token已失效请重新登录')
        user_id: str = payload.get("id", "")
        session_id: str = payload.get('session_id', "")
        if not user_id:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis = request.app.state.redis
        cache_key = f'{RedisKeyConfig.USER_INFO.key}:{user_id}'
        userInfo = await redis.get(cache_key)
        if userInfo:
            userInfo = cls._parse_cached_user_info(userInfo)
        if not userInfo:
            userInfo = await QueryController.get_user_info(user_id=user_id)
            if not userInfo:
                logger.warning('用户token不合法')
                raise AuthException(data='', message='用户token不合法')
            # Short-lived cache of the user info, stored as a Python literal.
            await redis.set(cache_key, str(jsonable_encoder(userInfo)), ex=timedelta(minutes=5))

        redis_token = await redis.get(f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}')
        if not redis_token:
            logger.warning('用户token已失效请重新登录')
            raise AuthException(data='', message='用户token已失效请重新登录')
        return userInfo

    @classmethod
    async def logout(cls, request: Request = Request, token: str = Depends(oauth2_scheme)) -> bool:
        """
        Log out the session the token belongs to.

        :param request: current request; ``request.app.state.redis`` is used
        :param token: bearer token supplied by ``oauth2_scheme``
        :return: True if the stored access token for the session matched and was
                 deleted, otherwise False
        :raises AuthException: the token is invalid or expired
        """
        token = cls._strip_bearer(token)
        try:
            payload = jwt.decode(token=token, key=JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
        except (JWTError, JWEError):
            logger.warning('用户token已失效请重新登录')
            raise AuthException(data='', message='用户token已失效请重新登录')
        session_id: str = payload.get('session_id', "")
        token_key = f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}'
        redis_token = await request.app.state.redis.get(token_key)
        if redis_token == token:
            await request.app.state.redis.delete(token_key)
            return True
        return False

    @classmethod
    async def get_user_routes(cls, user_id: str) -> Union[list, None]:
        """
        Build the route tree for a user from their permissions.

        :param user_id: user id
        :return: list of root route dicts (see ``find_node_recursive``)
        """
        permissions = await QueryController.get_user_permissions(user_id=user_id)
        for permission in permissions:
            # Ids become strings so parent/child comparison uses one type;
            # a missing/empty parentId becomes "" (a root).
            permission["id"] = str(permission["id"])
            permission["parentId"] = str(permission["parentId"]) if permission.get("parentId") else ""
        return await cls.find_complete_data(permissions)

    @classmethod
    def _normalize_component(cls, component: str) -> str:
        """Trim whitespace and drop one trailing front-end source extension."""
        component = component.strip()
        for suffix in cls._COMPONENT_SUFFIXES:
            if component.endswith(suffix):
                component = component[:-len(suffix)]
                break
        return component.strip()

    @classmethod
    async def find_node_recursive(cls, node_id: str, data: list) -> dict:
        """
        Recursively build the route dict for one node.

        Only entries with ``type == 0`` are considered.

        :param node_id: id of the node to build
        :param data: permission dicts (with string ``id``/``parentId``)
        :return: route dict with ``path``, ``meta`` and, when non-empty, ``name``,
                 ``component``, ``redirect`` and ``children`` (sorted by meta rank);
                 ``{}`` when no type-0 entry has this id
        """
        result = {}
        data = [x for x in data if x.get('type') == 0]
        for item in data:
            if item["id"] != node_id:
                continue
            children = []
            for child_item in data:
                if child_item["parentId"] == node_id:
                    child_node = await cls.find_node_recursive(child_item["id"], data)
                    if child_node:
                        children.append(child_node)
            # Falsy values are omitted from meta; "permissions" is always a
            # one-element list and therefore always kept.
            meta = {
                k: v for k, v in {
                    "title": item["title"],
                    "rank": item["rank"],
                    "icon": item["icon"],
                    "extraIcon": item["extraIcon"],
                    "showParent": item["showParent"],
                    "keepAlive": item["keepAlive"],
                    "frameSrc": item["frameSrc"],
                    "frameLoading": item["frameLoading"],
                    "permissions": [item["auths"]],
                }.items() if v
            }
            meta["showLink"] = bool(item["showLink"])
            result = {
                "name": item["name"],
                "path": item["path"],
                "meta": meta,
                "children": children
            }
            if item["component"]:
                result["component"] = cls._normalize_component(item["component"])
            if item["redirect"]:
                result["redirect"] = item["redirect"]
            if result["name"] == "":
                result.pop("name")
            if result["children"] == []:
                result.pop("children")
            else:
                # A rank of 0/None was dropped from meta above; treat it as 0
                # instead of raising KeyError.
                result["children"] = sorted(result["children"], key=lambda x: x["meta"].get("rank", 0))
            break
        return result

    @classmethod
    async def find_complete_data(cls, data: list) -> list:
        """
        Build the route trees for every root (entry with an empty ``parentId``).

        Roots that produce no route (e.g. not ``type == 0``) are skipped, the
        same way empty child nodes are skipped in ``find_node_recursive``.

        :param data: permission dicts
        :return: list of root route dicts
        """
        complete_data = []
        for root_id in [item["id"] for item in data if not item["parentId"]]:
            node = await cls.find_node_recursive(root_id, data)
            if node:
                complete_data.append(node)
        return complete_data

    @classmethod
    async def get_online_user(cls, request: Request) -> list:
        """
        List login-log records for the sessions whose access tokens are in Redis.

        Keys that vanish between KEYS and GET, tokens that no longer decode
        (expired/invalid) and sessions without a login log are skipped instead of
        failing the whole listing.

        :param request: current request; ``request.app.state.redis`` is used
        :return: list of login-log value dicts
        """
        redis = request.app.state.redis
        access_token_keys = await redis.keys(f'{RedisKeyConfig.ACCESS_TOKEN.key}*') or []
        access_token_values_list = [await redis.get(key) for key in access_token_keys]
        online_info_list = []
        for item in access_token_values_list:
            if not item:
                continue
            try:
                payload = jwt.decode(item, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
            except (JWTError, JWEError):
                continue
            session_id = payload.get("session_id")
            result = await LoginLog.get_or_none(session_id=session_id).values(
                id="id",
                user_id="user__id",
                username="user__username",
                user_nickname="user__nickname",
                department_id="user__department__id",
                department_name="user__department__name",
                login_ip="login_ip",
                login_location="login_location",
                browser="browser",
                os="os",
                status="status",
                login_time="login_time",
                session_id="session_id",
                create_time="create_time",
                update_time="update_time"
            )
            if result is not None:
                online_info_list.append(result)
        return online_info_list