import os
import re
import sys
import threading
import time
from typing import Optional

import mysql.connector
import mysql.connector.pooling
from mysql.connector.cursor import CursorBase
# noinspection PyProtectedMember
from pkg_resources import parse_version

from conf.config import BotConfig
from core.decorators import instance
from core.dict_object import DictObject
from core.logger import Logger


@instance()
class DB:
    """Pooled MySQL/MariaDB access layer.

    SQL passed to the query helpers uses ``?`` placeholders; they are rewritten to the
    ``%s`` paramstyle expected by ``mysql.connector`` just before execution.
    """

    MYSQL = "mysql"
    MARIADB = "mariadb-pool"

    def __init__(self):
        self.pool_size = 4
        # noinspection PyTypeChecker
        self.pool = None
        # Matches "<ws><field> <EXTENDED_LIKE=n> ?<ws>":
        #   group 1 = leading whitespace, group 2 = field expression,
        #   group 3 = index of the parameter holding the search text,
        #   group 4 = trailing whitespace.
        # See handle_extended_like().
        self.enhanced_like_regex = re.compile(r"(\s+)(\S+)\s+<EXTENDED_LIKE=(\d+)>\s+\?(\s*)", re.IGNORECASE)
        self.lastrowid = None
        self.logger = Logger(__name__)
        self.type = None
        # Limits concurrent pool users to pool_size; every code path that borrows a pooled
        # connection should hold it so get_connection() is not called on an exhausted pool.
        self.lock = threading.Semaphore(self.pool_size)
        self.transaction_level = 0
        # noinspection PyTypeChecker
        self.shared: DB = None
        # The bot config module is selected by the first command-line argument.
        mod = __import__(f'conf.{sys.argv[1]}', fromlist=['BotConfig'])
        config: BotConfig = getattr(mod, 'BotConfig')
        self.name = config.character

    def connect_mariadb(self, host, port, username, password, database_name):
        """Create the connection pool for ``database_name`` and ensure the db_version table exists.

        The pool is named after the database (create_view() relies on that name).
        """
        self.type = self.MARIADB
        self.connect_detail = {'host': host, 'port': port, 'user': username, 'password': password,
                               'database': database_name, 'autocommit': True}
        self.pool = mysql.connector.pooling.MySQLConnectionPool(pool_name=database_name,
                                                                pool_size=self.pool_size,
                                                                host=host, port=port, user=username,
                                                                password=password, database=database_name,
                                                                autocommit=True)
        # NOTE(review): each exec() runs on whichever pooled connection is handed out, so these
        # session-level SETs reach only one connection each, and the pool may reset session state
        # when a connection is returned. If they are meant to apply to every query, consider
        # passing sql_mode/collation through the pool config instead - confirm before changing.
        self.exec("SET collation_connection = 'utf8_general_ci'")
        self.exec("SET sql_mode = 'TRADITIONAL,ANSI'")
        self.create_db_version_table()

    def create_db_version_table(self):
        """Create the table that records which SQL files have been loaded, if missing."""
        self.exec("CREATE TABLE IF NOT EXISTS db_version ("
                  "file VARCHAR(255) NOT NULL, "
                  "version VARCHAR(255) NOT NULL, "
                  "verified SMALLINT NOT NULL, "
                  "bot varchar(32))")

    def _execute_wrapper(self, sql, params, callback):
        """Execute ``sql`` on a pooled connection and return ``callback(cursor)``.

        :param sql: statement using ``?`` placeholders.
        :param params: sequence of values bound to the placeholders.
        :param callback: receives the buffered, dictionary cursor after a successful execute;
            its return value is returned from this method.
        :raises SqlException: if executing the statement fails.
        """
        with self.lock:
            start_time = time.time()
            with self.pool.get_connection() as conn:
                conn.auto_reconnect = True
                conn.autocommit = True
                with conn.cursor(dictionary=True, buffered=True) as cur:
                    cur: CursorBase
                    is_write = False
                    try:
                        upper_sql: str = sql.upper()
                        # Statements containing UPDATE/INSERT run in an explicit transaction.
                        is_write = "UPDATE " in upper_sql or "INSERT " in upper_sql
                        if is_write:
                            cur.execute("START TRANSACTION;")
                        cur.execute(sql.replace("?", "%s"), params)
                        if is_write:
                            conn.commit()
                    except Exception as e:
                        if is_write:
                            # Best effort: don't hand a connection with an open transaction
                            # back to the pool, but never mask the original error.
                            try:
                                conn.rollback()
                            except Exception as rollback_error:
                                self.logger.warning(f"rollback failed: {rollback_error}")
                        print(sql, params)
                        raise SqlException(f"SQL Error: '{str(e)}' for '{sql}' "
                                           f"[{', '.join(map(str, params))}]") from e
                    elapsed = time.time() - start_time
                    result = callback(cur)
                    if elapsed > 5:
                        self.logger.warning(f"slow query ({elapsed:f}s) '{sql}' for params: {str(params)}")
                    return result

    def _prepare(self, sql, params, extended_like):
        """Default ``params`` to an empty list, expand extended LIKEs if requested, then format_sql()."""
        if params is None:
            params = []
        if extended_like:
            sql, params = self.handle_extended_like(sql, params)
        return self.format_sql(sql, params)

    def query_single(self, sql, params=None, extended_like=False) -> DictObject:
        """Return the first result row as a DictObject, or None when there are no rows."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            row = cur.fetchone()
            return DictObject(row) if row else None

        return self._execute_wrapper(sql, params, map_result)

    def query(self, sql, params=None, extended_like=False) -> list[DictObject]:
        """Return all result rows as DictObjects."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return [DictObject(row) for row in cur.fetchall()]

        return self._execute_wrapper(sql, params, map_result)

    def exec(self, sql, params=None, extended_like=False) -> int:
        """Execute a statement and return the affected row count.

        The cursor's lastrowid is stored and can be read via last_insert_id().
        """
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return [cur.rowcount, cur.lastrowid]

        row_count, lastrowid = self._execute_wrapper(sql, params, map_result)
        self.lastrowid = lastrowid
        return row_count

    def last_insert_id(self) -> int:
        """Return the lastrowid recorded by the most recent exec() call on this instance."""
        return self.lastrowid

    def format_sql(self, sql, params=None) -> tuple[str, list]:
        """Hook for dialect-specific rewriting; currently returns ``sql`` and ``params`` unchanged."""
        return sql, params

    def handle_extended_like(self, sql, params):
        """Expand each ``<field> <EXTENDED_LIKE=n> ?`` marker in ``sql`` into LIKE clauses.

        Parameter ``n`` is split on single spaces; every word becomes ``field LIKE ?``
        (``field NOT LIKE ?`` when prefixed with ``-``), the clauses are AND-ed in parentheses,
        and parameter ``n`` is replaced in place by the generated ``%word%`` values.

        :returns: ``(sql, params)`` with markers expanded and parameters flattened.
        """
        original_params = list(params)
        # One sub-list per original parameter, so an expanded parameter can be replaced by
        # several values without shifting the indexes referenced by later markers.
        grouped_params = [[p] for p in original_params]

        def expand(match):
            field = match.group(2)
            index = int(match.group(3))
            extra_sql, vals = self._get_extended_params(field, original_params[index].split(" "))
            grouped_params[index] = vals
            # Returned verbatim (a function replacement is not parsed for backslashes/group refs).
            return match.group(1) + "(" + " AND ".join(extra_sql) + ")" + match.group(4)

        sql = self.enhanced_like_regex.sub(expand, sql)
        return sql, [item for group in grouped_params for item in group]

    def _get_extended_params(self, field, params) -> tuple[list[str], list[str]]:
        """Build ``field [NOT] LIKE ?`` clauses and their ``%word%`` values for each search word.

        A word starting with ``-`` (other than a lone ``-``) produces a NOT LIKE clause for the
        rest of the word.
        """
        extra_sql = []
        vals = []
        for p in params:
            if p.startswith("-") and p != "-":
                vals.append("%" + p[1:] + "%")
                extra_sql.append(field + " NOT LIKE ?")
            else:
                vals.append("%" + p + "%")
                extra_sql.append(field + " LIKE ?")
        return extra_sql, vals

    def create_view(self, table) -> None:
        """Replace ``table`` in this database with a view onto the shared database's table.

        Does nothing when this instance is the shared database.
        """
        if self.shared == self:
            return
        self.exec(f"DROP TABLE if exists {table};")
        self.exec(f"CREATE OR REPLACE SQL SECURITY INVOKER VIEW {table} AS "
                  f"SELECT * FROM `{self.shared.pool.pool_name}`.{table};")

    def load_sql_file(self, sql_file: str, force_update=False, per_bot=False, pre_optimized=False) -> None:
        """Load ``sql_file`` if it is new or its mtime is newer than the recorded version.

        :param force_update: reload even when the recorded version is current.
        :param per_bot: record the version under this bot's name instead of ``"global"``.
        :param pre_optimized: forwarded to _load_file().
        """
        filename = sql_file.replace("\\", "/")
        bot = self.name if per_bot else "global"

        # NOTE(review): the version is read from the shared database but written through
        # self.exec(); when self is not the shared DB this relies on db_version being shared
        # (e.g. via create_view) - confirm.
        db_version = self.shared.get_db_version(filename, bot)
        file_version = self.get_file_version(filename)
        if db_version:
            if parse_version(file_version) > parse_version(db_version) or force_update:
                self.logger.debug("loading sql file '%s'" % sql_file)
                self._load_file(filename, pre_optimized)
                self.exec("UPDATE db_version SET version = ?, verified = 1 WHERE file = ? AND bot = ?",
                          [int(file_version), filename, bot])
        else:
            self.logger.debug("loading sql file '%s'" % sql_file)
            self._load_file(filename, pre_optimized)
            self.exec("INSERT INTO db_version (file, version, bot, verified) VALUES (?, ?, ?, 1)",
                      [filename, int(file_version), bot])

    def get_file_version(self, filename) -> str:
        """Return the file's modification time in whole seconds, as a string."""
        return str(int(os.path.getmtime(filename)))

    def get_db_version(self, filename, bot) -> Optional[str]:
        """Return the recorded version (VARCHAR column) for ``filename``/``bot``, or None."""
        row = self.query_single("SELECT version FROM db_version WHERE file = ? and bot = ?", [filename, bot])
        if row:
            return row.version
        else:
            return None

    def _load_optimized_file(self, filename):
        """Execute ``filename`` line by line on a connection borrowed from the shared pool.

        Every non-blank line that is not a ``#`` or ``--`` comment is executed on its own,
        so the file must contain exactly one complete statement per line.
        """
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            # Hold the shared pool's semaphore, as _execute_wrapper does, so this possibly long
            # load cannot leave concurrent queries facing an exhausted pool.
            with self.shared.lock:
                with self.shared.pool.get_connection() as conn:
                    with conn.cursor() as cur:
                        for line in f:
                            line = line.strip()
                            if not line or line.startswith(("#", "--")):
                                continue
                            cur.execute(line)
        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def _load_file(self, filename, pre_optimized=False) -> None:
        """Load an SQL file; only the pre-optimized (one statement per line) format is supported.

        :raises SqlException: if ``pre_optimized`` is false.
        """
        if pre_optimized:
            self._load_optimized_file(filename)
            return
        # Non-optimized files are not supported; convert them to one statement per line first.
        raise SqlException(f"SQL File not optimized: {filename}")

    def get_type(self) -> str:
        """Return the backend type set by the connect_* method (e.g. ``DB.MARIADB``), or None."""
        return self.type


class SqlException(Exception):
    """Raised when executing or loading SQL fails."""

    def __init__(self, message):
        super().__init__(message)