feat(user): implement user restrictions

## APIs Restricted for Restricted Users

A restricted user is blocked from performing the following actions, and will typically receive a `403 Forbidden` error:

*   **Chat & Notifications:**
    *   Sending any chat messages (public or private).
    *   Joining or leaving chat channels.
    *   Creating new PM channels.
*   **User Profile & Content:**
    *   Uploading a new avatar.
    *   Uploading a new profile cover image.
    *   Changing their username.
    *   Updating their userpage content.
*   **Scores & Gameplay:**
    *   Submitting scores in multiplayer rooms.
    *   Deleting their own scores (to prevent hiding evidence of cheating).
*   **Beatmaps:**
    *   Rating beatmaps.
    *   Taging beatmaps.
*   **Relationship:**
    *   Adding friends or blocking users.
    *   Removing friends or unblocking users.
*   **Teams:**
    *   Creating, updating, or deleting a team.
    *   Requesting to join a team.
    *   Handling join requests for a team they manage.
    *   Kicking a member from a team they manage.
*   **Multiplayer:**
    *   Creating or deleting multiplayer rooms.
    *   Joining or leaving multiplayer rooms.

## What is Invisible to Normal Users

*   **Leaderboards:**
    *   Beatmap leaderboards.
    *   Multiplayer (playlist) room leaderboards.
*   **User Search/Lists:**
    *   Restricted users will not appear in the results of the `/api/v2/users` endpoint.
    *   They will not appear in the list of a team's members.
*   **Relationship:**
    *   They will not appear in a user's friend list (`/friends`).
*   **Profile & History:**
    *   Attempting to view a restricted user's profile, events, kudosu history, or score history will result in a `404 Not Found` error, effectively making their profile invisible (unless the user viewing the profile is the restricted user themselves).
*   **Chat:**
    *   Normal users cannot start a new PM with a restricted user (they will get a `404 Not Found` error).
*   **Ranking:**
    *   Restricted users are excluded from any rankings.

### How to Restrict a User

Insert into `user_account_history` with `type=restriction`.

```sql
-- length is in seconds
INSERT INTO user_account_history (`description`, `length`, `permanent`, `timestamp`, `type`, `user_id`) VALUE ('some description', 86400, 0, '2025-10-05 01:00:00', 'RESTRICTION', 1);
```

---

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
MingxuanGame
2025-10-06 11:10:25 +08:00
committed by GitHub
parent d19f82df80
commit febc1d761f
25 changed files with 354 additions and 222 deletions

View File

@@ -99,6 +99,7 @@ USER_CACHE_CONCURRENT_LIMIT=10
# Anti-cheat Settings # Anti-cheat Settings
SUSPICIOUS_SCORE_CHECK=true SUSPICIOUS_SCORE_CHECK=true
BANNED_NAME='["mrekk", "vaxei", "btmc", "cookiezi", "peppy", "saragi", "chocomint"]' BANNED_NAME='["mrekk", "vaxei", "btmc", "cookiezi", "peppy", "saragi", "chocomint"]'
ALLOW_DELETE_SCORES=false
# Beatmap Syncing Settings # Beatmap Syncing Settings
# POST `/api/private/beatmapsets/{beatmapset_id}/sync?immediate=true` to sync a beatmapset immediately # POST `/api/private/beatmapsets/{beatmapset_id}/sync?immediate=true` to sync a beatmapset immediately

View File

@@ -601,6 +601,11 @@ STORAGE_SETTINGS='{
), ),
"反作弊设置", "反作弊设置",
] ]
allow_delete_scores: Annotated[
bool,
Field(default=False, description="允许用户删除自己的成绩"),
"反作弊设置",
]
# 存储设置 # 存储设置
storage_service: Annotated[ storage_service: Annotated[

View File

@@ -501,6 +501,7 @@ async def _score_where(
wheres: list[ColumnElement[bool] | TextClause] = [ wheres: list[ColumnElement[bool] | TextClause] = [
col(TotalScoreBestScore.beatmap_id) == beatmap, col(TotalScoreBestScore.beatmap_id) == beatmap,
col(TotalScoreBestScore.gamemode) == mode, col(TotalScoreBestScore.gamemode) == mode,
~User.is_restricted_query(col(TotalScoreBestScore.user_id)),
] ]
if type == LeaderboardType.FRIENDS: if type == LeaderboardType.FRIENDS:

View File

@@ -77,7 +77,7 @@ class UserStatistics(AsyncAttrs, UserStatisticsBase, table=True):
level_current: float = Field(default=1) level_current: float = Field(default=1)
user: "User" = Relationship(back_populates="statistics") # type: ignore[valid-type] user: "User" = Relationship(back_populates="statistics")
class UserStatisticsResp(UserStatisticsBase): class UserStatisticsResp(UserStatisticsBase):

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, overload
from app.config import settings from app.config import settings
from app.database.auth import TotpKeys from app.database.auth import TotpKeys
@@ -18,10 +18,11 @@ from .events import Event
from .rank_history import RankHistory, RankHistoryResp, RankTop from .rank_history import RankHistory, RankHistoryResp, RankTop
from .statistics import UserStatistics, UserStatisticsResp from .statistics import UserStatistics, UserStatisticsResp
from .team import Team, TeamMember from .team import Team, TeamMember
from .user_account_history import UserAccountHistory, UserAccountHistoryResp from .user_account_history import UserAccountHistory, UserAccountHistoryResp, UserAccountHistoryType
from pydantic import field_validator from pydantic import field_validator
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import Mapped
from sqlmodel import ( from sqlmodel import (
JSON, JSON,
BigInteger, BigInteger,
@@ -31,8 +32,10 @@ from sqlmodel import (
Relationship, Relationship,
SQLModel, SQLModel,
col, col,
exists,
func, func,
select, select,
text,
) )
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -88,7 +91,6 @@ class UserBase(UTCBaseModel, SQLModel):
badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON)) badges: list[Badge] = Field(default_factory=list, sa_column=Column(JSON))
# optional # optional
is_restricted: bool = False
# blocks # blocks
cover: UserProfileCover = Field( cover: UserProfileCover = Field(
default=UserProfileCover(url=""), default=UserProfileCover(url=""),
@@ -155,8 +157,8 @@ class User(AsyncAttrs, UserBase, table=True):
default=None, default=None,
sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True), sa_column=Column(BigInteger, primary_key=True, autoincrement=True, index=True),
) )
account_history: list[UserAccountHistory] = Relationship() account_history: list[UserAccountHistory] = Relationship(back_populates="user")
statistics: list[UserStatistics] = Relationship() statistics: list[UserStatistics] = Relationship(back_populates="user")
achievement: list[UserAchievement] = Relationship(back_populates="user") achievement: list[UserAchievement] = Relationship(back_populates="user")
team_membership: TeamMember | None = Relationship(back_populates="user") team_membership: TeamMember | None = Relationship(back_populates="user")
daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user") daily_challenge_stats: DailyChallengeStats | None = Relationship(back_populates="user")
@@ -206,8 +208,43 @@ class User(AsyncAttrs, UserBase, table=True):
return False, "Target user has blocked you." return False, "Target user has blocked you."
if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW): if self.pm_friends_only and (not relationship or relationship.type != RelationshipType.FOLLOW):
return False, "Target user has disabled non-friend communications" return False, "Target user has disabled non-friend communications"
if await self.is_restricted(session):
return False, "Target user is restricted"
return True, "" return True, ""
@classmethod
@overload
def is_restricted_query(cls, user_id: int): ...
@classmethod
@overload
def is_restricted_query(cls, user_id: Mapped[int]): ...
@classmethod
def is_restricted_query(cls, user_id: int | Mapped[int]):
return exists().where(
(col(UserAccountHistory.user_id) == user_id)
& (col(UserAccountHistory.type) == UserAccountHistoryType.RESTRICTION)
& (
(col(UserAccountHistory.permanent).is_(True))
| (
(
func.timestampadd(
text("SECOND"),
col(UserAccountHistory.length),
col(UserAccountHistory.timestamp),
)
> func.now()
)
& (func.now() > col(UserAccountHistory.timestamp))
)
),
)
async def is_restricted(self, session: AsyncSession) -> bool:
active_restrictions = (await session.exec(select(self.is_restricted_query(self.id)))).first()
return active_restrictions or False
class UserResp(UserBase): class UserResp(UserBase):
id: int | None = None id: int | None = None
@@ -246,6 +283,7 @@ class UserResp(UserBase):
daily_challenge_user_stats: DailyChallengeStatsResp | None = None daily_challenge_user_stats: DailyChallengeStatsResp | None = None
default_group: str = "" default_group: str = ""
is_deleted: bool = False # TODO is_deleted: bool = False # TODO
is_restricted: bool = False
# TODO: monthly_playcounts, unread_pm_count rank_history, user_preferences # TODO: monthly_playcounts, unread_pm_count rank_history, user_preferences
@@ -370,6 +408,8 @@ class UserResp(UserBase):
if rank_top if rank_top
else None else None
) )
if "is_restricted" in include:
u.is_restricted = await obj.is_restricted(session)
u.favourite_beatmapset_count = ( u.favourite_beatmapset_count = (
await session.exec( await session.exec(
@@ -468,6 +508,7 @@ ALL_INCLUDED = [
"monthly_playcounts", "monthly_playcounts",
"replays_watched_counts", "replays_watched_counts",
"rank_history", "rank_history",
"is_restricted",
"session_verified", "session_verified",
] ]

View File

@@ -1,10 +1,14 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING
from app.models.model import UTCBaseModel from app.models.model import UTCBaseModel
from app.utils import utcnow from app.utils import utcnow
from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, SQLModel from sqlmodel import BigInteger, Column, Field, ForeignKey, Integer, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
class UserAccountHistoryType(str, Enum): class UserAccountHistoryType(str, Enum):
@@ -35,6 +39,8 @@ class UserAccountHistory(UserAccountHistoryBase, table=True):
) )
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True)) user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("lazer_users.id"), index=True))
user: "User" = Relationship(back_populates="account_history")
class UserAccountHistoryResp(UserAccountHistoryBase): class UserAccountHistoryResp(UserAccountHistoryBase):
id: int | None = None id: int | None = None

View File

@@ -30,6 +30,7 @@ oauth2_password = OAuth2PasswordBearer(
scopes={"*": "允许访问全部 API。"}, scopes={"*": "允许访问全部 API。"},
description="osu!lazer 或网页客户端密码登录认证,具有全部权限", description="osu!lazer 或网页客户端密码登录认证,具有全部权限",
scheme_name="Password Grant", scheme_name="Password Grant",
auto_error=False,
) )
oauth2_code = OAuth2AuthorizationCodeBearer( oauth2_code = OAuth2AuthorizationCodeBearer(
@@ -48,6 +49,7 @@ oauth2_code = OAuth2AuthorizationCodeBearer(
}, },
description="osu! OAuth 认证 (授权码认证)", description="osu! OAuth 认证 (授权码认证)",
scheme_name="Authorization Code Grant", scheme_name="Authorization Code Grant",
auto_error=False,
) )
oauth2_client_credentials = OAuth2ClientCredentialsBearer( oauth2_client_credentials = OAuth2ClientCredentialsBearer(
@@ -58,6 +60,7 @@ oauth2_client_credentials = OAuth2ClientCredentialsBearer(
}, },
description="osu! OAuth 认证 (客户端凭证流)", description="osu! OAuth 认证 (客户端凭证流)",
scheme_name="Client Credentials Grant", scheme_name="Client Credentials Grant",
auto_error=False,
) )
v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API 密钥") v1_api_key = APIKeyQuery(name="k", scheme_name="V1 API Key", description="v1 API 密钥")
@@ -78,8 +81,11 @@ async def v1_authorize(
async def get_client_user_and_token( async def get_client_user_and_token(
db: Database, db: Database,
token: Annotated[str, Depends(oauth2_password)], token: Annotated[str | None, Depends(oauth2_password)],
) -> tuple[User, OAuthToken]: ) -> tuple[User, OAuthToken]:
if token is None:
raise HTTPException(status_code=401, detail="Not authenticated")
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
@@ -129,18 +135,11 @@ async def get_client_user(
return user return user
async def get_current_user_and_token( async def _validate_token(
db: Database, db: Database,
token: str,
security_scopes: SecurityScopes, security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
) -> UserAndToken: ) -> UserAndToken:
"""获取当前认证用户"""
token = token_pw or token_code or token_client_credentials
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")
token_record = await get_token_by_access_token(db, token) token_record = await get_token_by_access_token(db, token)
if not token_record: if not token_record:
raise HTTPException(status_code=401, detail="Invalid or expired token") raise HTTPException(status_code=401, detail="Invalid or expired token")
@@ -161,10 +160,39 @@ async def get_current_user_and_token(
return user, token_record return user, token_record
async def get_current_user_and_token(
db: Database,
security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
) -> UserAndToken:
"""获取当前认证用户"""
token = token_pw or token_code or token_client_credentials
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")
return await _validate_token(db, token, security_scopes)
async def get_current_user( async def get_current_user(
user_and_token: UserAndToken = Depends(get_current_user_and_token), user_and_token: UserAndToken = Depends(get_current_user_and_token),
) -> User: ) -> User:
return user_and_token[0] return user_and_token[0]
async def get_optional_user(
db: Database,
security_scopes: SecurityScopes,
token_pw: Annotated[str | None, Depends(oauth2_password)] = None,
token_code: Annotated[str | None, Depends(oauth2_code)] = None,
token_client_credentials: Annotated[str | None, Depends(oauth2_client_credentials)] = None,
) -> User | None:
token = token_pw or token_code or token_client_credentials
if not token:
return None
return (await _validate_token(db, token, security_scopes))[0]
ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])] ClientUser = Annotated[User, Security(get_client_user, scopes=["*"])]

View File

@@ -90,6 +90,9 @@ async def join_channel(
user: Annotated[str, Path(..., description="用户 ID")], user: Annotated[str, Path(..., description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from sending messages")
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
@@ -114,6 +117,9 @@ async def leave_channel(
user: Annotated[str, Path(..., description="用户 ID")], user: Annotated[str, Path(..., description="用户 ID")],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from sending messages")
# 使用明确的查询避免延迟加载 # 使用明确的查询避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
@@ -198,7 +204,7 @@ async def get_channel(
if int(id_) == current_user.id: if int(id_) == current_user.id:
continue continue
target_user = await session.get(User, int(id_)) target_user = await session.get(User, int(id_))
if target_user is None: if target_user is None or await target_user.is_restricted(session):
raise HTTPException(status_code=404, detail="Target user not found") raise HTTPException(status_code=404, detail="Target user not found")
users.extend([target_user, current_user]) users.extend([target_user, current_user])
break break
@@ -249,9 +255,12 @@ async def create_channel(
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])], current_user: Annotated[User, Security(get_current_user, scopes=["chat.write_manage"])],
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from sending messages")
if req.type == "PM": if req.type == "PM":
target = await session.get(User, req.target_id) target = await session.get(User, req.target_id)
if not target: if not target or await target.is_restricted(session):
raise HTTPException(status_code=404, detail="Target user not found") raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session) is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm: if not is_can_pm:

View File

@@ -11,7 +11,7 @@ from app.database.chat import (
UserSilenceResp, UserSilenceResp,
) )
from app.database.user import User from app.database.user import User
from app.dependencies.database import Database, Redis from app.dependencies.database import Database, Redis, redis_message_client
from app.dependencies.param import BodyOrForm from app.dependencies.param import BodyOrForm
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user
from app.log import log from app.log import log
@@ -79,6 +79,9 @@ async def send_message(
req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))], req: Annotated[MessageReq, Depends(BodyOrForm(MessageReq))],
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])], current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from sending messages")
# 使用明确的查询来获取 channel避免延迟加载 # 使用明确的查询来获取 channel避免延迟加载
if channel.isdigit(): if channel.isdigit():
db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first() db_channel = (await session.exec(select(ChatChannel).where(ChatChannel.channel_id == int(channel)))).first()
@@ -97,9 +100,7 @@ async def send_message(
# 对于多人游戏房间在发送消息前进行Redis键检查 # 对于多人游戏房间在发送消息前进行Redis键检查
if channel_type == ChannelType.MULTIPLAYER: if channel_type == ChannelType.MULTIPLAYER:
try: try:
from app.dependencies.database import get_redis redis = redis_message_client
redis = get_redis()
key = f"channel:{channel_id}:messages" key = f"channel:{channel_id}:messages"
key_type = await redis.type(key) key_type = await redis.type(key)
if key_type not in ["none", "zset"]: if key_type not in ["none", "zset"]:
@@ -265,9 +266,12 @@ async def create_new_pm(
current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])], current_user: Annotated[User, Security(get_current_user, scopes=["chat.write"])],
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from sending messages")
user_id = current_user.id user_id = current_user.id
target = await session.get(User, req.target_id) target = await session.get(User, req.target_id)
if target is None: if target is None or await target.is_restricted(session):
raise HTTPException(status_code=404, detail="Target user not found") raise HTTPException(status_code=404, detail="Target user not found")
is_can_pm, block = await target.is_user_can_pm(current_user, session) is_can_pm, block = await target.is_user_can_pm(current_user, session)
if not is_can_pm: if not is_can_pm:

View File

@@ -8,7 +8,7 @@ from app.utils import check_image
from .router import router from .router import router
from fastapi import File from fastapi import File, HTTPException
@router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"]) @router.post("/avatar/upload", name="上传头像", tags=["用户", "g0v0 API"])
@@ -30,6 +30,8 @@ async def upload_avatar(
返回: 返回:
- 头像 URL 和文件哈希值 - 头像 URL 和文件哈希值
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
# check file # check file
format_ = check_image(content, 5 * 1024 * 1024, 256, 256) format_ = check_image(content, 5 * 1024 * 1024, 256, 256)

View File

@@ -38,6 +38,9 @@ async def can_rate_beatmapset(
返回: 返回:
- bool: 用户是否可以评价谱面集 - bool: 用户是否可以评价谱面集
""" """
if await current_user.is_restricted(session):
return False
user_id = current_user.id user_id = current_user.id
prev_ratings = (await session.exec(select(BeatmapRating).where(BeatmapRating.user_id == user_id))).first() prev_ratings = (await session.exec(select(BeatmapRating).where(BeatmapRating.user_id == user_id))).first()
if prev_ratings is not None: if prev_ratings is not None:
@@ -73,6 +76,9 @@ async def rate_beatmaps(
返回: 返回:
- 成功: None - 成功: None
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
user_id = current_user.id user_id = current_user.id
current_beatmapset = (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first() current_beatmapset = (await session.exec(select(exists()).where(Beatmapset.id == beatmapset_id))).first()
if not current_beatmapset: if not current_beatmapset:

View File

@@ -9,7 +9,7 @@ from app.utils import check_image
from .router import router from .router import router
from fastapi import File from fastapi import File, HTTPException
@router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"]) @router.post("/cover/upload", name="上传头图", tags=["用户", "g0v0 API"])
@@ -31,6 +31,8 @@ async def upload_cover(
返回: 返回:
- 头图 URL 和文件哈希值 - 头图 URL 和文件哈希值
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
# check file # check file
format_ = check_image(content, 10 * 1024 * 1024, 3000, 2000) format_ = check_image(content, 10 * 1024 * 1024, 3000, 2000)

View File

@@ -1,3 +1,4 @@
from app.config import settings
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import Database, Redis from app.dependencies.database import Database, Redis
from app.dependencies.storage import StorageService from app.dependencies.storage import StorageService
@@ -8,37 +9,42 @@ from .router import router
from fastapi import BackgroundTasks, HTTPException from fastapi import BackgroundTasks, HTTPException
if settings.allow_delete_scores:
@router.delete( @router.delete(
"/score/{score_id}", "/score/{score_id}",
name="删除指定ID的成绩", name="删除指定ID的成绩",
tags=["成绩", "g0v0 API"], tags=["成绩", "g0v0 API"],
status_code=204, status_code=204,
) )
async def delete_score( async def delete_score(
session: Database, session: Database,
background_task: BackgroundTasks, background_task: BackgroundTasks,
score_id: int, score_id: int,
redis: Redis, redis: Redis,
current_user: ClientUser, current_user: ClientUser,
storage_service: StorageService, storage_service: StorageService,
): ):
"""删除成绩 """删除成绩
删除成绩同时删除对应的统计信息、排行榜分数、pp、回放文件 删除成绩同时删除对应的统计信息、排行榜分数、pp、回放文件
参数: 参数:
- score_id: 成绩ID - score_id: 成绩ID
错误情况: 错误情况:
- 404: 找不到指定成绩 - 404: 找不到指定成绩
""" """
score = await session.get(Score, score_id) if await current_user.is_restricted(session):
if not score or score.user_id != current_user.id: # avoid deleting the evidence of cheating
raise HTTPException(status_code=404, detail="找不到指定成绩") raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
gamemode = score.gamemode score = await session.get(Score, score_id)
user_id = score.user_id if not score or score.user_id != current_user.id:
await score.delete(session, storage_service) raise HTTPException(status_code=404, detail="找不到指定成绩")
await session.commit()
background_task.add_task(refresh_user_cache_background, redis, user_id, gamemode) gamemode = score.gamemode
user_id = score.user_id
await score.delete(session, storage_service)
await session.commit()
background_task.add_task(refresh_user_cache_background, redis, user_id, gamemode)

View File

@@ -19,7 +19,7 @@ from .router import router
from fastapi import File, Form, HTTPException, Path, Request from fastapi import File, Form, HTTPException, Path, Request
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import exists, select from sqlmodel import col, exists, select
@router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"]) @router.post("/team", name="创建战队", response_model=Team, tags=["战队", "g0v0 API"])
@@ -38,6 +38,9 @@ async def create_team(
flag 限制 240x120, 2MB; cover 限制 3000x2000, 10MB flag 限制 240x120, 2MB; cover 限制 3000x2000, 10MB
支持的图片格式: PNG、JPEG、GIF 支持的图片格式: PNG、JPEG、GIF
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
user_id = current_user.id user_id = current_user.id
if (await current_user.awaitable_attrs.team_membership) is not None: if (await current_user.awaitable_attrs.team_membership) is not None:
raise HTTPException(status_code=403, detail="You are already in a team") raise HTTPException(status_code=403, detail="You are already in a team")
@@ -98,6 +101,9 @@ async def update_team(
flag 限制 240x120, 2MB; cover 限制 3000x2000, 10MB flag 限制 240x120, 2MB; cover 限制 3000x2000, 10MB
支持的图片格式: PNG、JPEG、GIF 支持的图片格式: PNG、JPEG、GIF
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
user_id = current_user.id user_id = current_user.id
if not team: if not team:
@@ -162,6 +168,9 @@ async def delete_team(
current_user: ClientUser, current_user: ClientUser,
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
raise HTTPException(status_code=404, detail="Team not found") raise HTTPException(status_code=404, detail="Team not found")
@@ -190,7 +199,14 @@ async def get_team(
session: Database, session: Database,
team_id: Annotated[int, Path(..., description="战队 ID")], team_id: Annotated[int, Path(..., description="战队 ID")],
): ):
members = (await session.exec(select(TeamMember).where(TeamMember.team_id == team_id))).all() members = (
await session.exec(
select(TeamMember).where(
TeamMember.team_id == team_id,
~User.is_restricted_query(col(TeamMember.user_id)),
)
)
).all()
return TeamQueryResp( return TeamQueryResp(
team=members[0].team, team=members[0].team,
members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members], members=[await UserResp.from_db(m.user, session, include=BASE_INCLUDES) for m in members],
@@ -203,6 +219,9 @@ async def request_join_team(
team_id: Annotated[int, Path(..., description="战队 ID")], team_id: Annotated[int, Path(..., description="战队 ID")],
current_user: ClientUser, current_user: ClientUser,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
raise HTTPException(status_code=404, detail="Team not found") raise HTTPException(status_code=404, detail="Team not found")
@@ -233,6 +252,9 @@ async def handle_request(
current_user: ClientUser, current_user: ClientUser,
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
raise HTTPException(status_code=404, detail="Team not found") raise HTTPException(status_code=404, detail="Team not found")
@@ -274,6 +296,9 @@ async def kick_member(
current_user: ClientUser, current_user: ClientUser,
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
team = await session.get(Team, team_id) team = await session.get(Team, team_id)
if not team: if not team:
raise HTTPException(status_code=404, detail="Team not found") raise HTTPException(status_code=404, detail="Team not found")

View File

@@ -40,9 +40,13 @@ async def user_rename(
返回: 返回:
- 成功: None - 成功: None
""" """
if await current_user.is_restricted(session):
# https://github.com/ppy/osu-web/blob/cae2fdf03cfb8c30c8e332cfb142e03188ceffef/app/Libraries/ChangeUsername.php#L48-L49
raise HTTPException(403, "Your account is restricted and cannot perform this action.")
samename_user = (await session.exec(select(exists()).where(User.username == new_name))).first() samename_user = (await session.exec(select(exists()).where(User.username == new_name))).first()
if samename_user: if samename_user:
raise HTTPException(409, "Username Exisits") raise HTTPException(409, "Username Exists")
errors = validate_username(new_name) errors = validate_username(new_name)
if errors: if errors:
raise HTTPException(403, "\n".join(errors)) raise HTTPException(403, "\n".join(errors))
@@ -80,6 +84,8 @@ async def update_userpage(
current_user: ClientUser, current_user: ClientUser,
): ):
"""更新用户页面内容""" """更新用户页面内容"""
if await current_user.is_restricted(session):
raise HTTPException(403, "Your account is restricted and cannot perform this action.")
try: try:
# 处理BBCode内容 # 处理BBCode内容

View File

@@ -3,6 +3,7 @@ from typing import Annotated, Literal
from app.database.best_scores import BestScore from app.database.best_scores import BestScore
from app.database.score import Score, get_leaderboard from app.database.score import Score, get_leaderboard
from app.database.user import User
from app.dependencies.database import Database from app.dependencies.database import Database
from app.models.mods import int_to_mods, mod_to_save, mods_to_int from app.models.mods import int_to_mods, mod_to_save, mods_to_int
from app.models.score import GameMode, LeaderboardType from app.models.score import GameMode, LeaderboardType
@@ -80,6 +81,7 @@ async def get_user_best(
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
exists().where(col(BestScore.score_id) == Score.id), exists().where(col(BestScore.score_id) == Score.id),
~User.is_restricted_query(col(Score.user_id)),
) )
.order_by(col(Score.pp).desc()) .order_by(col(Score.pp).desc())
.options(joinedload(Score.beatmap)) .options(joinedload(Score.beatmap))
@@ -112,6 +114,7 @@ async def get_user_recent(
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.ended_at > utcnow() - timedelta(hours=24), Score.ended_at > utcnow() - timedelta(hours=24),
~User.is_restricted_query(col(Score.user_id)),
) )
.order_by(col(Score.pp).desc()) .order_by(col(Score.pp).desc())
.options(joinedload(Score.beatmap)) .options(joinedload(Score.beatmap))
@@ -147,6 +150,7 @@ async def get_scores(
Score.gamemode == GameMode.from_int_extra(ruleset_id), Score.gamemode == GameMode.from_int_extra(ruleset_id),
Score.beatmap_id == beatmap_id, Score.beatmap_id == beatmap_id,
Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user), Score.user_id == user if type == "id" or user.isdigit() else col(Score.user).has(username=user),
~User.is_restricted_query(col(Score.user_id)),
) )
.options(joinedload(Score.beatmap)) .options(joinedload(Score.beatmap))
.order_by(col(Score.classic_total_score).desc()) .order_by(col(Score.classic_total_score).desc())

View File

@@ -6,18 +6,12 @@ from app.database.user import User
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.log import logger from app.log import logger
from app.models.score import GameMode from app.models.score import GameMode
from app.models.v1_user import (
PlayerEventItem,
PlayerInfo,
PlayerModeStats,
PlayerStatsHistory,
)
from app.service.user_cache_service import get_user_cache_service from app.service.user_cache_service import get_user_cache_service
from .router import AllStrModel, router from .router import AllStrModel, router
from fastapi import BackgroundTasks, HTTPException, Query from fastapi import BackgroundTasks, HTTPException, Query
from sqlmodel import select from sqlmodel import col, select
class V1User(AllStrModel): class V1User(AllStrModel):
@@ -53,10 +47,6 @@ class V1User(AllStrModel):
@classmethod @classmethod
async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User": async def from_db(cls, session: Database, db_user: User, ruleset: GameMode | None = None) -> "V1User":
# 确保 user_id 不为 None
if db_user.id is None:
raise ValueError("User ID cannot be None")
ruleset = ruleset or db_user.playmode ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics: for i in await db_user.awaitable_attrs.statistics:
@@ -134,6 +124,7 @@ async def get_user(
await session.exec( await session.exec(
select(User).where( select(User).where(
User.id == user if is_id_query else User.username == user, User.id == user if is_id_query else User.username == user,
~User.is_restricted_query(col(User.id)),
) )
) )
).first() ).first()
@@ -168,7 +159,11 @@ async def _get_pp_history_for_mode(session: Database, user_id: int, mode: GameMo
# 获取最近 30 天的排名历史(由于没有 PP 历史,我们使用当前的 PP 填充) # 获取最近 30 天的排名历史(由于没有 PP 历史,我们使用当前的 PP 填充)
stats = ( stats = (
await session.exec( await session.exec(
select(UserStatistics).where(UserStatistics.user_id == user_id, UserStatistics.mode == mode) select(UserStatistics).where(
UserStatistics.user_id == user_id,
UserStatistics.mode == mode,
~User.is_restricted_query(col(UserStatistics.user_id)),
)
) )
).first() ).first()
@@ -178,128 +173,3 @@ async def _get_pp_history_for_mode(session: Database, user_id: int, mode: GameMo
except Exception as e: except Exception as e:
logger.error(f"Error getting PP history for user {user_id}, mode {mode}: {e}") logger.error(f"Error getting PP history for user {user_id}, mode {mode}: {e}")
return [0.0] * days return [0.0] * days
async def _create_player_mode_stats(
session: Database, user: User, mode: GameMode, user_statistics: list[UserStatistics]
) -> PlayerModeStats:
"""创建单个模式的玩家统计数据"""
# 查找对应模式的统计数据
stats = None
for stat in user_statistics:
if stat.mode == mode:
stats = stat
break
if not stats:
# 如果没有统计数据,创建默认数据
pp_history = [0.0] * 30
return PlayerModeStats(
id=user.id,
mode=int(mode),
tscore=0,
rscore=0,
pp=0.0,
plays=0,
playtime=0,
acc=0.0,
max_combo=0,
total_hits=0,
replay_views=0,
xh_count=0,
x_count=0,
sh_count=0,
s_count=0,
a_count=0,
level=1,
level_progress=0,
rank=0,
country_rank=0,
history=PlayerStatsHistory(pp=pp_history),
)
# 获取排名信息
try:
from app.database.statistics import get_rank
global_rank = await get_rank(session, stats) or 0
country_rank = await get_rank(session, stats, user.country_code) or 0
except Exception as e:
logger.error(f"Error getting rank for user {user.id}: {e}")
global_rank = 0
country_rank = 0
# 获取 PP 历史
pp_history = await _get_pp_history_for_mode(session, user.id, mode)
# 计算等级进度
level_current = int(stats.level_current)
level_progress = int((stats.level_current - level_current) * 100)
return PlayerModeStats(
id=user.id,
mode=int(mode),
tscore=stats.total_score,
rscore=stats.ranked_score,
pp=stats.pp,
plays=stats.play_count,
playtime=stats.play_time,
acc=stats.hit_accuracy,
max_combo=stats.maximum_combo,
total_hits=stats.total_hits,
replay_views=stats.replays_watched_by_others,
xh_count=stats.grade_ssh,
x_count=stats.grade_ss,
sh_count=stats.grade_sh,
s_count=stats.grade_s,
a_count=stats.grade_a,
level=level_current,
level_progress=level_progress,
rank=global_rank,
country_rank=country_rank,
history=PlayerStatsHistory(pp=pp_history),
)
async def _get_player_events(session: Database, user_id: int, event_days: int = 1) -> list[PlayerEventItem]:
"""获取玩家事件"""
try:
# 这里暂时返回空列表,因为事件系统需要更多的实现
# TODO: 实现真正的事件查询
return []
except Exception as e:
logger.error(f"Error getting events for user {user_id}: {e}")
return []
async def _create_player_info(user: User) -> PlayerInfo:
"""创建玩家基本信息"""
return PlayerInfo(
id=user.id,
name=user.username,
safe_name=user.username.lower(), # 使用 username 转小写作为 safe_name
priv=user.priv,
country=user.country_code,
silence_end=int(user.silence_end_at.timestamp()) if user.silence_end_at else 0,
donor_end=int(user.donor_end_at.timestamp()) if user.donor_end_at else 0,
creation_time=int(user.join_date.timestamp()),
latest_activity=int(user.last_visit.timestamp()) if user.last_visit else 0,
clan_id=0, # TODO: 实现战队系统
clan_priv=0,
preferred_mode=int(user.playmode),
preferred_type=0,
play_style=0,
custom_badge_enabled=0,
custom_badge_name="",
custom_badge_icon="",
custom_badge_color="white",
userpage_content=user.page.get("html", "") if user.page else "",
recentFailed=0,
social_discord=user.discord,
social_youtube=None,
social_twitter=user.twitter,
social_twitch=None,
social_github=None,
social_osu=None,
username_history=user.previous_usernames or [],
)

View File

@@ -99,6 +99,7 @@ async def get_team_ranking(
UserStatistics.mode == ruleset, UserStatistics.mode == ruleset,
UserStatistics.pp > 0, UserStatistics.pp > 0,
col(UserStatistics.user).has(col(User.team_membership).has(col(TeamMember.team_id) == team.id)), col(UserStatistics.user).has(col(User.team_membership).has(col(TeamMember.team_id) == team.id)),
~User.is_restricted_query(col(UserStatistics.user_id)),
) )
) )
).all() ).all()
@@ -249,6 +250,7 @@ async def get_country_ranking(
UserStatistics.pp > 0, UserStatistics.pp > 0,
col(UserStatistics.user).has(country_code=country), col(UserStatistics.user).has(country_code=country),
col(UserStatistics.user).has(is_active=True), col(UserStatistics.user).has(is_active=True),
~User.is_restricted_query(col(UserStatistics.user_id)),
) )
) )
).all() ).all()
@@ -363,7 +365,14 @@ async def get_user_ranking(
total_count = total_count_result.one() total_count = total_count_result.one()
statistics_list = await session.exec( statistics_list = await session.exec(
select(UserStatistics).where(*wheres).order_by(order_by).limit(50).offset(50 * (page - 1)) select(UserStatistics)
.where(
*wheres,
~User.is_restricted_query(col(UserStatistics.user_id)),
)
.order_by(order_by)
.limit(50)
.offset(50 * (page - 1))
) )
# 转换为响应格式 # 转换为响应格式

View File

@@ -10,7 +10,7 @@ from .router import router
from fastapi import HTTPException, Path, Query, Request, Security from fastapi import HTTPException, Path, Query, Request, Security
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import exists, select from sqlmodel import col, exists, select
@router.get( @router.get(
@@ -63,6 +63,7 @@ async def get_relationship(
select(Relationship).where( select(Relationship).where(
Relationship.user_id == current_user.id, Relationship.user_id == current_user.id,
Relationship.type == relationship_type, Relationship.type == relationship_type,
~User.is_restricted_query(col(Relationship.target_id)),
) )
) )
if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK: if api_version >= 20241022 or relationship_type == RelationshipType.BLOCK:
@@ -110,7 +111,11 @@ async def add_relationship(
target: Annotated[int, Query(description="目标用户 ID")], target: Annotated[int, Query(description="目标用户 ID")],
current_user: ClientUser, current_user: ClientUser,
): ):
if not (await db.exec(select(exists()).where(User.id == target))).first(): if await current_user.is_restricted(db):
raise HTTPException(403, "Your account is restricted and cannot perform this action.")
if not (
await db.exec(select(exists()).where((User.id == target) & ~User.is_restricted_query(col(User.id))))
).first():
raise HTTPException(404, "Target user not found") raise HTTPException(404, "Target user not found")
relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK relationship_type = RelationshipType.FOLLOW if request.url.path.endswith("/friends") else RelationshipType.BLOCK
@@ -179,7 +184,11 @@ async def delete_relationship(
target: Annotated[int, Path(..., description="目标用户 ID")], target: Annotated[int, Path(..., description="目标用户 ID")],
current_user: ClientUser, current_user: ClientUser,
): ):
if not (await db.exec(select(exists()).where(User.id == target))).first(): if await current_user.is_restricted(db):
raise HTTPException(403, "Your account is restricted and cannot perform this action.")
if not (
await db.exec(select(exists()).where((User.id == target) & ~User.is_restricted_query(col(User.id))))
).first():
raise HTTPException(404, "Target user not found") raise HTTPException(404, "Target user not found")
relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW relationship_type = RelationshipType.BLOCK if "/blocks/" in request.url.path else RelationshipType.FOLLOW

View File

@@ -143,6 +143,9 @@ async def create_room(
current_user: ClientUser, current_user: ClientUser,
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
user_id = current_user.id user_id = current_user.id
db_room = await create_playlist_room_from_api(db, room, user_id) db_room = await create_playlist_room_from_api(db, room, user_id)
await _participate_room(db_room.id, user_id, db_room, db, redis) await _participate_room(db_room.id, user_id, db_room, db, redis)
@@ -189,6 +192,9 @@ async def delete_room(
room_id: Annotated[int, Path(..., description="房间 ID")], room_id: Annotated[int, Path(..., description="房间 ID")],
current_user: ClientUser, current_user: ClientUser,
): ):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is None: if db_room is None:
raise HTTPException(404, "Room not found") raise HTTPException(404, "Room not found")
@@ -211,6 +217,9 @@ async def add_user_to_room(
redis: Redis, redis: Redis,
current_user: ClientUser, current_user: ClientUser,
): ):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
await _participate_room(room_id, user_id, db_room, db, redis) await _participate_room(room_id, user_id, db_room, db, redis)
@@ -235,6 +244,9 @@ async def remove_user_from_room(
current_user: ClientUser, current_user: ClientUser,
redis: Redis, redis: Redis,
): ):
if await current_user.is_restricted(db):
raise HTTPException(status_code=403, detail="Your account is restricted from multiplayer.")
db_room = (await db.exec(select(Room).where(Room.id == room_id))).first() db_room = (await db.exec(select(Room).where(Room.id == room_id))).first()
if db_room is not None: if db_room is not None:
participated_user = ( participated_user = (

View File

@@ -255,9 +255,6 @@ async def get_beatmap_scores(
] = LeaderboardType.GLOBAL, ] = LeaderboardType.GLOBAL,
limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50, limit: Annotated[int, Query(ge=1, le=200, description="返回条数 (1-200)")] = 50,
): ):
if legacy_only:
raise HTTPException(status_code=404, detail="this server only contains lazer scores")
all_scores, user_score, count = await get_leaderboard( all_scores, user_score, count = await get_leaderboard(
db, db,
beatmap_id, beatmap_id,
@@ -355,6 +352,7 @@ async def get_user_all_beatmap_scores(
Score.beatmap_id == beatmap_id, Score.beatmap_id == beatmap_id,
Score.user_id == user_id, Score.user_id == user_id,
col(Score.passed).is_(True), col(Score.passed).is_(True),
~User.is_restricted_query(col(Score.user_id)),
) )
.order_by(col(Score.total_score).desc()) .order_by(col(Score.total_score).desc())
) )
@@ -433,7 +431,9 @@ async def create_playlist_score(
current_user: ClientUser, current_user: ClientUser,
version_hash: Annotated[str, Form(description="谱面版本哈希")] = "", version_hash: Annotated[str, Form(description="谱面版本哈希")] = "",
): ):
# 立即获取用户ID避免懒加载问题 if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from submitting multiplayer scores")
user_id = current_user.id user_id = current_user.id
room = await session.get(Room, room_id) room = await session.get(Room, room_id)
@@ -499,7 +499,9 @@ async def submit_playlist_score(
redis: Redis, redis: Redis,
fetcher: Fetcher, fetcher: Fetcher,
): ):
# 立即获取用户ID避免懒加载问题 if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="You are restricted from submitting multiplayer scores")
user_id = current_user.id user_id = current_user.id
item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first() item = (await session.exec(select(Playlist).where(Playlist.id == playlist_id, Playlist.room_id == room_id))).first()
@@ -574,6 +576,7 @@ async def index_playlist_scores(
PlaylistBestScore.playlist_id == playlist_id, PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id, PlaylistBestScore.room_id == room_id,
PlaylistBestScore.total_score < cursor, PlaylistBestScore.total_score < cursor,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
) )
.order_by(col(PlaylistBestScore.total_score).desc()) .order_by(col(PlaylistBestScore.total_score).desc())
.limit(limit + 1) .limit(limit + 1)
@@ -641,6 +644,7 @@ async def show_playlist_score(
PlaylistBestScore.score_id == score_id, PlaylistBestScore.score_id == score_id,
PlaylistBestScore.playlist_id == playlist_id, PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id, PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
) )
) )
).first() ).first()
@@ -658,6 +662,7 @@ async def show_playlist_score(
select(PlaylistBestScore).where( select(PlaylistBestScore).where(
PlaylistBestScore.playlist_id == playlist_id, PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id, PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
) )
) )
).all() ).all()
@@ -702,6 +707,7 @@ async def get_user_playlist_score(
PlaylistBestScore.user_id == user_id, PlaylistBestScore.user_id == user_id,
PlaylistBestScore.playlist_id == playlist_id, PlaylistBestScore.playlist_id == playlist_id,
PlaylistBestScore.room_id == room_id, PlaylistBestScore.room_id == room_id,
~User.is_restricted_query(col(PlaylistBestScore.user_id)),
) )
) )
).first() ).first()

View File

@@ -58,6 +58,9 @@ async def vote_beatmap_tags(
session: Database, session: Database,
current_user: Annotated[User, Depends(get_client_user)], current_user: Annotated[User, Depends(get_client_user)],
): ):
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
try: try:
get_tag_by_id(tag_id) get_tag_by_id(tag_id)
beatmap = (await session.exec(select(exists()).where(Beatmap.id == beatmap_id))).first() beatmap = (await session.exec(select(exists()).where(Beatmap.id == beatmap_id))).first()
@@ -98,6 +101,9 @@ async def devote_beatmap_tags(
- **beatmap_id**: 谱面ID - **beatmap_id**: 谱面ID
- **tag_id**: 标签ID - **tag_id**: 标签ID
""" """
if await current_user.is_restricted(session):
raise HTTPException(status_code=403, detail="Your account is restricted and cannot perform this action.")
try: try:
tag = get_tag_by_id(tag_id) tag = get_tag_by_id(tag_id)
assert tag is not None assert tag is not None

View File

@@ -15,10 +15,10 @@ from app.database import (
from app.database.best_scores import BestScore from app.database.best_scores import BestScore
from app.database.events import Event from app.database.events import Event
from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores from app.database.score import LegacyScoreResp, Score, ScoreResp, get_user_first_scores
from app.database.user import SEARCH_INCLUDED from app.database.user import ALL_INCLUDED, SEARCH_INCLUDED
from app.dependencies.api_version import APIVersion from app.dependencies.api_version import APIVersion
from app.dependencies.database import Database, get_redis from app.dependencies.database import Database, get_redis
from app.dependencies.user import get_current_user from app.dependencies.user import get_current_user, get_optional_user
from app.helpers.asset_proxy_helper import asset_proxy_response from app.helpers.asset_proxy_helper import asset_proxy_response
from app.log import log from app.log import log
from app.models.mods import API_MODS from app.models.mods import API_MODS
@@ -52,6 +52,14 @@ def _get_difficulty_reduction_mods() -> set[str]:
return mods return mods
async def visible_to_current_user(user: User, current_user: User | None, session: Database) -> bool:
if user.id == BANCHOBOT_ID:
return False
if current_user and current_user.id == user.id:
return True
return not await user.is_restricted(session)
@router.get( @router.get(
"/users/", "/users/",
response_model=BatchUserResponse, response_model=BatchUserResponse,
@@ -90,7 +98,11 @@ async def get_users(
# 查询未缓存的用户 # 查询未缓存的用户
if uncached_user_ids: if uncached_user_ids:
searched_users = (await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))).all() searched_users = (
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids), ~User.is_restricted_query(col(User.id)))
)
).all()
# 将查询到的用户添加到缓存并返回 # 将查询到的用户添加到缓存并返回
for searched_user in searched_users: for searched_user in searched_users:
@@ -107,7 +119,9 @@ async def get_users(
response = BatchUserResponse(users=cached_users) response = BatchUserResponse(users=cached_users)
return response return response
else: else:
searched_users = (await session.exec(select(User).limit(50))).all() searched_users = (
await session.exec(select(User).limit(50).where(~User.is_restricted_query(col(User.id))))
).all()
users = [] users = []
for searched_user in searched_users: for searched_user in searched_users:
if searched_user.id == BANCHOBOT_ID: if searched_user.id == BANCHOBOT_ID:
@@ -139,7 +153,7 @@ async def get_user_events(
offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None, offset: Annotated[int | None, Query(description="活动日志的偏移量")] = None,
): ):
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
if db_user is None or db_user.id == BANCHOBOT_ID: if db_user is None or not await visible_to_current_user(db_user, None, session):
raise HTTPException(404, "User Not found") raise HTTPException(404, "User Not found")
events = ( events = (
await session.exec( await session.exec(
@@ -174,7 +188,7 @@ async def get_user_kudosu(
""" """
# 验证用户是否存在 # 验证用户是否存在
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
if db_user is None or db_user.id == BANCHOBOT_ID: if db_user is None or not await visible_to_current_user(db_user, None, session):
raise HTTPException(404, "User not found") raise HTTPException(404, "User not found")
# TODO: 实现 kudosu 记录获取逻辑 # TODO: 实现 kudosu 记录获取逻辑
@@ -214,7 +228,7 @@ async def get_user_beatmaps_passed(
raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items") raise HTTPException(status_code=413, detail="beatmapset_ids cannot exceed 50 items")
user = await session.get(User, user_id) user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID: if user is None or not await visible_to_current_user(user, current_user, session):
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
allowed_mode: GameMode | None = None allowed_mode: GameMode | None = None
@@ -282,7 +296,7 @@ async def get_user_info_ruleset(
background_task: BackgroundTasks, background_task: BackgroundTasks,
user_id: Annotated[str, Path(description="用户 ID 或用户名")], user_id: Annotated[str, Path(description="用户 ID 或用户名")],
ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")], ruleset: Annotated[GameMode | None, Path(description="指定 ruleset")],
# current_user: User = Security(get_current_user, scopes=["public"]), current_user: User | None = Security(get_optional_user, scopes=["public"]),
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -303,11 +317,18 @@ async def get_user_info_ruleset(
).first() ).first()
if not searched_user or searched_user.id == BANCHOBOT_ID: if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
searched_is_self = current_user is not None and current_user.id == searched_user.id
should_not_show = not searched_is_self and await searched_user.is_restricted(session)
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db( user_resp = await UserResp.from_db(
searched_user, searched_user,
session, session,
include=SEARCH_INCLUDED, include=include,
ruleset=ruleset, ruleset=ruleset,
) )
@@ -331,7 +352,7 @@ async def get_user_info(
session: Database, session: Database,
request: Request, request: Request,
user_id: Annotated[str, Path(description="用户 ID 或用户名")], user_id: Annotated[str, Path(description="用户 ID 或用户名")],
# current_user: User = Security(get_current_user, scopes=["public"]), current_user: User | None = Security(get_optional_user, scopes=["public"]),
): ):
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
@@ -352,11 +373,18 @@ async def get_user_info(
).first() ).first()
if not searched_user or searched_user.id == BANCHOBOT_ID: if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
searched_is_self = current_user is not None and current_user.id == searched_user.id
should_not_show = not searched_is_self and await searched_user.is_restricted(session)
if should_not_show:
raise HTTPException(404, detail="User not found")
include = SEARCH_INCLUDED
if searched_is_self:
include = ALL_INCLUDED
user_resp = await UserResp.from_db( user_resp = await UserResp.from_db(
searched_user, searched_user,
session, session,
include=SEARCH_INCLUDED, include=include,
) )
# 异步缓存结果 # 异步缓存结果
@@ -411,7 +439,7 @@ async def get_user_beatmapsets(
elif type == BeatmapsetType.FAVOURITE: elif type == BeatmapsetType.FAVOURITE:
user = await session.get(User, user_id) user = await session.get(User, user_id)
if not user: if user is None or not await visible_to_current_user(user, current_user, session):
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
favourites = await user.awaitable_attrs.favourite_beatmapsets favourites = await user.awaitable_attrs.favourite_beatmapsets
resp = [ resp = [
@@ -419,6 +447,10 @@ async def get_user_beatmapsets(
] ]
elif type == BeatmapsetType.MOST_PLAYED: elif type == BeatmapsetType.MOST_PLAYED:
user = await session.get(User, user_id)
if user is None or not await visible_to_current_user(user, current_user, session):
raise HTTPException(404, detail="User not found")
most_played = await session.exec( most_played = await session.exec(
select(BeatmapPlaycounts) select(BeatmapPlaycounts)
.where(BeatmapPlaycounts.user_id == user_id) .where(BeatmapPlaycounts.user_id == user_id)
@@ -484,7 +516,7 @@ async def get_user_scores(
return cached_scores return cached_scores
db_user = await session.get(User, user_id) db_user = await session.get(User, user_id)
if not db_user or db_user.id == BANCHOBOT_ID: if db_user is None or not await visible_to_current_user(db_user, current_user, session):
raise HTTPException(404, detail="User not found") raise HTTPException(404, detail="User not found")
gamemode = mode or db_user.playmode gamemode = mode or db_user.playmode

View File

@@ -6,7 +6,8 @@ from typing import Final
from app.config import settings from app.config import settings
from app.database.score import Score from app.database.score import Score
from app.dependencies.database import get_redis from app.database.user import User
from app.dependencies.database import get_redis, with_db
from app.dependencies.fetcher import get_fetcher from app.dependencies.fetcher import get_fetcher
from app.dependencies.scheduler import get_scheduler from app.dependencies.scheduler import get_scheduler
from app.log import logger from app.log import logger
@@ -107,9 +108,6 @@ async def schedule_user_cache_warmup_task() -> None:
redis = get_redis() redis = get_redis()
cache_service = get_user_cache_service(redis) cache_service = get_user_cache_service(redis)
from app.dependencies.database import with_db
async with with_db() as session: async with with_db() as session:
from app.database.statistics import UserStatistics from app.database.statistics import UserStatistics
from app.models.score import GameMode from app.models.score import GameMode
@@ -119,7 +117,10 @@ async def schedule_user_cache_warmup_task() -> None:
top_users = ( top_users = (
await session.exec( await session.exec(
select(UserStatistics.user_id) select(UserStatistics.user_id)
.where(UserStatistics.mode == mode) .where(
UserStatistics.mode == mode,
~User.is_restricted_query(col(UserStatistics.user_id)),
)
.order_by(col(UserStatistics.pp).desc()) .order_by(col(UserStatistics.pp).desc())
.limit(100) .limit(100)
) )

View File

@@ -0,0 +1,41 @@
"""user: remove is_restricted
Revision ID: 425b91532cb4
Revises: ee13ad926584
Create Date: 2025-10-05 11:11:46.391414
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
revision: str = "425b91532cb4"
down_revision: str | Sequence[str] | None = "ee13ad926584"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("lazer_users", "is_restricted")
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
# Step 1: Add the column as nullable
op.add_column(
"lazer_users", sa.Column("is_restricted", mysql.TINYINT(display_width=1), autoincrement=False, nullable=True)
)
# Step 2: Set a default value for all existing rows
op.execute("UPDATE lazer_users SET is_restricted = 0 WHERE is_restricted IS NULL")
# Step 3: Alter the column to be NOT NULL
op.alter_column("lazer_users", "is_restricted", nullable=False)
# ### end Alembic commands ###