refactor(database): migrate to sqlmodel
This commit is contained in:
282
app/utils.py
282
app/utils.py
@@ -1,8 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.database import User as DBUser
|
||||
from app.database import (
|
||||
LazerUserCounts,
|
||||
LazerUserProfile,
|
||||
LazerUserStatistics,
|
||||
User as DBUser,
|
||||
)
|
||||
from app.models.user import (
|
||||
Country,
|
||||
Cover,
|
||||
@@ -14,7 +19,6 @@ from app.models.user import (
|
||||
RankHighest,
|
||||
RankHistory,
|
||||
Statistics,
|
||||
Team,
|
||||
User,
|
||||
UserAchievement,
|
||||
)
|
||||
@@ -37,25 +41,28 @@ def convert_db_user_to_api_user(
|
||||
profile = db_user.lazer_profile
|
||||
if not profile:
|
||||
# 如果没有 lazer 资料,使用默认值
|
||||
profile = create_default_profile(db_user)
|
||||
profile = LazerUserProfile(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 获取 Lazer 用户计数
|
||||
# 获取 Lazer 用户计数 - 使用正确的 lazer_counts 关系
|
||||
lzrcnt = db_user.lazer_counts
|
||||
|
||||
lzrcnt = db_user.lazer_statistics
|
||||
if not lzrcnt:
|
||||
# 如果没有 lazer 计数,使用默认值
|
||||
lzrcnt = create_default_counts()
|
||||
lzrcnt = LazerUserCounts(user_id=user_id)
|
||||
|
||||
# 获取指定模式的统计信息
|
||||
user_stats = None
|
||||
for stat in db_user.lazer_statistics:
|
||||
if stat.mode == ruleset:
|
||||
user_stats = stat
|
||||
break
|
||||
if db_user.lazer_statistics:
|
||||
for stat in db_user.lazer_statistics:
|
||||
if stat.mode == ruleset:
|
||||
user_stats = stat
|
||||
break
|
||||
|
||||
if not user_stats:
|
||||
# 如果没有找到指定模式的统计,创建默认统计
|
||||
user_stats = create_default_lazer_statistics(ruleset)
|
||||
user_stats = LazerUserStatistics(user_id=user_id)
|
||||
|
||||
# 获取国家信息
|
||||
country_code = db_user.country_code if db_user.country_code is not None else "XX"
|
||||
@@ -66,7 +73,7 @@ def convert_db_user_to_api_user(
|
||||
kudosu = Kudosu(available=0, total=0)
|
||||
|
||||
# 获取计数信息
|
||||
counts = create_default_counts()
|
||||
counts = LazerUserCounts(user_id=user_id)
|
||||
|
||||
# 转换统计信息
|
||||
statistics = Statistics(
|
||||
@@ -199,12 +206,7 @@ def convert_db_user_to_api_user(
|
||||
team = None
|
||||
if db_user.team_membership:
|
||||
team_member = db_user.team_membership[0] # 假设用户只属于一个团队
|
||||
team = Team(
|
||||
flag_url=team_member.team.flag_url or "",
|
||||
id=team_member.team.id,
|
||||
name=team_member.team.name,
|
||||
short_name=team_member.team.short_name,
|
||||
)
|
||||
team = team_member.team
|
||||
|
||||
# 创建用户对象
|
||||
# 从db_user获取基本字段值
|
||||
@@ -229,27 +231,25 @@ def convert_db_user_to_api_user(
|
||||
avatar_url = str(db_user.avatar.r2_original_url)
|
||||
|
||||
# 如果还是没有找到,通过查询获取
|
||||
if db_session and avatar_url is None:
|
||||
try:
|
||||
# 导入UserAvatar模型
|
||||
from app.database import UserAvatar
|
||||
# if db_session and avatar_url is None:
|
||||
# try:
|
||||
# # 导入UserAvatar模型
|
||||
|
||||
# 尝试查找用户的头像记录
|
||||
avatar_record = (
|
||||
db_session.query(UserAvatar)
|
||||
.filter_by(user_id=user_id, is_active=True)
|
||||
.first()
|
||||
)
|
||||
if avatar_record is not None:
|
||||
if avatar_record.r2_game_url is not None:
|
||||
# 优先使用游戏用的头像URL
|
||||
avatar_url = str(avatar_record.r2_game_url)
|
||||
elif avatar_record.r2_original_url is not None:
|
||||
# 其次使用原始头像URL
|
||||
avatar_url = str(avatar_record.r2_original_url)
|
||||
except Exception as e:
|
||||
print(f"获取用户头像时出错: {e}")
|
||||
print(f"最终头像URL: {avatar_url}")
|
||||
# # 尝试查找用户的头像记录
|
||||
# statement = select(UserAvatar).where(
|
||||
# UserAvatar.user_id == user_id, UserAvatar.is_active == True
|
||||
# )
|
||||
# avatar_record = db_session.exec(statement).first()
|
||||
# if avatar_record is not None:
|
||||
# if avatar_record.r2_game_url is not None:
|
||||
# # 优先使用游戏用的头像URL
|
||||
# avatar_url = str(avatar_record.r2_game_url)
|
||||
# elif avatar_record.r2_original_url is not None:
|
||||
# # 其次使用原始头像URL
|
||||
# avatar_url = str(avatar_record.r2_original_url)
|
||||
# except Exception as e:
|
||||
# print(f"获取用户头像时出错: {e}")
|
||||
# print(f"最终头像URL: {avatar_url}")
|
||||
# 如果仍然没有找到头像URL,则使用默认URL
|
||||
if avatar_url is None:
|
||||
avatar_url = "https://a.gu-osu.gmoe.cc/api/users/avatar/1"
|
||||
@@ -265,15 +265,12 @@ def convert_db_user_to_api_user(
|
||||
"kudosu",
|
||||
]
|
||||
if profile and profile.profile_order:
|
||||
profile_order = profile.profile_order
|
||||
profile_order = profile.profile_order.split(",")
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加active_tournament_banners处理
|
||||
active_tournament_banners = []
|
||||
if (
|
||||
hasattr(db_user, "lazer_tournament_banners")
|
||||
and db_user.lazer_tournament_banners
|
||||
):
|
||||
for banner in db_user.lazer_tournament_banners:
|
||||
if db_user.active_banners:
|
||||
for banner in db_user.active_banners:
|
||||
active_tournament_banners.append(
|
||||
{
|
||||
"tournament_id": banner.tournament_id,
|
||||
@@ -284,7 +281,7 @@ def convert_db_user_to_api_user(
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加badges处理
|
||||
badges = []
|
||||
if hasattr(db_user, "lazer_badges") and db_user.lazer_badges:
|
||||
if db_user.lazer_badges:
|
||||
for badge in db_user.lazer_badges:
|
||||
badges.append(
|
||||
{
|
||||
@@ -298,10 +295,7 @@ def convert_db_user_to_api_user(
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加monthly_playcounts处理
|
||||
monthly_playcounts = []
|
||||
if (
|
||||
hasattr(db_user, "lazer_monthly_playcounts")
|
||||
and db_user.lazer_monthly_playcounts
|
||||
):
|
||||
if db_user.lazer_monthly_playcounts:
|
||||
for playcount in db_user.lazer_monthly_playcounts:
|
||||
monthly_playcounts.append(
|
||||
{
|
||||
@@ -314,10 +308,7 @@ def convert_db_user_to_api_user(
|
||||
|
||||
# 在convert_db_user_to_api_user函数中添加previous_usernames处理
|
||||
previous_usernames = []
|
||||
if (
|
||||
hasattr(db_user, "lazer_previous_usernames")
|
||||
and db_user.lazer_previous_usernames
|
||||
):
|
||||
if db_user.lazer_previous_usernames:
|
||||
for username in db_user.lazer_previous_usernames:
|
||||
previous_usernames.append(
|
||||
{
|
||||
@@ -348,22 +339,22 @@ def convert_db_user_to_api_user(
|
||||
avatar_url=avatar_url,
|
||||
country_code=str(country_code),
|
||||
default_group=profile.default_group if profile else "default",
|
||||
is_active=profile.is_active if profile else True,
|
||||
is_bot=profile.is_bot if profile else False,
|
||||
is_deleted=profile.is_deleted if profile else False,
|
||||
is_online=profile.is_online if profile else True,
|
||||
is_supporter=profile.is_supporter if profile else False,
|
||||
is_restricted=profile.is_restricted if profile else False,
|
||||
last_visit=db_user.last_visit if db_user.last_visit else None,
|
||||
pm_friends_only=profile.pm_friends_only if profile else False,
|
||||
profile_colour=profile.profile_colour if profile else None,
|
||||
is_active=profile.is_active,
|
||||
is_bot=profile.is_bot,
|
||||
is_deleted=profile.is_deleted,
|
||||
is_online=profile.is_online,
|
||||
is_supporter=profile.is_supporter,
|
||||
is_restricted=profile.is_restricted,
|
||||
last_visit=db_user.last_visit,
|
||||
pm_friends_only=profile.pm_friends_only,
|
||||
profile_colour=profile.profile_colour,
|
||||
cover_url=profile.cover_url
|
||||
if profile and profile.cover_url
|
||||
else "https://assets.ppy.sh/user-profile-covers/default.jpeg",
|
||||
discord=profile.discord if profile else None,
|
||||
has_supported=profile.has_supported if profile else False,
|
||||
interests=profile.interests if profile else None,
|
||||
join_date=profile.join_date,
|
||||
join_date=profile.join_date if profile.join_date else datetime.now(UTC),
|
||||
location=profile.location if profile else None,
|
||||
max_blocks=profile.max_blocks if profile and profile.max_blocks else 100,
|
||||
max_friends=profile.max_friends if profile and profile.max_friends else 500,
|
||||
@@ -408,7 +399,7 @@ def convert_db_user_to_api_user(
|
||||
daily_challenge_user_stats=None,
|
||||
groups=[],
|
||||
monthly_playcounts=monthly_playcounts,
|
||||
page=Page(html=profile.page_html, raw=profile.page_raw)
|
||||
page=Page(html=profile.page_html or "", raw=profile.page_raw or "")
|
||||
if profile.page_html or profile.page_raw
|
||||
else Page(),
|
||||
previous_usernames=previous_usernames,
|
||||
@@ -439,164 +430,3 @@ def get_country_name(country_code: str) -> str:
|
||||
# 可以添加更多国家
|
||||
}
|
||||
return country_names.get(country_code, "Unknown")
|
||||
|
||||
|
||||
def create_default_profile(db_user: DBUser):
|
||||
"""创建默认的用户资料"""
|
||||
|
||||
# 完善 MockProfile 类定义
|
||||
class MockProfile:
|
||||
def __init__(self):
|
||||
self.is_active = True
|
||||
self.is_bot = False
|
||||
self.is_deleted = False
|
||||
self.is_online = True
|
||||
self.is_supporter = False
|
||||
self.is_restricted = False
|
||||
self.session_verified = False
|
||||
self.has_supported = False
|
||||
self.pm_friends_only = False
|
||||
self.default_group = "default"
|
||||
self.last_visit = None
|
||||
self.join_date = db_user.join_date if db_user else datetime.utcnow()
|
||||
self.profile_colour = None
|
||||
self.profile_hue = None
|
||||
self.avatar_url = None
|
||||
self.cover_url = None
|
||||
self.discord = None
|
||||
self.twitter = None
|
||||
self.website = None
|
||||
self.title = None
|
||||
self.title_url = None
|
||||
self.interests = None
|
||||
self.location = None
|
||||
self.occupation = None
|
||||
self.playmode = "osu"
|
||||
self.support_level = 0
|
||||
self.max_blocks = 100
|
||||
self.max_friends = 500
|
||||
self.post_count = 0
|
||||
# 添加profile_order字段
|
||||
self.profile_order = [
|
||||
"me",
|
||||
"recent_activity",
|
||||
"top_ranks",
|
||||
"medals",
|
||||
"historical",
|
||||
"beatmaps",
|
||||
"kudosu",
|
||||
]
|
||||
self.page_html = ""
|
||||
self.page_raw = ""
|
||||
# 在MockProfile类中添加active_tournament_banners字段
|
||||
self.active_tournament_banners = (
|
||||
MockLazerTournamentBanner.create_default_banners()
|
||||
)
|
||||
self.active_tournament_banners = [] # 默认空列表
|
||||
|
||||
return MockProfile()
|
||||
|
||||
|
||||
def create_default_lazer_statistics(mode: str):
|
||||
"""创建默认的 Lazer 统计信息"""
|
||||
|
||||
class MockLazerStatistics:
|
||||
def __init__(self, mode: str):
|
||||
self.mode = mode
|
||||
self.count_100 = 0
|
||||
self.count_300 = 0
|
||||
self.count_50 = 0
|
||||
self.count_miss = 0
|
||||
self.level_current = 1
|
||||
self.level_progress = 0
|
||||
self.global_rank = None
|
||||
self.global_rank_exp = None
|
||||
self.pp = 0.0
|
||||
self.pp_exp = 0.0
|
||||
self.ranked_score = 0
|
||||
self.hit_accuracy = 0.0
|
||||
self.total_score = 0
|
||||
self.total_hits = 0
|
||||
self.maximum_combo = 0
|
||||
self.play_count = 0
|
||||
self.play_time = 0
|
||||
self.replays_watched_by_others = 0
|
||||
self.is_ranked = False
|
||||
self.grade_ss = 0
|
||||
self.grade_ssh = 0
|
||||
self.grade_s = 0
|
||||
self.grade_sh = 0
|
||||
self.grade_a = 0
|
||||
self.country_rank = None
|
||||
self.rank_highest = None
|
||||
self.rank_highest_updated_at = None
|
||||
|
||||
return MockLazerStatistics(mode)
|
||||
|
||||
|
||||
def create_default_country(country_code: str):
|
||||
"""创建默认的国家信息"""
|
||||
|
||||
class MockCountry:
|
||||
def __init__(self, code: str):
|
||||
self.code = code
|
||||
self.name = get_country_name(code)
|
||||
|
||||
return MockCountry(country_code)
|
||||
|
||||
|
||||
def create_default_kudosu():
|
||||
"""创建默认的 Kudosu 信息"""
|
||||
|
||||
class MockKudosu:
|
||||
def __init__(self):
|
||||
self.available = 0
|
||||
self.total = 0
|
||||
|
||||
return MockKudosu()
|
||||
|
||||
|
||||
def create_default_counts():
|
||||
"""创建默认的计数信息"""
|
||||
|
||||
class MockCounts:
|
||||
def __init__(self):
|
||||
self.recent_scores_count = None
|
||||
self.beatmap_playcounts_count = 0
|
||||
self.scores_first_count = 0
|
||||
self.scores_pinned_count = 0
|
||||
self.comments_count = 0
|
||||
self.favourite_beatmapset_count = 0
|
||||
self.follower_count = 0
|
||||
self.graveyard_beatmapset_count = 0
|
||||
self.guest_beatmapset_count = 0
|
||||
self.loved_beatmapset_count = 0
|
||||
self.mapping_follower_count = 0
|
||||
self.nominated_beatmapset_count = 0
|
||||
self.pending_beatmapset_count = 0
|
||||
self.ranked_beatmapset_count = 0
|
||||
self.ranked_and_approved_beatmapset_count = 0
|
||||
self.unranked_beatmapset_count = 0
|
||||
self.scores_best_count = 0
|
||||
self.scores_first_count = 0
|
||||
self.scores_pinned_count = 0
|
||||
self.scores_recent_count = 0
|
||||
|
||||
return MockCounts()
|
||||
|
||||
|
||||
class MockLazerTournamentBanner:
|
||||
def __init__(self, tournament_id: int, image_url: str, is_active: bool = True):
|
||||
self.tournament_id = tournament_id
|
||||
self.image_url = image_url
|
||||
self.is_active = is_active
|
||||
|
||||
@staticmethod
|
||||
def create_default_banners():
|
||||
"""创建默认的锦标赛横幅配置"""
|
||||
return [
|
||||
MockLazerTournamentBanner(1, "https://example.com/banner1.jpg", True),
|
||||
MockLazerTournamentBanner(2, "https://example.com/banner2.jpg", False),
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user