2025-02-13 02:27:44 +08:00
|
|
|
|
# _*_ coding : UTF-8 _*_
|
|
|
|
|
# @Time : 2025/01/19 01:42
|
|
|
|
|
# @UpdateTime : 2025/01/19 01:42
|
|
|
|
|
# @Author : sonder
|
|
|
|
|
# @File : login.py
|
|
|
|
|
# @Software : PyCharm
|
|
|
|
|
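# @Comment : Login controller — authentication, JWT issuance/verification, logout, user route tree and online-user listing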
|
|
|
|
|
import uuid
|
|
|
|
|
from datetime import timedelta, datetime
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
from fastapi import Form, Depends, Request
|
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
|
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
|
|
from jose import jwt
|
|
|
|
|
from jose.exceptions import JWEInvalidAuth, ExpiredSignatureError, JWEError
|
|
|
|
|
|
|
|
|
|
from config.constant import RedisKeyConfig
|
|
|
|
|
from config.env import JwtConfig
|
|
|
|
|
from controller.query import QueryController
|
|
|
|
|
from exceptions.exception import AuthException
|
|
|
|
|
from models import LoginLog
|
|
|
|
|
from schemas.login import LoginParams
|
|
|
|
|
from utils.log import logger
|
|
|
|
|
from utils.password import Password
|
|
|
|
|
|
|
|
|
|
# Bearer-token dependency for FastAPI's OAuth2 password flow; `tokenUrl`
# names the endpoint clients call to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl='login',
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomOAuth2PasswordRequestForm(OAuth2PasswordRequestForm):
    """
    OAuth2 password-flow request form extended with extra login fields.

    On top of the standard OAuth2 form fields it accepts:

    * ``code``      -- captcha code entered by the user
    * ``uuid``      -- session identifier accompanying the captcha
    * ``loginDays`` -- requested login duration in days (defaults to 1)
    """

    def __init__(
            self,
            # Anchored so that only the exact value "password" is accepted; the
            # previous unanchored pattern also let e.g. "xpasswordx" through.
            grant_type: Optional[str] = Form(default=None, regex='^password$'),
            username: str = Form(..., description="用户账号"),
            password: str = Form(..., description="用户密码"),
            scope: str = Form(default=''),
            client_id: Optional[str] = Form(default=None),
            client_secret: Optional[str] = Form(default=None),
            loginDays: Optional[int] = Form(default=1),
            code: Optional[str] = Form(default=''),
            uuid: Optional[str] = Form(default=''),
    ):
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
        self.code = code
        self.uuid = uuid
        self.loginDays = loginDays
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoginController:
    """
    Login controller.

    Authenticates users, issues and validates JWTs, handles logout, builds the
    per-user route tree, and lists online users.
    """

    @classmethod
    async def login(cls, params: LoginParams):
        """
        Authenticate a user by username/password and issue tokens.

        :param params: login parameters; ``username``, ``password`` and
            ``loginDays`` are read.
        :return: on success a dict with ``status`` (True), ``accessToken``,
            ``refreshToken``, ``userInfo``, ``expiresTime`` (POSIX seconds),
            ``session_id`` and ``expiresIn`` (minutes); on failure
            ``{"status": False}``.
        """
        result = await QueryController.get_user_by_username(params.username)
        if result and await Password.verify_password(params.password, result.password):
            user_id = str(result.id)
            userInfo = await QueryController.get_user_info(user_id=user_id)
            logger.success(f"用户 {params.username} 登录成功")
            session_id = str(uuid.uuid4())
            # Access and refresh tokens carry identical claims; only lifetimes differ.
            # create_token copies the dict, so sharing it between the two calls is safe.
            claims = {"user": jsonable_encoder(userInfo), "id": user_id, "session_id": session_id}
            expires_minutes = params.loginDays * 24 * 60
            accessToken = await cls.create_token(
                data=claims,
                expires_delta=timedelta(minutes=expires_minutes))
            expiresTime = (datetime.now() + timedelta(minutes=expires_minutes)).timestamp()
            # The refresh token outlives the access token by two hours.
            refreshToken = await cls.create_token(
                data=claims,
                expires_delta=timedelta(minutes=(params.loginDays * 24 + 2) * 60))
            return {"status": True, "accessToken": accessToken, "refreshToken": refreshToken,
                    "userInfo": userInfo,
                    "expiresTime": expiresTime, "session_id": session_id, "expiresIn": expires_minutes}
        logger.error(f"用户 {params.username} 登录失败")
        return {"status": False}

    @classmethod
    async def create_token(cls, data: dict, expires_delta: Union[timedelta, None] = None) -> str:
        """
        Create a signed JWT.

        :param data: claims to embed; the dict is copied, not mutated.
        :param expires_delta: token lifetime. A falsy value (``None`` or a zero
            ``timedelta``) falls back to ``JwtConfig.jwt_expire_minutes``.
        :return: the encoded token.
        """
        from datetime import timezone

        to_copy = data.copy()
        # python-jose turns a datetime ``exp`` into a timestamp via
        # ``utctimetuple()``, which treats *naive* datetimes as UTC. A naive
        # local ``datetime.now()`` therefore shifted the expiry by the server's
        # UTC offset; an aware UTC "now" gives the intended lifetime.
        now = datetime.now(timezone.utc)
        if expires_delta:
            expire = now + expires_delta
        else:
            expire = now + timedelta(minutes=JwtConfig.jwt_expire_minutes)
        to_copy.update({"exp": expire})
        return jwt.encode(claims=to_copy, key=JwtConfig.jwt_secret_key, algorithm=JwtConfig.jwt_algorithm)

    @staticmethod
    def _strip_bearer(token: str) -> str:
        """
        Drop a leading ``Bearer`` scheme word from ``token`` if present.

        A bare ``Bearer`` with nothing after it yields ``''``, which then fails
        decoding cleanly instead of raising ``IndexError``.
        """
        if token.startswith('Bearer'):
            parts = token.split(' ')
            return parts[1] if len(parts) > 1 else ''
        return token

    @classmethod
    def _decode_token(cls, token: str) -> dict:
        """
        Verify and decode a JWT (without any ``Bearer`` prefix).

        :raises AuthException: if the token is malformed, has an invalid
            signature or claims, or has expired.
        """
        # ``jwt.decode`` reports every failure as ``JWTError`` (expiry via its
        # subclass ``ExpiredSignatureError``); ``JWTError`` is not a JWE error,
        # so it must be listed explicitly or bad tokens surface as HTTP 500.
        from jose.exceptions import JWTError

        try:
            return jwt.decode(token=token, key=JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
        except (JWEInvalidAuth, ExpiredSignatureError, JWEError, JWTError) as err:
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录') from err

    @staticmethod
    def _parse_cached_user_info(raw):
        """
        Parse a cached user-info value written as ``str(jsonable_encoder(...))``.

        ``ast.literal_eval`` is used instead of ``eval`` so a tampered cache
        entry cannot execute code. Returns ``None`` when the value cannot be
        parsed; the caller then reloads the user from the database.
        """
        import ast

        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode('utf-8')
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            return None

    @classmethod
    async def get_current_user(cls, request: Request = Request, token: str = Depends(oauth2_scheme)):
        """
        Resolve the user behind the request's bearer token.

        User info is served from a 5-minute Redis cache when available and
        otherwise loaded via ``QueryController.get_user_info``. The session must
        still have an access-token entry in Redis.

        :param request: current request (used to reach ``app.state.redis``).
        :param token: bearer token supplied by ``oauth2_scheme``.
        :return: the user info.
        :raises AuthException: invalid/expired token, missing user id, unknown
            user, or no live session in Redis.
        """
        token = cls._strip_bearer(token)
        payload = cls._decode_token(token)
        user_id: str = payload.get("id", "")
        session_id: str = payload.get('session_id', "")
        if not user_id:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis = request.app.state.redis
        cache_key = f'{RedisKeyConfig.USER_INFO.key}:{user_id}'
        userInfo = await redis.get(cache_key)
        if userInfo:
            userInfo = cls._parse_cached_user_info(userInfo)
        if not userInfo:
            userInfo = await QueryController.get_user_info(user_id=user_id)
            await redis.set(cache_key,
                            str(jsonable_encoder(userInfo)),
                            ex=timedelta(minutes=5))
        if not userInfo:
            logger.warning('用户token不合法')
            raise AuthException(data='', message='用户token不合法')

        redis_token = await redis.get(f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}')
        if not redis_token:
            logger.warning('用户token已失效,请重新登录')
            raise AuthException(data='', message='用户token已失效,请重新登录')
        return userInfo

    @classmethod
    async def logout(cls, request: Request = Request, token: str = Depends(oauth2_scheme)) -> bool:
        """
        Log out by deleting the session's access-token entry from Redis.

        :return: True if the stored token matched and was deleted, else False.
        :raises AuthException: if the token is invalid or expired.
        """
        token = cls._strip_bearer(token)
        payload = cls._decode_token(token)
        session_id: str = payload.get('session_id', "")
        token_key = f'{RedisKeyConfig.ACCESS_TOKEN.key}:{session_id}'
        redis_token = await request.app.state.redis.get(token_key)
        if redis_token == token:
            await request.app.state.redis.delete(token_key)
            return True
        return False

    @classmethod
    async def get_user_routes(cls, user_id: str) -> Union[list, None]:
        """
        Build the route tree for a user from their permissions.

        :param user_id: user id.
        :return: list of root route dicts.
        """
        permissions = await QueryController.get_user_permissions(user_id=user_id)
        # Normalise ids to strings; a missing/falsy parentId becomes "" so that
        # find_complete_data can recognise root entries.
        for permission in permissions:
            permission["id"] = str(permission["id"])
            parent_id = permission.get("parentId")
            permission["parentId"] = str(parent_id) if parent_id else ""
        return await cls.find_complete_data(permissions)

    @staticmethod
    def _strip_component_ext(component: str) -> str:
        """
        Remove one trailing ``.vue``/``.ts``/``.tsx``/``.js``/``.jsx`` extension.

        Only a *suffix* is removed. The previous chained ``replace`` calls
        turned ``foo.tsx`` into ``foox`` because ``.ts`` was removed before
        ``.tsx`` could match.
        """
        component = component.strip()
        for ext in (".vue", ".tsx", ".ts", ".jsx", ".js"):
            if component.endswith(ext):
                component = component[:-len(ext)]
                break
        return component.strip()

    @classmethod
    async def find_node_recursive(cls, node_id: str, data: list) -> dict:
        """
        Recursively build the route node for ``node_id``.

        Only entries with ``type == 0`` take part (presumably menu entries —
        confirm against the permission model); children are matched by
        ``parentId``.

        :param node_id: id of the node to build.
        :param data: flat list of permission dicts.
        :return: the route dict, or ``{}`` if no type-0 entry has ``node_id``.
        """
        result = {}
        data = [x for x in data if x.get('type') == 0]
        for item in data:
            if item["id"] != node_id:
                continue
            children = []
            for child_item in data:
                if child_item["parentId"] == node_id:
                    child_node = await cls.find_node_recursive(child_item["id"], data)
                    if child_node:
                        children.append(child_node)
            # Keep only truthy meta values. "permissions" is always a
            # one-element list, so it is always kept.
            meta = {
                k: v for k, v in {
                    "title": item["title"],
                    "rank": item["rank"],
                    "icon": item["icon"],
                    "extraIcon": item["extraIcon"],
                    "showParent": item["showParent"],
                    "keepAlive": item["keepAlive"],
                    "frameSrc": item["frameSrc"],
                    "frameLoading": item["frameLoading"],
                    "permissions": [item["auths"]],
                }.items() if v
            }
            meta["showLink"] = bool(item["showLink"])
            result = {
                "name": item["name"],
                "path": item["path"],
                "meta": meta,
                "children": children
            }
            if item["component"]:
                result["component"] = cls._strip_component_ext(item["component"])
            if item["redirect"]:
                result["redirect"] = item["redirect"]
            if result["name"] == "":
                result.pop("name")
            if not children:
                result.pop("children")
            break
        return result

    @classmethod
    async def find_complete_data(cls, data: list) -> list:
        """
        Build the full route forest from a flat permission list.

        Roots are entries with a falsy ``parentId``. Roots that yield no route
        node (e.g. non type-0 entries) are skipped rather than emitted as
        ``{}``, matching how empty children are dropped.

        :param data: flat list of permission dicts.
        :return: list of root route dicts.
        """
        complete_data = []
        for root_id in [item["id"] for item in data if not item["parentId"]]:
            node = await cls.find_node_recursive(root_id, data)
            if node:
                complete_data.append(node)
        return complete_data

    @classmethod
    async def get_online_user(cls, request: Request) -> list:
        """
        List online users from the access tokens stored in Redis.

        Each stored token is decoded to get its ``session_id``, and the matching
        ``LoginLog`` row is returned. Tokens that vanished between KEYS and GET,
        tokens that no longer decode (expired or invalid), and sessions without
        a login-log row are skipped instead of failing the whole listing.

        :param request: current request (used to reach ``app.state.redis``).
        :return: list of login-log dicts.
        """
        from jose.exceptions import JWTError

        redis = request.app.state.redis
        access_token_keys = await redis.keys(f'{RedisKeyConfig.ACCESS_TOKEN.key}*') or []
        online_info_list = []
        for key in access_token_keys:
            token = await redis.get(key)
            if not token:
                continue
            try:
                payload = jwt.decode(token, JwtConfig.jwt_secret_key, algorithms=[JwtConfig.jwt_algorithm])
            except (JWEError, JWTError):
                continue
            session_id = payload.get("session_id")
            result = await LoginLog.get_or_none(session_id=session_id).values(
                id="id",
                user_id="user__id",
                username="user__username",
                user_nickname="user__nickname",
                department_id="user__department__id",
                department_name="user__department__name",
                login_ip="login_ip",
                login_location="login_location",
                browser="browser",
                os="os",
                status="status",
                login_time="login_time",
                session_id="session_id",
                create_time="create_time",
                update_time="update_time"
            )
            if result:
                online_info_list.append(result)
        return online_info_list
|