refactor(database): use asyncio
This commit is contained in:
49
app/auth.py
49
app/auth.py
@@ -14,7 +14,8 @@ from app.database import (
|
|||||||
import bcrypt
|
import bcrypt
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
# 密码哈希上下文
|
# 密码哈希上下文
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
|
|||||||
return pw_bcrypt.decode()
|
return pw_bcrypt.decode()
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
async def authenticate_user_legacy(
|
||||||
|
db: AsyncSession, name: str, password: str
|
||||||
|
) -> DBUser | None:
|
||||||
"""
|
"""
|
||||||
验证用户身份 - 使用类似 from_login 的逻辑
|
验证用户身份 - 使用类似 from_login 的逻辑
|
||||||
"""
|
"""
|
||||||
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
|||||||
|
|
||||||
# 2. 根据用户名查找用户
|
# 2. 根据用户名查找用户
|
||||||
statement = select(DBUser).where(DBUser.name == name)
|
statement = select(DBUser).where(DBUser.name == name)
|
||||||
user = db.exec(statement).first()
|
user = (await db.exec(statement)).first()
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
async def authenticate_user(
|
||||||
|
db: AsyncSession, username: str, password: str
|
||||||
|
) -> DBUser | None:
|
||||||
"""验证用户身份"""
|
"""验证用户身份"""
|
||||||
return authenticate_user_legacy(db, username, password)
|
return await authenticate_user_legacy(db, username, password)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||||
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def store_token(
|
async def store_token(
|
||||||
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
|
db: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
access_token: str,
|
||||||
|
refresh_token: str,
|
||||||
|
expires_in: int,
|
||||||
) -> OAuthToken:
|
) -> OAuthToken:
|
||||||
"""存储令牌到数据库"""
|
"""存储令牌到数据库"""
|
||||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
# 删除用户的旧令牌
|
# 删除用户的旧令牌
|
||||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||||
old_tokens = db.exec(statement).all()
|
old_tokens = (await db.exec(statement)).all()
|
||||||
for token in old_tokens:
|
for token in old_tokens:
|
||||||
db.delete(token)
|
await db.delete(token)
|
||||||
|
|
||||||
# 检查是否有重复的 access_token
|
# 检查是否有重复的 access_token
|
||||||
duplicate_token = db.exec(
|
duplicate_token = (
|
||||||
select(OAuthToken).where(OAuthToken.access_token == access_token)
|
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||||
).first()
|
).first()
|
||||||
if duplicate_token:
|
if duplicate_token:
|
||||||
db.delete(duplicate_token)
|
await db.delete(duplicate_token)
|
||||||
|
|
||||||
# 创建新令牌记录
|
# 创建新令牌记录
|
||||||
token_record = OAuthToken(
|
token_record = OAuthToken(
|
||||||
@@ -174,24 +183,28 @@ def store_token(
|
|||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
)
|
)
|
||||||
db.add(token_record)
|
db.add(token_record)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(token_record)
|
await db.refresh(token_record)
|
||||||
return token_record
|
return token_record
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
async def get_token_by_access_token(
|
||||||
|
db: AsyncSession, access_token: str
|
||||||
|
) -> OAuthToken | None:
|
||||||
"""根据访问令牌获取令牌记录"""
|
"""根据访问令牌获取令牌记录"""
|
||||||
statement = select(OAuthToken).where(
|
statement = select(OAuthToken).where(
|
||||||
OAuthToken.access_token == access_token,
|
OAuthToken.access_token == access_token,
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
)
|
)
|
||||||
return db.exec(statement).first()
|
return (await db.exec(statement)).first()
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
async def get_token_by_refresh_token(
|
||||||
|
db: AsyncSession, refresh_token: str
|
||||||
|
) -> OAuthToken | None:
|
||||||
"""根据刷新令牌获取令牌记录"""
|
"""根据刷新令牌获取令牌记录"""
|
||||||
statement = select(OAuthToken).where(
|
statement = select(OAuthToken).where(
|
||||||
OAuthToken.refresh_token == refresh_token,
|
OAuthToken.refresh_token == refresh_token,
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
)
|
)
|
||||||
return db.exec(statement).first()
|
return (await db.exec(statement)).first()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ load_dotenv()
|
|||||||
class Settings:
|
class Settings:
|
||||||
# 数据库设置
|
# 数据库设置
|
||||||
DATABASE_URL: str = os.getenv(
|
DATABASE_URL: str = os.getenv(
|
||||||
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
|
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
|
||||||
)
|
)
|
||||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class User(SQLModel, table=True):
|
|||||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
# 主键
|
# 主键
|
||||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
|
||||||
|
|
||||||
# 基本信息(匹配 migrations 中的结构)
|
# 基本信息(匹配 migrations 中的结构)
|
||||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .database import get_db as get_db
|
from .database import get_db as get_db
|
||||||
from .user import get_current_user as get_current_user
|
from .user import get_current_user as get_current_user
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlmodel import Session, create_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
@@ -9,7 +11,7 @@ except ImportError:
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
# 数据库引擎
|
# 数据库引擎
|
||||||
engine = create_engine(settings.DATABASE_URL)
|
engine = create_async_engine(settings.DATABASE_URL)
|
||||||
|
|
||||||
# Redis 连接
|
# Redis 连接
|
||||||
if redis:
|
if redis:
|
||||||
@@ -19,11 +21,16 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# 数据库依赖
|
# 数据库依赖
|
||||||
def get_db():
|
async def get_db():
|
||||||
with Session(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tables():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
# Redis 依赖
|
# Redis 依赖
|
||||||
def get_redis():
|
def get_redis():
|
||||||
return redis_client
|
return redis_client
|
||||||
|
|||||||
@@ -9,14 +9,16 @@ from .database import get_db
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy.orm import joinedload, selectinload
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> DBUser:
|
) -> DBUser:
|
||||||
"""获取当前认证用户"""
|
"""获取当前认证用户"""
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
@@ -27,9 +29,31 @@ async def get_current_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
|
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
|
||||||
token_record = get_token_by_access_token(db, token)
|
token_record = await get_token_by_access_token(db, token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
return None
|
return None
|
||||||
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
|
user = (
|
||||||
|
await db.exec(
|
||||||
|
select(DBUser)
|
||||||
|
.options(
|
||||||
|
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.where(DBUser.id == token_record.user_id)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
|||||||
from app.models.oauth import TokenResponse
|
from app.models.oauth import TokenResponse
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||||
from sqlmodel import Session
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ async def oauth_token(
|
|||||||
username: str | None = Form(None),
|
username: str | None = Form(None),
|
||||||
password: str | None = Form(None),
|
password: str | None = Form(None),
|
||||||
refresh_token: str | None = Form(None),
|
refresh_token: str | None = Form(None),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""OAuth 令牌端点"""
|
"""OAuth 令牌端点"""
|
||||||
# 验证客户端凭据
|
# 验证客户端凭据
|
||||||
@@ -46,7 +46,7 @@ async def oauth_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 验证用户
|
# 验证用户
|
||||||
user = authenticate_user(db, username, password)
|
user = await authenticate_user(db, username, password)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||||
|
|
||||||
@@ -58,9 +58,9 @@ async def oauth_token(
|
|||||||
refresh_token_str = generate_refresh_token()
|
refresh_token_str = generate_refresh_token()
|
||||||
|
|
||||||
# 存储令牌
|
# 存储令牌
|
||||||
store_token(
|
await store_token(
|
||||||
db,
|
db,
|
||||||
getattr(user, "id"),
|
user.id,
|
||||||
access_token,
|
access_token,
|
||||||
refresh_token_str,
|
refresh_token_str,
|
||||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
@@ -80,7 +80,7 @@ async def oauth_token(
|
|||||||
raise HTTPException(status_code=400, detail="Refresh token required")
|
raise HTTPException(status_code=400, detail="Refresh token required")
|
||||||
|
|
||||||
# 验证刷新令牌
|
# 验证刷新令牌
|
||||||
token_record = get_token_by_refresh_token(db, refresh_token)
|
token_record =await get_token_by_refresh_token(db, refresh_token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||||
|
|
||||||
@@ -92,10 +92,9 @@ async def oauth_token(
|
|||||||
new_refresh_token = generate_refresh_token()
|
new_refresh_token = generate_refresh_token()
|
||||||
|
|
||||||
# 更新令牌
|
# 更新令牌
|
||||||
user_id = int(getattr(token_record, "user_id"))
|
await store_token(
|
||||||
store_token(
|
|
||||||
db,
|
db,
|
||||||
user_id,
|
token_record.user_id,
|
||||||
access_token,
|
access_token,
|
||||||
new_refresh_token,
|
new_refresh_token,
|
||||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from app.database import (
|
|||||||
BeatmapResp,
|
BeatmapResp,
|
||||||
User as DBUser,
|
User as DBUser,
|
||||||
)
|
)
|
||||||
|
from app.database.beatmapset import Beatmapset
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import get_db
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
|
|
||||||
@@ -12,16 +13,24 @@ from .api_router import router
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import Session, col, select
|
from sqlalchemy.orm import joinedload
|
||||||
|
from sqlmodel import col, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||||
async def get_beatmap(
|
async def get_beatmap(
|
||||||
bid: int,
|
bid: int,
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
|
beatmap = (
|
||||||
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(Beatmap.id == bid)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
if not beatmap:
|
if not beatmap:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
return BeatmapResp.from_db(beatmap)
|
return BeatmapResp.from_db(beatmap)
|
||||||
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
|
|||||||
async def batch_get_beatmaps(
|
async def batch_get_beatmaps(
|
||||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
if not b_ids:
|
if not b_ids:
|
||||||
# select 50 beatmaps by last_updated
|
# select 50 beatmaps by last_updated
|
||||||
beatmaps = db.exec(
|
beatmaps = (
|
||||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(
|
||||||
|
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.order_by(col(Beatmap.last_updated).desc())
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
else:
|
else:
|
||||||
beatmaps = db.exec(
|
beatmaps = (
|
||||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(
|
||||||
|
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.where(col(Beatmap.id).in_(b_ids))
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||||
|
|||||||
@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
|
|||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy.orm import selectinload
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||||
async def get_beatmapset(
|
async def get_beatmapset(
|
||||||
sid: int,
|
sid: int,
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
|
beatmapset = (
|
||||||
|
await db.exec(
|
||||||
|
select(Beatmapset)
|
||||||
|
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(Beatmapset.id == sid)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
if not beatmapset:
|
if not beatmapset:
|
||||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||||
return BeatmapsetResp.from_db(beatmapset)
|
return BeatmapsetResp.from_db(beatmapset)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Literal
|
|||||||
from app.database import (
|
from app.database import (
|
||||||
User as DBUser,
|
User as DBUser,
|
||||||
)
|
)
|
||||||
from app.dependencies import get_current_user, get_db
|
from app.dependencies import get_current_user
|
||||||
from app.models.user import (
|
from app.models.user import (
|
||||||
User as ApiUser,
|
User as ApiUser,
|
||||||
)
|
)
|
||||||
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
|
|||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||||
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
|
|||||||
async def get_user_info_default(
|
async def get_user_info_default(
|
||||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""获取当前用户信息(默认使用osu模式)"""
|
"""获取当前用户信息(默认使用osu模式)"""
|
||||||
# 默认使用osu模式
|
# 默认使用osu模式
|
||||||
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
|
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||||
return api_user
|
return api_user
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .router import router as signalr_router
|
from .router import router as signalr_router
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from logging import info
|
|
||||||
import time
|
import time
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
import uuid
|
import uuid
|
||||||
@@ -9,15 +8,14 @@ import uuid
|
|||||||
from app.database import User as DBUser
|
from app.database import User as DBUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import get_db
|
||||||
from app.dependencies.user import get_current_user_by_token, security
|
from app.dependencies.user import get_current_user_by_token
|
||||||
from app.models.signalr import NegotiateResponse, Transport
|
from app.models.signalr import NegotiateResponse, Transport
|
||||||
from app.router.signalr.packet import SEP
|
from app.router.signalr.packet import SEP
|
||||||
|
|
||||||
from .hub import Hubs
|
from .hub import Hubs
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||||
from fastapi.security import HTTPAuthorizationCredentials
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -48,7 +46,7 @@ async def connect(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
id: str,
|
id: str,
|
||||||
authorization: str = Header(...),
|
authorization: str = Header(...),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
token = authorization[7:]
|
token = authorization[7:]
|
||||||
user_id = id.split(":")[0]
|
user_id = id.split(":")[0]
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, ForwardRef, cast
|
from typing import Any, ForwardRef, cast
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||||
|
|||||||
@@ -23,11 +23,9 @@ from app.models.user import (
|
|||||||
UserAchievement,
|
UserAchievement,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
|
async def convert_db_user_to_api_user(
|
||||||
def convert_db_user_to_api_user(
|
db_user: DBUser, ruleset: str = "osu"
|
||||||
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
|
|
||||||
) -> User:
|
) -> User:
|
||||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,419 +5,84 @@ osu! API 模拟服务器的示例数据填充脚本
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.auth import get_password_hash
|
from app.auth import get_password_hash
|
||||||
from app.database import (
|
from app.database import (
|
||||||
DailyChallengeStats,
|
|
||||||
LazerUserAchievement,
|
|
||||||
LazerUserStatistics,
|
|
||||||
RankHistory,
|
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from app.dependencies.database import engine, get_db
|
from app.dependencies.database import create_tables, engine
|
||||||
|
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
# 创建所有表
|
|
||||||
SQLModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sample_user():
|
async def create_sample_user():
|
||||||
"""创建示例用户数据"""
|
"""创建示例用户数据"""
|
||||||
with next(get_db()) as db:
|
async with AsyncSession(engine) as session:
|
||||||
# 检查用户是否已存在
|
async with session.begin():
|
||||||
from sqlmodel import select
|
|
||||||
|
|
||||||
statement = select(User).where(User.name == "Googujiang")
|
# 检查用户是否已存在
|
||||||
existing_user = db.exec(statement).first()
|
statement = select(User).where(User.name == "Googujiang")
|
||||||
if existing_user:
|
result = await session.execute(statement)
|
||||||
print("示例用户已存在,跳过创建")
|
existing_user = result.scalars().first()
|
||||||
return existing_user
|
if existing_user:
|
||||||
|
print("示例用户已存在,跳过创建")
|
||||||
|
return existing_user
|
||||||
|
|
||||||
# 当前时间戳
|
# 当前时间戳
|
||||||
current_timestamp = int(time.time())
|
current_timestamp = int(time.time())
|
||||||
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
||||||
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
||||||
|
|
||||||
# 创建用户
|
# 创建用户
|
||||||
user = User(
|
user = User(
|
||||||
name="Googujiang",
|
name="Googujiang",
|
||||||
safe_name="googujiang", # 安全用户名(小写)
|
safe_name="googujiang", # 安全用户名(小写)
|
||||||
email="googujiang@example.com",
|
email="googujiang@example.com",
|
||||||
priv=1, # 默认权限
|
priv=1, # 默认权限
|
||||||
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
||||||
country="JP",
|
country="JP",
|
||||||
silence_end=0,
|
silence_end=0,
|
||||||
donor_end=0,
|
donor_end=0,
|
||||||
creation_time=join_timestamp,
|
creation_time=join_timestamp,
|
||||||
latest_activity=last_visit_timestamp,
|
latest_activity=last_visit_timestamp,
|
||||||
clan_id=0,
|
clan_id=0,
|
||||||
clan_priv=0,
|
clan_priv=0,
|
||||||
preferred_mode=0, # 0 = osu!
|
preferred_mode=0, # 0 = osu!
|
||||||
play_style=0,
|
play_style=0,
|
||||||
custom_badge_name=None,
|
custom_badge_name=None,
|
||||||
custom_badge_icon=None,
|
custom_badge_icon=None,
|
||||||
userpage_content="「世界に忘れられた」",
|
userpage_content="「世界に忘れられた」",
|
||||||
api_key=None,
|
api_key=None,
|
||||||
# # 兼容性字段
|
)
|
||||||
# avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg",
|
|
||||||
# cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg",
|
|
||||||
# has_supported=True,
|
|
||||||
# interests="「世界に忘れられた」",
|
|
||||||
# location="咕谷国",
|
|
||||||
# website="https://gmoe.cc",
|
|
||||||
# playstyle=["mouse", "keyboard", "tablet"],
|
|
||||||
# profile_order=[
|
|
||||||
# "me",
|
|
||||||
# "recent_activity",
|
|
||||||
# "top_ranks",
|
|
||||||
# "medals",
|
|
||||||
# "historical",
|
|
||||||
# "beatmaps",
|
|
||||||
# "kudosu",
|
|
||||||
# ],
|
|
||||||
# beatmap_playcounts_count=3306,
|
|
||||||
# favourite_beatmapset_count=15,
|
|
||||||
# follower_count=98,
|
|
||||||
# graveyard_beatmapset_count=7,
|
|
||||||
# mapping_follower_count=1,
|
|
||||||
# previous_usernames=["hehejun"],
|
|
||||||
# monthly_playcounts=[
|
|
||||||
# {"start_date": "2019-11-01", "count": 43},
|
|
||||||
# {"start_date": "2020-04-01", "count": 216},
|
|
||||||
# {"start_date": "2020-05-01", "count": 656},
|
|
||||||
# {"start_date": "2020-07-01", "count": 158},
|
|
||||||
# {"start_date": "2020-08-01", "count": 174},
|
|
||||||
# {"start_date": "2020-10-01", "count": 13},
|
|
||||||
# {"start_date": "2020-11-01", "count": 52},
|
|
||||||
# {"start_date": "2020-12-01", "count": 140},
|
|
||||||
# {"start_date": "2021-01-01", "count": 359},
|
|
||||||
# {"start_date": "2021-02-01", "count": 452},
|
|
||||||
# {"start_date": "2021-03-01", "count": 77},
|
|
||||||
# {"start_date": "2021-04-01", "count": 114},
|
|
||||||
# {"start_date": "2021-05-01", "count": 270},
|
|
||||||
# {"start_date": "2021-06-01", "count": 148},
|
|
||||||
# {"start_date": "2021-07-01", "count": 246},
|
|
||||||
# {"start_date": "2021-08-01", "count": 56},
|
|
||||||
# {"start_date": "2021-09-01", "count": 136},
|
|
||||||
# {"start_date": "2021-10-01", "count": 45},
|
|
||||||
# {"start_date": "2021-11-01", "count": 98},
|
|
||||||
# {"start_date": "2021-12-01", "count": 54},
|
|
||||||
# {"start_date": "2022-01-01", "count": 88},
|
|
||||||
# {"start_date": "2022-02-01", "count": 45},
|
|
||||||
# {"start_date": "2022-03-01", "count": 6},
|
|
||||||
# {"start_date": "2022-04-01", "count": 54},
|
|
||||||
# {"start_date": "2022-05-01", "count": 105},
|
|
||||||
# {"start_date": "2022-06-01", "count": 37},
|
|
||||||
# {"start_date": "2022-07-01", "count": 88},
|
|
||||||
# {"start_date": "2022-08-01", "count": 7},
|
|
||||||
# {"start_date": "2022-09-01", "count": 9},
|
|
||||||
# {"start_date": "2022-10-01", "count": 6},
|
|
||||||
# {"start_date": "2022-11-01", "count": 2},
|
|
||||||
# {"start_date": "2022-12-01", "count": 16},
|
|
||||||
# {"start_date": "2023-01-01", "count": 7},
|
|
||||||
# {"start_date": "2023-04-01", "count": 16},
|
|
||||||
# {"start_date": "2023-05-01", "count": 3},
|
|
||||||
# {"start_date": "2023-06-01", "count": 8},
|
|
||||||
# {"start_date": "2023-07-01", "count": 23},
|
|
||||||
# {"start_date": "2023-08-01", "count": 3},
|
|
||||||
# {"start_date": "2023-09-01", "count": 1},
|
|
||||||
# {"start_date": "2023-10-01", "count": 25},
|
|
||||||
# {"start_date": "2023-11-01", "count": 160},
|
|
||||||
# {"start_date": "2023-12-01", "count": 306},
|
|
||||||
# {"start_date": "2024-01-01", "count": 735},
|
|
||||||
# {"start_date": "2024-02-01", "count": 420},
|
|
||||||
# {"start_date": "2024-03-01", "count": 549},
|
|
||||||
# {"start_date": "2024-04-01", "count": 466},
|
|
||||||
# {"start_date": "2024-05-01", "count": 333},
|
|
||||||
# {"start_date": "2024-06-01", "count": 1126},
|
|
||||||
# {"start_date": "2024-07-01", "count": 534},
|
|
||||||
# {"start_date": "2024-08-01", "count": 280},
|
|
||||||
# {"start_date": "2024-09-01", "count": 116},
|
|
||||||
# {"start_date": "2024-10-01", "count": 120},
|
|
||||||
# {"start_date": "2024-11-01", "count": 332},
|
|
||||||
# {"start_date": "2024-12-01", "count": 243},
|
|
||||||
# {"start_date": "2025-01-01", "count": 122},
|
|
||||||
# {"start_date": "2025-02-01", "count": 379},
|
|
||||||
# {"start_date": "2025-03-01", "count": 278},
|
|
||||||
# {"start_date": "2025-04-01", "count": 296},
|
|
||||||
# {"start_date": "2025-05-01", "count": 964},
|
|
||||||
# {"start_date": "2025-06-01", "count": 821},
|
|
||||||
# {"start_date": "2025-07-01", "count": 230},
|
|
||||||
# ],
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(user)
|
session.add(user)
|
||||||
db.commit()
|
await session.commit()
|
||||||
db.refresh(user)
|
await session.refresh(user)
|
||||||
|
|
||||||
# 确保用户ID存在
|
# 确保用户ID存在
|
||||||
if user.id is None:
|
if user.id is None:
|
||||||
raise ValueError("User ID is None after saving to database")
|
raise ValueError("User ID is None after saving to database")
|
||||||
|
|
||||||
# 创建 osu! 模式统计
|
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
||||||
osu_stats = LazerUserStatistics(
|
print(f"安全用户名: {user.safe_name}")
|
||||||
user_id=user.id,
|
print(f"邮箱: {user.email}")
|
||||||
mode="osu",
|
print(f"国家: {user.country}")
|
||||||
count_100=276274,
|
return user
|
||||||
count_300=1932068,
|
|
||||||
count_50=32776,
|
|
||||||
count_miss=111064,
|
|
||||||
level_current=97,
|
|
||||||
level_progress=96,
|
|
||||||
global_rank=298026,
|
|
||||||
country_rank=11221,
|
|
||||||
pp=2826.26,
|
|
||||||
ranked_score=4415081049,
|
|
||||||
hit_accuracy=95.7168,
|
|
||||||
play_count=12711,
|
|
||||||
play_time=836529,
|
|
||||||
total_score=12390140370,
|
|
||||||
total_hits=2241118,
|
|
||||||
maximum_combo=1859,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=True,
|
|
||||||
grade_ss=14,
|
|
||||||
grade_ssh=3,
|
|
||||||
grade_s=322,
|
|
||||||
grade_sh=11,
|
|
||||||
grade_a=757,
|
|
||||||
rank_highest=295701,
|
|
||||||
rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 taiko 模式统计
|
|
||||||
taiko_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="taiko",
|
|
||||||
count_100=160,
|
|
||||||
count_300=154,
|
|
||||||
count_50=0,
|
|
||||||
count_miss=480,
|
|
||||||
level_current=2,
|
|
||||||
level_progress=49,
|
|
||||||
global_rank=None,
|
|
||||||
pp=0,
|
|
||||||
ranked_score=0,
|
|
||||||
hit_accuracy=0,
|
|
||||||
play_count=6,
|
|
||||||
play_time=217,
|
|
||||||
total_score=79301,
|
|
||||||
total_hits=314,
|
|
||||||
maximum_combo=0,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 fruits 模式统计
|
|
||||||
fruits_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="fruits",
|
|
||||||
count_100=109,
|
|
||||||
count_300=1613,
|
|
||||||
count_50=1861,
|
|
||||||
count_miss=328,
|
|
||||||
level_current=6,
|
|
||||||
level_progress=14,
|
|
||||||
global_rank=None,
|
|
||||||
pp=0,
|
|
||||||
ranked_score=343854,
|
|
||||||
hit_accuracy=89.4779,
|
|
||||||
play_count=19,
|
|
||||||
play_time=669,
|
|
||||||
total_score=1362651,
|
|
||||||
total_hits=3583,
|
|
||||||
maximum_combo=75,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
grade_a=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 mania 模式统计
|
|
||||||
mania_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="mania",
|
|
||||||
count_100=7867,
|
|
||||||
count_300=12104,
|
|
||||||
count_50=991,
|
|
||||||
count_miss=2951,
|
|
||||||
level_current=12,
|
|
||||||
level_progress=89,
|
|
||||||
global_rank=660670,
|
|
||||||
pp=25.3784,
|
|
||||||
ranked_score=3812295,
|
|
||||||
hit_accuracy=77.9316,
|
|
||||||
play_count=85,
|
|
||||||
play_time=4834,
|
|
||||||
total_score=13454470,
|
|
||||||
total_hits=20962,
|
|
||||||
maximum_combo=573,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=True,
|
|
||||||
grade_a=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats])
|
|
||||||
|
|
||||||
# 创建每日挑战统计
|
|
||||||
daily_challenge = DailyChallengeStats(
|
|
||||||
user_id=user.id,
|
|
||||||
daily_streak_best=1,
|
|
||||||
daily_streak_current=0,
|
|
||||||
last_update=datetime(2025, 6, 21, 0, 0, 0),
|
|
||||||
last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0),
|
|
||||||
playcount=1,
|
|
||||||
top_10p_placements=0,
|
|
||||||
top_50p_placements=0,
|
|
||||||
weekly_streak_best=1,
|
|
||||||
weekly_streak_current=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(daily_challenge)
|
|
||||||
|
|
||||||
# 创建排名历史 (最近90天的数据)
|
|
||||||
rank_data = [
|
|
||||||
322806,
|
|
||||||
323092,
|
|
||||||
323341,
|
|
||||||
323616,
|
|
||||||
323853,
|
|
||||||
324106,
|
|
||||||
324378,
|
|
||||||
324676,
|
|
||||||
324958,
|
|
||||||
325254,
|
|
||||||
325492,
|
|
||||||
325780,
|
|
||||||
326075,
|
|
||||||
326356,
|
|
||||||
326586,
|
|
||||||
326845,
|
|
||||||
327067,
|
|
||||||
327286,
|
|
||||||
327526,
|
|
||||||
327778,
|
|
||||||
328039,
|
|
||||||
328347,
|
|
||||||
328631,
|
|
||||||
328858,
|
|
||||||
329323,
|
|
||||||
329557,
|
|
||||||
329809,
|
|
||||||
329911,
|
|
||||||
330188,
|
|
||||||
330425,
|
|
||||||
330650,
|
|
||||||
330881,
|
|
||||||
331068,
|
|
||||||
331325,
|
|
||||||
331575,
|
|
||||||
331816,
|
|
||||||
332061,
|
|
||||||
328959,
|
|
||||||
315648,
|
|
||||||
315881,
|
|
||||||
308784,
|
|
||||||
309023,
|
|
||||||
309252,
|
|
||||||
309433,
|
|
||||||
309537,
|
|
||||||
309364,
|
|
||||||
309548,
|
|
||||||
308957,
|
|
||||||
309182,
|
|
||||||
309426,
|
|
||||||
309607,
|
|
||||||
309831,
|
|
||||||
310054,
|
|
||||||
310269,
|
|
||||||
310485,
|
|
||||||
310714,
|
|
||||||
310956,
|
|
||||||
310924,
|
|
||||||
311125,
|
|
||||||
311203,
|
|
||||||
311422,
|
|
||||||
311640,
|
|
||||||
303091,
|
|
||||||
303309,
|
|
||||||
303500,
|
|
||||||
303691,
|
|
||||||
303758,
|
|
||||||
303750,
|
|
||||||
303957,
|
|
||||||
299867,
|
|
||||||
300088,
|
|
||||||
300273,
|
|
||||||
300457,
|
|
||||||
295799,
|
|
||||||
295976,
|
|
||||||
296153,
|
|
||||||
296350,
|
|
||||||
296566,
|
|
||||||
296756,
|
|
||||||
296933,
|
|
||||||
297141,
|
|
||||||
297314,
|
|
||||||
297480,
|
|
||||||
297114,
|
|
||||||
297296,
|
|
||||||
297480,
|
|
||||||
297645,
|
|
||||||
297815,
|
|
||||||
297993,
|
|
||||||
298026,
|
|
||||||
]
|
|
||||||
|
|
||||||
rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data)
|
|
||||||
|
|
||||||
db.add(rank_history)
|
|
||||||
|
|
||||||
# 创建一些成就
|
|
||||||
achievements = [
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=336,
|
|
||||||
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=319,
|
|
||||||
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=222,
|
|
||||||
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=38,
|
|
||||||
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=67,
|
|
||||||
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
db.add_all(achievements)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
|
||||||
print(f"安全用户名: {user.safe_name}")
|
|
||||||
print(f"邮箱: {user.email}")
|
|
||||||
print(f"国家: {user.country}")
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
async def main():
|
||||||
print("开始创建示例数据...")
|
print("开始创建示例数据...")
|
||||||
user = create_sample_user()
|
await create_tables()
|
||||||
|
user = await create_sample_user()
|
||||||
print("示例数据创建完成!")
|
print("示例数据创建完成!")
|
||||||
print(f"用户名: {user.name}")
|
print(f"用户名: {user.name}")
|
||||||
print("密码: password123")
|
print("密码: password123")
|
||||||
print("现在您可以使用这些凭据来测试API了。")
|
print("现在您可以使用这些凭据来测试API了。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
17
main.py
17
main.py
@@ -1,24 +1,31 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import create_tables
|
||||||
from app.router import api_router, auth_router, signalr_router
|
from app.router import api_router, auth_router, signalr_router
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
# 注意: 表结构现在通过 migrations 管理,不再自动创建
|
# 注意: 表结构现在通过 migrations 管理,不再自动创建
|
||||||
# 如需创建表,请运行: python quick_sync.py
|
# 如需创建表,请运行: python quick_sync.py
|
||||||
|
|
||||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# on startup
|
||||||
|
await create_tables()
|
||||||
|
# on shutdown
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||||
app.include_router(api_router, prefix="/api/v2")
|
app.include_router(api_router, prefix="/api/v2")
|
||||||
app.include_router(signalr_router, prefix="/signalr")
|
app.include_router(signalr_router, prefix="/signalr")
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
|
|
||||||
SQLModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"aiomysql>=0.2.0",
|
||||||
"alembic>=1.12.1",
|
"alembic>=1.12.1",
|
||||||
"bcrypt>=4.1.2",
|
"bcrypt>=4.1.2",
|
||||||
"cryptography>=41.0.7",
|
"cryptography>=41.0.7",
|
||||||
@@ -12,7 +13,6 @@ dependencies = [
|
|||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
"passlib[bcrypt]>=1.7.4",
|
"passlib[bcrypt]>=1.7.4",
|
||||||
"pydantic[email]>=2.5.0",
|
"pydantic[email]>=2.5.0",
|
||||||
"pymysql>=1.1.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"python-jose[cryptography]>=3.3.0",
|
"python-jose[cryptography]>=3.3.0",
|
||||||
"python-multipart>=0.0.6",
|
"python-multipart>=0.0.6",
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ fastapi==0.104.1
|
|||||||
uvicorn[standard]==0.24.0
|
uvicorn[standard]==0.24.0
|
||||||
sqlalchemy==2.0.23
|
sqlalchemy==2.0.23
|
||||||
alembic==1.12.1
|
alembic==1.12.1
|
||||||
pymysql==1.1.0
|
|
||||||
cryptography==41.0.7
|
cryptography==41.0.7
|
||||||
redis==5.0.1
|
redis==5.0.1
|
||||||
python-jose[cryptography]==3.3.0
|
python-jose[cryptography]==3.3.0
|
||||||
@@ -11,3 +10,4 @@ python-multipart==0.0.6
|
|||||||
pydantic[email]==2.5.0
|
pydantic[email]==2.5.0
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
bcrypt==4.1.2
|
bcrypt==4.1.2
|
||||||
|
aiomysql==0.2.0
|
||||||
|
|||||||
140
test_lazer.py
140
test_lazer.py
@@ -12,108 +12,106 @@ import sys
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from app.database import User
|
from app.database import User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import engine
|
||||||
from app.utils import convert_db_user_to_api_user
|
from app.utils import convert_db_user_to_api_user
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
def test_lazer_tables():
|
async def test_lazer_tables():
|
||||||
"""测试 lazer 表的基本功能"""
|
"""测试 lazer 表的基本功能"""
|
||||||
print("测试 Lazer API 表支持...")
|
print("测试 Lazer API 表支持...")
|
||||||
|
|
||||||
# 获取数据库会话
|
async with AsyncSession(engine) as session:
|
||||||
db_gen = get_db()
|
async with session.begin():
|
||||||
db = next(db_gen)
|
try:
|
||||||
|
# 测试查询用户
|
||||||
|
statement = select(User)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
print("❌ 没有找到用户,请先同步数据")
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
||||||
# 测试查询用户
|
|
||||||
statement = select(User)
|
|
||||||
user = db.exec(statement).first()
|
|
||||||
if not user:
|
|
||||||
print("❌ 没有找到用户,请先同步数据")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
# 测试 lazer 资料
|
||||||
|
if user.lazer_profile:
|
||||||
|
print(
|
||||||
|
f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
||||||
|
|
||||||
# 测试 lazer 资料
|
# 测试 lazer 统计
|
||||||
if user.lazer_profile:
|
osu_stats = None
|
||||||
print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}")
|
for stat in user.lazer_statistics:
|
||||||
else:
|
if stat.mode == "osu":
|
||||||
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
osu_stats = stat
|
||||||
|
break
|
||||||
|
|
||||||
# 测试 lazer 统计
|
if osu_stats:
|
||||||
osu_stats = None
|
print(
|
||||||
for stat in user.lazer_statistics:
|
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
||||||
if stat.mode == "osu":
|
f"游戏次数={osu_stats.play_count}"
|
||||||
osu_stats = stat
|
)
|
||||||
break
|
else:
|
||||||
|
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
||||||
|
|
||||||
if osu_stats:
|
# 测试转换为 API 格式
|
||||||
print(
|
api_user = convert_db_user_to_api_user(user, "osu")
|
||||||
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
print("✓ 成功转换为 API 用户格式")
|
||||||
f"游戏次数={osu_stats.play_count}"
|
print(f" - 用户名: {api_user.username}")
|
||||||
)
|
print(f" - 国家: {api_user.country_code}")
|
||||||
else:
|
print(f" - PP: {api_user.statistics.pp}")
|
||||||
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
print(f" - 是否支持者: {api_user.is_supporter}")
|
||||||
|
|
||||||
# 测试转换为 API 格式
|
return True
|
||||||
api_user = convert_db_user_to_api_user(user, "osu", db)
|
|
||||||
print("✓ 成功转换为 API 用户格式")
|
|
||||||
print(f" - 用户名: {api_user.username}")
|
|
||||||
print(f" - 国家: {api_user.country_code}")
|
|
||||||
print(f" - PP: {api_user.statistics.pp}")
|
|
||||||
print(f" - 是否支持者: {api_user.is_supporter}")
|
|
||||||
|
|
||||||
return True
|
except Exception as e:
|
||||||
|
print(f"❌ 测试失败: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
except Exception as e:
|
traceback.print_exc()
|
||||||
print(f"❌ 测试失败: {e}")
|
return False
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_authentication():
|
async def test_authentication():
|
||||||
"""测试认证功能"""
|
"""测试认证功能"""
|
||||||
print("\n测试认证功能...")
|
print("\n测试认证功能...")
|
||||||
|
|
||||||
db_gen = get_db()
|
async with AsyncSession(engine) as session:
|
||||||
db = next(db_gen)
|
async with session.begin():
|
||||||
|
try:
|
||||||
|
# 尝试认证第一个用户
|
||||||
|
statement = select(User)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
print("❌ 没有用户进行认证测试")
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
print(f"✓ 测试用户: {user.name}")
|
||||||
# 尝试认证第一个用户
|
print("⚠ 注意: 实际密码认证需要正确的密码")
|
||||||
statement = select(User)
|
|
||||||
user = db.exec(statement).first()
|
|
||||||
if not user:
|
|
||||||
print("❌ 没有用户进行认证测试")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✓ 测试用户: {user.name}")
|
return True
|
||||||
print("⚠ 注意: 实际密码认证需要正确的密码")
|
|
||||||
|
|
||||||
return True
|
except Exception as e:
|
||||||
|
print(f"❌ 认证测试失败: {e}")
|
||||||
except Exception as e:
|
return False
|
||||||
print(f"❌ 认证测试失败: {e}")
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def main():
|
||||||
"""主测试函数"""
|
"""主测试函数"""
|
||||||
print("Lazer API 系统测试")
|
print("Lazer API 系统测试")
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
|
|
||||||
# 测试表连接
|
# 测试表连接
|
||||||
success1 = test_lazer_tables()
|
success1 = await test_lazer_tables()
|
||||||
|
|
||||||
# 测试认证
|
# 测试认证
|
||||||
success2 = test_authentication()
|
success2 = await test_authentication()
|
||||||
|
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
if success1 and success2:
|
if success1 and success2:
|
||||||
@@ -130,4 +128,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
16
uv.lock
generated
16
uv.lock
generated
@@ -2,6 +2,18 @@ version = 1
|
|||||||
revision = 2
|
revision = 2
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aiomysql"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "pymysql" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/67/76/2c5b55e4406a1957ffdfd933a94c2517455291c97d2b81cec6813754791a/aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67", size = 114706, upload-time = "2023-06-11T19:57:53.608Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/42/87/c982ee8b333c85b8ae16306387d703a1fcdfc81a2f3f15a24820ab1a512d/aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a", size = 44215, upload-time = "2023-06-11T19:57:51.09Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alembic"
|
name = "alembic"
|
||||||
version = "1.16.4"
|
version = "1.16.4"
|
||||||
@@ -462,6 +474,7 @@ name = "osu-lazer-api"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "aiomysql" },
|
||||||
{ name = "alembic" },
|
{ name = "alembic" },
|
||||||
{ name = "bcrypt" },
|
{ name = "bcrypt" },
|
||||||
{ name = "cryptography" },
|
{ name = "cryptography" },
|
||||||
@@ -469,7 +482,6 @@ dependencies = [
|
|||||||
{ name = "msgpack" },
|
{ name = "msgpack" },
|
||||||
{ name = "passlib", extra = ["bcrypt"] },
|
{ name = "passlib", extra = ["bcrypt"] },
|
||||||
{ name = "pydantic", extra = ["email"] },
|
{ name = "pydantic", extra = ["email"] },
|
||||||
{ name = "pymysql" },
|
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
{ name = "python-jose", extra = ["cryptography"] },
|
{ name = "python-jose", extra = ["cryptography"] },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
@@ -487,6 +499,7 @@ dev = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||||
{ name = "alembic", specifier = ">=1.12.1" },
|
{ name = "alembic", specifier = ">=1.12.1" },
|
||||||
{ name = "bcrypt", specifier = ">=4.1.2" },
|
{ name = "bcrypt", specifier = ">=4.1.2" },
|
||||||
{ name = "cryptography", specifier = ">=41.0.7" },
|
{ name = "cryptography", specifier = ">=41.0.7" },
|
||||||
@@ -494,7 +507,6 @@ requires-dist = [
|
|||||||
{ name = "msgpack", specifier = ">=1.1.1" },
|
{ name = "msgpack", specifier = ">=1.1.1" },
|
||||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
|
||||||
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
|
{ name = "pydantic", extras = ["email"], specifier = ">=2.5.0" },
|
||||||
{ name = "pymysql", specifier = ">=1.1.0" },
|
|
||||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||||
|
|||||||
Reference in New Issue
Block a user