This commit is contained in:
咕谷酱
2025-08-22 00:07:19 +08:00
parent bade8658ed
commit 80d4237c5d
22 changed files with 423 additions and 356 deletions

View File

@@ -4,7 +4,6 @@ import asyncio
from datetime import datetime
from typing import Literal
from app.config import settings
from app.database.lazer_user import User
from app.database.statistics import UserStatistics, UserStatisticsResp
from app.dependencies.database import Database, get_redis
@@ -56,7 +55,7 @@ class V1User(AllStrModel):
# 确保 user_id 不为 None
if db_user.id is None:
raise ValueError("User ID cannot be None")
ruleset = ruleset or db_user.playmode
current_statistics: UserStatistics | None = None
for i in await db_user.awaitable_attrs.statistics:
@@ -118,26 +117,28 @@ async def get_user(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# 确定查询方式和用户ID
is_id_query = type == "id" or user.isdigit()
# 解析 ruleset
ruleset = GameMode.from_int_extra(ruleset_id) if ruleset_id else None
# 如果是 ID 查询,先尝试从缓存获取
cached_v1_user = None
user_id_for_cache = None
if is_id_query:
try:
user_id_for_cache = int(user)
cached_v1_user = await cache_service.get_v1_user_from_cache(user_id_for_cache, ruleset)
cached_v1_user = await cache_service.get_v1_user_from_cache(
user_id_for_cache, ruleset
)
if cached_v1_user:
return [V1User(**cached_v1_user)]
except (ValueError, TypeError):
pass # 不是有效的用户ID继续数据库查询
# 从数据库查询用户
db_user = (
await session.exec(
@@ -146,23 +147,23 @@ async def get_user(
)
)
).first()
if not db_user:
return []
try:
# 生成用户数据
v1_user = await V1User.from_db(session, db_user, ruleset)
# 异步缓存结果如果有用户ID
if db_user.id is not None:
user_data = v1_user.model_dump()
asyncio.create_task(
cache_service.cache_v1_user(user_data, db_user.id, ruleset)
)
return [v1_user]
except KeyError:
raise HTTPException(400, "Invalid request")
except ValueError as e:

View File

@@ -2,15 +2,15 @@
缓存管理和监控接口
提供缓存统计、清理和预热功能
"""
from __future__ import annotations
from app.dependencies.database import get_redis
from app.dependencies.user import get_current_user
from app.service.user_cache_service import get_user_cache_service
from .router import router
from fastapi import Depends, HTTPException, Security
from fastapi import Depends, HTTPException
from pydantic import BaseModel
from redis.asyncio import Redis
@@ -34,7 +34,7 @@ async def get_cache_stats(
try:
cache_service = get_user_cache_service(redis)
user_cache_stats = await cache_service.get_cache_stats()
# 获取 Redis 基本信息
redis_info = await redis.info()
redis_stats = {
@@ -47,20 +47,17 @@ async def get_cache_stats(
"evicted_keys": redis_info.get("evicted_keys", 0),
"expired_keys": redis_info.get("expired_keys", 0),
}
# 计算缓存命中率
hits = redis_stats["keyspace_hits"]
misses = redis_stats["keyspace_misses"]
hit_rate = hits / (hits + misses) * 100 if (hits + misses) > 0 else 0
redis_stats["cache_hit_rate_percent"] = round(hit_rate, 2)
return CacheStatsResponse(
user_cache=user_cache_stats,
redis_info=redis_stats
)
return CacheStatsResponse(user_cache=user_cache_stats, redis_info=redis_stats)
except Exception as e:
raise HTTPException(500, f"Failed to get cache stats: {str(e)}")
raise HTTPException(500, f"Failed to get cache stats: {e!s}")
@router.post(
@@ -80,7 +77,7 @@ async def invalidate_user_cache(
await cache_service.invalidate_v1_user_cache(user_id)
return {"message": f"Cache invalidated for user {user_id}"}
except Exception as e:
raise HTTPException(500, f"Failed to invalidate cache: {str(e)}")
raise HTTPException(500, f"Failed to invalidate cache: {e!s}")
@router.post(
@@ -98,15 +95,15 @@ async def clear_all_user_cache(
user_keys = await redis.keys("user:*")
v1_user_keys = await redis.keys("v1_user:*")
all_keys = user_keys + v1_user_keys
if all_keys:
await redis.delete(*all_keys)
return {"message": f"Cleared {len(all_keys)} cache entries"}
else:
return {"message": "No cache entries found"}
except Exception as e:
raise HTTPException(500, f"Failed to clear cache: {str(e)}")
raise HTTPException(500, f"Failed to clear cache: {e!s}")
class CacheWarmupRequest(BaseModel):
@@ -127,18 +124,22 @@ async def warmup_cache(
):
try:
cache_service = get_user_cache_service(redis)
if request.user_ids:
# 预热指定用户
from app.dependencies.database import with_db
async with with_db() as session:
await cache_service.preload_user_cache(session, request.user_ids)
return {"message": f"Warmed up cache for {len(request.user_ids)} users"}
else:
# 预热活跃用户
from app.scheduler.user_cache_scheduler import schedule_user_cache_preload_task
from app.scheduler.user_cache_scheduler import (
schedule_user_cache_preload_task,
)
await schedule_user_cache_preload_task()
return {"message": f"Warmed up cache for top {request.limit} active users"}
except Exception as e:
raise HTTPException(500, f"Failed to warmup cache: {str(e)}")
raise HTTPException(500, f"Failed to warmup cache: {e!s}")

View File

@@ -45,24 +45,24 @@ async def get_country_ranking(
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
# 尝试从缓存获取数据
cached_data = await cache_service.get_cached_country_ranking(ruleset, page)
if cached_data:
# 从缓存返回数据
return CountryResponse(
ranking=[CountryStatistics.model_validate(item) for item in cached_data]
)
# 缓存未命中,从数据库查询
response = CountryResponse(ranking=[])
countries = (await session.exec(select(User.country_code).distinct())).all()
for country in countries:
if not country: # 跳过空的国家代码
continue
statistics = (
await session.exec(
select(UserStatistics).where(
@@ -73,10 +73,10 @@ async def get_country_ranking(
)
)
).all()
if not statistics: # 跳过没有数据的国家
continue
pp = 0
country_stats = CountryStatistics(
code=country,
@@ -92,27 +92,28 @@ async def get_country_ranking(
pp += stat.pp
country_stats.performance = round(pp)
response.ranking.append(country_stats)
response.ranking.sort(key=lambda x: x.performance, reverse=True)
# 分页处理
page_size = 50
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
# 获取当前页的数据
current_page_data = response.ranking[start_idx:end_idx]
# 异步缓存数据(不等待完成)
cache_data = [item.model_dump() for item in current_page_data]
cache_task = cache_service.cache_country_ranking(
ruleset, cache_data, page, ttl=settings.ranking_cache_expire_minutes * 60
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
# 返回当前页的结果
response.ranking = current_page_data
return response
@@ -142,20 +143,16 @@ async def get_user_ranking(
# 获取 Redis 连接和缓存服务
redis = get_redis()
cache_service = get_ranking_cache_service(redis)
# 尝试从缓存获取数据
cached_data = await cache_service.get_cached_ranking(
ruleset, type, country, page
)
cached_data = await cache_service.get_cached_ranking(ruleset, type, country, page)
if cached_data:
# 从缓存返回数据
return TopUsersResponse(
ranking=[
UserStatisticsResp.model_validate(item) for item in cached_data
]
ranking=[UserStatisticsResp.model_validate(item) for item in cached_data]
)
# 缓存未命中,从数据库查询
wheres = [
col(UserStatistics.mode) == ruleset,
@@ -170,7 +167,7 @@ async def get_user_ranking(
order_by = col(UserStatistics.ranked_score).desc()
if country:
wheres.append(col(UserStatistics.user).has(country_code=country.upper()))
statistics_list = await session.exec(
select(UserStatistics)
.where(*wheres)
@@ -178,7 +175,7 @@ async def get_user_ranking(
.limit(50)
.offset(50 * (page - 1))
)
# 转换为响应格式
ranking_data = []
for statistics in statistics_list:
@@ -186,18 +183,24 @@ async def get_user_ranking(
statistics, session, None, include
)
ranking_data.append(user_stats_resp)
# 异步缓存数据(不等待完成)
# 使用配置文件中的TTL设置
cache_data = [item.model_dump() for item in ranking_data]
cache_task = cache_service.cache_ranking(
ruleset, type, cache_data, country, page, ttl=settings.ranking_cache_expire_minutes * 60
ruleset,
type,
cache_data,
country,
page,
ttl=settings.ranking_cache_expire_minutes * 60,
)
# 创建后台任务来缓存数据
import asyncio
asyncio.create_task(cache_task)
resp = TopUsersResponse(ranking=ranking_data)
return resp
@@ -328,4 +331,4 @@ async def get_ranking_cache_stats(
cache_service = get_ranking_cache_service(redis)
stats = await cache_service.get_cache_stats()
return stats """
return stats """

View File

@@ -183,7 +183,7 @@ async def submit_score(
}
db.add(rank_event)
await db.commit()
# 成绩提交后刷新用户缓存
try:
user_cache_service = get_user_cache_service(redis)
@@ -193,7 +193,7 @@ async def submit_score(
)
except Exception as e:
logger.error(f"Failed to refresh user cache after score submit: {e}")
background_task.add_task(process_user_achievement, resp.id)
return resp

View File

@@ -57,25 +57,27 @@ async def get_users(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
if user_ids:
# 先尝试从缓存获取
cached_users = []
uncached_user_ids = []
for user_id in user_ids[:50]: # 限制50个
cached_user = await cache_service.get_user_from_cache(user_id)
if cached_user:
cached_users.append(cached_user)
else:
uncached_user_ids.append(user_id)
# 查询未缓存的用户
if uncached_user_ids:
searched_users = (
await session.exec(select(User).where(col(User.id).in_(uncached_user_ids)))
await session.exec(
select(User).where(col(User.id).in_(uncached_user_ids))
)
).all()
# 将查询到的用户添加到缓存并返回
for searched_user in searched_users:
if searched_user.id != BANCHOBOT_ID:
@@ -87,7 +89,7 @@ async def get_users(
cached_users.append(user_resp)
# 异步缓存,不阻塞响应
asyncio.create_task(cache_service.cache_user(user_resp))
return BatchUserResponse(users=cached_users)
else:
searched_users = (await session.exec(select(User).limit(50))).all()
@@ -102,7 +104,7 @@ async def get_users(
users.append(user_resp)
# 异步缓存
asyncio.create_task(cache_service.cache_user(user_resp))
return BatchUserResponse(users=users)
@@ -121,14 +123,14 @@ async def get_user_info_ruleset(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# 如果是数字ID先尝试从缓存获取
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int, ruleset)
if cached_user:
return cached_user
searched_user = (
await session.exec(
select(User).where(
@@ -140,17 +142,17 @@ async def get_user_info_ruleset(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
ruleset=ruleset,
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp, ruleset))
return user_resp
@@ -169,14 +171,14 @@ async def get_user_info(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# 如果是数字ID先尝试从缓存获取
if user_id.isdigit():
user_id_int = int(user_id)
cached_user = await cache_service.get_user_from_cache(user_id_int)
if cached_user:
return cached_user
searched_user = (
await session.exec(
select(User).where(
@@ -188,16 +190,16 @@ async def get_user_info(
).first()
if not searched_user or searched_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
user_resp = await UserResp.from_db(
searched_user,
session,
include=SEARCH_INCLUDED,
)
# 异步缓存结果
asyncio.create_task(cache_service.cache_user(user_resp))
return user_resp
@@ -218,7 +220,7 @@ async def get_user_beatmapsets(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取
cached_result = await cache_service.get_user_beatmapsets_from_cache(
user_id, type.value, limit, offset
@@ -229,7 +231,7 @@ async def get_user_beatmapsets(
return [BeatmapPlaycountsResp(**item) for item in cached_result]
else:
return [BeatmapsetResp(**item) for item in cached_result]
user = await session.get(User, user_id)
if not user or user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -275,10 +277,14 @@ async def get_user_beatmapsets(
# 异步缓存结果
async def cache_beatmapsets():
try:
await cache_service.cache_user_beatmapsets(user_id, type.value, resp, limit, offset)
await cache_service.cache_user_beatmapsets(
user_id, type.value, resp, limit, offset
)
except Exception as e:
logger.error(f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}")
logger.error(
f"Error caching user beatmapsets for user {user_id}, type {type.value}: {e}"
)
asyncio.create_task(cache_beatmapsets())
return resp
@@ -311,7 +317,7 @@ async def get_user_scores(
):
redis = get_redis()
cache_service = get_user_cache_service(redis)
# 先尝试从缓存获取对于recent类型使用较短的缓存时间
cache_expire = 30 if type == "recent" else settings.user_scores_cache_expire_seconds
cached_scores = await cache_service.get_user_scores_from_cache(
@@ -319,7 +325,7 @@ async def get_user_scores(
)
if cached_scores is not None:
return cached_scores
db_user = await session.get(User, user_id)
if not db_user or db_user.id == BANCHOBOT_ID:
raise HTTPException(404, detail="User not found")
@@ -355,7 +361,7 @@ async def get_user_scores(
).all()
if not scores:
return []
score_responses = [
await ScoreResp.from_db(
session,
@@ -363,14 +369,14 @@ async def get_user_scores(
)
for score in scores
]
# 异步缓存结果
asyncio.create_task(
cache_service.cache_user_scores(
user_id, type, score_responses, mode, limit, offset, cache_expire
)
)
return score_responses