feat(developer): support custom OAuth 2.0 client

This commit is contained in:
MingxuanGame
2025-08-11 12:33:31 +00:00
parent ee9381d1f0
commit 6e71141146
21 changed files with 380 additions and 82 deletions

View File

@@ -21,9 +21,11 @@ PORT=8000
# 调试模式,生产环境请设置为 false
DEBUG=false
# osu!lazer 登录设置
OSU_CLIENT_ID="5"
OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
# osu! 登录设置
OSU_CLIENT_ID=5 # lazer client ID
OSU_CLIENT_SECRET="FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk" # lazer client secret
OSU_WEB_CLIENT_ID=6 # 网页端 client ID
OSU_WEB_CLIENT_SECRET="your_osu_web_client_secret_here" # 网页端 client secret使用 openssl rand -hex 40 生成
# SignalR 服务器设置
SIGNALR_NEGOTIATE_TIMEOUT=30
@@ -32,7 +34,7 @@ SIGNALR_PING_INTERVAL=15
# Fetcher 设置
FETCHER_CLIENT_ID=""
FETCHER_CLIENT_SECRET=""
FETCHER_SCOPES=["public"]
FETCHER_SCOPES=public
FETCHER_CALLBACK_URL="http://localhost:8000/fetcher/callback"
# 日志设置

View File

@@ -73,6 +73,8 @@ docker-compose -f docker-compose-osurx.yml up -d
|--------|------|--------|
| `OSU_CLIENT_ID` | OAuth 客户端 ID | `5` |
| `OSU_CLIENT_SECRET` | OAuth 客户端密钥 | `FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk` |
| `OSU_WEB_CLIENT_ID` | Web OAuth 客户端 ID | `6` |
| `OSU_WEB_CLIENT_SECRET` | Web OAuth 客户端密钥 | `your_osu_web_client_secret_here`
### SignalR 服务器设置
| 变量名 | 描述 | 默认值 |

View File

@@ -15,6 +15,7 @@ from app.log import logger
import bcrypt
from jose import JWTError, jwt
from passlib.context import CryptContext
from redis.asyncio import Redis
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -156,6 +157,8 @@ def verify_token(token: str) -> dict | None:
async def store_token(
db: AsyncSession,
user_id: int,
client_id: int,
scopes: list[str],
access_token: str,
refresh_token: str,
expires_in: int,
@@ -164,7 +167,9 @@ async def store_token(
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# 删除用户的旧令牌
statement = select(OAuthToken).where(OAuthToken.user_id == user_id)
statement = select(OAuthToken).where(
OAuthToken.user_id == user_id, OAuthToken.client_id == client_id
)
old_tokens = (await db.exec(statement)).all()
for token in old_tokens:
await db.delete(token)
@@ -179,7 +184,9 @@ async def store_token(
# 创建新令牌记录
token_record = OAuthToken(
user_id=user_id,
client_id=client_id,
access_token=access_token,
scope=",".join(scopes),
refresh_token=refresh_token,
expires_at=expires_at,
)
@@ -209,3 +216,18 @@ async def get_token_by_refresh_token(
OAuthToken.expires_at > datetime.utcnow(),
)
return (await db.exec(statement)).first()
async def get_user_by_authorization_code(
db: AsyncSession, redis: Redis, client_id: int, code: str
) -> tuple[User, list[str]] | None:
user_id = await redis.hget(f"oauth:code:{client_id}:{code}", "user_id") # pyright: ignore[reportGeneralTypeIssues]
scopes = await redis.hget(f"oauth:code:{client_id}:{code}", "scopes") # pyright: ignore[reportGeneralTypeIssues]
if not user_id or not scopes:
return None
await redis.hdel(f"oauth:code:{client_id}:{code}", "user_id", "scopes") # pyright: ignore[reportGeneralTypeIssues]
statement = select(User).where(User.id == int(user_id))
user = (await db.exec(statement)).first()
return (user, scopes.split(",")) if user else None

View File

@@ -23,13 +23,15 @@ class Settings(BaseSettings):
return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
# JWT 设置
secret_key: str = Field(default="your-secret-key-here", alias="jwt_secret_key")
secret_key: str = Field(default="your_jwt_secret_here", alias="jwt_secret_key")
algorithm: str = "HS256"
access_token_expire_minutes: int = 1440
# OAuth 设置
osu_client_id: str = "5"
osu_client_id: int = 5
osu_client_secret: str = "FGc9GAtyHzeQDshWP5Ah7dega8hJACAJpQtw6OXk"
osu_web_client_id: int = 6
osu_web_client_secret: str = "your_osu_web_client_secret_here"
# 服务器设置
host: str = "0.0.0.0"

View File

@@ -1,5 +1,5 @@
from .achievement import UserAchievement, UserAchievementResp
from .auth import OAuthToken
from .auth import OAuthClient, OAuthToken
from .beatmap import (
Beatmap as Beatmap,
BeatmapResp as BeatmapResp,
@@ -71,6 +71,7 @@ __all__ = [
"MultiplayerEvent",
"MultiplayerEventResp",
"MultiplayerScores",
"OAuthClient",
"OAuthToken",
"PPBestScore",
"Playlist",

View File

@@ -1,10 +1,11 @@
from datetime import datetime
import secrets
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel
from sqlalchemy import Column, DateTime
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
from sqlmodel import JSON, BigInteger, Field, ForeignKey, Relationship, SQLModel
if TYPE_CHECKING:
from .lazer_user import User
@@ -17,6 +18,7 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)
client_id: int = Field(index=True)
access_token: str = Field(max_length=500, unique=True)
refresh_token: str = Field(max_length=500, unique=True)
token_type: str = Field(default="Bearer", max_length=20)
@@ -27,3 +29,13 @@ class OAuthToken(UTCBaseModel, SQLModel, table=True):
)
user: "User" = Relationship()
class OAuthClient(SQLModel, table=True):
__tablename__ = "oauth_clients" # pyright: ignore[reportAssignmentType]
client_id: int | None = Field(default=None, primary_key=True, index=True)
client_secret: str = Field(default_factory=secrets.token_hex, index=True)
redirect_uris: list[str] = Field(default_factory=list, sa_column=Column(JSON))
owner_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)
)

View File

@@ -1,34 +1,84 @@
from __future__ import annotations
from typing import Annotated
from app.auth import get_token_by_access_token
from app.config import settings
from app.database import User
from .database import get_db
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security import (
HTTPBearer,
OAuth2AuthorizationCodeBearer,
OAuth2PasswordBearer,
SecurityScopes,
)
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
security = HTTPBearer()
oauth2_password = OAuth2PasswordBearer(
tokenUrl="oauth/token",
scopes={"*": "Allows access to all scopes."},
)
oauth2_code = OAuth2AuthorizationCodeBearer(
authorizationUrl="oauth/authorize",
tokenUrl="oauth/token",
scopes={
"chat.read": "Allows read chat messages on a user's behalf.",
"chat.write": "Allows sending chat messages on a user's behalf.",
"chat.write_manage": (
"Allows joining and leaving chat channels on a user's behalf."
),
"delegate": (
"Allows acting as the owner of a client; "
"only available for Client Credentials Grant."
),
"forum.write": "Allows creating and editing forum posts on a user's behalf.",
"friends.read": "Allows reading of the user's friend list.",
"identify": "Allows reading of the public profile of the user (/me).",
"public": "Allows reading of publicly available data on behalf of the user.",
},
)
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
security_scopes: SecurityScopes,
db: Annotated[AsyncSession, Depends(get_db)],
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
) -> User:
"""获取当前认证用户"""
token = credentials.credentials
token = token_pw or token_code
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")
user = await get_current_user_by_token(token, db)
token_record = await get_token_by_access_token(db, token)
if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token")
is_client = token_record.client_id in (
settings.osu_client_id,
settings.osu_web_client_id,
)
if security_scopes.scopes == ["*"]:
# client/web only
if not token_pw or not is_client:
raise HTTPException(status_code=401, detail="Not authenticated")
elif not is_client:
for scope in security_scopes.scopes:
if scope not in token_record.scope.split(","):
raise HTTPException(
status_code=403, detail=f"Insufficient scope: {scope}"
)
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
if not user:
raise HTTPException(status_code=401, detail="Invalid or expired token")
return user
async def get_current_user_by_token(token: str, db: AsyncSession) -> User | None:
token_record = await get_token_by_access_token(db, token)
if not token_record:
return None
user = (await db.exec(select(User).where(User.id == token_record.user_id))).first()
return user

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
import re
from typing import Literal
from app.auth import (
authenticate_user,
@@ -9,12 +10,14 @@ from app.auth import (
generate_refresh_token,
get_password_hash,
get_token_by_refresh_token,
get_user_by_authorization_code,
store_token,
)
from app.config import settings
from app.database import DailyChallengeStats, User
from app.database import DailyChallengeStats, OAuthClient, User
from app.database.statistics import UserStatistics
from app.dependencies import get_db
from app.dependencies.database import get_redis
from app.log import logger
from app.models.oauth import (
OAuthErrorResponse,
@@ -26,6 +29,7 @@ from app.models.score import GameMode
from fastapi import APIRouter, Depends, Form
from fastapi.responses import JSONResponse
from redis.asyncio import Redis
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -195,21 +199,36 @@ async def register_user(
@router.post("/oauth/token", response_model=TokenResponse)
async def oauth_token(
grant_type: str = Form(...),
client_id: str = Form(...),
grant_type: Literal[
"authorization_code", "refresh_token", "password", "client_credentials"
] = Form(...),
client_id: int = Form(...),
client_secret: str = Form(...),
code: str | None = Form(None),
scope: str = Form("*"),
username: str | None = Form(None),
password: str | None = Form(None),
refresh_token: str | None = Form(None),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
"""OAuth 令牌端点"""
# 验证客户端凭据
if (
client_id != settings.osu_client_id
or client_secret != settings.osu_client_secret
):
scopes = scope.split(" ")
client = (
await db.exec(
select(OAuthClient).where(
OAuthClient.client_id == client_id,
OAuthClient.client_secret == client_secret,
)
)
).first()
is_game_client = (client_id, client_secret) in [
(settings.osu_client_id, settings.osu_client_secret),
(settings.osu_web_client_id, settings.osu_web_client_secret),
]
if client is None and not is_game_client:
return create_oauth_error_response(
error="invalid_client",
description=(
@@ -222,7 +241,6 @@ async def oauth_token(
)
if grant_type == "password":
# 密码授权流程
if not username or not password:
return create_oauth_error_response(
error="invalid_request",
@@ -233,6 +251,16 @@ async def oauth_token(
),
hint="Username and password required",
)
if scopes != ["*"]:
return create_oauth_error_response(
error="invalid_scope",
description=(
"The requested scope is invalid, unknown, "
"or malformed. The client may not request "
"more than one scope at a time."
),
hint="Only '*' scope is allowed for password grant type",
)
# 验证用户
user = await authenticate_user(db, username, password)
@@ -261,6 +289,8 @@ async def oauth_token(
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
@@ -313,6 +343,8 @@ async def oauth_token(
await store_token(
db,
token_record.user_id,
client_id,
scopes,
access_token,
new_refresh_token,
settings.access_token_expire_minutes * 60,
@@ -325,7 +357,69 @@ async def oauth_token(
refresh_token=new_refresh_token,
scope=scope,
)
elif grant_type == "authorization_code":
if client is None:
return create_oauth_error_response(
error="invalid_client",
description=(
"Client authentication failed (e.g., unknown client, "
"no client authentication included, "
"or unsupported authentication method)."
),
hint="Invalid client credentials",
status_code=401,
)
if not code:
return create_oauth_error_response(
error="invalid_request",
description=(
"The request is missing a required parameter, "
"includes an invalid parameter value, "
"includes a parameter more than once, or is otherwise malformed."
),
hint="Authorization code required",
)
code_result = await get_user_by_authorization_code(db, redis, client_id, code)
if not code_result:
return create_oauth_error_response(
error="invalid_grant",
description=(
"The provided authorization grant (e.g., authorization code, "
"resource owner credentials) or refresh token is invalid, "
"expired, revoked, does not match the redirection URI used in "
"the authorization request, or was issued to another client."
),
hint="Invalid authorization code",
)
user, scopes = code_result
# 生成令牌
access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
access_token = create_access_token(
data={"sub": str(user.id)}, expires_delta=access_token_expires
)
refresh_token_str = generate_refresh_token()
# 存储令牌
assert user.id
await store_token(
db,
user.id,
client_id,
scopes,
access_token,
refresh_token_str,
settings.access_token_expire_minutes * 60,
)
return TokenResponse(
access_token=access_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
refresh_token=refresh_token_str,
scope=" ".join(scopes),
)
else:
return create_oauth_error_response(
error="unsupported_grant_type",

View File

@@ -19,7 +19,7 @@ from app.models.score import (
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Security
from httpx import HTTPError, HTTPStatusError
from pydantic import BaseModel
from redis.asyncio import Redis
@@ -33,7 +33,7 @@ async def lookup_beatmap(
id: int | None = Query(default=None, alias="id"),
md5: str | None = Query(default=None, alias="checksum"),
filename: str | None = Query(default=None, alias="filename"),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -56,7 +56,7 @@ async def lookup_beatmap(
@router.get("/beatmaps/{bid}", tags=["beatmap"], response_model=BeatmapResp)
async def get_beatmap(
bid: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -75,7 +75,7 @@ class BatchGetResp(BaseModel):
@router.get("/beatmaps/", tags=["beatmap"], response_model=BatchGetResp)
async def batch_get_beatmaps(
b_ids: list[int] = Query(alias="ids[]", default_factory=list),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -126,7 +126,7 @@ async def batch_get_beatmaps(
)
async def get_beatmap_attributes(
beatmap: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
mods: list[str] = Query(default_factory=list),
ruleset: GameMode | None = Query(default=None),
ruleset_id: int | None = Query(default=None),

View File

@@ -10,7 +10,7 @@ from app.fetcher import Fetcher
from .api_router import router
from fastapi import Depends, Form, HTTPException, Query
from fastapi import Depends, Form, HTTPException, Query, Security
from fastapi.responses import RedirectResponse
from httpx import HTTPError
from sqlmodel import select
@@ -20,7 +20,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/beatmapsets/lookup", tags=["beatmapset"], response_model=BeatmapsetResp)
async def lookup_beatmapset(
beatmap_id: int = Query(),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -34,7 +34,7 @@ async def lookup_beatmapset(
@router.get("/beatmapsets/{sid}", tags=["beatmapset"], response_model=BeatmapsetResp)
async def get_beatmapset(
sid: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
fetcher: Fetcher = Depends(get_fetcher),
):
@@ -51,7 +51,7 @@ async def get_beatmapset(
async def download_beatmapset(
beatmapset: int,
no_video: bool = Query(True, alias="noVideo"),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
):
if current_user.country_code == "CN":
return RedirectResponse(
@@ -68,7 +68,7 @@ async def download_beatmapset(
async def favourite_beatmapset(
beatmapset: int,
action: Literal["favourite", "unfavourite"] = Form(),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
existing_favourite = (

View File

@@ -8,7 +8,7 @@ from app.models.score import GameMode
from .api_router import router
from fastapi import Depends
from fastapi import Depends, Security
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -16,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/me/", response_model=UserResp)
async def get_user_info_default(
ruleset: GameMode | None = None,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["identify"]),
session: AsyncSession = Depends(get_db),
):
return await UserResp.from_db(

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
from app.database import User as DBUser
from app.database.relationship import Relationship, RelationshipResp, RelationshipType
from app.database import Relationship, RelationshipResp, RelationshipType, User
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
from .api_router import router
from fastapi import Depends, HTTPException, Query, Request
from fastapi import Depends, HTTPException, Query, Request, Security
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -17,7 +16,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
@router.get("/blocks", tags=["relationship"], response_model=list[RelationshipResp])
async def get_relationship(
request: Request,
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["friends.read"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
@@ -43,7 +42,7 @@ class AddFriendResp(BaseModel):
async def add_relationship(
request: Request,
target: int = Query(),
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (
@@ -106,7 +105,7 @@ async def add_relationship(
async def delete_relationship(
request: Request,
target: int,
current_user: DBUser = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
relationship_type = (

View File

@@ -20,7 +20,7 @@ from app.signalr.hub import MultiplayerHubs
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.sql.elements import ColumnElement
@@ -36,7 +36,7 @@ async def get_all_rooms(
category: RoomCategory = Query(RoomCategory.NORMAL),
status: RoomStatus | None = Query(None),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
):
resp_list: list[RoomResp] = []
where_clauses: list[ColumnElement[bool]] = [col(Room.category) == category]
@@ -124,7 +124,7 @@ async def _participate_room(
async def create_room(
room: APIUploadedRoom,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
):
user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id)
@@ -141,7 +141,7 @@ async def get_room(
room: int,
category: str = Query(default=""),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
redis: Redis = Depends(get_redis),
):
# 直接从db获取信息毕竟都一样
@@ -155,7 +155,11 @@ async def get_room(
@router.delete("/rooms/{room}", tags=["room"])
async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
async def delete_room(
room: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
raise HTTPException(404, "Room not found")
@@ -166,7 +170,12 @@ async def delete_room(room: int, db: AsyncSession = Depends(get_db)):
@router.put("/rooms/{room}/users/{user}", tags=["room"])
async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_db)):
async def add_user_to_room(
room: int,
user: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is not None:
await _participate_room(room, user, db_room, db)
@@ -181,7 +190,10 @@ async def add_user_to_room(room: int, user: int, db: AsyncSession = Depends(get_
@router.delete("/rooms/{room}/users/{user}", tags=["room"])
async def remove_user_from_room(
room: int, user: int, db: AsyncSession = Depends(get_db)
room: int,
user: int,
db: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is not None:
@@ -211,7 +223,7 @@ class APILeaderboard(BaseModel):
async def get_room_leaderboard(
room: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_room = (await db.exec(select(Room).where(Room.id == room))).first()
if db_room is None:
@@ -253,7 +265,7 @@ class RoomEvents(BaseModel):
async def get_room_events(
room_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
limit: int = Query(100, ge=1, le=1000),
after: int | None = Query(None, ge=0),
before: int | None = Query(None, ge=0),

View File

@@ -46,7 +46,7 @@ from app.path import REPLAY_DIR
from .api_router import router
from fastapi import Body, Depends, Form, HTTPException, Query
from fastapi import Body, Depends, Form, HTTPException, Query, Security
from fastapi.responses import FileResponse
from httpx import HTTPError
from pydantic import BaseModel
@@ -135,7 +135,7 @@ async def get_beatmap_scores(
legacy_only: bool = Query(None), # TODO:加入对这个参数的查询
mods: list[str] = Query(default_factory=set, alias="mods[]"),
type: LeaderboardType = Query(LeaderboardType.GLOBAL),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
limit: int = Query(50, ge=1, le=200),
):
@@ -170,7 +170,7 @@ async def get_user_beatmap_score(
legacy_only: bool = Query(None),
mode: str = Query(None),
mods: str = Query(None), # TODO:添加mods筛选
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -211,7 +211,7 @@ async def get_user_all_beatmap_scores(
user: int,
legacy_only: bool = Query(None),
ruleset: str = Query(None),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
if legacy_only:
@@ -241,7 +241,7 @@ async def create_solo_score(
version_hash: str = Form(""),
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
assert current_user.id
@@ -266,7 +266,7 @@ async def submit_solo_score(
beatmap: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher=Depends(get_fetcher),
@@ -284,7 +284,7 @@ async def create_playlist_score(
beatmap_hash: str = Form(),
ruleset_id: int = Form(..., ge=0, le=3),
version_hash: str = Form(""),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
@@ -351,7 +351,7 @@ async def submit_playlist_score(
playlist_id: int,
token: int,
info: SoloScoreSubmissionInfo,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
fetcher: Fetcher = Depends(get_fetcher),
@@ -404,7 +404,7 @@ async def index_playlist_scores(
playlist_id: int,
limit: int = 50,
cursor: int = Query(2000000, alias="cursor[total_score]"),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
):
room = await session.get(Room, room_id)
@@ -464,7 +464,7 @@ async def show_playlist_score(
room_id: int,
playlist_id: int,
score_id: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
redis: Redis = Depends(get_redis),
):
@@ -528,7 +528,7 @@ async def get_user_playlist_score(
room_id: int,
playlist_id: int,
user_id: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db),
):
score_record = None
@@ -558,7 +558,7 @@ async def get_user_playlist_score(
@router.put("/score-pins/{score}", status_code=204)
async def pin_score(
score: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
@@ -594,7 +594,7 @@ async def pin_score(
@router.delete("/score-pins/{score}", status_code=204)
async def unpin_score(
score: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
@@ -626,7 +626,7 @@ async def reorder_score_pin(
score: int,
after_score_id: int | None = Body(default=None),
before_score_id: int | None = Body(default=None),
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["*"]),
db: AsyncSession = Depends(get_db),
):
score_record = (
@@ -713,7 +713,7 @@ async def reorder_score_pin(
@router.get("/scores/{score_id}/download")
async def download_score_replay(
score_id: int,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
db: AsyncSession = Depends(get_db),
):
score = (await db.exec(select(Score).where(Score.id == score_id))).first()

View File

@@ -20,7 +20,7 @@ from app.models.user import BeatmapsetType
from .api_router import router
from fastapi import Depends, HTTPException, Query
from fastapi import Depends, HTTPException, Query, Security
from pydantic import BaseModel
from sqlmodel import exists, false, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -36,6 +36,7 @@ class BatchUserResponse(BaseModel):
@router.get("/users/lookup/", response_model=BatchUserResponse)
async def get_users(
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
current_user: User = Security(get_current_user, scopes=["public"]),
include_variant_statistics: bool = Query(default=False), # TODO: future use
session: AsyncSession = Depends(get_db),
):
@@ -64,6 +65,7 @@ async def get_user_info(
user: str,
ruleset: GameMode | None = None,
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
searched_user = (
await session.exec(
@@ -91,7 +93,7 @@ async def get_user_info(
async def get_user_beatmapsets(
user_id: int,
type: BeatmapsetType,
current_user: User = Depends(get_current_user),
current_user: User = Security(get_current_user, scopes=["public"]),
session: AsyncSession = Depends(get_db),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
@@ -147,6 +149,7 @@ async def get_user_scores(
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["public"]),
):
db_user = await session.get(User, user)
if not db_user:

View File

@@ -9,13 +9,13 @@ import uuid
from app.database import User as DBUser
from app.dependencies import get_current_user
from app.dependencies.database import get_db
from app.dependencies.user import get_current_user_by_token
from app.models.signalr import NegotiateResponse, Transport
from .hub import Hubs
from .packet import PROTOCOLS, SEP
from fastapi import APIRouter, Depends, Header, Query, WebSocket
from fastapi import APIRouter, Depends, Header, HTTPException, Query, WebSocket
from fastapi.security import SecurityScopes
from sqlmodel.ext.asyncio.session import AsyncSession
router = APIRouter()
@@ -55,9 +55,15 @@ async def connect(
if id not in hub_:
await websocket.close(code=1008)
return
if (user := await get_current_user_by_token(token, db)) is None or str(
user.id
) != user_id:
try:
if (
user := await get_current_user(
SecurityScopes(scopes=["*"]), db, token_pw=token
)
) is None or str(user.id) != user_id:
await websocket.close(code=1008)
return
except HTTPException:
await websocket.close(code=1008)
return
await websocket.accept()

15
main.py
View File

@@ -7,6 +7,7 @@ from app.config import settings
from app.dependencies.database import engine, redis_client
from app.dependencies.fetcher import get_fetcher
from app.dependencies.scheduler import init_scheduler, stop_scheduler
from app.log import logger
from app.router import (
api_router,
auth_router,
@@ -52,9 +53,19 @@ async def health_check():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
if __name__ == "__main__":
from app.log import logger # noqa: F401
if settings.secret_key == "your_jwt_secret_here":
logger.warning(
"jwt_secret_key is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 32"
)
if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
logger.warning(
"osu_web_client_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 40"
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(

View File

@@ -0,0 +1,67 @@
"""auth: support custom client
Revision ID: a8669ba11e96
Revises: aa582c13f905
Create Date: 2025-08-11 11:47:11.004301
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = "a8669ba11e96"
down_revision: str | Sequence[str] | None = "aa582c13f905"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"oauth_clients",
sa.Column("client_id", sa.Integer(), nullable=False),
sa.Column("client_secret", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("redirect_uris", sa.JSON(), nullable=True),
sa.Column("owner_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(
["owner_id"],
["lazer_users.id"],
),
sa.PrimaryKeyConstraint("client_id"),
)
op.create_index(
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=False
)
op.create_index(
op.f("ix_oauth_clients_client_secret"),
"oauth_clients",
["client_secret"],
unique=False,
)
op.create_index(
op.f("ix_oauth_clients_owner_id"), "oauth_clients", ["owner_id"], unique=False
)
op.add_column("oauth_tokens", sa.Column("client_id", sa.Integer(), nullable=False))
op.create_index(
op.f("ix_oauth_tokens_client_id"), "oauth_tokens", ["client_id"], unique=False
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_oauth_tokens_client_id"), table_name="oauth_tokens")
op.drop_column("oauth_tokens", "client_id")
op.drop_index(op.f("ix_oauth_clients_owner_id"), table_name="oauth_clients")
op.drop_index(op.f("ix_oauth_clients_client_secret"), table_name="oauth_clients")
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
op.drop_table("oauth_clients")
# ### end Alembic commands ###

View File

@@ -8,6 +8,7 @@ dependencies = [
"aiomysql>=0.2.0",
"alembic>=1.12.1",
"apscheduler>=3.11.0",
"authlib>=1.6.1",
"bcrypt>=4.1.2",
"cryptography>=41.0.7",
"fastapi>=0.104.1",

14
uv.lock generated
View File

@@ -69,6 +69,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/ae/9a053dd9229c0fde6b1f1f33f609ccff1ee79ddda364c756a924c6d8563b/APScheduler-3.11.0-py3-none-any.whl", hash = "sha256:fc134ca32e50f5eadcc4938e3a4545ab19131435e851abb40b34d63d5141c6da", size = 64004, upload-time = "2024-11-24T19:39:24.442Z" },
]
[[package]]
name = "authlib"
version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8e/a1/d8d1c6f8bc922c0b87ae0d933a8ed57be1bef6970894ed79c2852a153cd3/authlib-1.6.1.tar.gz", hash = "sha256:4dffdbb1460ba6ec8c17981a4c67af7d8af131231b5a36a88a1e8c80c111cdfd", size = 159988, upload-time = "2025-07-20T07:38:42.834Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f9/58/cc6a08053f822f98f334d38a27687b69c6655fb05cd74a7a5e70a2aeed95/authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e", size = 239299, upload-time = "2025-07-20T07:38:39.259Z" },
]
[[package]]
name = "bcrypt"
version = "4.3.0"
@@ -506,6 +518,7 @@ dependencies = [
{ name = "aiomysql" },
{ name = "alembic" },
{ name = "apscheduler" },
{ name = "authlib" },
{ name = "bcrypt" },
{ name = "cryptography" },
{ name = "fastapi" },
@@ -536,6 +549,7 @@ requires-dist = [
{ name = "aiomysql", specifier = ">=0.2.0" },
{ name = "alembic", specifier = ">=1.12.1" },
{ name = "apscheduler", specifier = ">=3.11.0" },
{ name = "authlib", specifier = ">=1.6.1" },
{ name = "bcrypt", specifier = ">=4.1.2" },
{ name = "cryptography", specifier = ">=41.0.7" },
{ name = "fastapi", specifier = ">=0.104.1" },