Files
igncore/core/db.py
T
Minidodo c04f76c0db Added the option to !opt-in/opt-out [onlinebot only]
Fixed command & event threading
Events are now threaded by event_type (i.e. all buddy_logon events are run in the same thread)
Added default preferences
Fixed recipe loading for multiple installs (i.e. on different machines)
2021-08-27 13:58:47 +02:00

278 lines
11 KiB
Python

import os
import re
import sys
import threading
import time
import mariadb
# noinspection PyProtectedMember
from mariadb._mariadb import ConnectionPool, OperationalError
from mysql.connector.cursor import CursorBase
from pkg_resources import parse_version
from conf.config import BotConfig
from core.decorators import instance
from core.dict_object import DictObject
from core.logger import Logger
@instance()
class DB:
    """Database access layer built on a MariaDB connection pool.

    SQL given to query_single/query/exec uses ``?`` placeholders, which are
    rewritten to the driver's ``%s`` style right before execution. A search
    term can be expanded into several LIKE / NOT LIKE clauses with the
    ``<field> <EXTENDED_LIKE=n> ?`` marker (see handle_extended_like).
    """

    MYSQL = "mysql"
    MARIADB = "mariadb-pool"

    # Single-line "INSERT INTO ... VALUES (...);" statements that _load_file batches.
    _INSERT_VALUES_REGEX = re.compile(r"(INSERT INTO .+? VALUES) (\(.+?\));")

    # NOTE(review): rows for this statement are inserted one at a time instead of
    # via executemany; the reason is not visible here.
    _TRICKLE_INSERT_SQL = ("INSERT INTO trickle (id, group_name, name, amount_agility, "
                           "amount_intelligence, amount_psychic, amount_stamina, "
                           "amount_strength, amount_sense) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")

    def __init__(self):
        self.pool_size = 4
        # noinspection PyTypeChecker
        self.pool: ConnectionPool = None
        # Matches "<ws><field> <EXTENDED_LIKE=<param index>> ?<ws>".
        self.enhanced_like_regex = re.compile(r"(\s+)(\S+)\s+<EXTENDED_LIKE=(\d+)>\s+\?(\s*)", re.IGNORECASE)
        self.lastrowid = None
        self.logger = Logger(__name__)
        self.type = None
        # Caps concurrent statements at the number of pooled connections.
        self.lock = threading.Semaphore(self.pool_size)
        # Not read anywhere in this class.
        self.transaction_level = 0
        # Database holding shared data; presumably assigned by the caller after construction.
        # noinspection PyTypeChecker
        self.shared: DB = None
        # The bot configuration module is chosen by the first command-line argument.
        mod = __import__(f'conf.{sys.argv[1]}', fromlist=['BotConfig'])
        config: BotConfig = getattr(mod, 'BotConfig')
        self.name = config.character

    def connect_mariadb(self, host, port, username, password, database_name):
        """Create the connection pool for ``database_name`` and prepare the schema.

        The pool is named after the database; create_view uses that name to
        reference tables in the shared database.
        """
        self.type = self.MARIADB
        # Kept so _execute_wrapper can open extra connections with identical settings.
        self.connect_detail = {'host': host, 'port': port, 'user': username,
                               'password': password, 'database': database_name, 'autocommit': True}
        self.pool = mariadb.ConnectionPool(pool_name=database_name, pool_size=self.pool_size,
                                           pool_reset_connection=False,
                                           host=host, port=port, user=username, password=password,
                                           database=database_name, autocommit=True)
        # NOTE(review): these SET statements are session-scoped, so they only affect the
        # one pooled connection that happens to run them, not the whole pool. Consider an
        # init_command connection option instead — confirm before changing.
        self.exec("SET collation_connection = 'utf8_general_ci'")
        self.exec("SET sql_mode = 'TRADITIONAL,ANSI'")
        self.create_db_version_table()

    def create_db_version_table(self):
        """Create the db_version bookkeeping table if it does not exist yet."""
        self.exec("CREATE TABLE IF NOT EXISTS db_version ("
                  "file VARCHAR(255) NOT NULL, "
                  "version VARCHAR(255) NOT NULL, "
                  "verified SMALLINT NOT NULL, "
                  "bot varchar(32))")

    def _execute_wrapper(self, sql, params, callback):
        """Run ``sql`` with ``params`` on a pooled connection and return ``callback(cursor)``.

        ``?`` placeholders are rewritten to ``%s`` for the driver. Statements
        containing ``UPDATE `` or ``INSERT `` run in an explicit transaction. The
        transaction is committed on success and rolled back on failure.

        :raises SqlException: when the statement fails; the driver error is chained.
        """
        with self.lock:
            start_time = time.time()
            if self.pool.pool_size < self.pool_size - 1:
                self.pool.add_connection(mariadb.connect(**self.connect_detail))
            with self.pool.get_connection() as conn:
                conn.auto_reconnect = True
                conn.autocommit = True
                with conn.cursor(dictionary=True) as cur:
                    is_write = False
                    try:
                        upper_sql: str = sql.upper()
                        is_write = "UPDATE " in upper_sql or "INSERT " in upper_sql
                        if is_write:
                            cur.execute("START TRANSACTION;")
                        cur.execute(sql.replace("?", "%s"), params)
                        if is_write:
                            conn.commit()
                    except Exception as e:
                        if is_write:
                            # Connections are not reset when returned to the pool
                            # (pool_reset_connection=False), so an open transaction would
                            # otherwise leak into the next statement run on this connection.
                            self._rollback_quietly(conn)
                        raise SqlException(f"SQL Error: '{str(e)}' for '{sql}' "
                                           f"[{', '.join(map(str, params))}]") from e
                    elapsed = time.time() - start_time
                    result = callback(cur)
                    if elapsed > 5:
                        self.logger.warning(f"slow query ({elapsed:f}s) '{sql}' for params: {str(params)}")
                    return result

    def _rollback_quietly(self, conn) -> None:
        """Roll back ``conn``'s current transaction, logging instead of raising on failure.

        This is only called while another error is propagating, so a failed
        rollback must not mask that error.
        """
        try:
            conn.rollback()
        except Exception as rollback_error:
            self.logger.warning(f"rollback after failed statement also failed: {rollback_error}")

    def _prepare(self, sql, params, extended_like):
        """Normalize ``params`` and apply the extended-LIKE and format_sql rewrites."""
        if params is None:
            params = []
        if extended_like:
            sql, params = self.handle_extended_like(sql, params)
        return self.format_sql(sql, params)

    def query_single(self, sql, params=None, extended_like=False) -> DictObject:
        """Return the first row of ``sql`` as a DictObject, or None when there is none."""
        sql, params = self._prepare(sql, params, extended_like)

        def first_row(cur):
            row = cur.fetchone()
            return DictObject(row) if row else None

        return self._execute_wrapper(sql, params, first_row)

    def query(self, sql, params=None, extended_like=False) -> list[DictObject]:
        """Return every row of ``sql`` as a list of DictObjects."""
        sql, params = self._prepare(sql, params, extended_like)

        def all_rows(cur):
            return [DictObject(row) for row in cur.fetchall()]

        return self._execute_wrapper(sql, params, all_rows)

    def exec(self, sql, params=None, extended_like=False) -> int:
        """Execute ``sql`` and return the affected row count.

        The cursor's lastrowid is stored so that last_insert_id() can return it.
        """
        sql, params = self._prepare(sql, params, extended_like)

        def counts(cur):
            return cur.rowcount, cur.lastrowid

        row_count, lastrowid = self._execute_wrapper(sql, params, counts)
        self.lastrowid = lastrowid
        return row_count

    def last_insert_id(self) -> int:
        """Return the lastrowid recorded by the most recent exec() call."""
        return self.lastrowid

    def format_sql(self, sql, params=None) -> tuple[str, list]:
        """Hook for adjusting SQL and parameters; currently returns them unchanged."""
        return sql, params

    def handle_extended_like(self, sql, params):
        """Expand each ``<field> <EXTENDED_LIKE=n> ?`` marker into AND-ed LIKE clauses.

        Parameter ``n`` is split on single spaces. Each word becomes
        ``field LIKE ?``, or ``field NOT LIKE ?`` when it starts with ``-``. The
        single parameter at index ``n`` is replaced by the generated ``%word%``
        values.

        :return: ``(sql, params)`` with all markers and their parameters expanded.
        """
        # Each parameter becomes its own group, so an expanded parameter can be swapped
        # for several values without shifting the indices used by later markers.
        grouped = [[p] for p in params]
        for match in self.enhanced_like_regex.finditer(sql):
            field = match.group(2)
            index = int(match.group(3))
            clauses, values = self._get_extended_params(field, params[index].split(" "))
            replacement = match.group(1) + "(" + " AND ".join(clauses) + ")" + match.group(4)
            # A callable replacement is inserted literally (no backslash/group expansion).
            sql = self.enhanced_like_regex.sub(lambda _m, text=replacement: text, sql, count=1)
            grouped[index] = values
        return sql, [value for group in grouped for value in group]

    def _get_extended_params(self, field, params) -> tuple[list, list]:
        """Build one LIKE clause per search word for ``field``.

        A word starting with ``-`` (other than a lone ``-``) produces a NOT LIKE
        clause on the word without its dash.

        :return: ``(clauses, values)``, where each value is the word wrapped in ``%``.
        """
        extra_sql = []
        vals = []
        for p in params:
            negated = p.startswith("-") and p != "-"
            term = p[1:] if negated else p
            vals.append("%" + term + "%")
            extra_sql.append(field + (" NOT LIKE ?" if negated else " LIKE ?"))
        return extra_sql, vals

    def create_view(self, table) -> None:
        """Replace the local ``table`` with a view onto the same table in the shared database.

        Any existing local table with that name is dropped first. This is a no-op
        when this instance is the shared database itself.
        """
        if self.shared is self:
            return
        self.exec(f"DROP TABLE if exists {table};")
        self.exec(f"CREATE OR REPLACE SQL SECURITY INVOKER VIEW {table} AS "
                  f"SELECT * FROM `{self.shared.pool.pool_name}`.{table};")

    def load_sql_file(self, sql_file: str, force_update=False, per_bot=False, pre_optimized=False) -> None:
        """Load ``sql_file`` into the shared database when it is new or has changed.

        The file's modification time serves as its version and is compared with
        the version recorded in ``db_version``.

        :param force_update: reload even if the recorded version is up to date.
        :param per_bot: record the version under this bot's name instead of "global".
        :param pre_optimized: execute the file line by line (see _load_optimized_file).
        """
        filename = sql_file.replace("\\", "/")
        bot = self.name if per_bot else "global"
        # NOTE(review): the version is read from (and the file loaded into) the shared
        # database, but recorded via self.exec, i.e. in this instance's database. Unless
        # both pools point at the same database, the row may never be found again —
        # confirm whether self.shared.exec was intended.
        db_version = self.shared.get_db_version(filename, bot)
        file_version = self.get_file_version(filename)
        if db_version:
            if parse_version(file_version) > parse_version(db_version) or force_update:
                self.logger.debug("loading sql file '%s'" % sql_file)
                self._load_file(filename, pre_optimized)
                self.exec("UPDATE db_version SET version = ?, verified = 1 WHERE file = ? and bot = ?",
                          [int(file_version), filename, bot])
        else:
            self.logger.debug("loading sql file '%s'" % sql_file)
            self._load_file(filename, pre_optimized)
            self.exec("INSERT INTO db_version (file, version, bot, verified) VALUES (?, ?, ?, 1)",
                      [filename, int(file_version), bot])

    def get_file_version(self, filename) -> str:
        """Return the file's modification time in whole seconds, as a string."""
        return str(int(os.path.getmtime(filename)))

    def get_db_version(self, filename, bot) -> "str | None":
        """Return the recorded version of ``filename`` for ``bot``, or None if there is none."""
        row = self.query_single("SELECT version FROM db_version WHERE file = ? and bot = ?", [filename, bot])
        return row.version if row else None

    def _load_optimized_file(self, filename):
        """Execute ``filename`` one line at a time on a shared-pool connection.

        Blank lines and lines starting with ``#`` or ``--`` are skipped. Every
        other line is sent to the server as-is, so each statement must fit on a
        single line.
        """
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            # Hold the shared semaphore (as _load_file does), so this connection counts
            # against the pool size like every other checkout.
            with self.shared.lock:
                with self.shared.pool.get_connection() as conn:
                    with conn.cursor() as cur:
                        for line in f:
                            line = line.strip()
                            if not line or line.startswith(("#", "--")):
                                continue
                            cur.execute(line)
        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def _parse_sql_lines(self, lines):
        """Split SQL file lines into plain statements and batched single-line INSERTs.

        :return: ``(others, insert_batches)``. Each batch is ``[sql, rows]``, where
            ``sql`` is the INSERT with ``?`` placeholders and ``rows`` holds the value
            tuples of consecutive INSERT lines sharing that exact statement.
        """
        insert_batches = []
        others = []
        stat = ""
        inserts = []
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if line.startswith("INSERT INTO"):
                match = self._INSERT_VALUES_REGEX.match(line)
                if not match:
                    self.logger.debug(f"skipping INSERT that could not be parsed: {line[:80]}")
                    continue
                # NOTE(review): eval() executes the VALUES text as Python. This is only acceptable
                # because these SQL files ship with the bot; never feed it untrusted files.
                values = eval(match[2].replace("NULL", "None").replace("null", "None"))
                query = f"{match[1]} ({', '.join(['?'] * len(values))})"
                if query != stat:
                    if inserts:
                        insert_batches.append([stat, inserts])
                        inserts = []
                    stat = query
                inserts.append(values)
            elif not line.startswith("--"):
                others.append(line)
        # Only keep the trailing batch if it has rows; an empty one would run executemany("", []).
        if inserts:
            insert_batches.append([stat, inserts])
        return others, insert_batches

    def _load_file(self, filename, pre_optimized=False) -> None:
        """Execute the SQL file ``filename`` against the shared database.

        Non-INSERT statements run one by one; an OperationalError is logged and
        the load continues. Single-line INSERTs are grouped by statement and sent
        with executemany, so a large data file costs one call per batch rather
        than one per row.
        """
        if pre_optimized:
            self._load_optimized_file(filename)
            return
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            others, insert_batches = self._parse_sql_lines(f)
        with self.shared.lock:
            with self.shared.pool.get_connection() as conn:
                with conn.cursor() as cur:
                    cur: CursorBase
                    for statement in others:
                        try:
                            cur.execute(statement)
                        except OperationalError as e:
                            # Best effort: a failing statement does not abort the load.
                            self.logger.debug(f"ignoring error in '{filename}': {e}")
                    for sql, rows in insert_batches:
                        if sql == self._TRICKLE_INSERT_SQL:
                            for row in rows:
                                cur.execute(sql, row)
                            continue
                        cur.executemany(sql, rows)
        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def get_type(self) -> str:
        """Return the backend type set by the connect method (e.g. DB.MARIADB)."""
        return self.type
class SqlException(Exception):
    """Error raised by the DB helpers when a SQL statement fails to execute.

    The message carries the failing statement and its parameters.
    """

    def __init__(self, message):
        Exception.__init__(self, message)