refactor(database): migrate to sqlmodel
This commit is contained in:
49
app/auth.py
49
app/auth.py
@@ -4,7 +4,6 @@ from datetime import datetime, timedelta
|
|||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import (
|
from app.database import (
|
||||||
@@ -15,7 +14,7 @@ from app.database import (
|
|||||||
import bcrypt
|
import bcrypt
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlalchemy.orm import Session
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
# 密码哈希上下文
|
# 密码哈希上下文
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
@@ -71,7 +70,7 @@ def get_password_hash(password: str) -> str:
|
|||||||
return pw_bcrypt.decode()
|
return pw_bcrypt.decode()
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[DBUser]:
|
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
||||||
"""
|
"""
|
||||||
验证用户身份 - 使用类似 from_login 的逻辑
|
验证用户身份 - 使用类似 from_login 的逻辑
|
||||||
"""
|
"""
|
||||||
@@ -79,12 +78,13 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
|
|||||||
pw_md5 = hashlib.md5(password.encode()).hexdigest()
|
pw_md5 = hashlib.md5(password.encode()).hexdigest()
|
||||||
|
|
||||||
# 2. 根据用户名查找用户
|
# 2. 根据用户名查找用户
|
||||||
user = db.query(DBUser).filter(DBUser.name == name).first()
|
statement = select(DBUser).where(DBUser.name == name)
|
||||||
|
user = db.exec(statement).first()
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 3. 验证密码
|
# 3. 验证密码 - 修复逻辑错误
|
||||||
if not (user.pw_bcrypt is None and user.pw_bcrypt != ""):
|
if user.pw_bcrypt is None or user.pw_bcrypt == "":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 4. 检查缓存
|
# 4. 检查缓存
|
||||||
@@ -107,12 +107,12 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(db: Session, username: str, password: str) -> Optional[DBUser]:
|
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
||||||
"""验证用户身份"""
|
"""验证用户身份"""
|
||||||
return authenticate_user_legacy(db, username, password)
|
return authenticate_user_legacy(db, username, password)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||||
"""创建访问令牌"""
|
"""创建访问令牌"""
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
@@ -136,7 +136,7 @@ def generate_refresh_token() -> str:
|
|||||||
return "".join(secrets.choice(characters) for _ in range(length))
|
return "".join(secrets.choice(characters) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
def verify_token(token: str) -> Optional[dict]:
|
def verify_token(token: str) -> dict | None:
|
||||||
"""验证访问令牌"""
|
"""验证访问令牌"""
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
@@ -154,7 +154,10 @@ def store_token(
|
|||||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
# 删除用户的旧令牌
|
# 删除用户的旧令牌
|
||||||
db.query(OAuthToken).filter(OAuthToken.user_id == user_id).delete()
|
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||||
|
old_tokens = db.exec(statement).all()
|
||||||
|
for token in old_tokens:
|
||||||
|
db.delete(token)
|
||||||
|
|
||||||
# 创建新令牌记录
|
# 创建新令牌记录
|
||||||
token_record = OAuthToken(
|
token_record = OAuthToken(
|
||||||
@@ -169,25 +172,19 @@ def store_token(
|
|||||||
return token_record
|
return token_record
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_access_token(db: Session, access_token: str) -> Optional[OAuthToken]:
|
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
||||||
"""根据访问令牌获取令牌记录"""
|
"""根据访问令牌获取令牌记录"""
|
||||||
return (
|
statement = select(OAuthToken).where(
|
||||||
db.query(OAuthToken)
|
OAuthToken.access_token == access_token,
|
||||||
.filter(
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
OAuthToken.access_token == access_token,
|
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
return db.exec(statement).first()
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> Optional[OAuthToken]:
|
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
||||||
"""根据刷新令牌获取令牌记录"""
|
"""根据刷新令牌获取令牌记录"""
|
||||||
return (
|
statement = select(OAuthToken).where(
|
||||||
db.query(OAuthToken)
|
OAuthToken.refresh_token == refresh_token,
|
||||||
.filter(
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
OAuthToken.refresh_token == refresh_token,
|
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
return db.exec(statement).first()
|
||||||
|
|||||||
873
app/database.py
873
app/database.py
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlmodel import Session, create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
@@ -11,7 +10,6 @@ from app.config import settings
|
|||||||
|
|
||||||
# 数据库引擎
|
# 数据库引擎
|
||||||
engine = create_engine(settings.DATABASE_URL)
|
engine = create_engine(settings.DATABASE_URL)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
# Redis 连接
|
# Redis 连接
|
||||||
if redis:
|
if redis:
|
||||||
@@ -22,11 +20,8 @@ else:
|
|||||||
|
|
||||||
# 数据库依赖
|
# 数据库依赖
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
with Session(engine) as session:
|
||||||
try:
|
yield session
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
# Redis 依赖
|
# Redis 依赖
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from fastapi import Depends, HTTPException
|
from __future__ import annotations
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.auth import get_token_by_access_token
|
from app.auth import get_token_by_access_token
|
||||||
|
|
||||||
from .database import get_db
|
|
||||||
from app.database import (
|
from app.database import (
|
||||||
User as DBUser,
|
User as DBUser,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .database import get_db
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
@@ -29,5 +31,5 @@ async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
|
|||||||
token_record = get_token_by_access_token(db, token)
|
token_record = get_token_by_access_token(db, token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
return None
|
return None
|
||||||
user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first()
|
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -3,10 +3,14 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from app.database import (
|
||||||
|
LazerUserAchievement,
|
||||||
|
Team as Team,
|
||||||
|
)
|
||||||
|
|
||||||
from .score import GameMode
|
from .score import GameMode
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from app.database import LazerUserAchievement # 添加数据库模型导入
|
|
||||||
|
|
||||||
|
|
||||||
class PlayStyle(str, Enum):
|
class PlayStyle(str, Enum):
|
||||||
@@ -110,13 +114,6 @@ class DailyChallengeStats(BaseModel):
|
|||||||
weekly_streak_current: int = 0
|
weekly_streak_current: int = 0
|
||||||
|
|
||||||
|
|
||||||
class Team(BaseModel):
|
|
||||||
flag_url: str
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
short_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class Page(BaseModel):
|
class Page(BaseModel):
|
||||||
html: str = ""
|
html: str = ""
|
||||||
raw: str = ""
|
raw: str = ""
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
|||||||
from app.models.oauth import TokenResponse
|
from app.models.oauth import TokenResponse
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||||
from sqlalchemy.orm import Session
|
from sqlmodel import Session
|
||||||
|
|
||||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ async def oauth_token(
|
|||||||
new_refresh_token = generate_refresh_token()
|
new_refresh_token = generate_refresh_token()
|
||||||
|
|
||||||
# 更新令牌
|
# 更新令牌
|
||||||
user_id = int(getattr(token_record, 'user_id'))
|
user_id = int(getattr(token_record, "user_id"))
|
||||||
store_token(
|
store_token(
|
||||||
db,
|
db,
|
||||||
user_id,
|
user_id,
|
||||||
|
|||||||
282
app/utils.py
282
app/utils.py
@@ -1,8 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from app.database import User as DBUser
|
from app.database import (
|
||||||
|
LazerUserCounts,
|
||||||
|
LazerUserProfile,
|
||||||
|
LazerUserStatistics,
|
||||||
|
User as DBUser,
|
||||||
|
)
|
||||||
from app.models.user import (
|
from app.models.user import (
|
||||||
Country,
|
Country,
|
||||||
Cover,
|
Cover,
|
||||||
@@ -14,7 +19,6 @@ from app.models.user import (
|
|||||||
RankHighest,
|
RankHighest,
|
||||||
RankHistory,
|
RankHistory,
|
||||||
Statistics,
|
Statistics,
|
||||||
Team,
|
|
||||||
User,
|
User,
|
||||||
UserAchievement,
|
UserAchievement,
|
||||||
)
|
)
|
||||||
@@ -37,25 +41,28 @@ def convert_db_user_to_api_user(
|
|||||||
profile = db_user.lazer_profile
|
profile = db_user.lazer_profile
|
||||||
if not profile:
|
if not profile:
|
||||||
# 如果没有 lazer 资料,使用默认值
|
# 如果没有 lazer 资料,使用默认值
|
||||||
profile = create_default_profile(db_user)
|
profile = LazerUserProfile(
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# 获取 Lazer 用户计数
|
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
|
||||||
|
lzrcnt = db_user.lazer_counts
|
||||||
|
|
||||||
lzrcnt = db_user.lazer_statistics
|
|
||||||
if not lzrcnt:
|
if not lzrcnt:
|
||||||
# 如果没有 lazer 计数,使用默认值
|
# 如果没有 lazer 计数,使用默认值
|
||||||
lzrcnt = create_default_counts()
|
lzrcnt = LazerUserCounts(user_id=user_id)
|
||||||
|
|
||||||
# 获取指定模式的统计信息
|
# 获取指定模式的统计信息
|
||||||
user_stats = None
|
user_stats = None
|
||||||
for stat in db_user.lazer_statistics:
|
if db_user.lazer_statistics:
|
||||||
if stat.mode == ruleset:
|
for stat in db_user.lazer_statistics:
|
||||||
user_stats = stat
|
if stat.mode == ruleset:
|
||||||
break
|
user_stats = stat
|
||||||
|
break
|
||||||
|
|
||||||
if not user_stats:
|
if not user_stats:
|
||||||
# 如果没有找到指定模式的统计,创建默认统计
|
# 如果没有找到指定模式的统计,创建默认统计
|
||||||
user_stats = create_default_lazer_statistics(ruleset)
|
user_stats = LazerUserStatistics(user_id=user_id)
|
||||||
|
|
||||||
# 获取国家信息
|
# 获取国家信息
|
||||||
country_code = db_user.country_code if db_user.country_code is not None else "XX"
|
country_code = db_user.country_code if db_user.country_code is not None else "XX"
|
||||||
@@ -66,7 +73,7 @@ def convert_db_user_to_api_user(
|
|||||||
kudosu = Kudosu(available=0, total=0)
|
kudosu = Kudosu(available=0, total=0)
|
||||||
|
|
||||||
# 获取计数信息
|
# 获取计数信息
|
||||||
counts = create_default_counts()
|
counts = LazerUserCounts(user_id=user_id)
|
||||||
|
|
||||||
# 转换统计信息
|
# 转换统计信息
|
||||||
statistics = Statistics(
|
statistics = Statistics(
|
||||||
@@ -199,12 +206,7 @@ def convert_db_user_to_api_user(
|
|||||||
team = None
|
team = None
|
||||||
if db_user.team_membership:
|
if db_user.team_membership:
|
||||||
team_member = db_user.team_membership[0] # 假设用户只属于一个团队
|
team_member = db_user.team_membership[0] # 假设用户只属于一个团队
|
||||||
team = Team(
|
team = team_member.team
|
||||||
flag_url=team_member.team.flag_url or "",
|
|
||||||
id=team_member.team.id,
|
|
||||||
name=team_member.team.name,
|
|
||||||
short_name=team_member.team.short_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建用户对象
|
# 创建用户对象
|
||||||
# 从db_user获取基本字段值
|
# 从db_user获取基本字段值
|
||||||
@@ -229,27 +231,25 @@ def convert_db_user_to_api_user(
|
|||||||
avatar_url = str(db_user.avatar.r2_original_url)
|
avatar_url = str(db_user.avatar.r2_original_url)
|
||||||
|
|
||||||
# 如果还是没有找到,通过查询获取
|
# 如果还是没有找到,通过查询获取
|
||||||
if db_session and avatar_url is None:
|
# if db_session and avatar_url is None:
|
||||||
try:
|
# try:
|
||||||
# 导入UserAvatar模型
|
# # 导入UserAvatar模型
|
||||||
from app.database import UserAvatar
|
|
||||||
|
|
||||||
# 尝试查找用户的头像记录
|
# # 尝试查找用户的头像记录
|
||||||
avatar_record = (
|
# statement = select(UserAvatar).where(
|
||||||
db_session.query(UserAvatar)
|
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
|
||||||
.filter_by(user_id=user_id, is_active=True)
|
# )
|
||||||
.first()
|
# avatar_record = db_session.exec(statement).first()
|
||||||
)
|
# if avatar_record is not None:
|
||||||
if avatar_record is not None:
|
# if avatar_record.r2_game_url is not None:
|
||||||
if avatar_record.r2_game_url is not None:
|
# # 优先使用游戏用的头像URL
|
||||||
# 优先使用游戏用的头像URL
|
# avatar_url = str(avatar_record.r2_game_url)
|
||||||
avatar_url = str(avatar_record.r2_game_url)
|
# elif avatar_record.r2_original_url is not None:
|
||||||
elif avatar_record.r2_original_url is not None:
|
# # 其次使用原始头像URL
|
||||||
# 其次使用原始头像URL
|
# avatar_url = str(avatar_record.r2_original_url)
|
||||||
avatar_url = str(avatar_record.r2_original_url)
|
# except Exception as e:
|
||||||
except Exception as e:
|
# print(f"获取用户头像时出错: {e}")
|
||||||
print(f"获取用户头像时出错: {e}")
|
# print(f"最终头像URL: {avatar_url}")
|
||||||
print(f"最终头像URL: {avatar_url}")
|
|
||||||
# 如果仍然没有找到头像URL,则使用默认URL
|
# 如果仍然没有找到头像URL,则使用默认URL
|
||||||
if avatar_url is None:
|
if avatar_url is None:
|
||||||
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
|
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
|
||||||
@@ -265,15 +265,12 @@ def convert_db_user_to_api_user(
|
|||||||
"kudosu",
|
"kudosu",
|
||||||
]
|
]
|
||||||
if profile and profile.profile_order:
|
if profile and profile.profile_order:
|
||||||
profile_order = profile.profile_order
|
profile_order = profile.profile_order.split(",")
|
||||||
|
|
||||||
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
|
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
|
||||||
active_tournament_banners = []
|
active_tournament_banners = []
|
||||||
if (
|
if db_user.active_banners:
|
||||||
hasattr(db_user, "lazer_tournament_banners")
|
for banner in db_user.active_banners:
|
||||||
and db_user.lazer_tournament_banners
|
|
||||||
):
|
|
||||||
for banner in db_user.lazer_tournament_banners:
|
|
||||||
active_tournament_banners.append(
|
active_tournament_banners.append(
|
||||||
{
|
{
|
||||||
"tournament_id": banner.tournament_id,
|
"tournament_id": banner.tournament_id,
|
||||||
@@ -284,7 +281,7 @@ def convert_db_user_to_api_user(
|
|||||||
|
|
||||||
# 在convert_db_user_to_api_user函数中添加badges处理
|
# 在convert_db_user_to_api_user函数中添加badges处理
|
||||||
badges = []
|
badges = []
|
||||||
if hasattr(db_user, "lazer_badges") and db_user.lazer_badges:
|
if db_user.lazer_badges:
|
||||||
for badge in db_user.lazer_badges:
|
for badge in db_user.lazer_badges:
|
||||||
badges.append(
|
badges.append(
|
||||||
{
|
{
|
||||||
@@ -298,10 +295,7 @@ def convert_db_user_to_api_user(
|
|||||||
|
|
||||||
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
|
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
|
||||||
monthly_playcounts = []
|
monthly_playcounts = []
|
||||||
if (
|
if db_user.lazer_monthly_playcounts:
|
||||||
hasattr(db_user, "lazer_monthly_playcounts")
|
|
||||||
and db_user.lazer_monthly_playcounts
|
|
||||||
):
|
|
||||||
for playcount in db_user.lazer_monthly_playcounts:
|
for playcount in db_user.lazer_monthly_playcounts:
|
||||||
monthly_playcounts.append(
|
monthly_playcounts.append(
|
||||||
{
|
{
|
||||||
@@ -314,10 +308,7 @@ def convert_db_user_to_api_user(
|
|||||||
|
|
||||||
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
|
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
|
||||||
previous_usernames = []
|
previous_usernames = []
|
||||||
if (
|
if db_user.lazer_previous_usernames:
|
||||||
hasattr(db_user, "lazer_previous_usernames")
|
|
||||||
and db_user.lazer_previous_usernames
|
|
||||||
):
|
|
||||||
for username in db_user.lazer_previous_usernames:
|
for username in db_user.lazer_previous_usernames:
|
||||||
previous_usernames.append(
|
previous_usernames.append(
|
||||||
{
|
{
|
||||||
@@ -348,22 +339,22 @@ def convert_db_user_to_api_user(
|
|||||||
avatar_url=avatar_url,
|
avatar_url=avatar_url,
|
||||||
country_code=str(country_code),
|
country_code=str(country_code),
|
||||||
default_group=profile.default_group if profile else "default",
|
default_group=profile.default_group if profile else "default",
|
||||||
is_active=profile.is_active if profile else True,
|
is_active=profile.is_active,
|
||||||
is_bot=profile.is_bot if profile else False,
|
is_bot=profile.is_bot,
|
||||||
is_deleted=profile.is_deleted if profile else False,
|
is_deleted=profile.is_deleted,
|
||||||
is_online=profile.is_online if profile else True,
|
is_online=profile.is_online,
|
||||||
is_supporter=profile.is_supporter if profile else False,
|
is_supporter=profile.is_supporter,
|
||||||
is_restricted=profile.is_restricted if profile else False,
|
is_restricted=profile.is_restricted,
|
||||||
last_visit=db_user.last_visit if db_user.last_visit else None,
|
last_visit=db_user.last_visit,
|
||||||
pm_friends_only=profile.pm_friends_only if profile else False,
|
pm_friends_only=profile.pm_friends_only,
|
||||||
profile_colour=profile.profile_colour if profile else None,
|
profile_colour=profile.profile_colour,
|
||||||
cover_url=profile.cover_url
|
cover_url=profile.cover_url
|
||||||
if profile and profile.cover_url
|
if profile and profile.cover_url
|
||||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
|
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
|
||||||
discord=profile.discord if profile else None,
|
discord=profile.discord if profile else None,
|
||||||
has_supported=profile.has_supported if profile else False,
|
has_supported=profile.has_supported if profile else False,
|
||||||
interests=profile.interests if profile else None,
|
interests=profile.interests if profile else None,
|
||||||
join_date=profile.join_date,
|
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
|
||||||
location=profile.location if profile else None,
|
location=profile.location if profile else None,
|
||||||
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
|
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
|
||||||
max_friends=profile.max_friends if profile and profile.max_friends else 500,
|
max_friends=profile.max_friends if profile and profile.max_friends else 500,
|
||||||
@@ -408,7 +399,7 @@ def convert_db_user_to_api_user(
|
|||||||
daily_challenge_user_stats=None,
|
daily_challenge_user_stats=None,
|
||||||
groups=[],
|
groups=[],
|
||||||
monthly_playcounts=monthly_playcounts,
|
monthly_playcounts=monthly_playcounts,
|
||||||
page=Page(html=profile.page_html, raw=profile.page_raw)
|
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
|
||||||
if profile.page_html or profile.page_raw
|
if profile.page_html or profile.page_raw
|
||||||
else Page(),
|
else Page(),
|
||||||
previous_usernames=previous_usernames,
|
previous_usernames=previous_usernames,
|
||||||
@@ -439,164 +430,3 @@ def get_country_name(country_code: str) -> str:
|
|||||||
# 可以添加更多国家
|
# 可以添加更多国家
|
||||||
}
|
}
|
||||||
return country_names.get(country_code, "Unknown")
|
return country_names.get(country_code, "Unknown")
|
||||||
|
|
||||||
|
|
||||||
def create_default_profile(db_user: DBUser):
|
|
||||||
"""创建默认的用户资料"""
|
|
||||||
|
|
||||||
# 完善 MockProfile 类定义
|
|
||||||
class MockProfile:
|
|
||||||
def __init__(self):
|
|
||||||
self.is_active = True
|
|
||||||
self.is_bot = False
|
|
||||||
self.is_deleted = False
|
|
||||||
self.is_online = True
|
|
||||||
self.is_supporter = False
|
|
||||||
self.is_restricted = False
|
|
||||||
self.session_verified = False
|
|
||||||
self.has_supported = False
|
|
||||||
self.pm_friends_only = False
|
|
||||||
self.default_group = "default"
|
|
||||||
self.last_visit = None
|
|
||||||
self.join_date = db_user.join_date if db_user else datetime.utcnow()
|
|
||||||
self.profile_colour = None
|
|
||||||
self.profile_hue = None
|
|
||||||
self.avatar_url = None
|
|
||||||
self.cover_url = None
|
|
||||||
self.discord = None
|
|
||||||
self.twitter = None
|
|
||||||
self.website = None
|
|
||||||
self.title = None
|
|
||||||
self.title_url = None
|
|
||||||
self.interests = None
|
|
||||||
self.location = None
|
|
||||||
self.occupation = None
|
|
||||||
self.playmode = "osu"
|
|
||||||
self.support_level = 0
|
|
||||||
self.max_blocks = 100
|
|
||||||
self.max_friends = 500
|
|
||||||
self.post_count = 0
|
|
||||||
# 添加profile_order字段
|
|
||||||
self.profile_order = [
|
|
||||||
"me",
|
|
||||||
"recent_activity",
|
|
||||||
"top_ranks",
|
|
||||||
"medals",
|
|
||||||
"historical",
|
|
||||||
"beatmaps",
|
|
||||||
"kudosu",
|
|
||||||
]
|
|
||||||
self.page_html = ""
|
|
||||||
self.page_raw = ""
|
|
||||||
# 在MockProfile类中添加active_tournament_banners字段
|
|
||||||
self.active_tournament_banners = (
|
|
||||||
MockLazerTournamentBanner.create_default_banners()
|
|
||||||
)
|
|
||||||
self.active_tournament_banners = [] # 默认空列表
|
|
||||||
|
|
||||||
return MockProfile()
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_lazer_statistics(mode: str):
|
|
||||||
"""创建默认的 Lazer 统计信息"""
|
|
||||||
|
|
||||||
class MockLazerStatistics:
|
|
||||||
def __init__(self, mode: str):
|
|
||||||
self.mode = mode
|
|
||||||
self.count_100 = 0
|
|
||||||
self.count_300 = 0
|
|
||||||
self.count_50 = 0
|
|
||||||
self.count_miss = 0
|
|
||||||
self.level_current = 1
|
|
||||||
self.level_progress = 0
|
|
||||||
self.global_rank = None
|
|
||||||
self.global_rank_exp = None
|
|
||||||
self.pp = 0.0
|
|
||||||
self.pp_exp = 0.0
|
|
||||||
self.ranked_score = 0
|
|
||||||
self.hit_accuracy = 0.0
|
|
||||||
self.total_score = 0
|
|
||||||
self.total_hits = 0
|
|
||||||
self.maximum_combo = 0
|
|
||||||
self.play_count = 0
|
|
||||||
self.play_time = 0
|
|
||||||
self.replays_watched_by_others = 0
|
|
||||||
self.is_ranked = False
|
|
||||||
self.grade_ss = 0
|
|
||||||
self.grade_ssh = 0
|
|
||||||
self.grade_s = 0
|
|
||||||
self.grade_sh = 0
|
|
||||||
self.grade_a = 0
|
|
||||||
self.country_rank = None
|
|
||||||
self.rank_highest = None
|
|
||||||
self.rank_highest_updated_at = None
|
|
||||||
|
|
||||||
return MockLazerStatistics(mode)
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_country(country_code: str):
|
|
||||||
"""创建默认的国家信息"""
|
|
||||||
|
|
||||||
class MockCountry:
|
|
||||||
def __init__(self, code: str):
|
|
||||||
self.code = code
|
|
||||||
self.name = get_country_name(code)
|
|
||||||
|
|
||||||
return MockCountry(country_code)
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_kudosu():
|
|
||||||
"""创建默认的 Kudosu 信息"""
|
|
||||||
|
|
||||||
class MockKudosu:
|
|
||||||
def __init__(self):
|
|
||||||
self.available = 0
|
|
||||||
self.total = 0
|
|
||||||
|
|
||||||
return MockKudosu()
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_counts():
|
|
||||||
"""创建默认的计数信息"""
|
|
||||||
|
|
||||||
class MockCounts:
|
|
||||||
def __init__(self):
|
|
||||||
self.recent_scores_count = None
|
|
||||||
self.beatmap_playcounts_count = 0
|
|
||||||
self.scores_first_count = 0
|
|
||||||
self.scores_pinned_count = 0
|
|
||||||
self.comments_count = 0
|
|
||||||
self.favourite_beatmapset_count = 0
|
|
||||||
self.follower_count = 0
|
|
||||||
self.graveyard_beatmapset_count = 0
|
|
||||||
self.guest_beatmapset_count = 0
|
|
||||||
self.loved_beatmapset_count = 0
|
|
||||||
self.mapping_follower_count = 0
|
|
||||||
self.nominated_beatmapset_count = 0
|
|
||||||
self.pending_beatmapset_count = 0
|
|
||||||
self.ranked_beatmapset_count = 0
|
|
||||||
self.ranked_and_approved_beatmapset_count = 0
|
|
||||||
self.unranked_beatmapset_count = 0
|
|
||||||
self.scores_best_count = 0
|
|
||||||
self.scores_first_count = 0
|
|
||||||
self.scores_pinned_count = 0
|
|
||||||
self.scores_recent_count = 0
|
|
||||||
|
|
||||||
return MockCounts()
|
|
||||||
|
|
||||||
|
|
||||||
class MockLazerTournamentBanner:
|
|
||||||
def __init__(self, tournament_id: int, image_url: str, is_active: bool = True):
|
|
||||||
self.tournament_id = tournament_id
|
|
||||||
self.image_url = image_url
|
|
||||||
self.is_active = is_active
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_default_banners():
|
|
||||||
"""创建默认的锦标赛横幅配置"""
|
|
||||||
return [
|
|
||||||
MockLazerTournamentBanner(1, "https://example.com/banner1.jpg", True),
|
|
||||||
MockLazerTournamentBanner(2, "https://example.com/banner2.jpg", False),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,33 +10,36 @@ import time
|
|||||||
|
|
||||||
from app.auth import get_password_hash
|
from app.auth import get_password_hash
|
||||||
from app.database import (
|
from app.database import (
|
||||||
Base,
|
|
||||||
DailyChallengeStats,
|
DailyChallengeStats,
|
||||||
|
LazerUserAchievement,
|
||||||
LazerUserStatistics,
|
LazerUserStatistics,
|
||||||
RankHistory,
|
RankHistory,
|
||||||
User,
|
User,
|
||||||
UserAchievement,
|
|
||||||
)
|
)
|
||||||
from app.dependencies.database import engine, get_db
|
from app.dependencies.database import engine, get_db
|
||||||
|
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
# 创建所有表
|
# 创建所有表
|
||||||
Base.metadata.create_all(bind=engine)
|
SQLModel.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
def create_sample_user():
|
def create_sample_user():
|
||||||
"""创建示例用户数据"""
|
"""创建示例用户数据"""
|
||||||
db = next(get_db())
|
with next(get_db()) as db:
|
||||||
|
# 检查用户是否已存在
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
# 检查用户是否已存在
|
statement = select(User).where(User.name == "Googujiang")
|
||||||
existing_user = db.query(User).filter(User.name == "Googujiang").first()
|
existing_user = db.exec(statement).first()
|
||||||
if existing_user:
|
if existing_user:
|
||||||
print("示例用户已存在,跳过创建")
|
print("示例用户已存在,跳过创建")
|
||||||
return existing_user
|
return existing_user
|
||||||
|
|
||||||
# 当前时间戳
|
# 当前时间戳
|
||||||
current_timestamp = int(time.time())
|
current_timestamp = int(time.time())
|
||||||
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
||||||
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
||||||
|
|
||||||
# 创建用户
|
# 创建用户
|
||||||
user = User(
|
user = User(
|
||||||
@@ -150,6 +153,10 @@ def create_sample_user():
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
|
|
||||||
|
# 确保用户ID存在
|
||||||
|
if user.id is None:
|
||||||
|
raise ValueError("User ID is None after saving to database")
|
||||||
|
|
||||||
# 创建 osu! 模式统计
|
# 创建 osu! 模式统计
|
||||||
osu_stats = LazerUserStatistics(
|
osu_stats = LazerUserStatistics(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
@@ -370,28 +377,28 @@ def create_sample_user():
|
|||||||
|
|
||||||
# 创建一些成就
|
# 创建一些成就
|
||||||
achievements = [
|
achievements = [
|
||||||
UserAchievement(
|
LazerUserAchievement(
|
||||||
# user_id=user.id,
|
user_id=user.id,
|
||||||
achievement_id=336,
|
achievement_id=336,
|
||||||
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
|
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
|
||||||
),
|
),
|
||||||
UserAchievement(
|
LazerUserAchievement(
|
||||||
# user_id=user.id,
|
user_id=user.id,
|
||||||
achievement_id=319,
|
achievement_id=319,
|
||||||
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
|
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
|
||||||
),
|
),
|
||||||
UserAchievement(
|
LazerUserAchievement(
|
||||||
# user_id=user.id,
|
user_id=user.id,
|
||||||
achievement_id=222,
|
achievement_id=222,
|
||||||
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
|
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
|
||||||
),
|
),
|
||||||
UserAchievement(
|
LazerUserAchievement(
|
||||||
# user_id=user.id,
|
user_id=user.id,
|
||||||
achievement_id=38,
|
achievement_id=38,
|
||||||
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
|
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
|
||||||
),
|
),
|
||||||
UserAchievement(
|
LazerUserAchievement(
|
||||||
# user_id=user.id,
|
user_id=user.id,
|
||||||
achievement_id=67,
|
achievement_id=67,
|
||||||
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
|
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ dependencies = [
|
|||||||
"python-multipart>=0.0.6",
|
"python-multipart>=0.0.6",
|
||||||
"redis>=5.0.1",
|
"redis>=5.0.1",
|
||||||
"sqlalchemy>=2.0.23",
|
"sqlalchemy>=2.0.23",
|
||||||
|
"sqlmodel>=0.0.24",
|
||||||
"uvicorn[standard]>=0.24.0",
|
"uvicorn[standard]>=0.24.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,11 @@ import sys
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from app.database import User
|
from app.database import User
|
||||||
from app.dependencies import get_db
|
from app.dependencies.database import get_db
|
||||||
from app.utils import convert_db_user_to_api_user
|
from app.utils import convert_db_user_to_api_user
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
|
||||||
def test_lazer_tables():
|
def test_lazer_tables():
|
||||||
"""测试 lazer 表的基本功能"""
|
"""测试 lazer 表的基本功能"""
|
||||||
@@ -26,7 +28,8 @@ def test_lazer_tables():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 测试查询用户
|
# 测试查询用户
|
||||||
user = db.query(User).first()
|
statement = select(User)
|
||||||
|
user = db.exec(statement).first()
|
||||||
if not user:
|
if not user:
|
||||||
print("❌ 没有找到用户,请先同步数据")
|
print("❌ 没有找到用户,请先同步数据")
|
||||||
return False
|
return False
|
||||||
@@ -83,7 +86,8 @@ def test_authentication():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试认证第一个用户
|
# 尝试认证第一个用户
|
||||||
user = db.query(User).first()
|
statement = select(User)
|
||||||
|
user = db.exec(statement).first()
|
||||||
if not user:
|
if not user:
|
||||||
print("❌ 没有用户进行认证测试")
|
print("❌ 没有用户进行认证测试")
|
||||||
return False
|
return False
|
||||||
|
|||||||
15
uv.lock
generated
15
uv.lock
generated
@@ -475,6 +475,7 @@ dependencies = [
|
|||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
{ name = "redis" },
|
{ name = "redis" },
|
||||||
{ name = "sqlalchemy" },
|
{ name = "sqlalchemy" },
|
||||||
|
{ name = "sqlmodel" },
|
||||||
{ name = "uvicorn", extra = ["standard"] },
|
{ name = "uvicorn", extra = ["standard"] },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -499,6 +500,7 @@ requires-dist = [
|
|||||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||||
{ name = "redis", specifier = ">=5.0.1" },
|
{ name = "redis", specifier = ">=5.0.1" },
|
||||||
{ name = "sqlalchemy", specifier = ">=2.0.23" },
|
{ name = "sqlalchemy", specifier = ">=2.0.23" },
|
||||||
|
{ name = "sqlmodel", specifier = ">=0.0.24" },
|
||||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" },
|
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -810,6 +812,19 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
|
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlmodel"
|
||||||
|
version = "0.0.24"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "pydantic" },
|
||||||
|
{ name = "sqlalchemy" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "starlette"
|
name = "starlette"
|
||||||
version = "0.47.2"
|
version = "0.47.2"
|
||||||
|
|||||||
Reference in New Issue
Block a user