c04f76c0db
Fixed command & event threading. Events are now threaded by event_type (i.e. all buddy_logon events run in the same thread). Added default preferences. Fixed recipe loading for multiple installs (i.e. on different machines).
278 lines
11 KiB
Python
278 lines
11 KiB
Python
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
import mariadb
|
|
# noinspection PyProtectedMember
|
|
from mariadb._mariadb import ConnectionPool, OperationalError
|
|
from mysql.connector.cursor import CursorBase
|
|
from pkg_resources import parse_version
|
|
|
|
from conf.config import BotConfig
|
|
from core.decorators import instance
|
|
from core.dict_object import DictObject
|
|
from core.logger import Logger
|
|
|
|
|
|
@instance()
class DB:
    """Database facade over a pooled MariaDB connection.

    SQL is written with ``?`` placeholders. Statements executed through
    ``_execute_wrapper`` are limited to ``pool_size`` concurrent executions by
    ``lock``. SQL files are loaded into the pool of ``self.shared``.
    """

    MYSQL = "mysql"
    MARIADB = "mariadb-pool"

    # One SET statement the driver runs on every new (or reconnected) pooled
    # connection. A SET issued through exec() alone only configures whichever
    # pooled connection happened to execute it.
    _SESSION_INIT_SQL = "SET collation_connection = 'utf8_general_ci', sql_mode = 'TRADITIONAL,ANSI'"

    # A single-line ``INSERT INTO ... VALUES (...);`` statement from a SQL file.
    _insert_line_regex = re.compile(r"(INSERT INTO .+? VALUES) (\(.+?\));")

    # Matches either a quoted string literal (group 1, kept verbatim) or a bare
    # NULL/null token, so NULLs become None without rewriting text inside strings.
    _null_token_regex = re.compile(r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")|\b(?:NULL|null)\b""")

    # Insert statements executed row by row instead of with executemany().
    # NOTE(review): the reason for this special case is not visible here.
    _row_by_row_inserts = frozenset({
        "INSERT INTO trickle (id, group_name, name, amount_agility, "
        "amount_intelligence, amount_psychic, amount_stamina, "
        "amount_strength, amount_sense) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
    })

    def __init__(self):
        self.pool_size = 4
        # noinspection PyTypeChecker
        self.pool: ConnectionPool = None
        self.enhanced_like_regex = re.compile(r"(\s+)(\S+)\s+<EXTENDED_LIKE=(\d+)>\s+\?(\s*)", re.IGNORECASE)
        self.lastrowid = None
        self.logger = Logger(__name__)
        self.type = None
        # Caps concurrent statements at the pool size (presumably so a pooled
        # connection is available whenever get_connection() is called).
        self.lock = threading.Semaphore(self.pool_size)
        self.transaction_level = 0
        # Keyword arguments for new connections; populated by connect_mariadb().
        self.connect_detail: dict = {}
        # DB instance whose pool SQL files are loaded into. Not assigned in this
        # class -- presumably set externally after construction.
        # noinspection PyTypeChecker
        self.shared: DB = None

        # The bot config module is named by the first command-line argument.
        mod = __import__(f'conf.{sys.argv[1]}', fromlist=['BotConfig'])
        config: BotConfig = getattr(mod, 'BotConfig')
        self.name = config.character

    def connect_mariadb(self, host, port, username, password, database_name):
        """Create the connection pool for ``database_name`` and ensure the db_version table exists."""
        self.type = self.MARIADB
        self.connect_detail = {'host': host, 'port': port, 'user': username,
                               'password': password, 'database': database_name, 'autocommit': True,
                               'init_command': self._SESSION_INIT_SQL}
        # The pool and connections added later share the same arguments, so every
        # connection gets the session settings from init_command.
        self.pool = mariadb.ConnectionPool(pool_name=database_name, pool_size=self.pool_size,
                                           pool_reset_connection=False,
                                           **self.connect_detail)
        self.exec("SET collation_connection = 'utf8_general_ci'")
        self.exec("SET sql_mode = 'TRADITIONAL,ANSI'")
        self.create_db_version_table()

    def create_db_version_table(self):
        """Create the table that records which SQL file versions have been loaded."""
        self.exec("CREATE TABLE IF NOT EXISTS db_version ("
                  "file VARCHAR(255) NOT NULL, "
                  "version VARCHAR(255) NOT NULL, "
                  "verified SMALLINT NOT NULL, "
                  "bot varchar(32))")

    def _execute_wrapper(self, sql, params, callback):
        """Execute ``sql`` on a pooled connection and return ``callback(cursor)``.

        ``?`` placeholders are rewritten to ``%s``. Statements containing
        ``UPDATE `` or ``INSERT `` run in an explicit transaction that is committed
        on success and rolled back on failure. A warning is logged when execution
        takes more than 5 seconds.

        Raises:
            SqlException: executing the statement failed (driver error chained).
        """
        with self.lock:
            start_time = time.time()
            # NOTE(review): confirm ConnectionPool.pool_size reports the current
            # connection count; if it is the configured maximum (created as
            # self.pool_size) this branch never runs.
            if self.pool.pool_size < self.pool_size - 1:
                self.pool.add_connection(mariadb.connect(**self.connect_detail))

            with self.pool.get_connection() as conn:
                conn.auto_reconnect = True
                conn.autocommit = True
                with conn.cursor(dictionary=True) as cur:
                    is_write = False
                    try:
                        upper_sql: str = sql.upper()
                        is_write = "UPDATE " in upper_sql or "INSERT " in upper_sql
                        if is_write:
                            cur.execute("START TRANSACTION;")
                        cur.execute(sql.replace("?", "%s"), params)
                        if is_write:
                            conn.commit()
                    except Exception as e:
                        if is_write:
                            # The pool does not reset connections, so never return
                            # one with a transaction still open.
                            try:
                                conn.rollback()
                            except mariadb.Error as rollback_error:
                                self.logger.warning(f"rollback failed after SQL error: {rollback_error}")
                        raise SqlException(f"SQL Error: '{str(e)}' for '{sql}' "
                                           f"[{', '.join(map(str, params))}]") from e

                    elapsed = time.time() - start_time
                    result = callback(cur)
                    if elapsed > 5:
                        self.logger.warning(f"slow query ({elapsed:f}s) '{sql}' for params: {str(params)}")
                    return result

    def _prepare_statement(self, sql, params, extended_like):
        """Default ``params`` to [], expand <EXTENDED_LIKE> markers if requested, then apply format_sql."""
        if params is None:
            params = []
        if extended_like:
            sql, params = self.handle_extended_like(sql, params)
        return self.format_sql(sql, params)

    def query_single(self, sql, params=None, extended_like=False) -> DictObject:
        """Return the first result row as a DictObject, or None if there are no rows."""
        sql, params = self._prepare_statement(sql, params, extended_like)

        def map_result(cur):
            row = cur.fetchone()
            return DictObject(row) if row else None

        return self._execute_wrapper(sql, params, map_result)

    def query(self, sql, params=None, extended_like=False) -> list[DictObject]:
        """Return every result row as a list of DictObjects."""
        sql, params = self._prepare_statement(sql, params, extended_like)
        return self._execute_wrapper(sql, params, lambda cur: [DictObject(row) for row in cur.fetchall()])

    def exec(self, sql, params=None, extended_like=False) -> int:
        """Execute a statement, remember its lastrowid, and return the affected row count."""
        sql, params = self._prepare_statement(sql, params, extended_like)
        row_count, lastrowid = self._execute_wrapper(sql, params, lambda cur: (cur.rowcount, cur.lastrowid))
        self.lastrowid = lastrowid
        return row_count

    def last_insert_id(self) -> int:
        """Return the lastrowid recorded by the most recent exec() call."""
        return self.lastrowid

    def format_sql(self, sql, params=None) -> tuple[str, list]:
        """Hook for dialect-specific rewriting; returns ``sql`` and ``params`` unchanged."""
        return sql, params

    def handle_extended_like(self, sql, params):
        """Expand ``field <EXTENDED_LIKE=n> ?`` markers into one LIKE term per search word.

        Param ``n`` is split on single spaces; words starting with ``-`` (other than
        a lone ``-``) become NOT LIKE terms. Terms are ANDed together.

        Returns:
            The rewritten SQL and the flattened parameter list.
        """
        original_params = params.copy()
        # Every param is wrapped in a list so an expanded marker can be swapped in
        # at its index without shifting the indexes of later params.
        params = [[param] for param in params]

        for match in self.enhanced_like_regex.finditer(sql):
            field = match.group(2)
            index = int(match.group(3))
            extra_sql, vals = self._get_extended_params(field, original_params[index].split(" "))

            replacement = match.group(1) + "(" + " AND ".join(extra_sql) + ")" + match.group(4)
            # Markers are rewritten left to right, so the first remaining marker is
            # this match. A function repl keeps the text free of escape processing.
            sql = self.enhanced_like_regex.sub(lambda _: replacement, sql, count=1)
            params[index] = vals
        return sql, [item for group in params for item in group]

    def _get_extended_params(self, field, params) -> tuple[list, list]:
        """Build ``field [NOT] LIKE ?`` clauses and their ``%word%`` values for each search word."""
        extra_sql = []
        vals = []
        for p in params:
            if p.startswith("-") and p != "-":
                vals.append("%" + p[1:] + "%")
                extra_sql.append(field + " NOT LIKE ?")
            else:
                vals.append("%" + p + "%")
                extra_sql.append(field + " LIKE ?")
        return extra_sql, vals

    def create_view(self, table) -> None:
        """Replace ``table`` with a view onto the same table in the shared database.

        Does nothing when this instance is itself the shared database.
        """
        if self.shared is self:
            return
        self.exec(f"DROP TABLE if exists {table};")
        self.exec(f"CREATE OR REPLACE SQL SECURITY INVOKER VIEW {table} AS "
                  f"SELECT * FROM `{self.shared.pool.pool_name}`.{table};")

    def load_sql_file(self, sql_file: str, force_update=False, per_bot=False, pre_optimized=False) -> None:
        """Load ``sql_file`` if it is unrecorded, newer than its recorded version, or ``force_update``.

        The file's mtime is then stored in db_version under the bot name (when
        ``per_bot``) or ``"global"``.
        """
        filename = sql_file.replace("\\", "/")
        bot = self.name if per_bot else "global"
        # NOTE(review): the version is read through self.shared but written through
        # self; confirm both resolve to the same db_version table.
        db_version = self.shared.get_db_version(filename, bot)
        file_version = self.get_file_version(filename)

        needs_load = not db_version or parse_version(file_version) > parse_version(db_version) or force_update
        if not needs_load:
            return

        self.logger.debug("loading sql file '%s'" % sql_file)
        self._load_file(filename, pre_optimized)
        if db_version:
            self.exec("UPDATE db_version SET version = ?, verified = 1 WHERE file = ? and bot = ?",
                      [int(file_version), filename, bot])
        else:
            self.exec("INSERT INTO db_version (file, version, bot, verified) VALUES (?, ?, ?, 1)",
                      [filename, int(file_version), bot])

    def get_file_version(self, filename) -> str:
        """Return the file's modification time in whole seconds, as a string."""
        return str(int(os.path.getmtime(filename)))

    def get_db_version(self, filename, bot) -> "str | None":
        """Return the recorded version (a VARCHAR) for ``filename``/``bot``, or None if not recorded."""
        row = self.query_single("SELECT version FROM db_version WHERE file = ? and bot = ?", [filename, bot])
        return row.version if row else None

    def _load_optimized_file(self, filename):
        """Execute every non-blank, non-comment line of ``filename`` as one statement on the shared pool."""
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            # Take a permit like _load_file does, so this cannot exceed the pool size.
            with self.shared.lock, self.shared.pool.get_connection() as conn:
                with conn.cursor() as cur:
                    for line in f:
                        line = line.strip()
                        if not line or line.startswith(("#", "--")):
                            continue
                        cur.execute(line)
        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def _parse_insert_values(self, values_sql):
        """Convert a SQL ``(v1, v2, ...)`` value list into a Python tuple.

        Only literals are accepted (``ast.literal_eval``); NULL/null tokens outside
        string literals become None.
        """
        import ast

        python_literal = self._null_token_regex.sub(lambda m: m.group(1) or "None", values_sql)
        values = ast.literal_eval(python_literal)
        # "(x)" parses as a bare value rather than a 1-tuple.
        return values if isinstance(values, tuple) else (values,)

    def _load_file(self, filename, pre_optimized=False) -> None:
        """Load a SQL file into the shared database.

        Non-INSERT statements run first, one at a time. OperationalErrors from them
        are logged and skipped. Single-line ``INSERT ... VALUES (...);`` rows are
        grouped by identical statement text and sent with executemany(), instead of
        one round trip per row (e.g. ~90 000 inserts for the item DB).
        """
        if pre_optimized:
            self._load_optimized_file(filename)
            return
        start = time.time()

        insert_batches = []  # [statement, rows] pairs in file order
        others = []
        statement = ""
        rows = []
        with open(filename, mode="r", encoding="UTF-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith("INSERT INTO"):
                    match = self._insert_line_regex.match(line)
                    if not match:
                        self.logger.warning(f"skipping unparseable INSERT in '{filename}': {line}")
                        continue
                    values = self._parse_insert_values(match[2])
                    query = f"{match[1]} ({', '.join('?' * len(values))})"
                    if query != statement:
                        if rows:
                            insert_batches.append([statement, rows])
                        statement = query
                        rows = []
                    rows.append(values)
                elif not line.startswith("--"):
                    others.append(line)
        # Only flush a real batch; an empty one would call executemany("", []).
        if rows:
            insert_batches.append([statement, rows])

        with self.shared.lock, self.shared.pool.get_connection() as conn:
            with conn.cursor() as cur:
                for other in others:
                    try:
                        cur.execute(other)
                    except OperationalError as e:
                        # Best effort: note the failure and keep loading the file.
                        self.logger.debug(f"ignored error in '{filename}' for '{other}': {e}")
                for sql, batch in insert_batches:
                    if sql in self._row_by_row_inserts:
                        for row in batch:
                            cur.execute(sql, row)
                    else:
                        cur.executemany(sql, batch)

        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def get_type(self) -> str:
        """Return the backend type set by connect_mariadb() (None before connecting)."""
        return self.type
|
|
|
|
|
|
class SqlException(Exception):
    """Error carrying a description of a failed SQL statement.

    Takes exactly one positional argument, ``message``, which becomes the
    exception text.
    """

    def __init__(self, message):
        Exception.__init__(self, message)
|