refactor(app): update database code
This commit is contained in:
@@ -11,7 +11,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User, UserResp
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
@@ -22,7 +22,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class UpdateResponse(BaseModel):
|
||||
@@ -38,6 +37,7 @@ class UpdateResponse(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_update(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
@@ -46,7 +46,6 @@ async def get_update(
|
||||
["presence", "silences"], alias="includes[]", description="要包含的更新类型"
|
||||
),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
resp = UpdateResponse()
|
||||
@@ -101,10 +100,10 @@ async def get_update(
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def join_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
@@ -121,10 +120,10 @@ async def join_channel(
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def leave_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
user: str = Path(..., description="用户 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
|
||||
@@ -142,8 +141,8 @@ async def leave_channel(
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel_list(
|
||||
session: Database,
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
channels = (
|
||||
@@ -181,9 +180,9 @@ class GetChannelResp(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_channel(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
@@ -250,9 +249,9 @@ class CreateChannelReq(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_channel(
|
||||
session: Database,
|
||||
req: CreateChannelReq = Depends(BodyOrForm(CreateChannelReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write_manage"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
if req.type == "PM":
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.database.chat import (
|
||||
UserSilenceResp,
|
||||
)
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import get_db, get_redis
|
||||
from app.dependencies.database import Database, get_redis
|
||||
from app.dependencies.param import BodyOrForm
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.router.v2 import api_v2_router as router
|
||||
@@ -23,7 +23,6 @@ from fastapi import Depends, HTTPException, Path, Query, Security
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class KeepAliveResp(BaseModel):
|
||||
@@ -38,12 +37,12 @@ class KeepAliveResp(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def keep_alive(
|
||||
session: Database,
|
||||
history_since: int | None = Query(
|
||||
None, description="获取自此禁言 ID 之后的禁言记录"
|
||||
),
|
||||
since: int | None = Query(None, description="获取自此消息 ID 之后的禁言记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
resp = KeepAliveResp()
|
||||
if history_since:
|
||||
@@ -84,10 +83,10 @@ class MessageReq(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def send_message(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
req: MessageReq = Depends(BodyOrForm(MessageReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
@@ -125,12 +124,12 @@ async def send_message(
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def get_message(
|
||||
session: Database,
|
||||
channel: str,
|
||||
limit: int = Query(50, ge=1, le=50, description="获取消息的数量"),
|
||||
since: int = Query(default=0, ge=0, description="获取自此消息 ID 之后的消息记录"),
|
||||
until: int | None = Query(None, description="获取自此消息 ID 之前的消息记录"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
@@ -158,10 +157,10 @@ async def get_message(
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def mark_as_read(
|
||||
session: Database,
|
||||
channel: str = Path(..., description="频道 ID/名称"),
|
||||
message: int = Path(..., description="消息 ID"),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.read"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
):
|
||||
db_channel = await ChatChannel.get(channel, session)
|
||||
if db_channel is None:
|
||||
@@ -191,9 +190,9 @@ class NewPMResp(BaseModel):
|
||||
tags=["聊天"],
|
||||
)
|
||||
async def create_new_pm(
|
||||
session: Database,
|
||||
req: PMReq = Depends(BodyOrForm(PMReq)),
|
||||
current_user: User = Security(get_current_user, scopes=["chat.write"]),
|
||||
session: AsyncSession = Depends(get_db),
|
||||
redis: Redis = Depends(get_redis),
|
||||
):
|
||||
user_id = current_user.id
|
||||
|
||||
@@ -6,9 +6,9 @@ from app.database.chat import ChannelType, ChatChannel, ChatChannelResp, ChatMes
|
||||
from app.database.lazer_user import User
|
||||
from app.dependencies.database import (
|
||||
DBFactory,
|
||||
engine,
|
||||
get_db_factory,
|
||||
get_redis,
|
||||
with_db,
|
||||
)
|
||||
from app.dependencies.user import get_current_user
|
||||
from app.log import logger
|
||||
@@ -200,7 +200,7 @@ class ChatServer:
|
||||
)
|
||||
|
||||
async def join_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
@@ -212,7 +212,7 @@ class ChatServer:
|
||||
await self.join_channel(user, channel, session)
|
||||
|
||||
async def leave_room_channel(self, channel_id: int, user_id: int):
|
||||
async with AsyncSession(engine) as session:
|
||||
async with with_db() as session:
|
||||
channel = await ChatChannel.get(channel_id, session)
|
||||
if channel is None:
|
||||
return
|
||||
@@ -268,7 +268,7 @@ async def chat_websocket(
|
||||
token = authorization[7:]
|
||||
if (
|
||||
user := await get_current_user(
|
||||
SecurityScopes(scopes=["chat.read"]), session, token_pw=token
|
||||
session, SecurityScopes(scopes=["chat.read"]), token_pw=token
|
||||
)
|
||||
) is None:
|
||||
await websocket.close(code=1008)
|
||||
|
||||
Reference in New Issue
Block a user