Files
g0v0-server/main.py
2025-07-23 18:03:30 +08:00

298 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Optional
from app.auth import (
authenticate_user,
create_access_token,
generate_refresh_token,
get_token_by_access_token,
get_token_by_refresh_token,
store_token,
)
from app.config import settings
from app.database import (
User as DBUser,
)
from app.dependencies import get_db
from app.models import (
TokenResponse,
User as ApiUser,
)
from app.utils import convert_db_user_to_api_user
from fastapi import Depends, FastAPI, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
# NOTE: the table schema is now managed through migrations and is no longer
# created automatically at startup.
# To create the tables, run: python quick_sync.py
app = FastAPI(title="osu! API 模拟服务器", version="1.0.0")
# Bearer-credentials dependency consumed by get_current_user via Depends(security).
security = HTTPBearer()
def _issue_token_response(db: Session, user_id, scope: str) -> TokenResponse:
    """Mint an access/refresh token pair for ``user_id``, persist it, and wrap it.

    Shared by both grant flows so the token lifetime, the ``store_token`` call
    and the response shape cannot drift apart between them.
    """
    lifetime_minutes = settings.ACCESS_TOKEN_EXPIRE_MINUTES
    lifetime_seconds = lifetime_minutes * 60
    access_token = create_access_token(
        data={"sub": str(user_id)},
        expires_delta=timedelta(minutes=lifetime_minutes),
    )
    refresh_token_str = generate_refresh_token()
    store_token(db, user_id, access_token, refresh_token_str, lifetime_seconds)
    return TokenResponse(
        access_token=access_token,
        token_type="Bearer",
        expires_in=lifetime_seconds,
        refresh_token=refresh_token_str,
        scope=scope,
    )


@app.post("/oauth/token", response_model=TokenResponse)
async def oauth_token(
    grant_type: str = Form(...),
    client_id: str = Form(...),
    client_secret: str = Form(...),
    scope: str = Form("*"),
    username: Optional[str] = Form(None),
    password: Optional[str] = Form(None),
    refresh_token: Optional[str] = Form(None),
    db: Session = Depends(get_db),
):
    """OAuth token endpoint.

    Supports two grant types, selected by the ``grant_type`` form field:

    * ``password`` -- authenticates ``username`` / ``password`` and issues a
      new token pair for that user.
    * ``refresh_token`` -- looks up the supplied refresh token and issues a
      new token pair for the user it belongs to.

    Returns:
        A ``TokenResponse`` carrying the new access token, refresh token,
        lifetime in seconds and the requested ``scope`` echoed back.

    Raises:
        HTTPException: 401 for bad client credentials, bad user credentials
            or an unknown refresh token; 400 for missing form fields or an
            unsupported grant type.
    """
    # Local import keeps this block self-contained; only stdlib is used.
    import secrets

    # Validate the client credentials. The secret is compared in constant
    # time so response timing does not leak how much of it matched. A
    # non-string configured secret can never equal the submitted form string,
    # so it is rejected outright (same outcome as a plain inequality test).
    expected_secret = settings.OSU_CLIENT_SECRET
    secret_matches = isinstance(expected_secret, str) and secrets.compare_digest(
        client_secret.encode("utf-8"), expected_secret.encode("utf-8")
    )
    if client_id != settings.OSU_CLIENT_ID or not secret_matches:
        raise HTTPException(status_code=401, detail="Invalid client credentials")

    if grant_type == "password":
        # Password grant flow.
        if not username or not password:
            raise HTTPException(
                status_code=400, detail="Username and password required"
            )
        user = authenticate_user(db, username, password)
        if not user:
            raise HTTPException(status_code=401, detail="Invalid username or password")
        return _issue_token_response(db, user.id, scope)

    if grant_type == "refresh_token":
        # Refresh-token grant flow.
        if not refresh_token:
            raise HTTPException(status_code=400, detail="Refresh token required")
        token_record = get_token_by_refresh_token(db, refresh_token)
        if not token_record:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        return _issue_token_response(db, token_record.user_id, scope)

    raise HTTPException(status_code=400, detail="Unsupported grant type")
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> DBUser:
    """Resolve the request's bearer token to the matching database user.

    Raises:
        HTTPException: 401 when no token record is found for the supplied
            access token; 404 when the record's user no longer exists.
    """
    bearer = credentials.credentials
    record = get_token_by_access_token(db, bearer)
    if record:
        # Token is known: load the user it was issued to.
        owner = db.query(DBUser).filter(DBUser.id == record.user_id).first()
        if owner:
            return owner
        raise HTTPException(status_code=404, detail="User not found")
    raise HTTPException(status_code=401, detail="Invalid or expired token")
