Files
igncore/core/db.py
T

275 lines
11 KiB
Python

import os
import re
import sys
import threading
import time
import mysql.connector
# noinspection PyProtectedMember
from pkg_resources import parse_version
from conf.config import BotConfig
from core.decorators import instance
from core.dict_object import DictObject
from core.logger import Logger
@instance()
class DB:
    """Database access layer backed by a mysql-connector connection pool.

    SQL is written with ``?`` placeholders, which are rewritten to the
    connector's ``%s`` style just before execution.  Rows come back as
    ``DictObject`` instances.  ``shared`` is the DB instance that
    create_view() reads from and load_sql_file() loads into.
    """

    MYSQL = "mysql"
    MARIADB = "mariadb-pool"

    def __init__(self):
        self.pool_size = 4
        # noinspection PyTypeChecker
        self.pool = None
        # Matches "<whitespace><field> <EXTENDED_LIKE=<param index>> ?<whitespace>";
        # expanded by handle_extended_like().
        self.enhanced_like_regex = re.compile(r"(\s+)(\S+)\s+<EXTENDED_LIKE=(\d+)>\s+\?(\s*)", re.IGNORECASE)
        # lastrowid of the most recent exec() on this instance (from any thread).
        self.lastrowid = None
        self.logger = Logger(__name__)
        self.type = None
        # One permit per pooled connection: bounds how many connections are checked
        # out at once (mysql-connector's pool raises PoolError instead of waiting
        # when it is exhausted).
        self.lock = threading.Semaphore(self.pool_size)
        self.transaction_level = 0
        # Connection parameters recorded by connect_mariadb().
        self.connect_detail = None
        # noinspection PyTypeChecker
        self.shared: DB = None
        # The bot's config module is selected by the first command-line argument.
        mod = __import__(f'conf.{sys.argv[1]}', fromlist=['BotConfig'])
        config: BotConfig = getattr(mod, 'BotConfig')
        self.name = config.character

    def connect_mariadb(self, host, port, username, password, database_name):
        """Create the connection pool for ``database_name`` and prepare the schema.

        The pool is named after the database (create_view() relies on that) and
        opens ``self.pool_size`` autocommit connections.
        """
        self.type = self.MARIADB
        self.connect_detail = {'host': host, 'port': port, 'user': username,
                               'password': password, 'database': database_name, 'autocommit': True}
        self.pool = mysql.connector.pooling.MySQLConnectionPool(pool_name=database_name, pool_size=self.pool_size,
                                                                host=host, port=port, user=username,
                                                                password=password, database=database_name,
                                                                autocommit=True)
        # NOTE(review): each exec() runs on whichever single pooled connection is free,
        # so these session-level SETs presumably do not reach the other pooled
        # connections and may be cleared when the pool resets the session on return
        # (mysql-connector's pool_reset_session defaults to True) — confirm.
        # Passing sql_mode=/collation= as pool options would apply them everywhere, but
        # would newly enable ANSI/TRADITIONAL for every query; audit queries first.
        self.exec("SET collation_connection = 'utf8_general_ci'")
        self.exec("SET sql_mode = 'TRADITIONAL,ANSI'")
        self.create_db_version_table()

    def create_db_version_table(self):
        """Create the ``db_version`` bookkeeping table if it does not exist yet."""
        self.exec("CREATE TABLE IF NOT EXISTS db_version ("
                  "file VARCHAR(255) NOT NULL, "
                  "version VARCHAR(255) NOT NULL, "
                  "verified SMALLINT NOT NULL, "
                  "bot varchar(32))")

    def _execute_wrapper(self, sql, params, callback):
        """Execute ``sql`` on a pooled connection and return ``callback(cursor)``.

        ``?`` placeholders are rewritten to ``%s`` before execution (note: this also
        rewrites any ``?`` inside string literals).  Statements containing
        ``UPDATE `` or ``INSERT `` run inside an explicit transaction that is
        committed on success and rolled back on failure.  Statements taking more
        than 5 seconds (measured before the callback runs) are logged as slow.

        Raises:
            SqlException: wrapping any error raised while executing the statement.
        """
        upper_sql = sql.upper()
        is_write = "UPDATE " in upper_sql or "INSERT " in upper_sql
        with self.lock:
            start_time = time.time()
            with self.pool.get_connection() as conn:
                # NOTE(review): on the pool's connection wrapper these assignments may
                # only set attributes on the wrapper rather than the real connection;
                # autocommit is already requested in connect_mariadb() — confirm.
                conn.auto_reconnect = True
                conn.autocommit = True
                with conn.cursor(dictionary=True) as cur:
                    try:
                        if is_write:
                            cur.execute("START TRANSACTION;")
                        cur.execute(sql.replace("?", "%s"), params)
                        if is_write:
                            conn.commit()
                    except Exception as e:
                        if is_write:
                            # Do not hand a connection with an open transaction back to the pool.
                            self._rollback_quietly(conn)
                        self.logger.warning(f"failed query '{sql}' for params: {params}")
                        raise SqlException(f"SQL Error: '{str(e)}' for '{sql}' "
                                           f"[{', '.join(map(str, params))}]") from e
                    elapsed = time.time() - start_time
                    result = callback(cur)
                    if elapsed > 5:
                        self.logger.warning(f"slow query ({elapsed:f}s) '{sql}' for params: {str(params)}")
                    return result

    def _rollback_quietly(self, conn):
        """Roll back ``conn`` after a failed write without masking the original error."""
        try:
            conn.rollback()
        except Exception as e:  # best effort: the connection itself may be broken
            self.logger.warning(f"rollback failed: {e}")

    def _prepare(self, sql, params, extended_like):
        """Default ``params`` to ``[]``, expand EXTENDED_LIKE markers, then apply format_sql()."""
        if params is None:
            params = []
        if extended_like:
            sql, params = self.handle_extended_like(sql, params)
        return self.format_sql(sql, params)

    def query_single(self, sql, params=None, extended_like=False) -> "DictObject | None":
        """Run a query and return its first row as a DictObject, or None if there is none.

        NOTE(review): only one row is fetched; with an unbuffered cursor, leftover rows
        may make closing the cursor fail ("Unread result found") — prefer LIMIT 1.
        """
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            row = cur.fetchone()
            return DictObject(row) if row else None

        return self._execute_wrapper(sql, params, map_result)

    def query(self, sql, params=None, extended_like=False) -> list[DictObject]:
        """Run a query and return every row as a DictObject (empty list when none)."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return [DictObject(row) for row in cur.fetchall()]

        return self._execute_wrapper(sql, params, map_result)

    def exec(self, sql, params=None, extended_like=False) -> int:
        """Run a statement, record its lastrowid and return the affected row count."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return cur.rowcount, cur.lastrowid

        row_count, lastrowid = self._execute_wrapper(sql, params, map_result)
        self.lastrowid = lastrowid
        return row_count

    def last_insert_id(self) -> int:
        """Return the lastrowid recorded by the most recent exec() on this instance.

        The value is stored per instance, not per thread.
        """
        return self.lastrowid

    def format_sql(self, sql, params=None) -> tuple[str, list]:
        """Hook for statement rewriting; this implementation returns its inputs unchanged."""
        return sql, params

    def handle_extended_like(self, sql, params):
        """Expand ``<EXTENDED_LIKE=n>`` markers into LIKE / NOT LIKE conditions.

        ``field <EXTENDED_LIKE=n> ?`` becomes ``(field LIKE ? AND ...)`` with one
        condition per space-separated word of ``params[n]`` (``NOT LIKE`` for words
        prefixed with ``-``).  The parameter at index ``n`` is replaced in place by
        the generated ``%word%`` values.

        Returns:
            The rewritten sql and the flattened parameter list.
        """
        original_params = params.copy()
        # One sub-list per original parameter, so expanding one parameter does not
        # shift the indexes referenced by later markers.
        grouped = [[p] for p in params]
        for match in self.enhanced_like_regex.finditer(sql):
            field = match.group(2)
            index = int(match.group(3))
            extra_sql, vals = self._get_extended_params(field, original_params[index].split(" "))
            replacement = match.group(1) + "(" + " AND ".join(extra_sql) + ")" + match.group(4)
            # Earlier markers have already been replaced, so the first remaining marker
            # in sql is this one (count=1).  A callable repl inserts the text literally
            # instead of interpreting backslash escapes.
            sql = self.enhanced_like_regex.sub(lambda _m, text=replacement: text, sql, count=1)
            grouped[index] = vals
        return sql, [item for sublist in grouped for item in sublist]

    def _get_extended_params(self, field, params) -> tuple[list[str], list[str]]:
        """Build the LIKE / NOT LIKE conditions on ``field`` and their ``%word%`` values.

        A word starting with ``-`` (other than a lone ``-``) becomes a NOT LIKE
        condition on the rest of the word.
        """
        extra_sql = []
        vals = []
        for p in params:
            if p.startswith("-") and p != "-":
                vals.append("%" + p[1:] + "%")
                extra_sql.append(field + " NOT LIKE ?")
            else:
                vals.append("%" + p + "%")
                extra_sql.append(field + " LIKE ?")
        return extra_sql, vals

    def create_view(self, table) -> None:
        """Replace ``table`` in this database with a view onto the shared database's table.

        No-op when this instance is the shared one.  ``table`` is interpolated into
        the SQL, so it must be a trusted identifier.
        """
        if self.shared is self:
            return
        self.exec(f"DROP TABLE if exists {table};")
        self.exec(f"CREATE OR REPLACE SQL SECURITY INVOKER VIEW {table} AS "
                  f"SELECT * FROM `{self.shared.pool.pool_name}`.{table};")

    def load_sql_file(self, sql_file: str, force_update=False, per_bot=False, pre_optimized=False) -> None:
        """Load ``sql_file`` if it is unrecorded or newer than its recorded version.

        The file version is its modification time (see get_file_version()).
        Versions are keyed by the slash-normalised path and a bot name ("global"
        unless ``per_bot``).  ``force_update`` reloads an already-recorded file
        even when unchanged.
        """
        filename = sql_file.replace("\\", "/")
        bot = self.name if per_bot else "global"
        # NOTE(review): the recorded version is read from self.shared and the file is
        # loaded into self.shared's pool, but the new version is written via self;
        # when self is not the shared instance these are different db_version tables
        # — confirm this is intended.
        db_version = self.shared.get_db_version(filename, bot)
        file_version = self.get_file_version(filename)
        if db_version:
            needs_load = parse_version(file_version) > parse_version(db_version) or force_update
            if not needs_load:
                return
        self.logger.debug(f"loading sql file '{sql_file}'")
        self._load_file(filename, pre_optimized)
        if db_version:
            self.exec("UPDATE db_version SET version = ?, verified = 1 WHERE file = ? and bot = ?",
                      [int(file_version), filename, bot])
        else:
            self.exec("INSERT INTO db_version (file, version, bot, verified) VALUES (?, ?, ?, 1)",
                      [filename, int(file_version), bot])

    def get_file_version(self, filename) -> str:
        """Return the file's modification time in whole seconds, as a string."""
        return str(int(os.path.getmtime(filename)))

    def get_db_version(self, filename, bot) -> "str | None":
        """Return the recorded version string for ``filename``/``bot``, or None if unrecorded."""
        row = self.query_single("SELECT version FROM db_version WHERE file = ? and bot = ?", [filename, bot])
        return row.version if row else None

    def _load_optimized_file(self, filename):
        """Execute ``filename`` on the shared pool, one statement per line.

        Blank lines and lines starting with ``#`` or ``--`` are skipped.
        """
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            # Hold a permit of the shared semaphore while a pooled connection is checked
            # out, as _execute_wrapper() does, so this load cannot push the number of
            # connections in use past the pool size.
            with self.shared.lock:
                with self.shared.pool.get_connection() as conn:
                    with conn.cursor() as cur:
                        for raw_line in f:
                            line = raw_line.strip()
                            if not line or line.startswith(("#", "--")):
                                continue
                            cur.execute(line)
        self.logger.debug(f"Runtime: {time.time() - start: .2f} for {filename}")

    def _load_file(self, filename, pre_optimized=False) -> None:
        """Load an SQL file; only pre-optimized (one statement per line) files are supported.

        Raises:
            SqlException: if ``pre_optimized`` is False.
        """
        if not pre_optimized:
            raise SqlException(f"SQL File not optimized: {filename}")
        self._load_optimized_file(filename)

    def get_type(self) -> str:
        """Return the backend type set by the connect method (e.g. ``DB.MARIADB``), or None."""
        return self.type
class SqlException(Exception):
    """Error raised when an SQL statement or SQL file cannot be processed.

    Takes exactly one argument: the human-readable error message.
    """

    def __init__(self, message):
        Exception.__init__(self, message)