diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py new file mode 100644 index 0000000..f36fe81 --- /dev/null +++ b/app/dependencies/__init__.py @@ -0,0 +1,2 @@ +from .database import get_db as get_db +from .user import get_current_user as get_current_user \ No newline at end of file diff --git a/app/dependencies.py b/app/dependencies/database.py similarity index 100% rename from app/dependencies.py rename to app/dependencies/database.py diff --git a/app/dependencies/user.py b/app/dependencies/user.py new file mode 100644 index 0000000..0b288ff --- /dev/null +++ b/app/dependencies/user.py @@ -0,0 +1,32 @@ +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy.orm import Session + +from app.auth import get_token_by_access_token + +from .database import get_db +from app.database import ( + User as DBUser, +) + +security = HTTPBearer() + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), + db: Session = Depends(get_db), +) -> DBUser: + """获取当前认证用户""" + token = credentials.credentials + + # 验证令牌 + token_record = get_token_by_access_token(db, token) + if not token_record: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + # 获取用户 + user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return user diff --git a/app/router/__init__.py b/app/router/__init__.py new file mode 100644 index 0000000..61e4a6e --- /dev/null +++ b/app/router/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from . import me # pyright: ignore[reportUnusedImport] # noqa: F401 +from .api_router import router as api_router +from .auth import router as auth_router diff --git a/app/router/api_router.py b/app/router/api_router.py new file mode 100644 index 0000000..e6f2f82 --- /dev/null +++ b/app/router/api_router.py @@ -0,0 +1,4 @@ +from fastapi import APIRouter + + +router = APIRouter() diff --git a/app/router/auth.py b/app/router/auth.py new file mode 100644 index 0000000..afe4cf2 --- /dev/null +++ b/app/router/auth.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from datetime import timedelta + +from app.auth import ( + authenticate_user, + create_access_token, + generate_refresh_token, + get_token_by_refresh_token, + store_token, +) +from app.config import settings +from app.dependencies import get_db +from app.models import TokenResponse + +from fastapi import APIRouter, Depends, Form, HTTPException +from sqlalchemy.orm import Session + +router = APIRouter(tags=["osu! OAuth 认证"]) + + +@router.post("/oauth/token", response_model=TokenResponse) +async def oauth_token( + grant_type: str = Form(...), + client_id: str = Form(...), + client_secret: str = Form(...), + scope: str = Form("*"), + username: str | None = Form(None), + password: str | None = Form(None), + refresh_token: str | None = Form(None), + db: Session = Depends(get_db), +): + """OAuth 令牌端点""" + # 验证客户端凭据 + if ( + client_id != settings.OSU_CLIENT_ID + or client_secret != settings.OSU_CLIENT_SECRET + ): + raise HTTPException(status_code=401, detail="Invalid client credentials") + + if grant_type == "password": + # 密码授权流程 + if not username or not password: + raise HTTPException( + status_code=400, detail="Username and password required" + ) + + # 验证用户 + user = authenticate_user(db, username, password) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + + # 生成令牌 + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": str(user.id)}, expires_delta=access_token_expires + ) + refresh_token_str = generate_refresh_token() + + # 存储令牌 + store_token( + db, + user.id, + access_token, + refresh_token_str, + settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + ) + + return TokenResponse( + access_token=access_token, + token_type="Bearer", + expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + refresh_token=refresh_token_str, + scope=scope, + ) + + elif grant_type == "refresh_token": + # 刷新令牌流程 + if not refresh_token: + raise HTTPException(status_code=400, detail="Refresh token required") + + # 验证刷新令牌 + token_record = get_token_by_refresh_token(db, refresh_token) + if not token_record: + raise HTTPException(status_code=401, detail="Invalid refresh token") + + # 生成新的访问令牌 + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires + ) + new_refresh_token = generate_refresh_token() + + # 更新令牌 + store_token( + db, + token_record.user_id, + access_token, + new_refresh_token, + settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + ) + + return TokenResponse( + access_token=access_token, + token_type="Bearer", + expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + refresh_token=new_refresh_token, + scope=scope, + ) + + else: + raise HTTPException(status_code=400, detail="Unsupported grant type") diff --git a/app/router/me.py b/app/router/me.py new file mode 100644 index 0000000..ca142d2 --- /dev/null +++ b/app/router/me.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Literal + +from app.database import ( + User as DBUser, +) +from app.dependencies import get_current_user, get_db +from app.models import ( + User as ApiUser, +) +from app.utils import convert_db_user_to_api_user + +from .api_router import router + +from fastapi import Depends +from sqlalchemy.orm import Session + + +@router.get("/me/{ruleset}", response_model=ApiUser) +async def get_user_info_default( + ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu", + current_user: DBUser = Depends(get_current_user), + db: Session = Depends(get_db), +): + """获取当前用户信息(默认使用osu模式)""" + # 默认使用osu模式 + api_user = convert_db_user_to_api_user(current_user, ruleset, db) + return api_user diff --git a/main.py b/main.py index ce95a4e..0d085b4 100644 --- a/main.py +++ b/main.py @@ -1,182 +1,18 @@ from __future__ import annotations -from datetime import datetime, timedelta -from typing import Optional +from datetime import datetime -from app.auth import ( - authenticate_user, - create_access_token, - generate_refresh_token, - get_token_by_access_token, - get_token_by_refresh_token, - store_token, -) from app.config import settings -from app.database import ( - User as DBUser, -) -from app.dependencies import get_db -from app.models import ( - TokenResponse, - User as ApiUser, -) -from app.utils import convert_db_user_to_api_user +from app.router import api_router, auth_router -from fastapi import Depends, FastAPI, Form, HTTPException -from fastapi.responses import JSONResponse -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from sqlalchemy.orm import Session +from fastapi import FastAPI # 注意: 表结构现在通过 migrations 管理,不再自动创建 # 如需创建表,请运行: python quick_sync.py app = FastAPI(title="osu! API 模拟服务器", version="1.0.0") - -security = HTTPBearer() - - -@app.post("/oauth/token", response_model=TokenResponse) -async def oauth_token( - grant_type: str = Form(...), - client_id: str = Form(...), - client_secret: str = Form(...), - scope: str = Form("*"), - username: Optional[str] = Form(None), - password: Optional[str] = Form(None), - refresh_token: Optional[str] = Form(None), - db: Session = Depends(get_db), -): - """OAuth 令牌端点""" - # 验证客户端凭据 - if ( - client_id != settings.OSU_CLIENT_ID - or client_secret != settings.OSU_CLIENT_SECRET - ): - raise HTTPException(status_code=401, detail="Invalid client credentials") - - if grant_type == "password": - # 密码授权流程 - if not username or not password: - raise HTTPException( - status_code=400, detail="Username and password required" - ) - - # 验证用户 - user = authenticate_user(db, username, password) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - - # 生成令牌 - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": str(user.id)}, expires_delta=access_token_expires - ) - refresh_token_str = generate_refresh_token() - - # 存储令牌 - store_token( - db, - user.id, - access_token, - refresh_token_str, - settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, - ) - - return TokenResponse( - access_token=access_token, - token_type="Bearer", - expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, - refresh_token=refresh_token_str, - scope=scope, - ) - - elif grant_type == "refresh_token": - # 刷新令牌流程 - if not refresh_token: - raise HTTPException(status_code=400, detail="Refresh token required") - - # 验证刷新令牌 - token_record = get_token_by_refresh_token(db, refresh_token) - if not token_record: - raise HTTPException(status_code=401, detail="Invalid refresh token") - - # 生成新的访问令牌 - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": str(token_record.user_id)}, expires_delta=access_token_expires - ) - new_refresh_token = generate_refresh_token() - - # 更新令牌 - store_token( - db, - token_record.user_id, - access_token, - new_refresh_token, - settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, - ) - - return TokenResponse( - access_token=access_token, - token_type="Bearer", - expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, - refresh_token=new_refresh_token, - scope=scope, - ) - - else: - raise HTTPException(status_code=400, detail="Unsupported grant type") - - -async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: Session = Depends(get_db), -) -> DBUser: - """获取当前认证用户""" - token = credentials.credentials - - # 验证令牌 - token_record = get_token_by_access_token(db, token) - if not token_record: - raise HTTPException(status_code=401, detail="Invalid or expired token") - - # 获取用户 - user = db.query(DBUser).filter(DBUser.id == token_record.user_id).first() - if not user: - raise HTTPException(status_code=404, detail="User not found") - - return user - - -@app.get("/api/v2/me", response_model=ApiUser) -@app.get("/api/v2/me/", response_model=ApiUser) -async def get_user_info_default( - current_user: DBUser = Depends(get_current_user), db: Session = Depends(get_db) -): - """获取当前用户信息(默认使用osu模式)""" - # 默认使用osu模式 - api_user = convert_db_user_to_api_user(current_user, "osu", db) - return api_user - - -@app.get("/api/v2/me/{ruleset}", response_model=ApiUser) -async def get_user_info( - ruleset: str = "osu", - current_user: DBUser = Depends(get_current_user), - db: Session = Depends(get_db), -): - """获取当前用户信息""" - - # 验证游戏模式 - valid_rulesets = ["osu", "taiko", "fruits", "mania"] - if ruleset not in valid_rulesets: - raise HTTPException( - status_code=400, detail=f"Invalid ruleset. Must be one of: {valid_rulesets}" - ) - - # 转换用户数据 - api_user = convert_db_user_to_api_user(current_user, ruleset, db) - return api_user +app.include_router(api_router, prefix="/api/v2") +app.include_router(auth_router) @app.get("/") @@ -191,102 +27,102 @@ async def health_check(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} -@app.get("/api/v2/friends") -async def get_friends(): - return JSONResponse( - content=[ - { - "id": 123456, - "username": "BestFriend", - "is_online": True, - "is_supporter": False, - "country": {"code": "US", "name": "United States"}, - } - ] - ) +# @app.get("/api/v2/friends") +# async def get_friends(): +# return JSONResponse( +# content=[ +# { +# "id": 123456, +# "username": "BestFriend", +# "is_online": True, +# "is_supporter": False, +# "country": {"code": "US", "name": "United States"}, +# } +# ] +# ) -@app.get("/api/v2/notifications") -async def get_notifications(): - return JSONResponse(content={"notifications": [], "unread_count": 0}) +# @app.get("/api/v2/notifications") +# async def get_notifications(): +# return JSONResponse(content={"notifications": [], "unread_count": 0}) -@app.post("/api/v2/chat/ack") -async def chat_ack(): - return JSONResponse(content={"status": "ok"}) +# @app.post("/api/v2/chat/ack") +# async def chat_ack(): +# return JSONResponse(content={"status": "ok"}) -@app.get("/api/v2/users/{user_id}/{mode}") -async def get_user_mode(user_id: int, mode: str): - return JSONResponse( - content={ - "id": user_id, - "username": "测试测试测", - "statistics": { - "level": {"current": 97, "progress": 96}, - "pp": 114514, - "global_rank": 666, - "country_rank": 1, - "hit_accuracy": 100, - }, - "country": {"code": "JP", "name": "Japan"}, - } - ) +# @app.get("/api/v2/users/{user_id}/{mode}") +# async def get_user_mode(user_id: int, mode: str): +# return JSONResponse( +# content={ +# "id": user_id, +# "username": "测试测试测", +# "statistics": { +# "level": {"current": 97, "progress": 96}, +# "pp": 114514, +# "global_rank": 666, +# "country_rank": 1, +# "hit_accuracy": 100, +# }, +# "country": {"code": "JP", "name": "Japan"}, +# } +# ) -@app.get("/api/v2/me") -async def get_me(): - return JSONResponse( - content={ - "id": 15651670, - "username": "Googujiang", - "is_online": True, - "country": {"code": "JP", "name": "Japan"}, - "statistics": { - "level": {"current": 97, "progress": 96}, - "pp": 2826.26, - "global_rank": 298026, - "country_rank": 11220, - "hit_accuracy": 95.7168, - }, - } - ) +# @app.get("/api/v2/me") +# async def get_me(): +# return JSONResponse( +# content={ +# "id": 15651670, +# "username": "Googujiang", +# "is_online": True, +# "country": {"code": "JP", "name": "Japan"}, +# "statistics": { +# "level": {"current": 97, "progress": 96}, +# "pp": 2826.26, +# "global_rank": 298026, +# "country_rank": 11220, +# "hit_accuracy": 95.7168, +# }, +# } +# ) -@app.post("/signalr/metadata/negotiate") -async def metadata_negotiate(negotiateVersion: int = 1): - return JSONResponse( - content={ - "connectionId": "abc123", - "availableTransports": [ - {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} - ], - } - ) +# @app.post("/signalr/metadata/negotiate") +# async def metadata_negotiate(negotiateVersion: int = 1): +# return JSONResponse( +# content={ +# "connectionId": "abc123", +# "availableTransports": [ +# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} +# ], +# } +# ) -@app.post("/signalr/spectator/negotiate") -async def spectator_negotiate(negotiateVersion: int = 1): - return JSONResponse( - content={ - "connectionId": "spec456", - "availableTransports": [ - {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} - ], - } - ) +# @app.post("/signalr/spectator/negotiate") +# async def spectator_negotiate(negotiateVersion: int = 1): +# return JSONResponse( +# content={ +# "connectionId": "spec456", +# "availableTransports": [ +# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} +# ], +# } +# ) -@app.post("/signalr/multiplayer/negotiate") -async def multiplayer_negotiate(negotiateVersion: int = 1): - return JSONResponse( - content={ - "connectionId": "multi789", - "availableTransports": [ - {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} - ], - } - ) +# @app.post("/signalr/multiplayer/negotiate") +# async def multiplayer_negotiate(negotiateVersion: int = 1): +# return JSONResponse( +# content={ +# "connectionId": "multi789", +# "availableTransports": [ +# {"transport": "WebSockets", "transferFormats": ["Text", "Binary"]} +# ], +# } +# ) if __name__ == "__main__":