@app.get("/api/v2/me", response_model=ApiUser)
@app.get("/api/v2/me/", response_model=ApiUser)
async def get_user_info_default(
    current_user: DBUser = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the authenticated user's profile for the default "osu" ruleset."""
    # No ruleset in the path, so fall back to "osu".
    return convert_db_user_to_api_user(current_user, "osu", db)
@app.get("/api/v2/me/{ruleset}", response_model=ApiUser)
async def get_user_info(
    ruleset: str = "osu",
    current_user: DBUser = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the authenticated user's profile for the given ruleset.

    Raises:
        HTTPException: 400 when ``ruleset`` is not one of the accepted names.
    """
    valid_rulesets = ["osu", "taiko", "fruits", "mania"]
    if ruleset in valid_rulesets:
        # Known game mode: convert the DB row into the API representation.
        return convert_db_user_to_api_user(current_user, ruleset, db)
    raise HTTPException(
        status_code=400, detail=f"Invalid ruleset. Must be one of: {valid_rulesets}"
    )
@app.get("/")
async def root():
    """Root endpoint: return a fixed message saying the server is running."""
    banner = {"message": "osu! API 模拟服务器正在运行"}
    return banner
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Returns:
        A dict with ``status`` fixed to ``"ok"`` and ``timestamp`` set to the
        current UTC time as a naive ISO-8601 string (no offset suffix).
    """
    # Local import keeps this block self-contained; only stdlib is used.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop tzinfo so the emitted string keeps its offset-less format.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {"status": "ok", "timestamp": now_utc.isoformat()}
@app.get("/api/v2/friends")
async def get_friends():
    """Stub endpoint: respond with a hard-coded friends list of one entry."""
    placeholder_friend = {
        "id": 123456,
        "username": "BestFriend",
        "is_online": True,
        "is_supporter": False,
        "country": {"code": "US", "name": "United States"},
    }
    return JSONResponse(content=[placeholder_friend])
@app.get("/api/v2/notifications")
async def get_notifications():
    """Stub endpoint: respond with an empty notification list and zero unread."""
    payload = {"notifications": [], "unread_count": 0}
    return JSONResponse(content=payload)
@app.post("/api/v2/chat/ack")
async def chat_ack():
    """Stub endpoint: acknowledge the request with a fixed ``ok`` status."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload)
@app.get("/api/v2/users/{user_id}/{mode}")
async def get_user_mode(user_id: int, mode: str):
    """Stub endpoint: respond with a hard-coded profile for any user.

    Only ``user_id`` is echoed back in the payload; ``mode`` is accepted from
    the path but not used.
    """
    placeholder_stats = {
        "level": {"current": 97, "progress": 96},
        "pp": 114514,
        "global_rank": 666,
        "country_rank": 1,
        "hit_accuracy": 100,
    }
    profile = {
        "id": user_id,
        "username": "测试测试测",
        "statistics": placeholder_stats,
        "country": {"code": "JP", "name": "Japan"},
    }
    return JSONResponse(content=profile)
# NOTE(review): GET /api/v2/me is already registered earlier in this file by
# get_user_info_default; routes are matched in registration order, so this
# stub appears to be unreachable -- confirm, then remove or re-path it.
@app.get("/api/v2/me")
async def get_me():
    """Stub endpoint: respond with a hard-coded user profile."""
    placeholder_stats = {
        "level": {"current": 97, "progress": 96},
        "pp": 2826.26,
        "global_rank": 298026,
        "country_rank": 11220,
        "hit_accuracy": 95.7168,
    }
    profile = {
        "id": 15651670,
        "username": "Googujiang",
        "is_online": True,
        "country": {"code": "JP", "name": "Japan"},
        "statistics": placeholder_stats,
    }
    return JSONResponse(content=profile)
@app.post("/signalr/metadata/negotiate")
async def metadata_negotiate(negotiateVersion: int = 1):
    """Stub negotiate endpoint for the metadata hub path.

    Responds with a fixed connection id and a single WebSockets transport;
    ``negotiateVersion`` is accepted but not used.
    """
    websocket_transport = {
        "transport": "WebSockets",
        "transferFormats": ["Text", "Binary"],
    }
    payload = {
        "connectionId": "abc123",
        "availableTransports": [websocket_transport],
    }
    return JSONResponse(content=payload)
@app.post("/signalr/spectator/negotiate")
async def spectator_negotiate(negotiateVersion: int = 1):
    """Stub negotiate endpoint for the spectator hub path.

    Responds with a fixed connection id and a single WebSockets transport;
    ``negotiateVersion`` is accepted but not used.
    """
    websocket_transport = {
        "transport": "WebSockets",
        "transferFormats": ["Text", "Binary"],
    }
    payload = {
        "connectionId": "spec456",
        "availableTransports": [websocket_transport],
    }
    return JSONResponse(content=payload)
@app.post("/signalr/multiplayer/negotiate")
async def multiplayer_negotiate(negotiateVersion: int = 1):
    """Stub negotiate endpoint for the multiplayer hub path.

    Responds with a fixed connection id and a single WebSockets transport;
    ``negotiateVersion`` is accepted but not used.
    """
    websocket_transport = {
        "transport": "WebSockets",
        "transferFormats": ["Text", "Binary"],
    }
    payload = {
        "connectionId": "multi789",
        "availableTransports": [websocket_transport],
    }
    return JSONResponse(content=payload)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn using host, port and
    # reload flag taken from the project settings.
    import uvicorn

    server_options = {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": settings.DEBUG,
    }
    uvicorn.run("main:app", **server_options)