From 6e7114114646253666385f42502df640ed8ce3bd Mon Sep 17 00:00:00 2001 From: MingxuanGame Date: Mon, 11 Aug 2025 12:33:31 +0000 Subject: [PATCH] feat(developer): support custom OAuth 2.0 client --- .env.example | 10 +- README.md | 2 + app/auth.py | 24 +++- app/config.py | 6 +- app/database/__init__.py | 3 +- app/database/auth.py | 14 ++- app/database/oauth_clients.py | 0 app/dependencies/user.py | 76 ++++++++++-- app/router/auth.py | 112 ++++++++++++++++-- app/router/beatmap.py | 10 +- app/router/beatmapset.py | 10 +- app/router/me.py | 4 +- app/router/relationship.py | 11 +- app/router/room.py | 30 +++-- app/router/score.py | 30 ++--- app/router/user.py | 7 +- app/signalr/router.py | 16 ++- main.py | 15 ++- ...a8669ba11e96_auth_support_custom_client.py | 67 +++++++++++ pyproject.toml | 1 + uv.lock | 14 +++ 21 files changed, 380 insertions(+), 82 deletions(-) delete mode 100644 app/database/oauth_clients.py create mode 100644 migrations/versions/a8669ba11e96_auth_support_custom_client.py diff --git a/.env.example b/.env.example index 82b575b..a713dbb 100644 --- a/.env.example +++ b/.env.example @@ -21,9 +21,11 @@ PORT=8000 # 调试模式,生产环境请设置为 false DEBUG=false -# osu!lazer 登录设置 -OSU_CLIENT_ID="5" -OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" +# osu! 登录设置 +OSU_CLIENT_ID=5 # lazer client ID +OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" # lazer client secret +OSU_WEB_CLIENT_ID=6 # 网页端 client ID +OSU_WEB_CLIENT_SECRET="your_osu_web_client_secret_here" # 网页端 client secret,使用 openssl rand -hex 40 生成 # SignalR 服务器设置 SIGNALR_NEGOTIATE_TIMEOUT=30 @@ -32,7 +34,7 @@ SIGNALR_PING_INTERVAL=15 # Fetcher 设置 FETCHER_CLIENT_ID="" FETCHER_CLIENT_SECRET="" -FETCHER_SCOPES=["public"] +FETCHER_SCOPES=public FETCHER_CALLBACK_URL="http://localhost:8000/fetcher/callback" # 日志设置 diff --git a/README.md b/README.md index 01ea6d9..4bf5445 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,8 @@ docker-compose -f docker-compose-osurx.yml up -d |--------|------|--------| | `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` | | `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` | +| `OSU_WEB_CLIENT_ID` | Web OAuth 客户端 ID | `6` | +| `OSU_WEB_CLIENT_SECRET` | Web OAuth 客户端密钥 | `your_osu_web_client_secret_here` ### SignalR 服务器设置 | 变量名 | 描述 | 默认值 | diff --git a/app/auth.py b/app/auth.py index ddf5f56..c1a0000 100644 --- a/app/auth.py +++ b/app/auth.py @@ -15,6 +15,7 @@ from app.log import logger import bcrypt from jose import JWTError, jwt from passlib.context import CryptContext +from redis.asyncio import Redis from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -156,6 +157,8 @@ def verify_token(token: str) -> dict | None: async def store_token( db: AsyncSession, user_id: int, + client_id: int, + scopes: list[str], access_token: str, refresh_token: str, expires_in: int, @@ -164,7 +167,9 @@ async def store_token( expires_at = datetime.utcnow() + timedelta(seconds=expires_in) # 删除用户的旧令牌 - statement = select(OAuthToken).where(OAuthToken.user_id == user_id) + statement = select(OAuthToken).where( + OAuthToken.user_id == user_id, OAuthToken.client_id == client_id + ) old_tokens = (await db.exec(statement)).all() for token in old_tokens: await db.delete(token) @@ -179,7 +184,9 @@ async def store_token( # 创建新令牌记录 token_record = OAuthToken( user_id=user_id, + client_id=client_id, access_token=access_token, + scope=",".join(scopes), refresh_token=refresh_token, expires_at=expires_at, ) @@ -209,3 +216,18 @@ async def get_token_by_refresh_token( OAuthToken.expires_at > datetime.utcnow(), ) return (await db.exec(statement)).first() + + +async def get_user_by_authorization_code( + db: AsyncSession, redis: Redis, client_id: int, code: str +) -> tuple[User, list[str]] | None: + user_id = await redis.hget(f"oauth:code:{client_id}:{code}", "user_id") # pyright: ignore[reportGeneralTypeIssues] + scopes = await redis.hget(f"oauth:code:{client_id}:{code}", "scopes") # pyright: ignore[reportGeneralTypeIssues] + if not user_id or not scopes: + return None + + await redis.hdel(f"oauth:code:{client_id}:{code}", "user_id", "scopes") # pyright: ignore[reportGeneralTypeIssues] + + statement = select(User).where(User.id == int(user_id)) + user = (await db.exec(statement)).first() + return (user, scopes.split(",")) if user else None diff --git a/app/config.py b/app/config.py index a5045e3..004971d 100644 --- a/app/config.py +++ b/app/config.py @@ -23,13 +23,15 @@ class Settings(BaseSettings): return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}" # JWT 设置 - secret_key: str = Field(default="your-secret-key-here", alias="jwt_secret_key") + secret_key: str = Field(default="your_jwt_secret_here", alias="jwt_secret_key") algorithm: str = "HS256" access_token_expire_minutes: int = 1440 # OAuth 设置 - osu_client_id: str = "5" + osu_client_id: int = 5 osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" + osu_web_client_id: int = 6 + osu_web_client_secret: str = "your_osu_web_client_secret_here" # 服务器设置 host: str = "0.0.0.0" diff --git a/app/database/__init__.py b/app/database/__init__.py index b4167f0..5304b34 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -1,5 +1,5 @@ from .achievement import UserAchievement, UserAchievementResp -from .auth import OAuthToken +from .auth import OAuthClient, OAuthToken from .beatmap import ( Beatmap as Beatmap, BeatmapResp as BeatmapResp, @@ -71,6 +71,7 @@ __all__ = [ "MultiplayerEvent", "MultiplayerEventResp", "MultiplayerScores", + "OAuthClient", "OAuthToken", "PPBestScore", "Playlist", diff --git a/app/database/auth.py b/app/database/auth.py index 554dced..cf62afe 100644 --- a/app/database/auth.py +++ b/app/database/auth.py @@ -1,10 +1,11 @@ from datetime import datetime +import secrets from typing import TYPE_CHECKING from app.models.model import UTCBaseModel from sqlalchemy import Column, DateTime -from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel +from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel if TYPE_CHECKING: from .lazer_user import User @@ -17,6 +18,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True): user_id: int = Field( sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) ) + client_id: int = Field(index=True) access_token: str = Field(max_length=500, unique=True) refresh_token: str = Field(max_length=500, unique=True) token_type: str = Field(default="Bearer", max_length=20) @@ -27,3 +29,13 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True): ) user: "User" = Relationship() + + +class OAuthClient(SQLModel, table=True): + __tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType] + client_id: int | None = Field(default=None, primary_key=True, index=True) + client_secret: str = Field(default_factory=secrets.token_hex, index=True) + redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + owner_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True) + ) diff --git a/app/database/oauth_clients.py b/app/database/oauth_clients.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/dependencies/user.py b/app/dependencies/user.py index 5537f4f..8ebde87 100644 --- a/app/dependencies/user.py +++ b/app/dependencies/user.py @@ -1,34 +1,84 @@ from __future__ import annotations +from typing import Annotated + from app.auth import get_token_by_access_token +from app.config import settings from app.database import User from .database import get_db from fastapi import Depends, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from fastapi.security import ( + HTTPBearer, + OAuth2AuthorizationCodeBearer, + OAuth2PasswordBearer, + SecurityScopes, +) from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession security = HTTPBearer() +oauth2_password = OAuth2PasswordBearer( + tokenUrl="oauth/token", + scopes={"*": "Allows access to all scopes."}, +) + +oauth2_code = OAuth2AuthorizationCodeBearer( + authorizationUrl="oauth/authorize", + tokenUrl="oauth/token", + scopes={ + "chat.read": "Allows read chat messages on a user's behalf.", + "chat.write": "Allows sending chat messages on a user's behalf.", + "chat.write_manage": ( + "Allows joining and leaving chat channels on a user's behalf." + ), + "delegate": ( + "Allows acting as the owner of a client; " + "only available for Client Credentials Grant." + ), + "forum.write": "Allows creating and editing forum posts on a user's behalf.", + "friends.read": "Allows reading of the user's friend list.", + "identify": "Allows reading of the public profile of the user (/me).", + "public": "Allows reading of publicly available data on behalf of the user.", + }, +) + + async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - db: AsyncSession = Depends(get_db), + security_scopes: SecurityScopes, + db: Annotated[AsyncSession, Depends(get_db)], + token_pw: Annotated[str | None, Depends(oauth2_password)] = None, + token_code: Annotated[str | None, Depends(oauth2_code)] = None, ) -> User: """获取当前认证用户""" - token = credentials.credentials + token = token_pw or token_code + if not token: + raise HTTPException(status_code=401, detail="Not authenticated") - user = await get_current_user_by_token(token, db) + token_record = await get_token_by_access_token(db, token) + if not token_record: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + is_client = token_record.client_id in ( + settings.osu_client_id, + settings.osu_web_client_id, + ) + + if security_scopes.scopes == ["*"]: + # client/web only + if not token_pw or not is_client: + raise HTTPException(status_code=401, detail="Not authenticated") + elif not is_client: + for scope in security_scopes.scopes: + if scope not in token_record.scope.split(","): + raise HTTPException( + status_code=403, detail=f"Insufficient scope: {scope}" + ) + + user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() if not user: raise HTTPException(status_code=401, detail="Invalid or expired token") return user - - -async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None: - token_record = await get_token_by_access_token(db, token) - if not token_record: - return None - user = (await db.exec(select(User).where(User.id == token_record.user_id))).first() - return user diff --git a/app/router/auth.py b/app/router/auth.py index f5015ab..d0c826a 100644 --- a/app/router/auth.py +++ b/app/router/auth.py @@ -2,6 +2,7 @@ from __future__ import annotations from datetime import UTC, datetime, timedelta import re +from typing import Literal from app.auth import ( authenticate_user, @@ -9,12 +10,14 @@ from app.auth import ( generate_refresh_token, get_password_hash, get_token_by_refresh_token, + get_user_by_authorization_code, store_token, ) from app.config import settings -from app.database import DailyChallengeStats, User +from app.database import DailyChallengeStats, OAuthClient, User from app.database.statistics import UserStatistics from app.dependencies import get_db +from app.dependencies.database import get_redis from app.log import logger from app.models.oauth import ( OAuthErrorResponse, @@ -26,6 +29,7 @@ from app.models.score import GameMode from fastapi import APIRouter, Depends, Form from fastapi.responses import JSONResponse +from redis.asyncio import Redis from sqlalchemy import text from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -195,21 +199,36 @@ async def register_user( @router.post("/oauth/token", response_model=TokenResponse) async def oauth_token( - grant_type: str = Form(...), - client_id: str = Form(...), + grant_type: Literal[ + "authorization_code", "refresh_token", "password", "client_credentials" + ] = Form(...), + client_id: int = Form(...), client_secret: str = Form(...), + code: str | None = Form(None), scope: str = Form("*"), username: str | None = Form(None), password: str | None = Form(None), refresh_token: str | None = Form(None), db: AsyncSession = Depends(get_db), + redis: Redis = Depends(get_redis), ): """OAuth 令牌端点""" - # 验证客户端凭据 - if ( - client_id != settings.osu_client_id - or client_secret != settings.osu_client_secret - ): + scopes = scope.split(" ") + + client = ( + await db.exec( + select(OAuthClient).where( + OAuthClient.client_id == client_id, + OAuthClient.client_secret == client_secret, + ) + ) + ).first() + is_game_client = (client_id, client_secret) in [ + (settings.osu_client_id, settings.osu_client_secret), + (settings.osu_web_client_id, settings.osu_web_client_secret), + ] + + if client is None and not is_game_client: return create_oauth_error_response( error="invalid_client", description=( @@ -222,7 +241,6 @@ async def oauth_token( ) if grant_type == "password": - # 密码授权流程 if not username or not password: return create_oauth_error_response( error="invalid_request", @@ -233,6 +251,16 @@ async def oauth_token( ), hint="Username and password required", ) + if scopes != ["*"]: + return create_oauth_error_response( + error="invalid_scope", + description=( + "The requested scope is invalid, unknown, " + "or malformed. The client may not request " + "more than one scope at a time." + ), + hint="Only '*' scope is allowed for password grant type", + ) # 验证用户 user = await authenticate_user(db, username, password) @@ -261,6 +289,8 @@ async def oauth_token( await store_token( db, user.id, + client_id, + scopes, access_token, refresh_token_str, settings.access_token_expire_minutes * 60, @@ -313,6 +343,8 @@ async def oauth_token( await store_token( db, token_record.user_id, + client_id, + scopes, access_token, new_refresh_token, settings.access_token_expire_minutes * 60, @@ -325,7 +357,69 @@ async def oauth_token( refresh_token=new_refresh_token, scope=scope, ) + elif grant_type == "authorization_code": + if client is None: + return create_oauth_error_response( + error="invalid_client", + description=( + "Client authentication failed (e.g., unknown client, " + "no client authentication included, " + "or unsupported authentication method)." + ), + hint="Invalid client credentials", + status_code=401, + ) + if not code: + return create_oauth_error_response( + error="invalid_request", + description=( + "The request is missing a required parameter, " + "includes an invalid parameter value, " + "includes a parameter more than once, or is otherwise malformed." + ), + hint="Authorization code required", + ) + + code_result = await get_user_by_authorization_code(db, redis, client_id, code) + if not code_result: + return create_oauth_error_response( + error="invalid_grant", + description=( + "The provided authorization grant (e.g., authorization code, " + "resource owner credentials) or refresh token is invalid, " + "expired, revoked, does not match the redirection URI used in " + "the authorization request, or was issued to another client." + ), + hint="Invalid authorization code", + ) + user, scopes = code_result + # 生成令牌 + access_token_expires = timedelta(minutes=settings.access_token_expire_minutes) + access_token = create_access_token( + data={"sub": str(user.id)}, expires_delta=access_token_expires + ) + refresh_token_str = generate_refresh_token() + + # 存储令牌 + assert user.id + await store_token( + db, + user.id, + client_id, + scopes, + access_token, + refresh_token_str, + settings.access_token_expire_minutes * 60, + ) + + return TokenResponse( + access_token=access_token, + token_type="Bearer", + expires_in=settings.access_token_expire_minutes * 60, + refresh_token=refresh_token_str, + scope=" ".join(scopes), + ) else: return create_oauth_error_response( error="unsupported_grant_type", diff --git a/app/router/beatmap.py b/app/router/beatmap.py index 591a7ae..123c260 100644 --- a/app/router/beatmap.py +++ b/app/router/beatmap.py @@ -19,7 +19,7 @@ from app.models.score import ( from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query, Security from httpx import HTTPError, HTTPStatusError from pydantic import BaseModel from redis.asyncio import Redis @@ -33,7 +33,7 @@ async def lookup_beatmap( id: int | None = Query(default=None, alias="id"), md5: str | None = Query(default=None, alias="checksum"), filename: str | None = Query(default=None, alias="filename"), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -56,7 +56,7 @@ async def lookup_beatmap( @router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp) async def get_beatmap( bid: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -75,7 +75,7 @@ class BatchGetResp(BaseModel): @router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp) async def batch_get_beatmaps( b_ids: list[int] = Query(alias="ids[]", default_factory=list), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -126,7 +126,7 @@ async def batch_get_beatmaps( ) async def get_beatmap_attributes( beatmap: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), mods: list[str] = Query(default_factory=list), ruleset: GameMode | None = Query(default=None), ruleset_id: int | None = Query(default=None), diff --git a/app/router/beatmapset.py b/app/router/beatmapset.py index bebd178..09f1aeb 100644 --- a/app/router/beatmapset.py +++ b/app/router/beatmapset.py @@ -10,7 +10,7 @@ from app.fetcher import Fetcher from .api_router import router -from fastapi import Depends, Form, HTTPException, Query +from fastapi import Depends, Form, HTTPException, Query, Security from fastapi.responses import RedirectResponse from httpx import HTTPError from sqlmodel import select @@ -20,7 +20,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp) async def lookup_beatmapset( beatmap_id: int = Query(), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -34,7 +34,7 @@ async def lookup_beatmapset( @router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp) async def get_beatmapset( sid: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), fetcher: Fetcher = Depends(get_fetcher), ): @@ -51,7 +51,7 @@ async def get_beatmapset( async def download_beatmapset( beatmapset: int, no_video: bool = Query(True, alias="noVideo"), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), ): if current_user.country_code == "CN": return RedirectResponse( @@ -68,7 +68,7 @@ async def download_beatmapset( async def favourite_beatmapset( beatmapset: int, action: Literal["favourite", "unfavourite"] = Form(), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): existing_favourite = ( diff --git a/app/router/me.py b/app/router/me.py index b6d7d26..28b37ea 100644 --- a/app/router/me.py +++ b/app/router/me.py @@ -8,7 +8,7 @@ from app.models.score import GameMode from .api_router import router -from fastapi import Depends +from fastapi import Depends, Security from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/me/", response_model=UserResp) async def get_user_info_default( ruleset: GameMode | None = None, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["identify"]), session: AsyncSession = Depends(get_db), ): return await UserResp.from_db( diff --git a/app/router/relationship.py b/app/router/relationship.py index 02292c9..b63b3a8 100644 --- a/app/router/relationship.py +++ b/app/router/relationship.py @@ -1,13 +1,12 @@ from __future__ import annotations -from app.database import User as DBUser -from app.database.relationship import Relationship, RelationshipResp, RelationshipType +from app.database import Relationship, RelationshipResp, RelationshipType, User from app.dependencies.database import get_db from app.dependencies.user import get_current_user from .api_router import router -from fastapi import Depends, HTTPException, Query, Request +from fastapi import Depends, HTTPException, Query, Request, Security from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -17,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession @router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp]) async def get_relationship( request: Request, - current_user: DBUser = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["friends.read"]), db: AsyncSession = Depends(get_db), ): relationship_type = ( @@ -43,7 +42,7 @@ class AddFriendResp(BaseModel): async def add_relationship( request: Request, target: int = Query(), - current_user: DBUser = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): relationship_type = ( @@ -106,7 +105,7 @@ async def add_relationship( async def delete_relationship( request: Request, target: int, - current_user: DBUser = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): relationship_type = ( diff --git a/app/router/room.py b/app/router/room.py index eaee8d0..3f65fcc 100644 --- a/app/router/room.py +++ b/app/router/room.py @@ -20,7 +20,7 @@ from app.signalr.hub import MultiplayerHubs from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query, Security from pydantic import BaseModel, Field from redis.asyncio import Redis from sqlalchemy.sql.elements import ColumnElement @@ -36,7 +36,7 @@ async def get_all_rooms( category: RoomCategory = Query(RoomCategory.NORMAL), status: RoomStatus | None = Query(None), db: AsyncSession = Depends(get_db), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), ): resp_list: list[RoomResp] = [] where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category] @@ -124,7 +124,7 @@ async def _participate_room( async def create_room( room: APIUploadedRoom, db: AsyncSession = Depends(get_db), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), ): user_id = current_user.id db_room = await create_playlist_room_from_api(db, room, user_id) @@ -141,7 +141,7 @@ async def get_room( room: int, category: str = Query(default=""), db: AsyncSession = Depends(get_db), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), redis: Redis = Depends(get_redis), ): # 直接从db获取信息,毕竟都一样 @@ -155,7 +155,11 @@ async def get_room( @router.delete("/rooms/{room}", tags=["room"]) -async def delete_room(room: int, db: AsyncSession = Depends(get_db)): +async def delete_room( + room: int, + db: AsyncSession = Depends(get_db), + current_user: User = Security(get_current_user, scopes=["*"]), +): db_room = (await db.exec(select(Room).where(Room.id == room))).first() if db_room is None: raise HTTPException(404, "Room not found") @@ -166,7 +170,12 @@ async def delete_room(room: int, db: AsyncSession = Depends(get_db)): @router.put("/rooms/{room}/users/{user}", tags=["room"]) -async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)): +async def add_user_to_room( + room: int, + user: int, + db: AsyncSession = Depends(get_db), + current_user: User = Security(get_current_user, scopes=["*"]), +): db_room = (await db.exec(select(Room).where(Room.id == room))).first() if db_room is not None: await _participate_room(room, user, db_room, db) @@ -181,7 +190,10 @@ async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_ @router.delete("/rooms/{room}/users/{user}", tags=["room"]) async def remove_user_from_room( - room: int, user: int, db: AsyncSession = Depends(get_db) + room: int, + user: int, + db: AsyncSession = Depends(get_db), + current_user: User = Security(get_current_user, scopes=["*"]), ): db_room = (await db.exec(select(Room).where(Room.id == room))).first() if db_room is not None: @@ -211,7 +223,7 @@ class APILeaderboard(BaseModel): async def get_room_leaderboard( room: int, db: AsyncSession = Depends(get_db), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), ): db_room = (await db.exec(select(Room).where(Room.id == room))).first() if db_room is None: @@ -253,7 +265,7 @@ class RoomEvents(BaseModel): async def get_room_events( room_id: int, db: AsyncSession = Depends(get_db), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), limit: int = Query(100, ge=1, le=1000), after: int | None = Query(None, ge=0), before: int | None = Query(None, ge=0), diff --git a/app/router/score.py b/app/router/score.py index d197c8b..a31030a 100644 --- a/app/router/score.py +++ b/app/router/score.py @@ -46,7 +46,7 @@ from app.path import REPLAY_DIR from .api_router import router -from fastapi import Body, Depends, Form, HTTPException, Query +from fastapi import Body, Depends, Form, HTTPException, Query, Security from fastapi.responses import FileResponse from httpx import HTTPError from pydantic import BaseModel @@ -135,7 +135,7 @@ async def get_beatmap_scores( legacy_only: bool = Query(None), # TODO:加入对这个参数的查询 mods: list[str] = Query(default_factory=set, alias="mods[]"), type: LeaderboardType = Query(LeaderboardType.GLOBAL), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), limit: int = Query(50, ge=1, le=200), ): @@ -170,7 +170,7 @@ async def get_user_beatmap_score( legacy_only: bool = Query(None), mode: str = Query(None), mods: str = Query(None), # TODO:添加mods筛选 - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -211,7 +211,7 @@ async def get_user_all_beatmap_scores( user: int, legacy_only: bool = Query(None), ruleset: str = Query(None), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), ): if legacy_only: @@ -241,7 +241,7 @@ async def create_solo_score( version_hash: str = Form(""), beatmap_hash: str = Form(), ruleset_id: int = Form(..., ge=0, le=3), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): assert current_user.id @@ -266,7 +266,7 @@ async def submit_solo_score( beatmap: int, token: int, info: SoloScoreSubmissionInfo, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher=Depends(get_fetcher), @@ -284,7 +284,7 @@ async def create_playlist_score( beatmap_hash: str = Form(), ruleset_id: int = Form(..., ge=0, le=3), version_hash: str = Form(""), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), session: AsyncSession = Depends(get_db), ): room = await session.get(Room, room_id) @@ -351,7 +351,7 @@ async def submit_playlist_score( playlist_id: int, token: int, info: SoloScoreSubmissionInfo, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), fetcher: Fetcher = Depends(get_fetcher), @@ -404,7 +404,7 @@ async def index_playlist_scores( playlist_id: int, limit: int = 50, cursor: int = Query(2000000, alias="cursor[total_score]"), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), session: AsyncSession = Depends(get_db), ): room = await session.get(Room, room_id) @@ -464,7 +464,7 @@ async def show_playlist_score( room_id: int, playlist_id: int, score_id: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), session: AsyncSession = Depends(get_db), redis: Redis = Depends(get_redis), ): @@ -528,7 +528,7 @@ async def get_user_playlist_score( room_id: int, playlist_id: int, user_id: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), session: AsyncSession = Depends(get_db), ): score_record = None @@ -558,7 +558,7 @@ async def get_user_playlist_score( @router.put("/score-pins/{score}", status_code=204) async def pin_score( score: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): score_record = ( @@ -594,7 +594,7 @@ async def pin_score( @router.delete("/score-pins/{score}", status_code=204) async def unpin_score( score: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): score_record = ( @@ -626,7 +626,7 @@ async def reorder_score_pin( score: int, after_score_id: int | None = Body(default=None), before_score_id: int | None = Body(default=None), - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["*"]), db: AsyncSession = Depends(get_db), ): score_record = ( @@ -713,7 +713,7 @@ async def reorder_score_pin( @router.get("/scores/{score_id}/download") async def download_score_replay( score_id: int, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), db: AsyncSession = Depends(get_db), ): score = (await db.exec(select(Score).where(Score.id == score_id))).first() diff --git a/app/router/user.py b/app/router/user.py index e45426d..eb92552 100644 --- a/app/router/user.py +++ b/app/router/user.py @@ -20,7 +20,7 @@ from app.models.user import BeatmapsetType from .api_router import router -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query, Security from pydantic import BaseModel from sqlmodel import exists, false, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -36,6 +36,7 @@ class BatchUserResponse(BaseModel): @router.get("/users/lookup/", response_model=BatchUserResponse) async def get_users( user_ids: list[int] = Query(default_factory=list, alias="ids[]"), + current_user: User = Security(get_current_user, scopes=["public"]), include_variant_statistics: bool = Query(default=False), # TODO: future use session: AsyncSession = Depends(get_db), ): @@ -64,6 +65,7 @@ async def get_user_info( user: str, ruleset: GameMode | None = None, session: AsyncSession = Depends(get_db), + current_user: User = Security(get_current_user, scopes=["public"]), ): searched_user = ( await session.exec( @@ -91,7 +93,7 @@ async def get_user_info( async def get_user_beatmapsets( user_id: int, type: BeatmapsetType, - current_user: User = Depends(get_current_user), + current_user: User = Security(get_current_user, scopes=["public"]), session: AsyncSession = Depends(get_db), limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), @@ -147,6 +149,7 @@ async def get_user_scores( limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), session: AsyncSession = Depends(get_db), + current_user: User = Security(get_current_user, scopes=["public"]), ): db_user = await session.get(User, user) if not db_user: diff --git a/app/signalr/router.py b/app/signalr/router.py index 237a575..053117c 100644 --- a/app/signalr/router.py +++ b/app/signalr/router.py @@ -9,13 +9,13 @@ import uuid from app.database import User as DBUser from app.dependencies import get_current_user from app.dependencies.database import get_db -from app.dependencies.user import get_current_user_by_token from app.models.signalr import NegotiateResponse, Transport from .hub import Hubs from .packet import PROTOCOLS, SEP -from fastapi import APIRouter, Depends, Header, Query, WebSocket +from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket +from fastapi.security import SecurityScopes from sqlmodel.ext.asyncio.session import AsyncSession router = APIRouter() @@ -55,9 +55,15 @@ async def connect( if id not in hub_: await websocket.close(code=1008) return - if (user := await get_current_user_by_token(token, db)) is None or str( - user.id - ) != user_id: + try: + if ( + user := await get_current_user( + SecurityScopes(scopes=["*"]), db, token_pw=token + ) + ) is None or str(user.id) != user_id: + await websocket.close(code=1008) + return + except HTTPException: await websocket.close(code=1008) return await websocket.accept() diff --git a/main.py b/main.py index 5c88a6b..3f3522c 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from app.config import settings from app.dependencies.database import engine, redis_client from app.dependencies.fetcher import get_fetcher from app.dependencies.scheduler import init_scheduler, stop_scheduler +from app.log import logger from app.router import ( api_router, auth_router, @@ -52,9 +53,19 @@ async def health_check(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} -if __name__ == "__main__": - from app.log import logger # noqa: F401 +if settings.secret_key == "your_jwt_secret_here": + logger.warning( + "jwt_secret_key is unset. Your server is unsafe. " + "Use this command to generate: openssl rand -hex 32" + ) +if settings.osu_web_client_secret == "your_osu_web_client_secret_here": + logger.warning( + "osu_web_client_secret is unset. Your server is unsafe. " + "Use this command to generate: openssl rand -hex 40" + ) + +if __name__ == "__main__": import uvicorn uvicorn.run( diff --git a/migrations/versions/a8669ba11e96_auth_support_custom_client.py b/migrations/versions/a8669ba11e96_auth_support_custom_client.py new file mode 100644 index 0000000..0a765a2 --- /dev/null +++ b/migrations/versions/a8669ba11e96_auth_support_custom_client.py @@ -0,0 +1,67 @@ +"""auth: support custom client + +Revision ID: a8669ba11e96 +Revises: aa582c13f905 +Create Date: 2025-08-11 11:47:11.004301 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +import sqlmodel + +# revision identifiers, used by Alembic. +revision: str = "a8669ba11e96" +down_revision: str | Sequence[str] | None = "aa582c13f905" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "oauth_clients", + sa.Column("client_id", sa.Integer(), nullable=False), + sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("redirect_uris", sa.JSON(), nullable=True), + sa.Column("owner_id", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["owner_id"], + ["lazer_users.id"], + ), + sa.PrimaryKeyConstraint("client_id"), + ) + op.create_index( + op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=False + ) + op.create_index( + op.f("ix_oauth_clients_client_secret"), + "oauth_clients", + ["client_secret"], + unique=False, + ) + op.create_index( + op.f("ix_oauth_clients_owner_id"), "oauth_clients", ["owner_id"], unique=False + ) + op.add_column("oauth_tokens", sa.Column("client_id", sa.Integer(), nullable=False)) + op.create_index( + op.f("ix_oauth_tokens_client_id"), "oauth_tokens", ["client_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_oauth_tokens_client_id"), table_name="oauth_tokens") + op.drop_column("oauth_tokens", "client_id") + op.drop_index(op.f("ix_oauth_clients_owner_id"), table_name="oauth_clients") + op.drop_index(op.f("ix_oauth_clients_client_secret"), table_name="oauth_clients") + op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients") + op.drop_table("oauth_clients") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index a687f17..b717f26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "aiomysql>=0.2.0", "alembic>=1.12.1", "apscheduler>=3.11.0", + "authlib>=1.6.1", "bcrypt>=4.1.2", "cryptography>=41.0.7", "fastapi>=0.104.1", diff --git a/uv.lock b/uv.lock index ffc6105..5609771 100644 --- a/uv.lock +++ b/uv.lock @@ -69,6 +69,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/ae/9a053dd9229c0fde6b1f1f33f609ccff1ee79ddda364c756a924c6d8563b/APScheduler-3.11.0-py3-none-any.whl", hash = "sha256:fc134ca32e50f5eadcc4938e3a4545ab19131435e851abb40b34d63d5141c6da", size = 64004, upload-time = "2024-11-24T19:39:24.442Z" }, ] +[[package]] +name = "authlib" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/a1/d8d1c6f8bc922c0b87ae0d933a8ed57be1bef6970894ed79c2852a153cd3/authlib-1.6.1.tar.gz", hash = "sha256:4dffdbb1460ba6ec8c17981a4c67af7d8af131231b5a36a88a1e8c80c111cdfd", size = 159988, upload-time = "2025-07-20T07:38:42.834Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/58/cc6a08053f822f98f334d38a27687b69c6655fb05cd74a7a5e70a2aeed95/authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e", size = 239299, upload-time = "2025-07-20T07:38:39.259Z" }, +] + [[package]] name = "bcrypt" version = "4.3.0" @@ -506,6 +518,7 @@ dependencies = [ { name = "aiomysql" }, { name = "alembic" }, { name = "apscheduler" }, + { name = "authlib" }, { name = "bcrypt" }, { name = "cryptography" }, { name = "fastapi" }, @@ -536,6 +549,7 @@ requires-dist = [ { name = "aiomysql", specifier = ">=0.2.0" }, { name = "alembic", specifier = ">=1.12.1" }, { name = "apscheduler", specifier = ">=3.11.0" }, + { name = "authlib", specifier = ">=1.6.1" }, { name = "bcrypt", specifier = ">=4.1.2" }, { name = "cryptography", specifier = ">=41.0.7" }, { name = "fastapi", specifier = ">=0.104.1" },