Merge branch 'main' into score-database-model
This commit is contained in:
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
default_install_hook_types: [pre-commit, prepare-commit-msg]
|
||||||
|
ci:
|
||||||
|
autofix_commit_msg: "chore(deps): auto fix by pre-commit hooks"
|
||||||
|
autofix_prs: true
|
||||||
|
autoupdate_branch: master
|
||||||
|
autoupdate_schedule: monthly
|
||||||
|
autoupdate_commit_msg: "chore(deps): auto update by pre-commit hooks"
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.12.2
|
||||||
|
hooks:
|
||||||
|
- id: ruff-check
|
||||||
|
args: [--fix]
|
||||||
|
stages: [pre-commit]
|
||||||
|
- id: ruff-format
|
||||||
|
stages: [pre-commit]
|
||||||
49
app/auth.py
49
app/auth.py
@@ -14,7 +14,8 @@ from app.database import (
|
|||||||
import bcrypt
|
import bcrypt
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
# 密码哈希上下文
|
# 密码哈希上下文
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
@@ -70,7 +71,9 @@ def get_password_hash(password: str) -> str:
|
|||||||
return pw_bcrypt.decode()
|
return pw_bcrypt.decode()
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser | None:
|
async def authenticate_user_legacy(
|
||||||
|
db: AsyncSession, name: str, password: str
|
||||||
|
) -> DBUser | None:
|
||||||
"""
|
"""
|
||||||
验证用户身份 - 使用类似 from_login 的逻辑
|
验证用户身份 - 使用类似 from_login 的逻辑
|
||||||
"""
|
"""
|
||||||
@@ -79,7 +82,7 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
|||||||
|
|
||||||
# 2. 根据用户名查找用户
|
# 2. 根据用户名查找用户
|
||||||
statement = select(DBUser).where(DBUser.name == name)
|
statement = select(DBUser).where(DBUser.name == name)
|
||||||
user = db.exec(statement).first()
|
user = (await db.exec(statement)).first()
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -107,9 +110,11 @@ def authenticate_user_legacy(db: Session, name: str, password: str) -> DBUser |
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(db: Session, username: str, password: str) -> DBUser | None:
|
async def authenticate_user(
|
||||||
|
db: AsyncSession, username: str, password: str
|
||||||
|
) -> DBUser | None:
|
||||||
"""验证用户身份"""
|
"""验证用户身份"""
|
||||||
return authenticate_user_legacy(db, username, password)
|
return await authenticate_user_legacy(db, username, password)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
||||||
@@ -147,24 +152,28 @@ def verify_token(token: str) -> dict | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def store_token(
|
async def store_token(
|
||||||
db: Session, user_id: int, access_token: str, refresh_token: str, expires_in: int
|
db: AsyncSession,
|
||||||
|
user_id: int,
|
||||||
|
access_token: str,
|
||||||
|
refresh_token: str,
|
||||||
|
expires_in: int,
|
||||||
) -> OAuthToken:
|
) -> OAuthToken:
|
||||||
"""存储令牌到数据库"""
|
"""存储令牌到数据库"""
|
||||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
# 删除用户的旧令牌
|
# 删除用户的旧令牌
|
||||||
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
|
||||||
old_tokens = db.exec(statement).all()
|
old_tokens = (await db.exec(statement)).all()
|
||||||
for token in old_tokens:
|
for token in old_tokens:
|
||||||
db.delete(token)
|
await db.delete(token)
|
||||||
|
|
||||||
# 检查是否有重复的 access_token
|
# 检查是否有重复的 access_token
|
||||||
duplicate_token = db.exec(
|
duplicate_token = (
|
||||||
select(OAuthToken).where(OAuthToken.access_token == access_token)
|
await db.exec(select(OAuthToken).where(OAuthToken.access_token == access_token))
|
||||||
).first()
|
).first()
|
||||||
if duplicate_token:
|
if duplicate_token:
|
||||||
db.delete(duplicate_token)
|
await db.delete(duplicate_token)
|
||||||
|
|
||||||
# 创建新令牌记录
|
# 创建新令牌记录
|
||||||
token_record = OAuthToken(
|
token_record = OAuthToken(
|
||||||
@@ -174,24 +183,28 @@ def store_token(
|
|||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
)
|
)
|
||||||
db.add(token_record)
|
db.add(token_record)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(token_record)
|
await db.refresh(token_record)
|
||||||
return token_record
|
return token_record
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_access_token(db: Session, access_token: str) -> OAuthToken | None:
|
async def get_token_by_access_token(
|
||||||
|
db: AsyncSession, access_token: str
|
||||||
|
) -> OAuthToken | None:
|
||||||
"""根据访问令牌获取令牌记录"""
|
"""根据访问令牌获取令牌记录"""
|
||||||
statement = select(OAuthToken).where(
|
statement = select(OAuthToken).where(
|
||||||
OAuthToken.access_token == access_token,
|
OAuthToken.access_token == access_token,
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
)
|
)
|
||||||
return db.exec(statement).first()
|
return (await db.exec(statement)).first()
|
||||||
|
|
||||||
|
|
||||||
def get_token_by_refresh_token(db: Session, refresh_token: str) -> OAuthToken | None:
|
async def get_token_by_refresh_token(
|
||||||
|
db: AsyncSession, refresh_token: str
|
||||||
|
) -> OAuthToken | None:
|
||||||
"""根据刷新令牌获取令牌记录"""
|
"""根据刷新令牌获取令牌记录"""
|
||||||
statement = select(OAuthToken).where(
|
statement = select(OAuthToken).where(
|
||||||
OAuthToken.refresh_token == refresh_token,
|
OAuthToken.refresh_token == refresh_token,
|
||||||
OAuthToken.expires_at > datetime.utcnow(),
|
OAuthToken.expires_at > datetime.utcnow(),
|
||||||
)
|
)
|
||||||
return db.exec(statement).first()
|
return (await db.exec(statement)).first()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ load_dotenv()
|
|||||||
class Settings:
|
class Settings:
|
||||||
# 数据库设置
|
# 数据库设置
|
||||||
DATABASE_URL: str = os.getenv(
|
DATABASE_URL: str = os.getenv(
|
||||||
"DATABASE_URL", "mysql+pymysql://root:password@localhost:3306/osu_api"
|
"DATABASE_URL", "mysql+aiomysql://root:password@localhost:3306/osu_api"
|
||||||
)
|
)
|
||||||
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ from sqlmodel import Field, Relationship, SQLModel
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .user import User
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
class Team(SQLModel, table=True):
|
class Team(SQLModel, table=True):
|
||||||
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "teams" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class User(SQLModel, table=True):
|
|||||||
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
__tablename__ = "users" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
# 主键
|
# 主键
|
||||||
id: int | None = Field(default=None, primary_key=True, index=True)
|
id: int = Field(default=None, primary_key=True, index=True, nullable=False)
|
||||||
|
|
||||||
# 基本信息(匹配 migrations 中的结构)
|
# 基本信息(匹配 migrations 中的结构)
|
||||||
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
name: str = Field(max_length=32, unique=True, index=True) # 用户名
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .database import get_db as get_db
|
from .database import get_db as get_db
|
||||||
from .user import get_current_user as get_current_user
|
from .user import get_current_user as get_current_user
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from sqlmodel import Session, create_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
@@ -9,7 +11,7 @@ except ImportError:
|
|||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
# 数据库引擎
|
# 数据库引擎
|
||||||
engine = create_engine(settings.DATABASE_URL)
|
engine = create_async_engine(settings.DATABASE_URL)
|
||||||
|
|
||||||
# Redis 连接
|
# Redis 连接
|
||||||
if redis:
|
if redis:
|
||||||
@@ -19,11 +21,16 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# 数据库依赖
|
# 数据库依赖
|
||||||
def get_db():
|
async def get_db():
|
||||||
with Session(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tables():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
# Redis 依赖
|
# Redis 依赖
|
||||||
def get_redis():
|
def get_redis():
|
||||||
return redis_client
|
return redis_client
|
||||||
|
|||||||
@@ -9,14 +9,16 @@ from .database import get_db
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy.orm import joinedload, selectinload
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> DBUser:
|
) -> DBUser:
|
||||||
"""获取当前认证用户"""
|
"""获取当前认证用户"""
|
||||||
token = credentials.credentials
|
token = credentials.credentials
|
||||||
@@ -27,9 +29,31 @@ async def get_current_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_by_token(token: str, db: Session) -> DBUser | None:
|
async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | None:
|
||||||
token_record = get_token_by_access_token(db, token)
|
token_record = await get_token_by_access_token(db, token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
return None
|
return None
|
||||||
user = db.exec(select(DBUser).where(DBUser.id == token_record.user_id)).first()
|
user = (
|
||||||
|
await db.exec(
|
||||||
|
select(DBUser)
|
||||||
|
.options(
|
||||||
|
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.where(DBUser.id == token_record.user_id)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -4,14 +4,14 @@ import datetime
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import msgpack
|
|
||||||
from pydantic import Field, field_validator
|
|
||||||
|
|
||||||
from .signalr import MessagePackArrayModel
|
|
||||||
from .score import (
|
from .score import (
|
||||||
APIMod as APIModBase,
|
APIMod as APIModBase,
|
||||||
HitResult,
|
HitResult,
|
||||||
)
|
)
|
||||||
|
from .signalr import MessagePackArrayModel
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
class APIMod(APIModBase, MessagePackArrayModel): ...
|
class APIMod(APIModBase, MessagePackArrayModel): ...
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class UserAchievement(BaseModel):
|
|||||||
return LazerUserAchievement(
|
return LazerUserAchievement(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
achievement_id=self.achievement_id,
|
achievement_id=self.achievement_id,
|
||||||
achieved_at=self.achieved_at
|
achieved_at=self.achieved_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.dependencies import get_db
|
|||||||
from app.models.oauth import TokenResponse
|
from app.models.oauth import TokenResponse
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||||
from sqlmodel import Session
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
router = APIRouter(tags=["osu! OAuth 认证"])
|
router = APIRouter(tags=["osu! OAuth 认证"])
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ async def oauth_token(
|
|||||||
username: str | None = Form(None),
|
username: str | None = Form(None),
|
||||||
password: str | None = Form(None),
|
password: str | None = Form(None),
|
||||||
refresh_token: str | None = Form(None),
|
refresh_token: str | None = Form(None),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""OAuth 令牌端点"""
|
"""OAuth 令牌端点"""
|
||||||
# 验证客户端凭据
|
# 验证客户端凭据
|
||||||
@@ -46,7 +46,7 @@ async def oauth_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 验证用户
|
# 验证用户
|
||||||
user = authenticate_user(db, username, password)
|
user = await authenticate_user(db, username, password)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||||
|
|
||||||
@@ -58,9 +58,9 @@ async def oauth_token(
|
|||||||
refresh_token_str = generate_refresh_token()
|
refresh_token_str = generate_refresh_token()
|
||||||
|
|
||||||
# 存储令牌
|
# 存储令牌
|
||||||
store_token(
|
await store_token(
|
||||||
db,
|
db,
|
||||||
getattr(user, "id"),
|
user.id,
|
||||||
access_token,
|
access_token,
|
||||||
refresh_token_str,
|
refresh_token_str,
|
||||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
@@ -80,7 +80,7 @@ async def oauth_token(
|
|||||||
raise HTTPException(status_code=400, detail="Refresh token required")
|
raise HTTPException(status_code=400, detail="Refresh token required")
|
||||||
|
|
||||||
# 验证刷新令牌
|
# 验证刷新令牌
|
||||||
token_record = get_token_by_refresh_token(db, refresh_token)
|
token_record = await get_token_by_refresh_token(db, refresh_token)
|
||||||
if not token_record:
|
if not token_record:
|
||||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||||
|
|
||||||
@@ -92,10 +92,9 @@ async def oauth_token(
|
|||||||
new_refresh_token = generate_refresh_token()
|
new_refresh_token = generate_refresh_token()
|
||||||
|
|
||||||
# 更新令牌
|
# 更新令牌
|
||||||
user_id = int(getattr(token_record, "user_id"))
|
await store_token(
|
||||||
store_token(
|
|
||||||
db,
|
db,
|
||||||
user_id,
|
token_record.user_id,
|
||||||
access_token,
|
access_token,
|
||||||
new_refresh_token,
|
new_refresh_token,
|
||||||
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from app.database import (
|
|||||||
BeatmapResp,
|
BeatmapResp,
|
||||||
User as DBUser,
|
User as DBUser,
|
||||||
)
|
)
|
||||||
|
from app.database.beatmapset import Beatmapset
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import get_db
|
||||||
from app.dependencies.user import get_current_user
|
from app.dependencies.user import get_current_user
|
||||||
|
|
||||||
@@ -12,16 +13,24 @@ from .api_router import router
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import Session, col, select
|
from sqlalchemy.orm import joinedload
|
||||||
|
from sqlmodel import col, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
|
||||||
async def get_beatmap(
|
async def get_beatmap(
|
||||||
bid: int,
|
bid: int,
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
beatmap = db.exec(select(Beatmap).where(Beatmap.id == bid)).first()
|
beatmap = (
|
||||||
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(Beatmap.id == bid)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
if not beatmap:
|
if not beatmap:
|
||||||
raise HTTPException(status_code=404, detail="Beatmap not found")
|
raise HTTPException(status_code=404, detail="Beatmap not found")
|
||||||
return BeatmapResp.from_db(beatmap)
|
return BeatmapResp.from_db(beatmap)
|
||||||
@@ -36,16 +45,30 @@ class BatchGetResp(BaseModel):
|
|||||||
async def batch_get_beatmaps(
|
async def batch_get_beatmaps(
|
||||||
b_ids: list[int] = Query(alias="id", default_factory=list),
|
b_ids: list[int] = Query(alias="id", default_factory=list),
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
if not b_ids:
|
if not b_ids:
|
||||||
# select 50 beatmaps by last_updated
|
# select 50 beatmaps by last_updated
|
||||||
beatmaps = db.exec(
|
beatmaps = (
|
||||||
select(Beatmap).order_by(col(Beatmap.last_updated).desc()).limit(50)
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(
|
||||||
|
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.order_by(col(Beatmap.last_updated).desc())
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
else:
|
else:
|
||||||
beatmaps = db.exec(
|
beatmaps = (
|
||||||
select(Beatmap).where(col(Beatmap.id).in_(b_ids)).limit(50)
|
await db.exec(
|
||||||
|
select(Beatmap)
|
||||||
|
.options(
|
||||||
|
joinedload(Beatmap.beatmapset).selectinload(Beatmapset.beatmaps) # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
.where(col(Beatmap.id).in_(b_ids))
|
||||||
|
.limit(50)
|
||||||
|
)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
return BatchGetResp(beatmaps=[BeatmapResp.from_db(bm) for bm in beatmaps])
|
||||||
|
|||||||
@@ -11,16 +11,24 @@ from app.dependencies.user import get_current_user
|
|||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy.orm import selectinload
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
|
||||||
async def get_beatmapset(
|
async def get_beatmapset(
|
||||||
sid: int,
|
sid: int,
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
beatmapset = db.exec(select(Beatmapset).where(Beatmapset.id == sid)).first()
|
beatmapset = (
|
||||||
|
await db.exec(
|
||||||
|
select(Beatmapset)
|
||||||
|
.options(selectinload(Beatmapset.beatmaps)) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(Beatmapset.id == sid)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
if not beatmapset:
|
if not beatmapset:
|
||||||
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
raise HTTPException(status_code=404, detail="Beatmapset not found")
|
||||||
return BeatmapsetResp.from_db(beatmapset)
|
return BeatmapsetResp.from_db(beatmapset)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Literal
|
|||||||
from app.database import (
|
from app.database import (
|
||||||
User as DBUser,
|
User as DBUser,
|
||||||
)
|
)
|
||||||
from app.dependencies import get_current_user, get_db
|
from app.dependencies import get_current_user
|
||||||
from app.models.user import (
|
from app.models.user import (
|
||||||
User as ApiUser,
|
User as ApiUser,
|
||||||
)
|
)
|
||||||
@@ -14,7 +14,6 @@ from app.utils import convert_db_user_to_api_user
|
|||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me/{ruleset}", response_model=ApiUser)
|
@router.get("/me/{ruleset}", response_model=ApiUser)
|
||||||
@@ -22,9 +21,8 @@ from sqlalchemy.orm import Session
|
|||||||
async def get_user_info_default(
|
async def get_user_info_default(
|
||||||
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||||
current_user: DBUser = Depends(get_current_user),
|
current_user: DBUser = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
):
|
||||||
"""获取当前用户信息(默认使用osu模式)"""
|
"""获取当前用户信息(默认使用osu模式)"""
|
||||||
# 默认使用osu模式
|
# 默认使用osu模式
|
||||||
api_user = convert_db_user_to_api_user(current_user, ruleset, db)
|
api_user = await convert_db_user_to_api_user(current_user, ruleset)
|
||||||
return api_user
|
return api_user
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .router import router as signalr_router
|
from .router import router as signalr_router
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
class SignalRException(Exception):
|
class SignalRException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvokeException(SignalRException):
|
class InvokeException(SignalRException):
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .hub import Hub
|
from .hub import Hub
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from .hub import Hub
|
from .hub import Hub
|
||||||
|
|
||||||
|
|
||||||
class MultiplayerHub(Hub): ...
|
class MultiplayerHub(Hub): ...
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ from .hub import Client, Hub
|
|||||||
class SpectatorHub(Hub):
|
class SpectatorHub(Hub):
|
||||||
async def BeginPlaySession(
|
async def BeginPlaySession(
|
||||||
self, client: Client, score_token: int, state: SpectatorState
|
self, client: Client, score_token: int, state: SpectatorState
|
||||||
) -> None:
|
) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
async def SendFrameData(
|
async def SendFrameData(
|
||||||
self, client: Client, frame_data: FrameDataBundle
|
self, client: Client, frame_data: FrameDataBundle
|
||||||
) -> None:
|
) -> None: ...
|
||||||
...
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from enum import IntEnum
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
from pydantic import BaseModel, model_validator
|
|
||||||
|
|
||||||
SEP = b"\x1e"
|
SEP = b"\x1e"
|
||||||
|
|
||||||
@@ -18,6 +17,7 @@ class PacketType(IntEnum):
|
|||||||
PING = 6
|
PING = 6
|
||||||
CLOSE = 7
|
CLOSE = 7
|
||||||
|
|
||||||
|
|
||||||
class ResultKind(IntEnum):
|
class ResultKind(IntEnum):
|
||||||
ERROR = 1
|
ERROR = 1
|
||||||
VOID = 2
|
VOID = 2
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from logging import info
|
|
||||||
import time
|
import time
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
import uuid
|
import uuid
|
||||||
@@ -9,15 +8,14 @@ import uuid
|
|||||||
from app.database import User as DBUser
|
from app.database import User as DBUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import get_db
|
||||||
from app.dependencies.user import get_current_user_by_token, security
|
from app.dependencies.user import get_current_user_by_token
|
||||||
from app.models.signalr import NegotiateResponse, Transport
|
from app.models.signalr import NegotiateResponse, Transport
|
||||||
from app.router.signalr.packet import SEP
|
from app.router.signalr.packet import SEP
|
||||||
|
|
||||||
from .hub import Hubs
|
from .hub import Hubs
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
from fastapi import APIRouter, Depends, Header, Query, WebSocket
|
||||||
from fastapi.security import HTTPAuthorizationCredentials
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -48,7 +46,7 @@ async def connect(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
id: str,
|
id: str,
|
||||||
authorization: str = Header(...),
|
authorization: str = Header(...),
|
||||||
db: Session = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
token = authorization[7:]
|
token = authorization[7:]
|
||||||
user_id = id.split(":")[0]
|
user_id = id.split(":")[0]
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, ForwardRef, cast
|
from typing import Any, ForwardRef, cast
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
# https://github.com/pydantic/pydantic/blob/main/pydantic/v1/typing.py#L56-L66
|
||||||
|
|||||||
39
app/utils.py
39
app/utils.py
@@ -11,7 +11,6 @@ from app.database import (
|
|||||||
from app.models.user import (
|
from app.models.user import (
|
||||||
Country,
|
Country,
|
||||||
Cover,
|
Cover,
|
||||||
DailyChallengeStats,
|
|
||||||
GradeCounts,
|
GradeCounts,
|
||||||
Kudosu,
|
Kudosu,
|
||||||
Level,
|
Level,
|
||||||
@@ -23,12 +22,8 @@ from app.models.user import (
|
|||||||
UserAchievement,
|
UserAchievement,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
|
async def convert_db_user_to_api_user(db_user: DBUser, ruleset: str = "osu") -> User:
|
||||||
def convert_db_user_to_api_user(
|
|
||||||
db_user: DBUser, ruleset: str = "osu", db_session: Session | None = None
|
|
||||||
) -> User:
|
|
||||||
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
"""将数据库用户模型转换为API用户模型(使用 Lazer 表)"""
|
||||||
|
|
||||||
# 从db_user获取基本字段值
|
# 从db_user获取基本字段值
|
||||||
@@ -73,7 +68,7 @@ def convert_db_user_to_api_user(
|
|||||||
kudosu = Kudosu(available=0, total=0)
|
kudosu = Kudosu(available=0, total=0)
|
||||||
|
|
||||||
# 获取计数信息
|
# 获取计数信息
|
||||||
counts = LazerUserCounts(user_id=user_id)
|
# counts = LazerUserCounts(user_id=user_id)
|
||||||
|
|
||||||
# 转换统计信息
|
# 转换统计信息
|
||||||
statistics = Statistics(
|
statistics = Statistics(
|
||||||
@@ -178,21 +173,21 @@ def convert_db_user_to_api_user(
|
|||||||
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
|
rank_history = RankHistory(mode=ruleset, data=rank_history_data)
|
||||||
|
|
||||||
# 转换每日挑战统计
|
# 转换每日挑战统计
|
||||||
daily_challenge_stats = None
|
# daily_challenge_stats = None
|
||||||
if db_user.daily_challenge_stats:
|
# if db_user.daily_challenge_stats:
|
||||||
dcs = db_user.daily_challenge_stats
|
# dcs = db_user.daily_challenge_stats
|
||||||
daily_challenge_stats = DailyChallengeStats(
|
# daily_challenge_stats = DailyChallengeStats(
|
||||||
daily_streak_best=dcs.daily_streak_best,
|
# daily_streak_best=dcs.daily_streak_best,
|
||||||
daily_streak_current=dcs.daily_streak_current,
|
# daily_streak_current=dcs.daily_streak_current,
|
||||||
last_update=dcs.last_update,
|
# last_update=dcs.last_update,
|
||||||
last_weekly_streak=dcs.last_weekly_streak,
|
# last_weekly_streak=dcs.last_weekly_streak,
|
||||||
playcount=dcs.playcount,
|
# playcount=dcs.playcount,
|
||||||
top_10p_placements=dcs.top_10p_placements,
|
# top_10p_placements=dcs.top_10p_placements,
|
||||||
top_50p_placements=dcs.top_50p_placements,
|
# top_50p_placements=dcs.top_50p_placements,
|
||||||
user_id=dcs.user_id,
|
# user_id=dcs.user_id,
|
||||||
weekly_streak_best=dcs.weekly_streak_best,
|
# weekly_streak_best=dcs.weekly_streak_best,
|
||||||
weekly_streak_current=dcs.weekly_streak_current,
|
# weekly_streak_current=dcs.weekly_streak_current,
|
||||||
)
|
# )
|
||||||
|
|
||||||
# 转换最高排名
|
# 转换最高排名
|
||||||
rank_highest = None
|
rank_highest = None
|
||||||
|
|||||||
@@ -5,419 +5,82 @@ osu! API 模拟服务器的示例数据填充脚本
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import time
|
|
||||||
|
|
||||||
from app.auth import get_password_hash
|
from app.auth import get_password_hash
|
||||||
from app.database import (
|
from app.database import (
|
||||||
DailyChallengeStats,
|
|
||||||
LazerUserAchievement,
|
|
||||||
LazerUserStatistics,
|
|
||||||
RankHistory,
|
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from app.dependencies.database import engine, get_db
|
from app.dependencies.database import create_tables, engine
|
||||||
|
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
# 创建所有表
|
|
||||||
SQLModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sample_user():
|
async def create_sample_user():
|
||||||
"""创建示例用户数据"""
|
"""创建示例用户数据"""
|
||||||
with next(get_db()) as db:
|
async with AsyncSession(engine) as session:
|
||||||
# 检查用户是否已存在
|
async with session.begin():
|
||||||
from sqlmodel import select
|
# 检查用户是否已存在
|
||||||
|
statement = select(User).where(User.name == "Googujiang")
|
||||||
|
result = await session.execute(statement)
|
||||||
|
existing_user = result.scalars().first()
|
||||||
|
if existing_user:
|
||||||
|
print("示例用户已存在,跳过创建")
|
||||||
|
return existing_user
|
||||||
|
|
||||||
statement = select(User).where(User.name == "Googujiang")
|
# 当前时间戳
|
||||||
existing_user = db.exec(statement).first()
|
# current_timestamp = int(time.time())
|
||||||
if existing_user:
|
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
||||||
print("示例用户已存在,跳过创建")
|
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
||||||
return existing_user
|
|
||||||
|
|
||||||
# 当前时间戳
|
# 创建用户
|
||||||
current_timestamp = int(time.time())
|
user = User(
|
||||||
join_timestamp = int(datetime(2019, 11, 29, 17, 23, 13).timestamp())
|
name="Googujiang",
|
||||||
last_visit_timestamp = int(datetime(2025, 7, 18, 16, 31, 29).timestamp())
|
safe_name="googujiang", # 安全用户名(小写)
|
||||||
|
email="googujiang@example.com",
|
||||||
|
priv=1, # 默认权限
|
||||||
|
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
||||||
|
country="JP",
|
||||||
|
silence_end=0,
|
||||||
|
donor_end=0,
|
||||||
|
creation_time=join_timestamp,
|
||||||
|
latest_activity=last_visit_timestamp,
|
||||||
|
clan_id=0,
|
||||||
|
clan_priv=0,
|
||||||
|
preferred_mode=0, # 0 = osu!
|
||||||
|
play_style=0,
|
||||||
|
custom_badge_name=None,
|
||||||
|
custom_badge_icon=None,
|
||||||
|
userpage_content="「世界に忘れられた」",
|
||||||
|
api_key=None,
|
||||||
|
)
|
||||||
|
|
||||||
# 创建用户
|
session.add(user)
|
||||||
user = User(
|
await session.commit()
|
||||||
name="Googujiang",
|
await session.refresh(user)
|
||||||
safe_name="googujiang", # 安全用户名(小写)
|
|
||||||
email="googujiang@example.com",
|
|
||||||
priv=1, # 默认权限
|
|
||||||
pw_bcrypt=get_password_hash("password123"), # 使用新的哈希方式
|
|
||||||
country="JP",
|
|
||||||
silence_end=0,
|
|
||||||
donor_end=0,
|
|
||||||
creation_time=join_timestamp,
|
|
||||||
latest_activity=last_visit_timestamp,
|
|
||||||
clan_id=0,
|
|
||||||
clan_priv=0,
|
|
||||||
preferred_mode=0, # 0 = osu!
|
|
||||||
play_style=0,
|
|
||||||
custom_badge_name=None,
|
|
||||||
custom_badge_icon=None,
|
|
||||||
userpage_content="「世界に忘れられた」",
|
|
||||||
api_key=None,
|
|
||||||
# # 兼容性字段
|
|
||||||
# avatar_url="https://a.ppy.sh/15651670?1732362658.jpeg",
|
|
||||||
# cover_url="https://assets.ppy.sh/user-profile-covers/15651670/0fc7b77adef39765a570e7f535bc383e5a848850d41a8943f8857984330b8bc6.jpeg",
|
|
||||||
# has_supported=True,
|
|
||||||
# interests="「世界に忘れられた」",
|
|
||||||
# location="咕谷国",
|
|
||||||
# website="https://gmoe.cc",
|
|
||||||
# playstyle=["mouse", "keyboard", "tablet"],
|
|
||||||
# profile_order=[
|
|
||||||
# "me",
|
|
||||||
# "recent_activity",
|
|
||||||
# "top_ranks",
|
|
||||||
# "medals",
|
|
||||||
# "historical",
|
|
||||||
# "beatmaps",
|
|
||||||
# "kudosu",
|
|
||||||
# ],
|
|
||||||
# beatmap_playcounts_count=3306,
|
|
||||||
# favourite_beatmapset_count=15,
|
|
||||||
# follower_count=98,
|
|
||||||
# graveyard_beatmapset_count=7,
|
|
||||||
# mapping_follower_count=1,
|
|
||||||
# previous_usernames=["hehejun"],
|
|
||||||
# monthly_playcounts=[
|
|
||||||
# {"start_date": "2019-11-01", "count": 43},
|
|
||||||
# {"start_date": "2020-04-01", "count": 216},
|
|
||||||
# {"start_date": "2020-05-01", "count": 656},
|
|
||||||
# {"start_date": "2020-07-01", "count": 158},
|
|
||||||
# {"start_date": "2020-08-01", "count": 174},
|
|
||||||
# {"start_date": "2020-10-01", "count": 13},
|
|
||||||
# {"start_date": "2020-11-01", "count": 52},
|
|
||||||
# {"start_date": "2020-12-01", "count": 140},
|
|
||||||
# {"start_date": "2021-01-01", "count": 359},
|
|
||||||
# {"start_date": "2021-02-01", "count": 452},
|
|
||||||
# {"start_date": "2021-03-01", "count": 77},
|
|
||||||
# {"start_date": "2021-04-01", "count": 114},
|
|
||||||
# {"start_date": "2021-05-01", "count": 270},
|
|
||||||
# {"start_date": "2021-06-01", "count": 148},
|
|
||||||
# {"start_date": "2021-07-01", "count": 246},
|
|
||||||
# {"start_date": "2021-08-01", "count": 56},
|
|
||||||
# {"start_date": "2021-09-01", "count": 136},
|
|
||||||
# {"start_date": "2021-10-01", "count": 45},
|
|
||||||
# {"start_date": "2021-11-01", "count": 98},
|
|
||||||
# {"start_date": "2021-12-01", "count": 54},
|
|
||||||
# {"start_date": "2022-01-01", "count": 88},
|
|
||||||
# {"start_date": "2022-02-01", "count": 45},
|
|
||||||
# {"start_date": "2022-03-01", "count": 6},
|
|
||||||
# {"start_date": "2022-04-01", "count": 54},
|
|
||||||
# {"start_date": "2022-05-01", "count": 105},
|
|
||||||
# {"start_date": "2022-06-01", "count": 37},
|
|
||||||
# {"start_date": "2022-07-01", "count": 88},
|
|
||||||
# {"start_date": "2022-08-01", "count": 7},
|
|
||||||
# {"start_date": "2022-09-01", "count": 9},
|
|
||||||
# {"start_date": "2022-10-01", "count": 6},
|
|
||||||
# {"start_date": "2022-11-01", "count": 2},
|
|
||||||
# {"start_date": "2022-12-01", "count": 16},
|
|
||||||
# {"start_date": "2023-01-01", "count": 7},
|
|
||||||
# {"start_date": "2023-04-01", "count": 16},
|
|
||||||
# {"start_date": "2023-05-01", "count": 3},
|
|
||||||
# {"start_date": "2023-06-01", "count": 8},
|
|
||||||
# {"start_date": "2023-07-01", "count": 23},
|
|
||||||
# {"start_date": "2023-08-01", "count": 3},
|
|
||||||
# {"start_date": "2023-09-01", "count": 1},
|
|
||||||
# {"start_date": "2023-10-01", "count": 25},
|
|
||||||
# {"start_date": "2023-11-01", "count": 160},
|
|
||||||
# {"start_date": "2023-12-01", "count": 306},
|
|
||||||
# {"start_date": "2024-01-01", "count": 735},
|
|
||||||
# {"start_date": "2024-02-01", "count": 420},
|
|
||||||
# {"start_date": "2024-03-01", "count": 549},
|
|
||||||
# {"start_date": "2024-04-01", "count": 466},
|
|
||||||
# {"start_date": "2024-05-01", "count": 333},
|
|
||||||
# {"start_date": "2024-06-01", "count": 1126},
|
|
||||||
# {"start_date": "2024-07-01", "count": 534},
|
|
||||||
# {"start_date": "2024-08-01", "count": 280},
|
|
||||||
# {"start_date": "2024-09-01", "count": 116},
|
|
||||||
# {"start_date": "2024-10-01", "count": 120},
|
|
||||||
# {"start_date": "2024-11-01", "count": 332},
|
|
||||||
# {"start_date": "2024-12-01", "count": 243},
|
|
||||||
# {"start_date": "2025-01-01", "count": 122},
|
|
||||||
# {"start_date": "2025-02-01", "count": 379},
|
|
||||||
# {"start_date": "2025-03-01", "count": 278},
|
|
||||||
# {"start_date": "2025-04-01", "count": 296},
|
|
||||||
# {"start_date": "2025-05-01", "count": 964},
|
|
||||||
# {"start_date": "2025-06-01", "count": 821},
|
|
||||||
# {"start_date": "2025-07-01", "count": 230},
|
|
||||||
# ],
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(user)
|
# 确保用户ID存在
|
||||||
db.commit()
|
if user.id is None:
|
||||||
db.refresh(user)
|
raise ValueError("User ID is None after saving to database")
|
||||||
|
|
||||||
# 确保用户ID存在
|
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
||||||
if user.id is None:
|
print(f"安全用户名: {user.safe_name}")
|
||||||
raise ValueError("User ID is None after saving to database")
|
print(f"邮箱: {user.email}")
|
||||||
|
print(f"国家: {user.country}")
|
||||||
# 创建 osu! 模式统计
|
return user
|
||||||
osu_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="osu",
|
|
||||||
count_100=276274,
|
|
||||||
count_300=1932068,
|
|
||||||
count_50=32776,
|
|
||||||
count_miss=111064,
|
|
||||||
level_current=97,
|
|
||||||
level_progress=96,
|
|
||||||
global_rank=298026,
|
|
||||||
country_rank=11221,
|
|
||||||
pp=2826.26,
|
|
||||||
ranked_score=4415081049,
|
|
||||||
hit_accuracy=95.7168,
|
|
||||||
play_count=12711,
|
|
||||||
play_time=836529,
|
|
||||||
total_score=12390140370,
|
|
||||||
total_hits=2241118,
|
|
||||||
maximum_combo=1859,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=True,
|
|
||||||
grade_ss=14,
|
|
||||||
grade_ssh=3,
|
|
||||||
grade_s=322,
|
|
||||||
grade_sh=11,
|
|
||||||
grade_a=757,
|
|
||||||
rank_highest=295701,
|
|
||||||
rank_highest_updated_at=datetime(2025, 7, 2, 17, 30, 21),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 taiko 模式统计
|
|
||||||
taiko_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="taiko",
|
|
||||||
count_100=160,
|
|
||||||
count_300=154,
|
|
||||||
count_50=0,
|
|
||||||
count_miss=480,
|
|
||||||
level_current=2,
|
|
||||||
level_progress=49,
|
|
||||||
global_rank=None,
|
|
||||||
pp=0,
|
|
||||||
ranked_score=0,
|
|
||||||
hit_accuracy=0,
|
|
||||||
play_count=6,
|
|
||||||
play_time=217,
|
|
||||||
total_score=79301,
|
|
||||||
total_hits=314,
|
|
||||||
maximum_combo=0,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 fruits 模式统计
|
|
||||||
fruits_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="fruits",
|
|
||||||
count_100=109,
|
|
||||||
count_300=1613,
|
|
||||||
count_50=1861,
|
|
||||||
count_miss=328,
|
|
||||||
level_current=6,
|
|
||||||
level_progress=14,
|
|
||||||
global_rank=None,
|
|
||||||
pp=0,
|
|
||||||
ranked_score=343854,
|
|
||||||
hit_accuracy=89.4779,
|
|
||||||
play_count=19,
|
|
||||||
play_time=669,
|
|
||||||
total_score=1362651,
|
|
||||||
total_hits=3583,
|
|
||||||
maximum_combo=75,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=False,
|
|
||||||
grade_a=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 mania 模式统计
|
|
||||||
mania_stats = LazerUserStatistics(
|
|
||||||
user_id=user.id,
|
|
||||||
mode="mania",
|
|
||||||
count_100=7867,
|
|
||||||
count_300=12104,
|
|
||||||
count_50=991,
|
|
||||||
count_miss=2951,
|
|
||||||
level_current=12,
|
|
||||||
level_progress=89,
|
|
||||||
global_rank=660670,
|
|
||||||
pp=25.3784,
|
|
||||||
ranked_score=3812295,
|
|
||||||
hit_accuracy=77.9316,
|
|
||||||
play_count=85,
|
|
||||||
play_time=4834,
|
|
||||||
total_score=13454470,
|
|
||||||
total_hits=20962,
|
|
||||||
maximum_combo=573,
|
|
||||||
replays_watched_by_others=0,
|
|
||||||
is_ranked=True,
|
|
||||||
grade_a=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add_all([osu_stats, taiko_stats, fruits_stats, mania_stats])
|
|
||||||
|
|
||||||
# 创建每日挑战统计
|
|
||||||
daily_challenge = DailyChallengeStats(
|
|
||||||
user_id=user.id,
|
|
||||||
daily_streak_best=1,
|
|
||||||
daily_streak_current=0,
|
|
||||||
last_update=datetime(2025, 6, 21, 0, 0, 0),
|
|
||||||
last_weekly_streak=datetime(2025, 6, 19, 0, 0, 0),
|
|
||||||
playcount=1,
|
|
||||||
top_10p_placements=0,
|
|
||||||
top_50p_placements=0,
|
|
||||||
weekly_streak_best=1,
|
|
||||||
weekly_streak_current=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(daily_challenge)
|
|
||||||
|
|
||||||
# 创建排名历史 (最近90天的数据)
|
|
||||||
rank_data = [
|
|
||||||
322806,
|
|
||||||
323092,
|
|
||||||
323341,
|
|
||||||
323616,
|
|
||||||
323853,
|
|
||||||
324106,
|
|
||||||
324378,
|
|
||||||
324676,
|
|
||||||
324958,
|
|
||||||
325254,
|
|
||||||
325492,
|
|
||||||
325780,
|
|
||||||
326075,
|
|
||||||
326356,
|
|
||||||
326586,
|
|
||||||
326845,
|
|
||||||
327067,
|
|
||||||
327286,
|
|
||||||
327526,
|
|
||||||
327778,
|
|
||||||
328039,
|
|
||||||
328347,
|
|
||||||
328631,
|
|
||||||
328858,
|
|
||||||
329323,
|
|
||||||
329557,
|
|
||||||
329809,
|
|
||||||
329911,
|
|
||||||
330188,
|
|
||||||
330425,
|
|
||||||
330650,
|
|
||||||
330881,
|
|
||||||
331068,
|
|
||||||
331325,
|
|
||||||
331575,
|
|
||||||
331816,
|
|
||||||
332061,
|
|
||||||
328959,
|
|
||||||
315648,
|
|
||||||
315881,
|
|
||||||
308784,
|
|
||||||
309023,
|
|
||||||
309252,
|
|
||||||
309433,
|
|
||||||
309537,
|
|
||||||
309364,
|
|
||||||
309548,
|
|
||||||
308957,
|
|
||||||
309182,
|
|
||||||
309426,
|
|
||||||
309607,
|
|
||||||
309831,
|
|
||||||
310054,
|
|
||||||
310269,
|
|
||||||
310485,
|
|
||||||
310714,
|
|
||||||
310956,
|
|
||||||
310924,
|
|
||||||
311125,
|
|
||||||
311203,
|
|
||||||
311422,
|
|
||||||
311640,
|
|
||||||
303091,
|
|
||||||
303309,
|
|
||||||
303500,
|
|
||||||
303691,
|
|
||||||
303758,
|
|
||||||
303750,
|
|
||||||
303957,
|
|
||||||
299867,
|
|
||||||
300088,
|
|
||||||
300273,
|
|
||||||
300457,
|
|
||||||
295799,
|
|
||||||
295976,
|
|
||||||
296153,
|
|
||||||
296350,
|
|
||||||
296566,
|
|
||||||
296756,
|
|
||||||
296933,
|
|
||||||
297141,
|
|
||||||
297314,
|
|
||||||
297480,
|
|
||||||
297114,
|
|
||||||
297296,
|
|
||||||
297480,
|
|
||||||
297645,
|
|
||||||
297815,
|
|
||||||
297993,
|
|
||||||
298026,
|
|
||||||
]
|
|
||||||
|
|
||||||
rank_history = RankHistory(user_id=user.id, mode="osu", rank_data=rank_data)
|
|
||||||
|
|
||||||
db.add(rank_history)
|
|
||||||
|
|
||||||
# 创建一些成就
|
|
||||||
achievements = [
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=336,
|
|
||||||
achieved_at=datetime(2025, 6, 21, 19, 6, 32),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=319,
|
|
||||||
achieved_at=datetime(2025, 6, 1, 0, 52, 0),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=222,
|
|
||||||
achieved_at=datetime(2025, 5, 28, 12, 24, 37),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=38,
|
|
||||||
achieved_at=datetime(2024, 7, 5, 15, 43, 23),
|
|
||||||
),
|
|
||||||
LazerUserAchievement(
|
|
||||||
user_id=user.id,
|
|
||||||
achievement_id=67,
|
|
||||||
achieved_at=datetime(2024, 6, 24, 5, 6, 44),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
db.add_all(achievements)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
print(f"成功创建示例用户: {user.name} (ID: {user.id})")
|
|
||||||
print(f"安全用户名: {user.safe_name}")
|
|
||||||
print(f"邮箱: {user.email}")
|
|
||||||
print(f"国家: {user.country}")
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
async def main():
|
||||||
print("开始创建示例数据...")
|
print("开始创建示例数据...")
|
||||||
user = create_sample_user()
|
await create_tables()
|
||||||
|
user = await create_sample_user()
|
||||||
print("示例数据创建完成!")
|
print("示例数据创建完成!")
|
||||||
print(f"用户名: {user.name}")
|
print(f"用户名: {user.name}")
|
||||||
print("密码: password123")
|
print("密码: password123")
|
||||||
print("现在您可以使用这些凭据来测试API了。")
|
print("现在您可以使用这些凭据来测试API了。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
17
main.py
17
main.py
@@ -1,24 +1,31 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies.database import engine
|
from app.dependencies.database import create_tables
|
||||||
from app.router import api_router, auth_router, signalr_router
|
from app.router import api_router, auth_router, signalr_router
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from sqlmodel import SQLModel
|
|
||||||
|
|
||||||
# 注意: 表结构现在通过 migrations 管理,不再自动创建
|
# 注意: 表结构现在通过 migrations 管理,不再自动创建
|
||||||
# 如需创建表,请运行: python quick_sync.py
|
# 如需创建表,请运行: python quick_sync.py
|
||||||
|
|
||||||
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# on startup
|
||||||
|
await create_tables()
|
||||||
|
# on shutdown
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0", lifespan=lifespan)
|
||||||
app.include_router(api_router, prefix="/api/v2")
|
app.include_router(api_router, prefix="/api/v2")
|
||||||
app.include_router(signalr_router, prefix="/signalr")
|
app.include_router(signalr_router, prefix="/signalr")
|
||||||
app.include_router(auth_router)
|
app.include_router(auth_router)
|
||||||
|
|
||||||
SQLModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"aiomysql>=0.2.0",
|
||||||
"alembic>=1.12.1",
|
"alembic>=1.12.1",
|
||||||
"bcrypt>=4.1.2",
|
"bcrypt>=4.1.2",
|
||||||
"cryptography>=41.0.7",
|
"cryptography>=41.0.7",
|
||||||
@@ -12,7 +13,6 @@ dependencies = [
|
|||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
"passlib[bcrypt]>=1.7.4",
|
"passlib[bcrypt]>=1.7.4",
|
||||||
"pydantic[email]>=2.5.0",
|
"pydantic[email]>=2.5.0",
|
||||||
"pymysql>=1.1.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"python-jose[cryptography]>=3.3.0",
|
"python-jose[cryptography]>=3.3.0",
|
||||||
"python-multipart>=0.0.6",
|
"python-multipart>=0.0.6",
|
||||||
@@ -85,5 +85,6 @@ reportIncompatibleVariableOverride = false
|
|||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"msgpack-types>=0.5.0",
|
"msgpack-types>=0.5.0",
|
||||||
|
"pre-commit>=4.2.0",
|
||||||
"ruff>=0.12.4",
|
"ruff>=0.12.4",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ python-multipart==0.0.6
|
|||||||
pydantic[email]~=2.11.7
|
pydantic[email]~=2.11.7
|
||||||
python-dotenv~=1.1.1
|
python-dotenv~=1.1.1
|
||||||
bcrypt~=4.3.0
|
bcrypt~=4.3.0
|
||||||
|
|
||||||
msgpack~=1.1.1
|
msgpack~=1.1.1
|
||||||
sqlmodel~=0.0.24
|
sqlmodel~=0.0.24
|
||||||
starlette~=0.47.2
|
starlette~=0.47.2
|
||||||
|
aiomysql==0.2.0
|
||||||
|
|||||||
140
test_lazer.py
140
test_lazer.py
@@ -12,108 +12,106 @@ import sys
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from app.database import User
|
from app.database import User
|
||||||
from app.dependencies.database import get_db
|
from app.dependencies.database import engine
|
||||||
from app.utils import convert_db_user_to_api_user
|
from app.utils import convert_db_user_to_api_user
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
def test_lazer_tables():
|
async def test_lazer_tables():
|
||||||
"""测试 lazer 表的基本功能"""
|
"""测试 lazer 表的基本功能"""
|
||||||
print("测试 Lazer API 表支持...")
|
print("测试 Lazer API 表支持...")
|
||||||
|
|
||||||
# 获取数据库会话
|
async with AsyncSession(engine) as session:
|
||||||
db_gen = get_db()
|
async with session.begin():
|
||||||
db = next(db_gen)
|
try:
|
||||||
|
# 测试查询用户
|
||||||
|
statement = select(User)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
print("❌ 没有找到用户,请先同步数据")
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
||||||
# 测试查询用户
|
|
||||||
statement = select(User)
|
|
||||||
user = db.exec(statement).first()
|
|
||||||
if not user:
|
|
||||||
print("❌ 没有找到用户,请先同步数据")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✓ 找到用户: {user.name} (ID: {user.id})")
|
# 测试 lazer 资料
|
||||||
|
if user.lazer_profile:
|
||||||
|
print(
|
||||||
|
f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
||||||
|
|
||||||
# 测试 lazer 资料
|
# 测试 lazer 统计
|
||||||
if user.lazer_profile:
|
osu_stats = None
|
||||||
print(f"✓ 用户有 lazer 资料: 支持者={user.lazer_profile.is_supporter}")
|
for stat in user.lazer_statistics:
|
||||||
else:
|
if stat.mode == "osu":
|
||||||
print("⚠ 用户没有 lazer 资料,将使用默认值")
|
osu_stats = stat
|
||||||
|
break
|
||||||
|
|
||||||
# 测试 lazer 统计
|
if osu_stats:
|
||||||
osu_stats = None
|
print(
|
||||||
for stat in user.lazer_statistics:
|
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
||||||
if stat.mode == "osu":
|
f"游戏次数={osu_stats.play_count}"
|
||||||
osu_stats = stat
|
)
|
||||||
break
|
else:
|
||||||
|
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
||||||
|
|
||||||
if osu_stats:
|
# 测试转换为 API 格式
|
||||||
print(
|
api_user = convert_db_user_to_api_user(user, "osu")
|
||||||
f"✓ 用户有 osu! 统计: PP={osu_stats.pp}, "
|
print("✓ 成功转换为 API 用户格式")
|
||||||
f"游戏次数={osu_stats.play_count}"
|
print(f" - 用户名: {api_user.username}")
|
||||||
)
|
print(f" - 国家: {api_user.country_code}")
|
||||||
else:
|
print(f" - PP: {api_user.statistics.pp}")
|
||||||
print("⚠ 用户没有 osu! 统计,将使用默认值")
|
print(f" - 是否支持者: {api_user.is_supporter}")
|
||||||
|
|
||||||
# 测试转换为 API 格式
|
return True
|
||||||
api_user = convert_db_user_to_api_user(user, "osu", db)
|
|
||||||
print("✓ 成功转换为 API 用户格式")
|
|
||||||
print(f" - 用户名: {api_user.username}")
|
|
||||||
print(f" - 国家: {api_user.country_code}")
|
|
||||||
print(f" - PP: {api_user.statistics.pp}")
|
|
||||||
print(f" - 是否支持者: {api_user.is_supporter}")
|
|
||||||
|
|
||||||
return True
|
except Exception as e:
|
||||||
|
print(f"❌ 测试失败: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
except Exception as e:
|
traceback.print_exc()
|
||||||
print(f"❌ 测试失败: {e}")
|
return False
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_authentication():
|
async def test_authentication():
|
||||||
"""测试认证功能"""
|
"""测试认证功能"""
|
||||||
print("\n测试认证功能...")
|
print("\n测试认证功能...")
|
||||||
|
|
||||||
db_gen = get_db()
|
async with AsyncSession(engine) as session:
|
||||||
db = next(db_gen)
|
async with session.begin():
|
||||||
|
try:
|
||||||
|
# 尝试认证第一个用户
|
||||||
|
statement = select(User)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
user = result.scalars().first()
|
||||||
|
if not user:
|
||||||
|
print("❌ 没有用户进行认证测试")
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
print(f"✓ 测试用户: {user.name}")
|
||||||
# 尝试认证第一个用户
|
print("⚠ 注意: 实际密码认证需要正确的密码")
|
||||||
statement = select(User)
|
|
||||||
user = db.exec(statement).first()
|
|
||||||
if not user:
|
|
||||||
print("❌ 没有用户进行认证测试")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✓ 测试用户: {user.name}")
|
return True
|
||||||
print("⚠ 注意: 实际密码认证需要正确的密码")
|
|
||||||
|
|
||||||
return True
|
except Exception as e:
|
||||||
|
print(f"❌ 认证测试失败: {e}")
|
||||||
except Exception as e:
|
return False
|
||||||
print(f"❌ 认证测试失败: {e}")
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def main():
|
||||||
"""主测试函数"""
|
"""主测试函数"""
|
||||||
print("Lazer API 系统测试")
|
print("Lazer API 系统测试")
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
|
|
||||||
# 测试表连接
|
# 测试表连接
|
||||||
success1 = test_lazer_tables()
|
success1 = await test_lazer_tables()
|
||||||
|
|
||||||
# 测试认证
|
# 测试认证
|
||||||
success2 = test_authentication()
|
success2 = await test_authentication()
|
||||||
|
|
||||||
print("\n" + "=" * 40)
|
print("\n" + "=" * 40)
|
||||||
if success1 and success2:
|
if success1 and success2:
|
||||||
@@ -130,4 +128,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user