refactor(private-api): use OAuth to authorize

This commit is contained in:
MingxuanGame
2025-08-12 16:04:19 +00:00
parent 186656d72f
commit b5afbed36c
8 changed files with 37 additions and 68 deletions

View File

@@ -26,8 +26,6 @@ CORS_URLS='[]'
FRONTEND_URL FRONTEND_URL
# 调试模式,生产环境请设置为 false # 调试模式,生产环境请设置为 false
DEBUG=false DEBUG=false
# 私有 API 密钥,用于前后端 API 调用,使用 openssl rand -hex 32 生成
PRIVATE_API_SECRET="your_private_api_secret_here"
# osu! 登录设置 # osu! 登录设置
OSU_CLIENT_ID=5 # lazer client ID OSU_CLIENT_ID=5 # lazer client ID

View File

@@ -71,7 +71,6 @@ docker-compose -f docker-compose-osurx.yml up -d
| `SERVER_URL` | 服务器 URL | `http://localhost:8000` | | `SERVER_URL` | 服务器 URL | `http://localhost:8000` |
| `CORS_URLS` | 额外的 CORS 允许的域名列表 (JSON 格式) | `[]` | | `CORS_URLS` | 额外的 CORS 允许的域名列表 (JSON 格式) | `[]` |
| `FRONTEND_URL` | 前端 URL当访问从游戏打开的 URL 时会重定向到这个 URL为空表示不重定向 | `` | | `FRONTEND_URL` | 前端 URL当访问从游戏打开的 URL 时会重定向到这个 URL为空表示不重定向 | `` |
| `PRIVATE_API_SECRET` | 私有 API 密钥,用于前后端 API 调用 | `your_private_api_secret_here` |
### OAuth 设置 ### OAuth 设置
| 变量名 | 描述 | 默认值 | | 变量名 | 描述 | 默认值 |

View File

@@ -64,7 +64,6 @@ class Settings(BaseSettings):
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 8000 port: int = 8000
debug: bool = False debug: bool = False
private_api_secret: str = "your_private_api_secret_here"
cors_urls: list[HttpUrl] = [] cors_urls: list[HttpUrl] = []
server_url: HttpUrl = HttpUrl("http://localhost:8000") server_url: HttpUrl = HttpUrl("http://localhost:8000")
frontend_url: HttpUrl | None = None frontend_url: HttpUrl | None = None

View File

