fix: make database async

This commit is contained in:
beerpsi
2024-11-14 12:36:22 +07:00
parent 1331d473c9
commit bc7524c8fc
9 changed files with 297 additions and 154 deletions

View File

@@ -1,8 +1,14 @@
from __future__ import with_statement
from alembic import context
from sqlalchemy import engine_from_config, pool
import asyncio
import threading
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from core.data.schema.base import metadata
# this is the Alembic Config object, which provides
@@ -37,20 +43,29 @@ def run_migrations_offline():
script output.
"""
raise Exception('Not implemented or configured!')
raise Exception("Not implemented or configured!")
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url, target_metadata=target_metadata, literal_binds=True)
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
)
In this scenario we need to create an Engine
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
@@ -59,21 +74,32 @@ def run_migrations_online():
for override in overrides:
ini_section[override] = overrides[override]
connectable = engine_from_config(
ini_section,
prefix='sqlalchemy.',
poolclass=pool.NullPool)
connectable = async_engine_from_config(
ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
compare_server_default=True,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# there's no event loop
asyncio.run(run_async_migrations())
else:
# there's currently an event loop and trying to wait for a coroutine
# to finish without using `await` is pretty wormy. nested event loops
# are explicitly forbidden by asyncio.
#
# take the easy way out, spawn it in another thread.
thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),))
thread.start()
thread.join()
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()

View File

@@ -1,54 +1,65 @@
import logging, coloredlogs
from typing import Optional
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import create_engine
from logging.handlers import TimedRotatingFileHandler
import asyncio
import logging
import os
import secrets, string
import bcrypt
import secrets
import string
import warnings
from hashlib import sha256
from logging.handlers import TimedRotatingFileHandler
from typing import ClassVar, Optional
import alembic.config
import glob
import bcrypt
import coloredlogs
import pymysql.err
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_scoped_session,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from core.config import CoreConfig
from core.data.schema import *
from core.utils import Utils
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
from core.utils import MISSING, Utils
class Data:
engine = None
session = None
user = None
arcade = None
card = None
base = None
engine: ClassVar[AsyncEngine] = MISSING
session: ClassVar[AsyncSession] = MISSING
user: ClassVar[UserData] = MISSING
arcade: ClassVar[ArcadeData] = MISSING
card: ClassVar[CardData] = MISSING
base: ClassVar[BaseData] = MISSING
def __init__(self, cfg: CoreConfig) -> None:
self.config = cfg
if self.config.database.sha2_password:
passwd = sha256(self.config.database.password.encode()).digest()
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
else:
self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}"
if Data.engine is None:
Data.engine = create_engine(self.__url, pool_recycle=3600)
if Data.engine is MISSING:
Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT")
self.__engine = Data.engine
if Data.session is None:
s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True)
Data.session = scoped_session(s)
if Data.session is MISSING:
s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
Data.session = async_scoped_session(s, asyncio.current_task)
if Data.user is None:
if Data.user is MISSING:
Data.user = UserData(self.config, self.session)
if Data.arcade is None:
if Data.arcade is MISSING:
Data.arcade = ArcadeData(self.config, self.session)
if Data.card is None:
if Data.card is MISSING:
Data.card = CardData(self.config, self.session)
if Data.base is None:
if Data.base is MISSING:
Data.base = BaseData(self.config, self.session)
self.logger = logging.getLogger("database")
@@ -94,40 +105,73 @@ class Data:
alembic.config.main(argv=alembicArgs)
os.chdir(old_dir)
def create_database(self):
async def create_database(self):
self.logger.info("Creating databases...")
metadata.create_all(
self.engine,
checkfirst=True,
)
for _, mod in Utils.get_all_titles().items():
if hasattr(mod, "database"):
mod.database(self.config)
metadata.create_all(
self.engine,
checkfirst=True,
)
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)
# Stamp the end revision as if alembic had created it, so it can take off after this.
self.__alembic_cmd(
"stamp",
"head",
)
async with self.engine.begin() as conn:
await conn.run_sync(metadata.create_all, checkfirst=True)
def schema_upgrade(self, ver: str = None):
self.__alembic_cmd(
"upgrade",
"head" if not ver else ver,
)
for _, mod in Utils.get_all_titles().items():
if hasattr(mod, "database"):
mod.database(self.config)
await conn.run_sync(metadata.create_all, checkfirst=True)
# Stamp the end revision as if alembic had created it, so it can take off after this.
self.__alembic_cmd(
"stamp",
"head",
)
def schema_upgrade(self, ver: Optional[str] = None):
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)
self.__alembic_cmd(
"upgrade",
"head" if not ver else ver,
)
def schema_downgrade(self, ver: str):
self.__alembic_cmd(
"downgrade",
ver,
)
with warnings.catch_warnings():
# SQLAlchemy will generate a nice primary key constraint name, but in
# MySQL/MariaDB the constraint name is always PRIMARY. Every time a
# custom primary key name is generated, a warning is emitted from pymysql,
# which we don't care about. Other warnings may be helpful though, don't
# suppress everything.
warnings.filterwarnings(
action="ignore",
message=r"Name '(.+)' ignored for PRIMARY key\.",
category=pymysql.err.Warning,
)
async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None:
self.__alembic_cmd(
"downgrade",
ver,
)
async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None:
pw = "".join(
secrets.choice(string.ascii_letters + string.digits) for i in range(20)
)
@@ -150,12 +194,12 @@ class Data:
async def migrate(self) -> None:
exist = await self.base.execute("SELECT * FROM alembic_version")
if exist is not None:
self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!")
return
self.logger.info("Upgrading to latest with legacy system")
if not await self.legacy_upgrade():
self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!")
return
self.logger.info("Done")

View File

@@ -1,16 +1,16 @@
from typing import Optional, Dict, List
from sqlalchemy import Table, Column, and_, or_
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import Integer, String, Boolean, JSON
from sqlalchemy.sql import func, select
import re
from typing import List, Optional
from sqlalchemy import Column, Table, and_, or_
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row
import re
from sqlalchemy.sql import func, select
from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint
from sqlalchemy.types import JSON, Boolean, Integer, String
from core.data.schema.base import BaseData, metadata
from core.const import *
arcade = Table(
arcade: Table = Table(
"arcade",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -26,7 +26,7 @@ arcade = Table(
mysql_charset="utf8mb4",
)
machine = Table(
machine: Table = Table(
"machine",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -47,7 +47,7 @@ machine = Table(
mysql_charset="utf8mb4",
)
arcade_owner = Table(
arcade_owner: Table = Table(
"arcade_owner",
metadata,
Column(
@@ -69,7 +69,7 @@ arcade_owner = Table(
class ArcadeData(BaseData):
async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]:
async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]:
if serial is not None:
serial = serial.replace("-", "")
if len(serial) == 11:
@@ -98,8 +98,8 @@ class ArcadeData(BaseData):
self,
arcade_id: int,
serial: str = "",
board: str = None,
game: str = None,
board: Optional[str] = None,
game: Optional[str] = None,
is_cab: bool = False,
) -> Optional[int]:
if not arcade_id:
@@ -150,8 +150,8 @@ class ArcadeData(BaseData):
async def create_arcade(
self,
name: str = None,
nickname: str = None,
name: Optional[str] = None,
nickname: Optional[str] = None,
country: str = "JPN",
country_id: int = 1,
state: str = "",

View File

@@ -1,22 +1,23 @@
import asyncio
import json
import logging
from random import randrange
from typing import Any, Optional, Dict, List
from typing import Any, Dict, List, Optional
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import Row
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql import text, func, select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import MetaData, Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.schema import ForeignKey
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, text
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
from core.config import CoreConfig
metadata = MetaData()
event_log = Table(
event_log: Table = Table(
"event_log",
metadata,
Column("id", Integer, primary_key=True, nullable=False),
@@ -37,7 +38,7 @@ event_log = Table(
class BaseData:
def __init__(self, cfg: CoreConfig, conn: Connection) -> None:
def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
self.config = cfg
self.conn = conn
self.logger = logging.getLogger("database")
@@ -47,7 +48,7 @@ class BaseData:
try:
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
res = self.conn.execute(text(sql), opts)
res = await self.conn.execute(text(sql), opts)
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
@@ -59,7 +60,7 @@ class BaseData:
except Exception:
try:
res = self.conn.execute(sql, opts)
res = await self.conn.execute(sql, opts)
except SQLAlchemyError as e:
self.logger.error(f"SQLAlchemy error {e}")
@@ -83,7 +84,7 @@ class BaseData:
async def log_event(
self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None,
arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None
arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None
) -> Optional[int]:
sql = event_log.insert().values(
system=system,

View File

@@ -1,13 +1,14 @@
from typing import Dict, List, Optional
from sqlalchemy import Table, Column, UniqueConstraint
from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql import func
from sqlalchemy import Column, Table, UniqueConstraint
from sqlalchemy.engine import Row
from sqlalchemy.sql import func
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String
from core.data.schema.base import BaseData, metadata
aime_card = Table(
aime_card: Table = Table(
"aime_card",
metadata,
Column("id", Integer, primary_key=True, nullable=False),

View File

@@ -1,15 +1,15 @@
from typing import Optional, List
from sqlalchemy import Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, select
from sqlalchemy.engine import Row
from typing import List, Optional
import bcrypt
from sqlalchemy import Column, Table
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.engine import Row
from sqlalchemy.sql import func, select
from sqlalchemy.types import TIMESTAMP, Integer, String
from core.data.schema.base import BaseData, metadata
aime_user = Table(
aime_user: Table = Table(
"aime_user",
metadata,
Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
@@ -26,10 +26,10 @@ aime_user = Table(
class UserData(BaseData):
async def create_user(
self,
id: int = None,
username: str = None,
email: str = None,
password: str = None,
id: Optional[int] = None,
username: Optional[str] = None,
email: Optional[str] = None,
password: Optional[str] = None,
permission: int = 1,
) -> Optional[int]:
if id is None: