refactor(database): migrate to sqlmodel

This commit is contained in:
MingxuanGame
2025-07-24 20:49:07 +08:00
parent 1655bb9f53
commit c43ca883a5
11 changed files with 582 additions and 743 deletions

View File

@@ -4,7 +4,6 @@ from datetime import datetime, timedelta
import hashlib import hashlib
import secrets import secrets
import string import string
from typing import Optional
from app.config import settings from app.config import settings
from app.database import ( from app.database import (
@@ -15,7 +14,7 @@ from app.database import (
import bcrypt import bcrypt
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlalchemy.orm import Session from sqlmodel import Session, select
# 密码哈希上下文 # 密码哈希上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -71,7 +70,7 @@ def get_password_hash(password: str) -> str:
return pw_bcrypt.decode() return pw_bcrypt.decode()
def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[DBUser]: def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
""" """
验证用户身份 - 使用类似 from_login 的逻辑 验证用户身份 - 使用类似 from_login 的逻辑
""" """
@@ -79,12 +78,13 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
pw_md5 = hashlib.md5(password.encode()).hexdigest() pw_md5 = hashlib.md5(password.encode()).hexdigest()
# 2. 根据用户名查找用户 # 2. 根据用户名查找用户
user = db.query(DBUser).filter(DBUser.name == name).first() statement = select(DBUser).where(DBUser.name == name)
user = db.exec(statement).first()
if not user: if not user:
return None return None
# 3. 验证密码 # 3. 验证密码 - 修复逻辑错误
if not (user.pw_bcrypt is None and user.pw_bcrypt != ""): if user.pw_bcrypt is None or user.pw_bcrypt == "":
return None return None
# 4. 检查缓存 # 4. 检查缓存
@@ -107,12 +107,12 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> Optional[
return None return None
def authenticate_user(db: Session, username: str, password: str) -> Optional[DBUser]: def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
"""验证用户身份""" """验证用户身份"""
return authenticate_user_legacy(db, username, password) return authenticate_user_legacy(db, username, password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
"""创建访问令牌""" """创建访问令牌"""
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
@@ -136,7 +136,7 @@ def generate_refresh_token() -> str:
return "".join(secrets.choice(characters) for _ in range(length)) return "".join(secrets.choice(characters) for _ in range(length))
def verify_token(token: str) -> Optional[dict]: def verify_token(token: str) -> dict | None:
"""验证访问令牌""" """验证访问令牌"""
try: try:
payload = jwt.decode( payload = jwt.decode(
@@ -154,7 +154,10 @@ def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in) expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌 # 删除用户的旧令牌
db.query(OAuthToken).filter(OAuthToken.user_id == user_id).delete() statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
old_tokens = db.exec(statement).all()
for token in old_tokens:
db.delete(token)
# 创建新令牌记录 # 创建新令牌记录
token_record = OAuthToken( token_record = OAuthToken(
@@ -169,25 +172,19 @@ def store_token(
return token_record return token_record
def get_token_by_access_token(db: Session, access_token: str) -> Optional[OAuthToken]: def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
"""根据访问令牌获取令牌记录""" """根据访问令牌获取令牌记录"""
return ( statement = select(OAuthToken).where(
db.query(OAuthToken) OAuthToken.access_token == access_token,
.filter( OAuthToken.expires_at > datetime.utcnow(),
OAuthToken.access_token == access_token,
OAuthToken.expires_at > datetime.utcnow(),
)
.first()
) )
return db.exec(statement).first()
def get_token_by_refresh_token(db: Session, refresh_token: str) -> Optional[OAuthToken]: def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
"""根据刷新令牌获取令牌记录""" """根据刷新令牌获取令牌记录"""
return ( statement = select(OAuthToken).where(
db.query(OAuthToken) OAuthToken.refresh_token == refresh_token,
.filter( OAuthToken.expires_at > datetime.utcnow(),
OAuthToken.refresh_token == refresh_token,
OAuthToken.expires_at > datetime.utcnow(),
)
.first()
) )
return db.exec(statement).first()

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import create_engine from sqlmodel import Session, create_engine
from sqlalchemy.orm import sessionmaker
try: try:
import redis import redis
@@ -11,7 +10,6 @@ from app.config import settings
# 数据库引擎 # 数据库引擎
engine = create_engine(settings.DATABASE_URL) engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Redis 连接 # Redis 连接
if redis: if redis:
@@ -22,11 +20,8 @@ else:
# 数据库依赖 # 数据库依赖
def get_db(): def get_db():
db = SessionLocal() with Session(engine) as session:
try: yield session
yield db
finally:
db.close()
# Redis 依赖 # Redis 依赖

View File

@@ -1,14 +1,16 @@
from fastapi import Depends, HTTPException from __future__ import annotations
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.auth import get_token_by_access_token from app.auth import get_token_by_access_token
from .database import get_db
from app.database import ( from app.database import (
User as DBUser, User as DBUser,
) )
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session, select
security = HTTPBearer() security = HTTPBearer()
@@ -29,5 +31,5 @@ async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
token_record = get_token_by_access_token(db, token) token_record = get_token_by_access_token(db, token)
if not token_record: if not token_record:
return None return None
user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first() user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
return user return user

View File

@@ -3,10 +3,14 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from app.database import (
LazerUserAchievement,
Team as Team,
)
from .score import GameMode from .score import GameMode
from pydantic import BaseModel from pydantic import BaseModel
from app.database import LazerUserAchievement # 添加数据库模型导入
class PlayStyle(str, Enum): class PlayStyle(str, Enum):
@@ -110,13 +114,6 @@ class DailyChallengeStats(BaseModel):
weekly_streak_current: int = 0 weekly_streak_current: int = 0
class Team(BaseModel):
flag_url: str
id: int
name: str
short_name: str
class Page(BaseModel): class Page(BaseModel):
html: str = "" html: str = ""
raw: str = "" raw: str = ""

View File

@@ -14,7 +14,7 @@ from app.dependencies import get_db
from app.models.oauth import TokenResponse from app.models.oauth import TokenResponse
from fastapi import APIRouter, Depends, Form, HTTPException from fastapi import APIRouter, Depends, Form, HTTPException
from sqlalchemy.orm import Session from sqlmodel import Session
router = APIRouter(tags=["osu! OAuth 认证"]) router = APIRouter(tags=["osu! OAuth 认证"])
@@ -92,7 +92,7 @@ async def oauth_token(
new_refresh_token = generate_refresh_token() new_refresh_token = generate_refresh_token()
# 更新令牌 # 更新令牌
user_id = int(getattr(token_record, 'user_id')) user_id = int(getattr(token_record, "user_id"))
store_token( store_token(
db, db,
user_id, user_id,

View File

@@ -1,8 +1,13 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import UTC, datetime
from app.database import User as DBUser from app.database import (
LazerUserCounts,
LazerUserProfile,
LazerUserStatistics,
User as DBUser,
)
from app.models.user import ( from app.models.user import (
Country, Country,
Cover, Cover,
@@ -14,7 +19,6 @@ from app.models.user import (
RankHighest, RankHighest,
RankHistory, RankHistory,
Statistics, Statistics,
Team,
User, User,
UserAchievement, UserAchievement,
) )
@@ -37,25 +41,28 @@ def convert_db_user_to_api_user(
profile = db_user.lazer_profile profile = db_user.lazer_profile
if not profile: if not profile:
# 如果没有 lazer 资料,使用默认值 # 如果没有 lazer 资料,使用默认值
profile = create_default_profile(db_user) profile = LazerUserProfile(
user_id=user_id,
)
# 获取 Lazer 用户计数 # 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
lzrcnt = db_user.lazer_counts
lzrcnt = db_user.lazer_statistics
if not lzrcnt: if not lzrcnt:
# 如果没有 lazer 计数,使用默认值 # 如果没有 lazer 计数,使用默认值
lzrcnt = create_default_counts() lzrcnt = LazerUserCounts(user_id=user_id)
# 获取指定模式的统计信息 # 获取指定模式的统计信息
user_stats = None user_stats = None
for stat in db_user.lazer_statistics: if db_user.lazer_statistics:
if stat.mode == ruleset: for stat in db_user.lazer_statistics:
user_stats = stat if stat.mode == ruleset:
break user_stats = stat
break
if not user_stats: if not user_stats:
# 如果没有找到指定模式的统计,创建默认统计 # 如果没有找到指定模式的统计,创建默认统计
user_stats = create_default_lazer_statistics(ruleset) user_stats = LazerUserStatistics(user_id=user_id)
# 获取国家信息 # 获取国家信息
country_code = db_user.country_code if db_user.country_code is not None else "XX" country_code = db_user.country_code if db_user.country_code is not None else "XX"
@@ -66,7 +73,7 @@ def convert_db_user_to_api_user(
kudosu = Kudosu(available=0, total=0) kudosu = Kudosu(available=0, total=0)
# 获取计数信息 # 获取计数信息
counts = create_default_counts() counts = LazerUserCounts(user_id=user_id)
# 转换统计信息 # 转换统计信息
statistics = Statistics( statistics = Statistics(
@@ -199,12 +206,7 @@ def convert_db_user_to_api_user(
team = None team = None
if db_user.team_membership: if db_user.team_membership:
team_member = db_user.team_membership[0] # 假设用户只属于一个团队 team_member = db_user.team_membership[0] # 假设用户只属于一个团队
team = Team( team = team_member.team
flag_url=team_member.team.flag_url or "",
id=team_member.team.id,
name=team_member.team.name,
short_name=team_member.team.short_name,
)
# 创建用户对象 # 创建用户对象
# 从db_user获取基本字段值 # 从db_user获取基本字段值
@@ -229,27 +231,25 @@ def convert_db_user_to_api_user(
avatar_url = str(db_user.avatar.r2_original_url) avatar_url = str(db_user.avatar.r2_original_url)
# 如果还是没有找到,通过查询获取 # 如果还是没有找到,通过查询获取
if db_session and avatar_url is None: # if db_session and avatar_url is None:
try: # try:
# 导入UserAvatar模型 # # 导入UserAvatar模型
from app.database import UserAvatar
# 尝试查找用户的头像记录 # # 尝试查找用户的头像记录
avatar_record = ( # statement = select(UserAvatar).where(
db_session.query(UserAvatar) # UserAvatar.user_id == user_id, UserAvatar.is_active == True
.filter_by(user_id=user_id, is_active=True) # )
.first() # avatar_record = db_session.exec(statement).first()
) # if avatar_record is not None:
if avatar_record is not None: # if avatar_record.r2_game_url is not None:
if avatar_record.r2_game_url is not None: # # 优先使用游戏用的头像URL
# 优先使用游戏用的头像URL # avatar_url = str(avatar_record.r2_game_url)
avatar_url = str(avatar_record.r2_game_url) # elif avatar_record.r2_original_url is not None:
elif avatar_record.r2_original_url is not None: # # 其次使用原始头像URL
# 其次使用原始头像URL # avatar_url = str(avatar_record.r2_original_url)
avatar_url = str(avatar_record.r2_original_url) # except Exception as e:
except Exception as e: # print(f"获取用户头像时出错: {e}")
print(f"获取用户头像时出错: {e}") # print(f"最终头像URL: {avatar_url}")
print(f"最终头像URL: {avatar_url}")
# 如果仍然没有找到头像URL则使用默认URL # 如果仍然没有找到头像URL则使用默认URL
if avatar_url is None: if avatar_url is None:
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1" avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
@@ -265,15 +265,12 @@ def convert_db_user_to_api_user(
"kudosu", "kudosu",
] ]
if profile and profile.profile_order: if profile and profile.profile_order:
profile_order = profile.profile_order profile_order = profile.profile_order.split(",")
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理 # 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
active_tournament_banners = [] active_tournament_banners = []
if ( if db_user.active_banners:
hasattr(db_user, "lazer_tournament_banners") for banner in db_user.active_banners:
and db_user.lazer_tournament_banners
):
for banner in db_user.lazer_tournament_banners:
active_tournament_banners.append( active_tournament_banners.append(
{ {
"tournament_id": banner.tournament_id, "tournament_id": banner.tournament_id,
@@ -284,7 +281,7 @@ def convert_db_user_to_api_user(
# 在convert_db_user_to_api_user函数中添加badges处理 # 在convert_db_user_to_api_user函数中添加badges处理
badges = [] badges = []
if hasattr(db_user, "lazer_badges") and db_user.lazer_badges: if db_user.lazer_badges:
for badge in db_user.lazer_badges: for badge in db_user.lazer_badges:
badges.append( badges.append(
{ {
@@ -298,10 +295,7 @@ def convert_db_user_to_api_user(
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理 # 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
monthly_playcounts = [] monthly_playcounts = []
if ( if db_user.lazer_monthly_playcounts:
hasattr(db_user, "lazer_monthly_playcounts")
and db_user.lazer_monthly_playcounts
):
for playcount in db_user.lazer_monthly_playcounts: for playcount in db_user.lazer_monthly_playcounts:
monthly_playcounts.append( monthly_playcounts.append(
{ {
@@ -314,10 +308,7 @@ def convert_db_user_to_api_user(
# 在convert_db_user_to_api_user函数中添加previous_usernames处理 # 在convert_db_user_to_api_user函数中添加previous_usernames处理
previous_usernames = [] previous_usernames = []
if ( if db_user.lazer_previous_usernames:
hasattr(db_user, "lazer_previous_usernames")
and db_user.lazer_previous_usernames
):
for username in db_user.lazer_previous_usernames: for username in db_user.lazer_previous_usernames:
previous_usernames.append( previous_usernames.append(
{ {
@@ -348,22 +339,22 @@ def convert_db_user_to_api_user(
avatar_url=avatar_url, avatar_url=avatar_url,
country_code=str(country_code), country_code=str(country_code),
default_group=profile.default_group if profile else "default", default_group=profile.default_group if profile else "default",
is_active=profile.is_active if profile else True, is_active=profile.is_active,
is_bot=profile.is_bot if profile else False, is_bot=profile.is_bot,
is_deleted=profile.is_deleted if profile else False, is_deleted=profile.is_deleted,
is_online=profile.is_online if profile else True, is_online=profile.is_online,
is_supporter=profile.is_supporter if profile else False, is_supporter=profile.is_supporter,
is_restricted=profile.is_restricted if profile else False, is_restricted=profile.is_restricted,
last_visit=db_user.last_visit if db_user.last_visit else None, last_visit=db_user.last_visit,
pm_friends_only=profile.pm_friends_only if profile else False, pm_friends_only=profile.pm_friends_only,
profile_colour=profile.profile_colour if profile else None, profile_colour=profile.profile_colour,
cover_url=profile.cover_url cover_url=profile.cover_url
if profile and profile.cover_url if profile and profile.cover_url
else "https://assets.ppy.sh/user-profile-covers/default.jpeg", else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
discord=profile.discord if profile else None, discord=profile.discord if profile else None,
has_supported=profile.has_supported if profile else False, has_supported=profile.has_supported if profile else False,
interests=profile.interests if profile else None, interests=profile.interests if profile else None,
join_date=profile.join_date, join_date=profile.join_date if profile.join_date else datetime.now(UTC),
location=profile.location if profile else None, location=profile.location if profile else None,
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100, max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
max_friends=profile.max_friends if profile and profile.max_friends else 500, max_friends=profile.max_friends if profile and profile.max_friends else 500,
@@ -408,7 +399,7 @@ def convert_db_user_to_api_user(
daily_challenge_user_stats=None, daily_challenge_user_stats=None,
groups=[], groups=[],
monthly_playcounts=monthly_playcounts, monthly_playcounts=monthly_playcounts,
page=Page(html=profile.page_html, raw=profile.page_raw) page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
if profile.page_html or profile.page_raw if profile.page_html or profile.page_raw
else Page(), else Page(),
previous_usernames=previous_usernames, previous_usernames=previous_usernames,
@@ -439,164 +430,3 @@ def get_country_name(country_code: str) -> str:
# 可以添加更多国家 # 可以添加更多国家
} }
return country_names.get(country_code, "Unknown") return country_names.get(country_code, "Unknown")
def create_default_profile(db_user: DBUser):
"""创建默认的用户资料"""
# 完善 MockProfile 类定义
class MockProfile:
def __init__(self):
self.is_active = True
self.is_bot = False
self.is_deleted = False
self.is_online = True
self.is_supporter = False
self.is_restricted = False
self.session_verified = False
self.has_supported = False
self.pm_friends_only = False
self.default_group = "default"
self.last_visit = None
self.join_date = db_user.join_date if db_user else datetime.utcnow()
self.profile_colour = None
self.profile_hue = None
self.avatar_url = None
self.cover_url = None
self.discord = None
self.twitter = None
self.website = None
self.title = None
self.title_url = None
self.interests = None
self.location = None
self.occupation = None
self.playmode = "osu"
self.support_level = 0
self.max_blocks = 100
self.max_friends = 500
self.post_count = 0
# 添加profile_order字段
self.profile_order = [
"me",
"recent_activity",
"top_ranks",
"medals",
"historical",
"beatmaps",
"kudosu",
]
self.page_html = ""
self.page_raw = ""
# 在MockProfile类中添加active_tournament_banners字段
self.active_tournament_banners = (
MockLazerTournamentBanner.create_default_banners()
)
self.active_tournament_banners = [] # 默认空列表
return MockProfile()
def create_default_lazer_statistics(mode: str):
"""创建默认的 Lazer 统计信息"""
class MockLazerStatistics:
def __init__(self, mode: str):
self.mode = mode
self.count_100 = 0
self.count_300 = 0
self.count_50 = 0
self.count_miss = 0
self.level_current = 1
self.level_progress = 0
self.global_rank = None
self.global_rank_exp = None
self.pp = 0.0
self.pp_exp = 0.0
self.ranked_score = 0
self.hit_accuracy = 0.0
self.total_score = 0
self.total_hits = 0
self.maximum_combo = 0
self.play_count = 0
self.play_time = 0
self.replays_watched_by_others = 0
self.is_ranked = False
self.grade_ss = 0
self.grade_ssh = 0
self.grade_s = 0
self.grade_sh = 0
self.grade_a = 0
self.country_rank = None
self.rank_highest = None
self.rank_highest_updated_at = None
return MockLazerStatistics(mode)
def create_default_country(country_code: str):
"""创建默认的国家信息"""
class MockCountry:
def __init__(self, code: str):
self.code = code
self.name = get_country_name(code)
return MockCountry(country_code)
def create_default_kudosu():
"""创建默认的 Kudosu 信息"""
class MockKudosu:
def __init__(self):
self.available = 0
self.total = 0
return MockKudosu()
def create_default_counts():
"""创建默认的计数信息"""
class MockCounts:
def __init__(self):
self.recent_scores_count = None
self.beatmap_playcounts_count = 0
self.scores_first_count = 0
self.scores_pinned_count = 0
self.comments_count = 0
self.favourite_beatmapset_count = 0
self.follower_count = 0
self.graveyard_beatmapset_count = 0
self.guest_beatmapset_count = 0
self.loved_beatmapset_count = 0
self.mapping_follower_count = 0
self.nominated_beatmapset_count = 0
self.pending_beatmapset_count = 0
self.ranked_beatmapset_count = 0
self.ranked_and_approved_beatmapset_count = 0
self.unranked_beatmapset_count = 0
self.scores_best_count = 0
self.scores_first_count = 0
self.scores_pinned_count = 0
self.scores_recent_count = 0
return MockCounts()
class MockLazerTournamentBanner:
def __init__(self, tournament_id: int, image_url: str, is_active: bool = True):
self.tournament_id = tournament_id
self.image_url = image_url
self.is_active = is_active
@staticmethod
def create_default_banners():
"""创建默认的锦标赛横幅配置"""
return [
MockLazerTournamentBanner(1, "https://example.com/banner1.jpg", True),
MockLazerTournamentBanner(2, "https://example.com/banner2.jpg", False),
]

View File

@@ -10,33 +10,36 @@ import time
from app.auth import get_password_hash from app.auth import get_password_hash
from app.database import ( from app.database import (
Base,
DailyChallengeStats, DailyChallengeStats,
LazerUserAchievement,
LazerUserStatistics, LazerUserStatistics,
RankHistory, RankHistory,
User, User,
UserAchievement,
) )
from app.dependencies.database import engine, get_db from app.dependencies.database import engine, get_db
from sqlmodel import SQLModel
# 创建所有表 # 创建所有表
Base.metadata.create_all(bind=engine) SQLModel.metadata.create_all(bind=engine)
def create_sample_user(): def create_sample_user():
"""创建示例用户数据""" """创建示例用户数据"""
db = next(get_db()) with next(get_db()) as db:
# 检查用户是否已存在
from sqlmodel import select
# 检查用户是否已存在 statement = select(User).where(User.name == "Googujiang")
existing_user = db.query(User).filter(User.name == "Googujiang").first() existing_user = db.exec(statement).first()
if existing_user: if existing_user:
print("示例用户已存在,跳过创建") print("示例用户已存在,跳过创建")
return existing_user return existing_user
# 当前时间戳 # 当前时间戳
current_timestamp = int(time.time()) current_timestamp = int(time.time())
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp()) join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp()) last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
# 创建用户 # 创建用户
user = User( user = User(
@@ -150,6 +153,10 @@ def create_sample_user():
db.commit() db.commit()
db.refresh(user) db.refresh(user)
# 确保用户ID存在
if user.id is None:
raise ValueError("User ID is None after saving to database")
# 创建 osu! 模式统计 # 创建 osu! 模式统计
osu_stats = LazerUserStatistics( osu_stats = LazerUserStatistics(
user_id=user.id, user_id=user.id,
@@ -370,28 +377,28 @@ def create_sample_user():
# 创建一些成就 # 创建一些成就
achievements = [ achievements = [
UserAchievement( LazerUserAchievement(
# user_id=user.id, user_id=user.id,
achievement_id=336, achievement_id=336,
achieved_at=datetime(2025, 6, 21, 19, 6, 32), achieved_at=datetime(2025, 6, 21, 19, 6, 32),
), ),
UserAchievement( LazerUserAchievement(
# user_id=user.id, user_id=user.id,
achievement_id=319, achievement_id=319,
achieved_at=datetime(2025, 6, 1, 0, 52, 0), achieved_at=datetime(2025, 6, 1, 0, 52, 0),
), ),
UserAchievement( LazerUserAchievement(
# user_id=user.id, user_id=user.id,
achievement_id=222, achievement_id=222,
achieved_at=datetime(2025, 5, 28, 12, 24, 37), achieved_at=datetime(2025, 5, 28, 12, 24, 37),
), ),
UserAchievement( LazerUserAchievement(
# user_id=user.id, user_id=user.id,
achievement_id=38, achievement_id=38,
achieved_at=datetime(2024, 7, 5, 15, 43, 23), achieved_at=datetime(2024, 7, 5, 15, 43, 23),
), ),
UserAchievement( LazerUserAchievement(
# user_id=user.id, user_id=user.id,
achievement_id=67, achievement_id=67,
achieved_at=datetime(2024, 6, 24, 5, 6, 44), achieved_at=datetime(2024, 6, 24, 5, 6, 44),
), ),

View File

@@ -18,6 +18,7 @@ dependencies = [
"python-multipart>=0.0.6", "python-multipart>=0.0.6",
"redis>=5.0.1", "redis>=5.0.1",
"sqlalchemy>=2.0.23", "sqlalchemy>=2.0.23",
"sqlmodel>=0.0.24",
"uvicorn[standard]>=0.24.0", "uvicorn[standard]>=0.24.0",
] ]

View File

@@ -12,9 +12,11 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__))) sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from app.database import User from app.database import User
from app.dependencies import get_db from app.dependencies.database import get_db
from app.utils import convert_db_user_to_api_user from app.utils import convert_db_user_to_api_user
from sqlmodel import select
def test_lazer_tables(): def test_lazer_tables():
"""测试 lazer 表的基本功能""" """测试 lazer 表的基本功能"""
@@ -26,7 +28,8 @@ def test_lazer_tables():
try: try:
# 测试查询用户 # 测试查询用户
user = db.query(User).first() statement = select(User)
user = db.exec(statement).first()
if not user: if not user:
print("❌ 没有找到用户,请先同步数据") print("❌ 没有找到用户,请先同步数据")
return False return False
@@ -83,7 +86,8 @@ def test_authentication():
try: try:
# 尝试认证第一个用户 # 尝试认证第一个用户
user = db.query(User).first() statement = select(User)
user = db.exec(statement).first()
if not user: if not user:
print("❌ 没有用户进行认证测试") print("❌ 没有用户进行认证测试")
return False return False

15
uv.lock generated
View File

@@ -475,6 +475,7 @@ dependencies = [
{ name = "python-multipart" }, { name = "python-multipart" },
{ name = "redis" }, { name = "redis" },
{ name = "sqlalchemy" }, { name = "sqlalchemy" },
{ name = "sqlmodel" },
{ name = "uvicorn", extra = ["standard"] }, { name = "uvicorn", extra = ["standard"] },
] ]
@@ -499,6 +500,7 @@ requires-dist = [
{ name = "python-multipart", specifier = ">=0.0.6" }, { name = "python-multipart", specifier = ">=0.0.6" },
{ name = "redis", specifier = ">=5.0.1" }, { name = "redis", specifier = ">=5.0.1" },
{ name = "sqlalchemy", specifier = ">=2.0.23" }, { name = "sqlalchemy", specifier = ">=2.0.23" },
{ name = "sqlmodel", specifier = ">=0.0.24" },
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.24.0" },
] ]
@@ -810,6 +812,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
] ]
[[package]]
name = "sqlmodel"
version = "0.0.24"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic" },
{ name = "sqlalchemy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" },
]
[[package]] [[package]]
name = "starlette" name = "starlette"
version = "0.47.2" version = "0.47.2"