@@ -1,17 +1,17 @@
from __future__ import annotations from __future__ import annotations
import base64
import hashlib import hashlib
from io import BytesIO from io import BytesIO
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.dependencies.storage import get_storage_service from app.dependencies.storage import get_storage_service
from app.dependencies.user import get_current_user
from app.storage.base import StorageService from app.storage.base import StorageService
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException from fastapi import Depends, File, HTTPException, Security
from PIL import Image from PIL import Image
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -21,8 +21,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="上传头像", name="上传头像",
) )
async def upload_avatar( async def upload_avatar(
file: str = Body(..., description="Base64 编码的图片数据"), content: bytes = File(...),
user_id: int = Body(..., description="用户 ID"), current_user: User = Security(get_current_user, scopes=["*"]),
storage: StorageService = Depends(get_storage_service), storage: StorageService = Depends(get_storage_service),
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
): ):
@@ -38,11 +38,6 @@ async def upload_avatar(
返回: 返回:
- 头像 URL 和文件哈希值 - 头像 URL 和文件哈希值
""" """
content = base64.b64decode(file)
user = await session.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
# check file # check file
if len(content) > 5 * 1024 * 1024: # 5MB limit if len(content) > 5 * 1024 * 1024: # 5MB limit
@@ -58,11 +53,11 @@ async def upload_avatar(
) )
filehash = hashlib.sha256(content).hexdigest() filehash = hashlib.sha256(content).hexdigest()
storage_path = f"avatars/{user_id}_{filehash}.png" storage_path = f"avatars/{current_user.id}_{filehash}.png"
if not await storage.is_exists(storage_path): if not await storage.is_exists(storage_path):
await storage.write_file(storage_path, content) await storage.write_file(storage_path, content)
url = await storage.get_file_url(storage_path) url = await storage.get_file_url(storage_path)
user.avatar_url = url current_user.avatar_url = url
await session.commit() await session.commit()
return { return {

View File

@@ -3,11 +3,13 @@ from __future__ import annotations
import secrets import secrets
from app.database.auth import OAuthClient, OAuthToken from app.database.auth import OAuthClient, OAuthToken
from app.database.lazer_user import User
from app.dependencies.database import get_db, get_redis from app.dependencies.database import get_db, get_redis
from app.dependencies.user import get_current_user
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException from fastapi import Body, Depends, HTTPException, Security
from redis.asyncio import Redis from redis.asyncio import Redis
from sqlmodel import select, text from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -22,7 +24,7 @@ async def create_oauth_app(
name: str = Body(..., max_length=100, description="应用程序名称"), name: str = Body(..., max_length=100, description="应用程序名称"),
description: str = Body("", description="应用程序描述"), description: str = Body("", description="应用程序描述"),
redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"), redirect_uris: list[str] = Body(..., description="允许的重定向 URI 列表"),
owner_id: int = Body(..., description="应用程序所有者的用户 ID"), current_user: User = Security(get_current_user, scopes=["*"]),
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
): ):
result = await session.execute( # pyright: ignore[reportDeprecated] result = await session.execute( # pyright: ignore[reportDeprecated]
@@ -40,7 +42,7 @@ async def create_oauth_app(
name=name, name=name,
description=description, description=description,
redirect_uris=redirect_uris, redirect_uris=redirect_uris,
owner_id=owner_id, owner_id=current_user.id,
) )
session.add(oauth_client) session.add(oauth_client)
await session.commit() await session.commit()
@@ -60,6 +62,7 @@ async def create_oauth_app(
async def get_oauth_app( async def get_oauth_app(
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
): ):
oauth_app = await session.get(OAuthClient, client_id) oauth_app = await session.get(OAuthClient, client_id)
if not oauth_app: if not oauth_app:
@@ -73,16 +76,16 @@ async def get_oauth_app(
@router.get( @router.get(
"/oauth-apps/user/{owner_id}", "/oauth-apps",
name="获取用户的 OAuth 应用列表", name="获取用户的 OAuth 应用列表",
description="获取指定用户创建的所有 OAuth 应用程序", description="获取当前用户创建的所有 OAuth 应用程序",
) )
async def get_user_oauth_apps( async def get_user_oauth_apps(
owner_id: int,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
): ):
oauth_apps = await session.exec( oauth_apps = await session.exec(
select(OAuthClient).where(OAuthClient.owner_id == owner_id) select(OAuthClient).where(OAuthClient.owner_id == current_user.id)
) )
return [ return [
{ {
@@ -104,10 +107,15 @@ async def get_user_oauth_apps(
async def delete_oauth_app( async def delete_oauth_app(
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
tokens = await session.exec( tokens = await session.exec(
select(OAuthToken).where(OAuthToken.client_id == client_id) select(OAuthToken).where(OAuthToken.client_id == client_id)
@@ -130,10 +138,15 @@ async def update_oauth_app(
description: str = Body("", description="应用程序新描述"), description: str = Body("", description="应用程序新描述"),
redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"), redirect_uris: list[str] = Body(..., description="新的重定向 URI 列表"),
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
oauth_client.name = name oauth_client.name = name
oauth_client.description = description oauth_client.description = description
@@ -157,10 +170,15 @@ async def update_oauth_app(
async def refresh_secret( async def refresh_secret(
client_id: int, client_id: int,
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
current_user: User = Security(get_current_user, scopes=["*"]),
): ):
oauth_client = await session.get(OAuthClient, client_id) oauth_client = await session.get(OAuthClient, client_id)
if not oauth_client: if not oauth_client:
raise HTTPException(status_code=404, detail="OAuth app not found") raise HTTPException(status_code=404, detail="OAuth app not found")
if oauth_client.owner_id != current_user.id:
raise HTTPException(
status_code=403, detail="Forbidden: Not the owner of this app"
)
oauth_client.client_secret = secrets.token_hex() oauth_client.client_secret = secrets.token_hex()
tokens = await session.exec( tokens = await session.exec(
@@ -186,7 +204,7 @@ async def refresh_secret(
) )
async def generate_oauth_code( async def generate_oauth_code(
client_id: int, client_id: int,
user_id: int = Body(..., description="授权用户的 ID"), current_user: User = Security(get_current_user, scopes=["*"]),
redirect_uri: str = Body(..., description="授权后重定向的 URI"), redirect_uri: str = Body(..., description="授权后重定向的 URI"),
scopes: list[str] = Body(..., description="请求的权限范围列表"), scopes: list[str] = Body(..., description="请求的权限范围列表"),
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
@@ -204,7 +222,7 @@ async def generate_oauth_code(
code = secrets.token_urlsafe(80) code = secrets.token_urlsafe(80)
await redis.hset( # pyright: ignore[reportGeneralTypeIssues] await redis.hset( # pyright: ignore[reportGeneralTypeIssues]
f"oauth:code:{client_id}:{code}", f"oauth:code:{client_id}:{code}",
mapping={"user_id": user_id, "scopes": ",".join(scopes)}, mapping={"user_id": current_user.id, "scopes": ",".join(scopes)},
) )
await redis.expire(f"oauth:code:{client_id}:{code}", 300) await redis.expire(f"oauth:code:{client_id}:{code}", 300)

View File

@@ -1,40 +1,11 @@
from __future__ import annotations from __future__ import annotations
import hashlib
import hmac
import time
from app.config import settings from app.config import settings
from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi import APIRouter
async def verify_signature(
request: Request,
ts: int = Header(..., alias="X-Timestamp"),
nonce: str = Header(..., alias="X-Nonce"),
signature: str = Header(..., alias="X-Signature"),
):
path = request.url.path
data = await request.body()
body = data.decode("utf-8")
py_ts = ts // 1000
if abs(time.time() - py_ts) > 30:
raise HTTPException(status_code=403, detail="Invalid timestamp")
payload = f"{path}|{body}|{ts}|{nonce}"
expected_sig = hmac.new(
settings.private_api_secret.encode(), payload.encode(), hashlib.sha256
).hexdigest()
if not hmac.compare_digest(expected_sig, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
router = APIRouter( router = APIRouter(
prefix="/api/private", prefix="/api/private",
dependencies=[Depends(verify_signature)],
include_in_schema=settings.debug, include_in_schema=settings.debug,
tags=["私有 API"], tags=["私有 API"],
) )

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
from app.database.lazer_user import User from app.database.lazer_user import User
from app.dependencies.database import get_db from app.dependencies.database import get_db
from app.dependencies.user import get_current_user
from .router import router from .router import router
from fastapi import Body, Depends, HTTPException from fastapi import Body, Depends, HTTPException, Security
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -15,10 +16,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession
name="修改用户名", name="修改用户名",
) )
async def user_rename( async def user_rename(
user_id: int = Body(..., description="要修改名称的用户 ID"),
new_name: str = Body(..., description="新的用户名"), new_name: str = Body(..., description="新的用户名"),
session: AsyncSession = Depends(get_db), session: AsyncSession = Depends(get_db),
# currentUser: User = Depends(get_current_user) current_user: User = Security(get_current_user, scopes=["*"]),
): ):
"""修改用户名 """修改用户名
@@ -31,9 +31,6 @@ async def user_rename(
返回: 返回:
- 成功: None - 成功: None
""" """
current_user = (await session.exec(select(User).where(User.id == user_id))).first()
if current_user is None:
raise HTTPException(404, "User not found")
samename_user = ( samename_user = (
await session.exec(select(User).where(User.username == new_name)) await session.exec(select(User).where(User.username == new_name))
).first() ).first()

View File

@@ -45,9 +45,6 @@ desc = (
"osu! API 模拟服务器,支持 osu! API v2 和 osu!lazer 的绝大部分功能。\n\n" "osu! API 模拟服务器,支持 osu! API v2 和 osu!lazer 的绝大部分功能。\n\n"
"官方文档:[osu!web 文档](https://osu.ppy.sh/docs/index.html)" "官方文档:[osu!web 文档](https://osu.ppy.sh/docs/index.html)"
) )
if settings.debug:
desc += "\n\n私有 API 签名机制:[GitHub](https://github.com/GooGuTeam/osu_lazer_api/wiki/%E7%A7%81%E6%9C%89-API-%E7%AD%BE%E5%90%8D%E9%AA%8C%E8%AF%81%E6%9C%BA%E5%88%B6)"
app = FastAPI( app = FastAPI(
title="osu! API 模拟服务器", title="osu! API 模拟服务器",
version="1.0.0", version="1.0.0",
@@ -110,11 +107,6 @@ if settings.osu_web_client_secret == "your_osu_web_client_secret_here":
"osu_web_client_secret is unset. Your server is unsafe. " "osu_web_client_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 40" "Use this command to generate: openssl rand -hex 40"
) )
if settings.private_api_secret == "your_private_api_secret_here":
logger.warning(
"private_api_secret is unset. Your server is unsafe. "
"Use this command to generate: openssl rand -hex 32"
)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn