Merge branch 'main' of https://github.com/GooGuTeam/osu_lazer_api
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from .user import User
|
from app.models.user import User as APIUser
|
||||||
|
|
||||||
|
from .user import User as DBUser
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlmodel import (
|
from sqlmodel import (
|
||||||
@@ -41,14 +43,14 @@ class Relationship(SQLModel, table=True):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
type: RelationshipType = Field(default=RelationshipType.FOLLOW, nullable=False)
|
||||||
target: "User" = SQLRelationship(
|
target: DBUser = SQLRelationship(
|
||||||
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
sa_relationship_kwargs={"foreign_keys": "[Relationship.target_id]"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RelationshipResp(BaseModel):
|
class RelationshipResp(BaseModel):
|
||||||
target_id: int
|
target_id: int
|
||||||
# FIXME: target: User
|
target: APIUser
|
||||||
mutual: bool = False
|
mutual: bool = False
|
||||||
type: RelationshipType
|
type: RelationshipType
|
||||||
|
|
||||||
@@ -56,6 +58,8 @@ class RelationshipResp(BaseModel):
|
|||||||
async def from_db(
|
async def from_db(
|
||||||
cls, session: AsyncSession, relationship: Relationship
|
cls, session: AsyncSession, relationship: Relationship
|
||||||
) -> "RelationshipResp":
|
) -> "RelationshipResp":
|
||||||
|
from app.utils import convert_db_user_to_api_user
|
||||||
|
|
||||||
target_relationship = (
|
target_relationship = (
|
||||||
await session.exec(
|
await session.exec(
|
||||||
select(Relationship).where(
|
select(Relationship).where(
|
||||||
@@ -71,7 +75,7 @@ class RelationshipResp(BaseModel):
|
|||||||
)
|
)
|
||||||
return cls(
|
return cls(
|
||||||
target_id=relationship.target_id,
|
target_id=relationship.target_id,
|
||||||
# target=relationship.target,
|
target=await convert_db_user_to_api_user(relationship.target),
|
||||||
mutual=mutual,
|
mutual=mutual,
|
||||||
type=relationship.type,
|
type=relationship.type,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from .team import TeamMember
|
|||||||
|
|
||||||
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
|
from sqlalchemy import DECIMAL, JSON, Column, Date, DateTime, Text
|
||||||
from sqlalchemy.dialects.mysql import VARCHAR
|
from sqlalchemy.dialects.mysql import VARCHAR
|
||||||
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel
|
from sqlalchemy.orm import joinedload, selectinload
|
||||||
|
from sqlmodel import BigInteger, Field, ForeignKey, Relationship, SQLModel, select
|
||||||
|
|
||||||
|
|
||||||
class User(SQLModel, table=True):
|
class User(SQLModel, table=True):
|
||||||
@@ -100,6 +101,30 @@ class User(SQLModel, table=True):
|
|||||||
back_populates="user"
|
back_populates="user"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all_select_option(cls):
|
||||||
|
return (
|
||||||
|
joinedload(cls.lazer_profile), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(cls.lazer_counts), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(cls.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(cls.avatar), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_achievements), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.statistics), # pyright: ignore[reportArgumentType]
|
||||||
|
joinedload(cls.team_membership), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.rank_history), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.active_banners), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_badges), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
||||||
|
selectinload(cls.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def all_select_clause(cls):
|
||||||
|
return select(cls).options(*cls.all_select_option())
|
||||||
|
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# Lazer API 专用表模型
|
# Lazer API 专用表模型
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ from .database import get_db
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlalchemy.orm import joinedload, selectinload
|
|
||||||
from sqlmodel import select
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
security = HTTPBearer()
|
security = HTTPBearer()
|
||||||
@@ -35,25 +33,7 @@ async def get_current_user_by_token(token: str, db: AsyncSession) -> DBUser | No
|
|||||||
return None
|
return None
|
||||||
user = (
|
user = (
|
||||||
await db.exec(
|
await db.exec(
|
||||||
select(DBUser)
|
DBUser.all_select_clause().where(DBUser.id == token_record.user_id)
|
||||||
.options(
|
|
||||||
joinedload(DBUser.lazer_profile), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(DBUser.lazer_counts), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(DBUser.daily_challenge_stats), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(DBUser.avatar), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_statistics), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_achievements), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_profile_sections), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.statistics), # pyright: ignore[reportArgumentType]
|
|
||||||
joinedload(DBUser.team_membership), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.rank_history), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.active_banners), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_badges), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_monthly_playcounts), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_previous_usernames), # pyright: ignore[reportArgumentType]
|
|
||||||
selectinload(DBUser.lazer_replays_watched), # pyright: ignore[reportArgumentType]
|
|
||||||
)
|
|
||||||
.where(DBUser.id == token_record.user_id)
|
|
||||||
)
|
)
|
||||||
).first()
|
).first()
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -2,16 +2,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
from app.database import (
|
|
||||||
LazerUserAchievement,
|
|
||||||
Team as Team,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .score import GameMode
|
from .score import GameMode
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.database import LazerUserAchievement, Team
|
||||||
|
|
||||||
|
|
||||||
class PlayStyle(str, Enum):
|
class PlayStyle(str, Enum):
|
||||||
MOUSE = "mouse"
|
MOUSE = "mouse"
|
||||||
@@ -83,7 +82,11 @@ class UserAchievement(BaseModel):
|
|||||||
achievement_id: int
|
achievement_id: int
|
||||||
|
|
||||||
# 添加数据库模型转换方法
|
# 添加数据库模型转换方法
|
||||||
def to_db_model(self, user_id: int) -> LazerUserAchievement:
|
def to_db_model(self, user_id: int) -> "LazerUserAchievement":
|
||||||
|
from app.database import (
|
||||||
|
LazerUserAchievement,
|
||||||
|
)
|
||||||
|
|
||||||
return LazerUserAchievement(
|
return LazerUserAchievement(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
achievement_id=self.achievement_id,
|
achievement_id=self.achievement_id,
|
||||||
@@ -207,5 +210,5 @@ class User(BaseModel):
|
|||||||
rank_history: RankHistory | None = None
|
rank_history: RankHistory | None = None
|
||||||
rankHistory: RankHistory | None = None # 兼容性别名
|
rankHistory: RankHistory | None = None # 兼容性别名
|
||||||
replays_watched_counts: list[dict] = []
|
replays_watched_counts: list[dict] = []
|
||||||
team: Team | None = None
|
team: "Team | None" = None
|
||||||
user_achievements: list[UserAchievement] = []
|
user_achievements: list[UserAchievement] = []
|
||||||
|
|||||||
@@ -2,43 +2,42 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import re
|
import re
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from app.auth import (
|
from app.auth import (
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
generate_refresh_token,
|
generate_refresh_token,
|
||||||
|
get_password_hash,
|
||||||
get_token_by_refresh_token,
|
get_token_by_refresh_token,
|
||||||
store_token,
|
store_token,
|
||||||
get_password_hash,
|
|
||||||
)
|
)
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.dependencies import get_db
|
|
||||||
from app.models.oauth import TokenResponse, OAuthErrorResponse, RegistrationErrorResponse, UserRegistrationErrors, RegistrationRequestErrors
|
|
||||||
from app.database import User as DBUser
|
from app.database import User as DBUser
|
||||||
|
from app.dependencies import get_db
|
||||||
|
from app.models.oauth import (
|
||||||
|
OAuthErrorResponse,
|
||||||
|
RegistrationRequestErrors,
|
||||||
|
TokenResponse,
|
||||||
|
UserRegistrationErrors,
|
||||||
|
)
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
from fastapi import APIRouter, Depends, Form
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
def create_oauth_error_response(error: str, description: str, hint: str, status_code: int = 400):
|
def create_oauth_error_response(
|
||||||
|
error: str, description: str, hint: str, status_code: int = 400
|
||||||
|
):
|
||||||
"""创建标准的 OAuth 错误响应"""
|
"""创建标准的 OAuth 错误响应"""
|
||||||
error_data = OAuthErrorResponse(
|
error_data = OAuthErrorResponse(
|
||||||
error=error,
|
error=error, error_description=description, hint=hint, message=description
|
||||||
error_description=description,
|
|
||||||
hint=hint,
|
|
||||||
message=description
|
|
||||||
)
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status_code,
|
|
||||||
content=error_data.model_dump()
|
|
||||||
)
|
)
|
||||||
|
return JSONResponse(status_code=status_code, content=error_data.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
def validate_username(username: str) -> list[str]:
|
||||||
def validate_username(username: str) -> List[str]:
|
|
||||||
"""验证用户名"""
|
"""验证用户名"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
@@ -53,8 +52,10 @@ def validate_username(username: str) -> List[str]:
|
|||||||
errors.append("Username must be at most 15 characters long")
|
errors.append("Username must be at most 15 characters long")
|
||||||
|
|
||||||
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
# 检查用户名格式(只允许字母、数字、下划线、连字符)
|
||||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||||
errors.append("Username can only contain letters, numbers, underscores, and hyphens")
|
errors.append(
|
||||||
|
"Username can only contain letters, numbers, underscores, and hyphens"
|
||||||
|
)
|
||||||
|
|
||||||
# 检查是否以数字开头
|
# 检查是否以数字开头
|
||||||
if username[0].isdigit():
|
if username[0].isdigit():
|
||||||
@@ -63,7 +64,7 @@ def validate_username(username: str) -> List[str]:
|
|||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
def validate_email(email: str) -> List[str]:
|
def validate_email(email: str) -> list[str]:
|
||||||
"""验证邮箱"""
|
"""验证邮箱"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
@@ -72,14 +73,14 @@ def validate_email(email: str) -> List[str]:
|
|||||||
return errors
|
return errors
|
||||||
|
|
||||||
# 基本的邮箱格式验证
|
# 基本的邮箱格式验证
|
||||||
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||||
if not re.match(email_pattern, email):
|
if not re.match(email_pattern, email):
|
||||||
errors.append("Please enter a valid email address")
|
errors.append("Please enter a valid email address")
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
def validate_password(password: str) -> List[str]:
|
def validate_password(password: str) -> list[str]:
|
||||||
"""验证密码"""
|
"""验证密码"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
@@ -109,16 +110,12 @@ async def register_user(
|
|||||||
email_errors = validate_email(user_email)
|
email_errors = validate_email(user_email)
|
||||||
password_errors = validate_password(user_password)
|
password_errors = validate_password(user_password)
|
||||||
|
|
||||||
result = await db.exec(
|
result = await db.exec(select(DBUser).where(DBUser.name == user_username))
|
||||||
select(DBUser).where(DBUser.name == user_username)
|
|
||||||
)
|
|
||||||
existing_user = result.first()
|
existing_user = result.first()
|
||||||
if existing_user:
|
if existing_user:
|
||||||
username_errors.append("Username is already taken")
|
username_errors.append("Username is already taken")
|
||||||
|
|
||||||
result = await db.exec(
|
result = await db.exec(select(DBUser).where(DBUser.email == user_email))
|
||||||
select(DBUser).where(DBUser.email == user_email)
|
|
||||||
)
|
|
||||||
existing_email = result.first()
|
existing_email = result.first()
|
||||||
if existing_email:
|
if existing_email:
|
||||||
email_errors.append("Email is already taken")
|
email_errors.append("Email is already taken")
|
||||||
@@ -128,13 +125,12 @@ async def register_user(
|
|||||||
user=UserRegistrationErrors(
|
user=UserRegistrationErrors(
|
||||||
username=username_errors,
|
username=username_errors,
|
||||||
user_email=email_errors,
|
user_email=email_errors,
|
||||||
password=password_errors
|
password=password_errors,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=422,
|
status_code=422, content={"form_error": errors.model_dump()}
|
||||||
content={"form_error": errors.model_dump()}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -194,11 +190,13 @@ async def register_user(
|
|||||||
if user_id <= 2:
|
if user_id <= 2:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
errors = RegistrationRequestErrors(
|
errors = RegistrationRequestErrors(
|
||||||
message="Failed to create account with valid ID. Please contact support."
|
message=(
|
||||||
|
"Failed to create account with valid ID. "
|
||||||
|
"Please contact support."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500, content={"form_error": errors.model_dump()}
|
||||||
content={"form_error": errors.model_dump()}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as fix_error:
|
except Exception as fix_error:
|
||||||
@@ -208,12 +206,12 @@ async def register_user(
|
|||||||
message="Failed to create account with valid ID. Please try again."
|
message="Failed to create account with valid ID. Please try again."
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500, content={"form_error": errors.model_dump()}
|
||||||
content={"form_error": errors.model_dump()}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建默认的 lazer_profile
|
# 创建默认的 lazer_profile
|
||||||
from app.database.user import LazerUserProfile
|
from app.database.user import LazerUserProfile
|
||||||
|
|
||||||
lazer_profile = LazerUserProfile(
|
lazer_profile = LazerUserProfile(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
@@ -240,10 +238,7 @@ async def register_user(
|
|||||||
# 返回成功响应
|
# 返回成功响应
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=201,
|
status_code=201,
|
||||||
content={
|
content={"message": "Account created successfully", "user_id": user_id},
|
||||||
"message": "Account created successfully",
|
|
||||||
"user_id": user_id
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -251,6 +246,7 @@ async def register_user(
|
|||||||
# 打印详细错误信息用于调试
|
# 打印详细错误信息用于调试
|
||||||
print(f"Registration error: {e}")
|
print(f"Registration error: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# 返回通用错误
|
# 返回通用错误
|
||||||
@@ -259,8 +255,7 @@ async def register_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500, content={"form_error": errors.model_dump()}
|
||||||
content={"form_error": errors.model_dump()}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -283,9 +278,13 @@ async def oauth_token(
|
|||||||
):
|
):
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="invalid_client",
|
error="invalid_client",
|
||||||
description="Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method).",
|
description=(
|
||||||
|
"Client authentication failed (e.g., unknown client, "
|
||||||
|
"no client authentication included, "
|
||||||
|
"or unsupported authentication method)."
|
||||||
|
),
|
||||||
hint="Invalid client credentials",
|
hint="Invalid client credentials",
|
||||||
status_code=401
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
if grant_type == "password":
|
if grant_type == "password":
|
||||||
@@ -293,8 +292,12 @@ async def oauth_token(
|
|||||||
if not username or not password:
|
if not username or not password:
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="invalid_request",
|
error="invalid_request",
|
||||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
description=(
|
||||||
hint="Username and password required"
|
"The request is missing a required parameter, includes an "
|
||||||
|
"invalid parameter value, "
|
||||||
|
"includes a parameter more than once, or is otherwise malformed."
|
||||||
|
),
|
||||||
|
hint="Username and password required",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证用户
|
# 验证用户
|
||||||
@@ -302,8 +305,14 @@ async def oauth_token(
|
|||||||
if not user:
|
if not user:
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="invalid_grant",
|
error="invalid_grant",
|
||||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
description=(
|
||||||
hint="Incorrect sign in"
|
"The provided authorization grant (e.g., authorization code, "
|
||||||
|
"resource owner credentials) "
|
||||||
|
"or refresh token is invalid, expired, revoked, "
|
||||||
|
"does not match the redirection URI used in "
|
||||||
|
"the authorization request, or was issued to another client."
|
||||||
|
),
|
||||||
|
hint="Incorrect sign in",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成令牌
|
# 生成令牌
|
||||||
@@ -335,8 +344,12 @@ async def oauth_token(
|
|||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="invalid_request",
|
error="invalid_request",
|
||||||
description="The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.",
|
description=(
|
||||||
hint="Refresh token required"
|
"The request is missing a required parameter, "
|
||||||
|
"includes an invalid parameter value, "
|
||||||
|
"includes a parameter more than once, or is otherwise malformed."
|
||||||
|
),
|
||||||
|
hint="Refresh token required",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证刷新令牌
|
# 验证刷新令牌
|
||||||
@@ -344,8 +357,14 @@ async def oauth_token(
|
|||||||
if not token_record:
|
if not token_record:
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="invalid_grant",
|
error="invalid_grant",
|
||||||
description="The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.",
|
description=(
|
||||||
hint="Invalid refresh token"
|
"The provided authorization grant (e.g., authorization code, "
|
||||||
|
"resource owner credentials) or refresh token is "
|
||||||
|
"invalid, expired, revoked, "
|
||||||
|
"does not match the redirection URI used "
|
||||||
|
"in the authorization request, or was issued to another client."
|
||||||
|
),
|
||||||
|
hint="Invalid refresh token",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成新的访问令牌
|
# 生成新的访问令牌
|
||||||
@@ -375,6 +394,9 @@ async def oauth_token(
|
|||||||
else:
|
else:
|
||||||
return create_oauth_error_response(
|
return create_oauth_error_response(
|
||||||
error="unsupported_grant_type",
|
error="unsupported_grant_type",
|
||||||
description="The authorization grant type is not supported by the authorization server.",
|
description=(
|
||||||
hint="Unsupported grant type"
|
"The authorization grant type is not supported "
|
||||||
|
"by the authorization server."
|
||||||
|
),
|
||||||
|
hint="Unsupported grant type",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.dependencies.user import get_current_user
|
|||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query, Request
|
from fastapi import Depends, HTTPException, Query, Request
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -25,7 +26,9 @@ async def get_relationship(
|
|||||||
else RelationshipType.BLOCK
|
else RelationshipType.BLOCK
|
||||||
)
|
)
|
||||||
relationships = await db.exec(
|
relationships = await db.exec(
|
||||||
select(Relationship).where(
|
select(Relationship)
|
||||||
|
.options(joinedload(Relationship.target).options(*DBUser.all_select_option())) # pyright: ignore[reportArgumentType]
|
||||||
|
.where(
|
||||||
Relationship.user_id == current_user.id,
|
Relationship.user_id == current_user.id,
|
||||||
Relationship.type == relationship_type,
|
Relationship.type == relationship_type,
|
||||||
)
|
)
|
||||||
@@ -79,8 +82,20 @@ async def add_relationship(
|
|||||||
if target_relationship and target_relationship.type == RelationshipType.FOLLOW:
|
if target_relationship and target_relationship.type == RelationshipType.FOLLOW:
|
||||||
await db.delete(target_relationship)
|
await db.delete(target_relationship)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(relationship)
|
|
||||||
if relationship.type == RelationshipType.FOLLOW:
|
if relationship.type == RelationshipType.FOLLOW:
|
||||||
|
relationship = (
|
||||||
|
await db.exec(
|
||||||
|
select(Relationship)
|
||||||
|
.where(
|
||||||
|
Relationship.user_id == current_user.id,
|
||||||
|
Relationship.target_id == target,
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
joinedload(Relationship.target).options(*DBUser.all_select_option()) # pyright: ignore[reportArgumentType]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
assert relationship, "Relationship should exist after commit"
|
||||||
return await RelationshipResp.from_db(db, relationship)
|
return await RelationshipResp.from_db(db, relationship)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,17 +3,22 @@ from __future__ import annotations
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from app.database import User as DBUser
|
from app.database import User as DBUser
|
||||||
from app.dependencies import get_db, get_current_user
|
from app.dependencies.database import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.score import INT_TO_MODE
|
||||||
from app.models.user import User as ApiUser
|
from app.models.user import User as ApiUser
|
||||||
from app.utils import convert_db_user_to_api_user
|
from app.utils import convert_db_user_to_api_user
|
||||||
|
|
||||||
from .api_router import router
|
from .api_router import router
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, Query
|
from fastapi import Depends, HTTPException, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlmodel.sql.expression import col
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Shared Utility ----------
|
||||||
async def get_user_by_lookup(
|
async def get_user_by_lookup(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
lookup: str,
|
lookup: str,
|
||||||
@@ -38,7 +43,39 @@ async def get_user_by_lookup(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Batch Users ----------
|
||||||
|
class BatchUserResponse(BaseModel):
|
||||||
|
users: list[ApiUser]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users", response_model=BatchUserResponse)
|
||||||
|
@router.get("/users/lookup", response_model=BatchUserResponse)
|
||||||
|
async def get_users(
|
||||||
|
user_ids: list[int] = Query(default_factory=list, alias="ids[]"),
|
||||||
|
include_variant_statistics: bool = Query(default=False), # TODO: future use
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
if user_ids:
|
||||||
|
searched_users = (
|
||||||
|
await session.exec(
|
||||||
|
DBUser.all_select_clause().limit(50).where(col(DBUser.id).in_(user_ids))
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
else:
|
||||||
|
searched_users = (
|
||||||
|
await session.exec(DBUser.all_select_clause().limit(50))
|
||||||
|
).all()
|
||||||
|
return BatchUserResponse(
|
||||||
|
users=[
|
||||||
|
await convert_db_user_to_api_user(
|
||||||
|
searched_user, ruleset=INT_TO_MODE[searched_user.preferred_mode].value
|
||||||
|
)
|
||||||
|
for searched_user in searched_users
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Individual User ----------
|
||||||
@router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
@router.get("/users/{user_lookup}/{mode}", response_model=ApiUser)
|
||||||
@router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
@router.get("/users/{user_lookup}/{mode}/", response_model=ApiUser)
|
||||||
async def get_user_with_mode(
|
async def get_user_with_mode(
|
||||||
@@ -53,8 +90,7 @@ async def get_user_with_mode(
|
|||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
api_user = await convert_db_user_to_api_user(user, mode)
|
return await convert_db_user_to_api_user(user, mode)
|
||||||
return api_user
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/users/{user_lookup}", response_model=ApiUser)
|
@router.get("/users/{user_lookup}", response_model=ApiUser)
|
||||||
@@ -70,5 +106,24 @@ async def get_user_default(
|
|||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
api_user = await convert_db_user_to_api_user(user, "osu")
|
return await convert_db_user_to_api_user(user, "osu")
|
||||||
return api_user
|
|
||||||
|
|
||||||
|
@router.get("/users/{user}/{ruleset}", response_model=ApiUser)
|
||||||
|
async def get_user_info(
|
||||||
|
user: str,
|
||||||
|
ruleset: Literal["osu", "taiko", "fruits", "mania"] = "osu",
|
||||||
|
session: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
searched_user = (
|
||||||
|
await session.exec(
|
||||||
|
DBUser.all_select_clause().where(
|
||||||
|
DBUser.id == int(user)
|
||||||
|
if user.isdigit()
|
||||||
|
else DBUser.name == user.removeprefix("@")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if not searched_user:
|
||||||
|
raise HTTPException(404, detail="User not found")
|
||||||
|
return await convert_db_user_to_api_user(searched_user, ruleset=ruleset)
|
||||||
|
|||||||
4
main.py
4
main.py
@@ -4,12 +4,16 @@ from contextlib import asynccontextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
from app.database import Team # noqa: F401
|
||||||
from app.dependencies.database import create_tables, engine
|
from app.dependencies.database import create_tables, engine
|
||||||
from app.dependencies.fetcher import get_fetcher
|
from app.dependencies.fetcher import get_fetcher
|
||||||
|
from app.models.user import User
|
||||||
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
from app.router import api_router, auth_router, fetcher_router, signalr_router
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
User.model_rebuild()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
|||||||
Reference in New Issue
Block a user