# _*_ coding : UTF-8 _*_
# @Time : 2025/01/19 01:42
# @UpdateTime : 2025/01/19 01:42
# @Author : sonder
# @File : login.py
# @Software : PyCharm
# @Comment : Login controller: sign-in, JWT issuing/validation, sign-out, user routes and online sessions
import ast
import uuid
from datetime import timedelta, datetime
from datetime import timezone
from typing import Optional, Union

from fastapi import Form, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from jose.exceptions import JWEInvalidAuth, ExpiredSignatureError, JWEError
from jose.exceptions import JWTError

from config.constant import RedisKeyConfig
from config.env import JwtConfig
from controller.query import QueryController
from exceptions.exception import AuthException
from models import LoginLog
from schemas.login import LoginParams
from utils.log import logger
from utils.password import Password

# Bearer-token dependency: used via Depends(...) by LoginController.get_current_user
# and LoginController.logout to obtain the token; 'login' is the token URL it advertises.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='login')
|
||
|
||
|
||
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
    """
    OAuth2 password request form extended with captcha and session-id fields.

    The standard OAuth2 fields are forwarded unchanged to
    ``OAuth2PasswordRequestForm``; the three extra form values
    (``loginDays``, ``code`` and ``uuid``) are stored on the instance.
    """

    def __init__(
            self,
            grant_type: str = Form(default=None, regex='password'),
            username: str = Form(..., description="用户账号"),
            password: str = Form(..., description="用户密码"),
            scope: str = Form(default=''),
            client_id: Optional[str] = Form(default=None),
            client_secret: Optional[str] = Form(default=None),
            loginDays: Optional[int] = Form(default=1),
            code: Optional[str] = Form(default=''),
            uuid: Optional[str] = Form(default=''),
    ):
        """
        Read the login form fields.

        :param grant_type: OAuth2 grant type; validated against the pattern ``password``
        :param username: user account (required)
        :param password: user password (required)
        :param scope: scope string, empty by default
        :param client_id: optional OAuth2 client id
        :param client_secret: optional OAuth2 client secret
        :param loginDays: defaults to 1; presumably the number of days the login
            should stay valid -- confirm against the consumer of this form
        :param code: captcha code, empty by default
        :param uuid: session identifier, empty by default (note: shadows the
            ``uuid`` module inside this method only)
        """
        # The extra fields do not depend on the base-class initialisation,
        # so they can be stored up front.
        self.loginDays = loginDays
        self.uuid = uuid
        self.code = code
        # The standard OAuth2 fields are handled by the parent form.
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )

class LoginController:
    """
    Login controller.

    Class-level helpers for signing in, issuing and validating JWTs,
    signing out, building a user's route tree and listing online sessions.
    """

    # Longer extensions come first so that ".tsx" / ".jsx" are removed as a
    # whole instead of being cut down to a stray "x" by ".ts" / ".js".
    _COMPONENT_EXTENSIONS = (".vue", ".tsx", ".ts", ".jsx", ".js")

    @staticmethod
    def _strip_bearer(token: str) -> str:
        """
        Return the bare token, dropping a leading ``Bearer`` scheme if present.

        :param token: raw token, optionally prefixed with ``Bearer ``
        :return: the part after the first space; an empty string when the
            value starts with ``Bearer`` but carries no token
        """
        if not token.startswith('Bearer'):
            return token
        parts = token.split(' ')
        # A bare "Bearer" has no second part; return "" so that decoding
        # fails cleanly instead of raising IndexError here.
        return parts[1] if len(parts) > 1 else ''

    @staticmethod
    def _to_text(value):
        """
        Decode ``bytes`` to ``str`` (UTF-8); return any other value unchanged.

        The Redis client may hand back bytes or str depending on how it was
        configured, which is not visible from this module.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode('utf-8')
        return value

    @classmethod
    def _decode_token(cls, token: str) -> dict:
        """
        Decode and verify a JWT with the configured key and algorithm.

        :param token: bare JWT (without the ``Bearer`` prefix)
        :return: the decoded claims
        :raises AuthException: if the token is expired, malformed or fails verification
        """
        try:
            return jwt.decode(token=token, key=JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
        except (JWTError, JWEError) as exc:
            # JWTError is what jwt.decode raises for bad signatures, malformed
            # tokens and invalid claims; ExpiredSignatureError is a subclass.
            # JWEError (incl. JWEInvalidAuth) is kept from the original handler.
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录') from exc

    @classmethod
    def _parse_cached_user_info(cls, cached):
        """
        Turn the cached ``str(dict)`` user info back into Python data.

        :param cached: value read from Redis (str, bytes or ``None``)
        :return: the parsed value, or ``None`` for a missing or unreadable entry
        """
        if not cached:
            return None
        try:
            # SECURITY: the original used eval() on the cached value. The cache
            # holds the repr of jsonable_encoder output (plain literals), so
            # literal_eval parses it without executing arbitrary code.
            return ast.literal_eval(cls._to_text(cached))
        except (ValueError, SyntaxError, TypeError):
            # Unreadable entry: treat it as a cache miss so it gets rebuilt.
            logger.warning('cached user info could not be parsed; reloading from database')
            return None

    @classmethod
    def _strip_component_extension(cls, component: str) -> str:
        """
        Remove front-end file extensions from a component path and trim whitespace.

        :param component: component path, e.g. ``views/user/index.tsx``
        :return: the path without ``.vue`` / ``.tsx`` / ``.ts`` / ``.jsx`` / ``.js``
        """
        for extension in cls._COMPONENT_EXTENSIONS:
            component = component.replace(extension, "")
        return component.strip()

    @classmethod
    async def login(cls, params: LoginParams):
        """
        Authenticate a user and issue an access/refresh token pair.

        The access token is valid for ``params.loginDays`` days and the refresh
        token for two hours longer. This method does not write anything to
        Redis; ``get_current_user`` expects an ACCESS_TOKEN entry per session,
        presumably stored by the caller -- confirm.

        :param params: login parameters (``username``, ``password``, ``loginDays``)
        :return: ``{"status": False}`` on failure; on success a dict with
            ``status``, ``accessToken``, ``refreshToken``, ``userInfo``,
            ``expiresTime`` (POSIX timestamp), ``session_id`` and
            ``expiresIn`` (minutes)
        """
        result = await QueryController.get_user_by_username(params.username)
        if not (result and await Password.verify_password(params.password, result.password)):
            logger.error(f"用户 {params.username} 登录失败")
            return {"status": False}

        user_id = str(result.id)
        userInfo = await QueryController.get_user_info(user_id=user_id)
        logger.success(f"用户 {params.username} 登录成功")

        session_id = str(uuid.uuid4())
        access_minutes = params.loginDays * 24 * 60
        refresh_minutes = (params.loginDays * 24 + 2) * 60
        # Both tokens carry the same claims; create_token copies the dict.
        claims = {"user": jsonable_encoder(userInfo), "id": user_id, "session_id": session_id}

        accessToken = await cls.create_token(data=claims, expires_delta=timedelta(minutes=access_minutes))
        expiresTime = (datetime.now(timezone.utc) + timedelta(minutes=access_minutes)).timestamp()
        refreshToken = await cls.create_token(data=claims, expires_delta=timedelta(minutes=refresh_minutes))

        return {
            "status": True,
            "accessToken": accessToken,
            "refreshToken": refreshToken,
            "userInfo": userInfo,
            "expiresTime": expiresTime,
            "session_id": session_id,
            "expiresIn": access_minutes,
        }

    @classmethod
    async def create_token(cls, data: dict, expires_delta: Union[timedelta, None] = None) -> str:
        """
        Create a signed JWT.

        :param data: claims to store in the token (copied, not modified)
        :param expires_delta: token lifetime; a missing or zero value falls back
            to ``JwtConfig.jwt_expire_minutes``
        :return: the encoded token
        """
        claims = data.copy()
        lifetime = expires_delta if expires_delta else timedelta(minutes=JwtConfig.jwt_expire_minutes)
        # Must be timezone-aware UTC: python-jose converts datetime claims with
        # utctimetuple(), which would read a naive local time as UTC and shift
        # the expiry by the server's UTC offset.
        claims["exp"] = datetime.now(timezone.utc) + lifetime
        return jwt.encode(claims=claims, key=JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm)

    @classmethod
    async def get_current_user(cls, request: Request = Request, token: str = Depends(oauth2_scheme)):
        """
        Resolve the current user from the bearer token.

        User info is read from the Redis cache when available, otherwise loaded
        through ``QueryController.get_user_info`` and cached for five minutes.
        The session must still have an ACCESS_TOKEN entry in Redis.

        NOTE(review): a cache hit returns the cached (JSON-encoded) data while a
        miss returns whatever ``get_user_info`` yields -- confirm callers accept both.

        :param request: current request; Redis is taken from ``request.app.state.redis``
            (the ``= Request`` default is kept only for signature compatibility)
        :param token: bearer token, with or without the ``Bearer`` prefix
        :return: the user info
        :raises AuthException: if the token is invalid or expired, the user
            cannot be found, or the session is no longer present in Redis
        """
        payload = cls._decode_token(cls._strip_bearer(token))
        user_id: str = payload.get("id", "")
        session_id: str = payload.get('session_id', "")
        if not user_id:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis = request.app.state.redis
        cache_key = f'{RedisKeyConfig.USER_INFO.key}:{user_id}'
        userInfo = cls._parse_cached_user_info(await redis.get(cache_key))
        if not userInfo:
            userInfo = await QueryController.get_user_info(user_id=user_id)
            if not userInfo:
                logger.warning('用户token不合法')
                raise AuthException(data='', message='用户token不合法')
            # Stored as str(dict), the same format the cache has always used,
            # so entries written by other code paths stay readable.
            await redis.set(cache_key, str(jsonable_encoder(userInfo)), ex=timedelta(minutes=5))

        redis_token = await redis.get(f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}')
        if not redis_token:
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录')
        return userInfo

    @classmethod
    async def logout(cls, request: Request = Request, token: str = Depends(oauth2_scheme)) -> bool:
        """
        Sign out the session the token belongs to.

        :param request: current request; Redis is taken from ``request.app.state.redis``
        :param token: bearer token, with or without the ``Bearer`` prefix
        :return: ``True`` if the stored token matched and was deleted, else ``False``
        :raises AuthException: if the token is invalid or expired
        """
        token = cls._strip_bearer(token)
        payload = cls._decode_token(token)
        session_id: str = payload.get('session_id', "")

        redis = request.app.state.redis
        token_key = f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}'
        # Decode first so a bytes reply still compares equal to the str token.
        redis_token = cls._to_text(await redis.get(token_key))
        if redis_token != token:
            return False
        await redis.delete(token_key)
        return True

    @classmethod
    async def get_user_routes(cls, user_id: str) -> Union[list, None]:
        """
        Build the route tree for a user from their permissions.

        Each permission's ``id`` and ``parentId`` are converted to strings in
        place (a missing parent becomes ``""``) before the tree is assembled.

        :param user_id: user id
        :return: list of root route nodes; empty if the user has no permissions
        """
        permissions = await QueryController.get_user_permissions(user_id=user_id)
        if not permissions:
            return []
        for permission in permissions:
            permission["id"] = str(permission["id"])
            permission["parentId"] = str(permission["parentId"]) if permission.get("parentId") else ""
        return await cls.find_complete_data(permissions)

    @classmethod
    async def find_node_recursive(cls, node_id: str, data: list) -> dict:
        """
        Build the route node with the given id, including its descendants.

        Only entries whose ``type`` equals 0 take part in the tree.

        :param node_id: id of the node to build
        :param data: flat list of permission dicts
        :return: the route node, or ``{}`` if no type-0 entry has that id
        """
        menus = [entry for entry in data if entry.get('type') == 0]
        return await cls._build_route_node(node_id, menus)

    @classmethod
    async def _build_route_node(cls, node_id: str, menus: list) -> dict:
        """
        Recursively assemble one route node from the pre-filtered entries.

        :param node_id: id of the node to build
        :param menus: entries already restricted to ``type == 0``
        :return: the route node, or ``{}`` if the id is not present
        """
        item = next((entry for entry in menus if entry["id"] == node_id), None)
        if item is None:
            return {}

        children = []
        for candidate in menus:
            if candidate["parentId"] == node_id:
                child_node = await cls._build_route_node(candidate["id"], menus)
                if child_node:
                    children.append(child_node)

        node = {
            "name": item["name"],
            "path": item["path"],
            "meta": {
                "title": item["title"],
                "rank": item["rank"],
                "icon": item["icon"],
                "permissions": [item["auths"]],
            },
            "children": children,
        }
        if item["component"]:
            node["component"] = cls._strip_component_extension(item["component"])
        if item["redirect"]:
            node["redirect"] = item["redirect"]
        # Empty names and empty child lists are left out of the node entirely.
        if node["name"] == "":
            node.pop("name")
        if not children:
            node.pop("children")
        return node

    @classmethod
    async def find_complete_data(cls, data: list) -> list:
        """
        Build the full route tree for every root entry (empty ``parentId``).

        Roots that yield no node (not of type 0) are skipped, the same way
        empty child nodes are skipped while building the tree.

        :param data: flat list of permission dicts
        :return: list of root route nodes, in the order the roots appear in ``data``
        """
        complete_data = []
        for entry in data:
            if entry["parentId"]:
                continue
            node = await cls.find_node_recursive(entry["id"], data)
            if node:
                complete_data.append(node)
        return complete_data

    @classmethod
    async def get_online_user(cls, request: Request) -> list:
        """
        List the login-log records of all sessions that have a token in Redis.

        Keys whose value has disappeared and tokens that can no longer be
        decoded (e.g. expired) are skipped instead of failing the whole listing.

        :param request: current request; Redis is taken from ``request.app.state.redis``
        :return: one entry per readable session: the ``LoginLog`` values for
            its ``session_id`` (``None`` when no log record exists)
        """
        redis = request.app.state.redis
        access_token_keys = await redis.keys(f'{RedisKeyConfig.ACCESS_TOKEN.key}*') or []

        online_info_list = []
        for key in access_token_keys:
            stored_token = await redis.get(key)
            if not stored_token:
                # The key expired or was deleted between KEYS and GET.
                continue
            try:
                payload = jwt.decode(stored_token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
            except (JWTError, JWEError):
                logger.warning('skipping an online-session token that could not be decoded')
                continue
            session_id = payload.get("session_id")
            result = await LoginLog.get_or_none(session_id=session_id).values(
                id="id",
                user_id="user__id",
                username="user__username",
                user_nickname="user__nickname",
                department_id="user__department__id",
                department_name="user__department__name",
                login_ip="login_ip",
                login_location="login_location",
                browser="browser",
                os="os",
                status="status",
                login_time="login_time",
                session_id="session_id",
                create_time="create_time",
                update_time="update_time"
            )
            online_info_list.append(result)
        return online_info_list