277 lines
12 KiB
Python
277 lines
12 KiB
Python
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
import mysql.connector
|
|
from mysql.connector.cursor import CursorBase
|
|
# noinspection PyProtectedMember
|
|
from pkg_resources import parse_version
|
|
|
|
from conf.config import BotConfig
|
|
from core.decorators import instance
|
|
from core.dict_object import DictObject
|
|
from core.logger import Logger
|
|
|
|
|
|
@instance()
class DB:
    """Wrapper around a mysql-connector connection pool.

    Statements are written with ``?`` placeholders, which are rewritten to the
    ``%s`` paramstyle before execution. Concurrent use of the pool is bounded
    by a semaphore with one permit per pooled connection.
    """

    MYSQL = "mysql"
    MARIADB = "mariadb-pool"

    def __init__(self):
        self.pool_size = 4
        # noinspection PyTypeChecker
        self.pool = None
        # Matches "<ws><field> <EXTENDED_LIKE=n> ?<ws>", where n is the index of the
        # parameter holding the space-separated search terms.
        self.enhanced_like_regex = re.compile(r"(\s+)(\S+)\s+<EXTENDED_LIKE=(\d+)>\s+\?(\s*)", re.IGNORECASE)
        self.lastrowid = None
        self.logger = Logger(__name__)
        self.type = None
        # One permit per pooled connection; every pool checkout must hold a permit so
        # the fixed-size pool is never asked for more connections than it has.
        self.lock = threading.Semaphore(self.pool_size)
        self.transaction_level = 0
        self.connect_detail = None
        # noinspection PyTypeChecker
        self.shared: DB = None

        mod = __import__(f'conf.{sys.argv[1]}', fromlist=['BotConfig'])
        config: BotConfig = getattr(mod, 'BotConfig')
        self.name = config.character

    def connect_mariadb(self, host, port, username, password, database_name):
        """Create the connection pool for ``database_name`` and ensure the db_version table exists.

        The pool is named after the database; ``create_view`` relies on that name
        to address the shared database's tables.
        """
        self.type = self.MARIADB
        self.connect_detail = {'host': host, 'port': port, 'user': username,
                               'password': password, 'database': database_name, 'autocommit': True}
        self.pool = mysql.connector.pooling.MySQLConnectionPool(pool_name=database_name, pool_size=self.pool_size,
                                                              host=host, port=port, user=username, password=password,
                                                              database=database_name, autocommit=True)
        # NOTE(review): each exec() runs on whichever pooled connection is free, so these
        # session settings apply to one connection only and may be reset when it is
        # returned to the pool. Confirm whether they are meant to be global.
        self.exec("SET collation_connection = 'utf8_general_ci'")
        self.exec("SET sql_mode = 'TRADITIONAL,ANSI'")
        self.create_db_version_table()

    def create_db_version_table(self):
        """Create the table that records which SQL files were loaded, and at which version."""
        self.exec("CREATE TABLE IF NOT EXISTS db_version ("
                  "file VARCHAR(255) NOT NULL, "
                  "version VARCHAR(255) NOT NULL, "
                  "verified SMALLINT NOT NULL, "
                  "bot varchar(32))")

    def _execute_wrapper(self, sql, params, callback):
        """Execute ``sql`` on a pooled connection and return ``callback(cursor)``.

        ``?`` placeholders are rewritten to ``%s``. Statements containing
        ``UPDATE `` or ``INSERT `` run inside an explicit transaction. The
        transaction is committed on success and rolled back on failure.
        Queries taking longer than 5 seconds are logged as slow.

        :param sql: statement using ``?`` placeholders.
        :param params: positional parameters for the placeholders.
        :param callback: called with the (buffered, dict-row) cursor after execution.
        :return: whatever ``callback`` returns.
        :raises SqlException: if executing or committing the statement fails.
        """
        with self.lock:
            start_time = time.time()
            with self.pool.get_connection() as conn:
                # NOTE(review): on a pooled connection wrapper these assignments may only
                # set attributes on the wrapper. Autocommit is already requested in the
                # pool configuration. Confirm whether they have any effect.
                conn.auto_reconnect = True
                conn.autocommit = True
                with conn.cursor(dictionary=True, buffered=True) as cur:
                    cur: CursorBase
                    upper_sql = sql.upper()
                    is_write = "UPDATE " in upper_sql or "INSERT " in upper_sql
                    in_transaction = False
                    try:
                        if is_write:
                            cur.execute("START TRANSACTION;")
                            in_transaction = True
                        cur.execute(sql.replace("?", "%s"), params)
                        if is_write:
                            conn.commit()
                            in_transaction = False
                    except Exception as e:
                        # Do not hand the connection back to the pool with an open transaction.
                        if in_transaction:
                            self._rollback_quietly(conn)
                        print(sql, params)
                        raise SqlException(f"SQL Error: '{str(e)}' for '{sql}' "
                                           f"[{', '.join(map(str, params))}]") from e
                    elapsed = time.time() - start_time
                    result = callback(cur)
                    if elapsed > 5:
                        self.logger.warning(f"slow query ({elapsed:f}s) '{sql}' for params: {str(params)}")
                    return result

    def _rollback_quietly(self, conn):
        """Best-effort rollback that never masks the error that triggered it."""
        try:
            conn.rollback()
        except Exception as rollback_error:
            self.logger.warning(f"rollback failed: {rollback_error}")

    def _prepare(self, sql, params, extended_like):
        """Replace ``None`` params with ``[]``, expand EXTENDED_LIKE markers if requested, then apply ``format_sql``."""
        if params is None:
            params = []
        if extended_like:
            sql, params = self.handle_extended_like(sql, params)
        return self.format_sql(sql, params)

    def query_single(self, sql, params=None, extended_like=False) -> "DictObject | None":
        """Run a query and return its first row as a DictObject, or None if there are no rows."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            row = cur.fetchone()
            return DictObject(row) if row else None

        return self._execute_wrapper(sql, params, map_result)

    def query(self, sql, params=None, extended_like=False) -> list[DictObject]:
        """Run a query and return all rows as a list of DictObjects."""
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return [DictObject(row) for row in cur.fetchall()]

        return self._execute_wrapper(sql, params, map_result)

    def exec(self, sql, params=None, extended_like=False) -> int:
        """Execute a statement and return the affected row count.

        The cursor's ``lastrowid`` is stored and can be read via ``last_insert_id()``.
        """
        sql, params = self._prepare(sql, params, extended_like)

        def map_result(cur):
            return [cur.rowcount, cur.lastrowid]

        row_count, lastrowid = self._execute_wrapper(sql, params, map_result)
        self.lastrowid = lastrowid
        return row_count

    def last_insert_id(self) -> int:
        """Return the ``lastrowid`` recorded by the most recent ``exec`` call on this instance."""
        return self.lastrowid

    def format_sql(self, sql, params=None) -> "tuple[str, list]":
        """Hook for rewriting SQL before execution; this implementation returns its inputs unchanged."""
        return sql, params

    def handle_extended_like(self, sql, params):
        """Expand every ``<field> <EXTENDED_LIKE=n> ?`` marker into a conjunction of LIKE clauses.

        The parameter at index ``n`` is split on single spaces. Each term becomes
        ``field LIKE ?``, or ``field NOT LIKE ?`` when prefixed with ``-``. The term
        values are wrapped in ``%...%``. The single original parameter is replaced,
        in place, by the generated values.

        :param sql: statement possibly containing EXTENDED_LIKE markers.
        :param params: positional parameters (list or tuple); ``n`` indexes into it.
        :return: ``(sql, flat_params)`` with all markers expanded.
        """
        original_params = list(params)
        # One slot per original parameter, so that expanding a slot into several
        # values does not shift the indices that later markers refer to.
        grouped = [[p] for p in original_params]

        for match in self.enhanced_like_regex.finditer(sql):
            field = match.group(2)
            index = int(match.group(3))
            extra_sql, vals = self._get_extended_params(field, original_params[index].split(" "))

            replacement = match.group(1) + "(" + " AND ".join(extra_sql) + ")" + match.group(4)
            # Earlier markers are already rewritten, so the first remaining match is
            # this one. A callable replacement stops re.sub from interpreting
            # backslashes or group references in the generated text.
            sql = self.enhanced_like_regex.sub(lambda _m: replacement, sql, count=1)
            grouped[index] = vals
        return sql, [item for group in grouped for item in group]

    def _get_extended_params(self, field, params) -> "tuple[list[str], list[str]]":
        """Build one LIKE / NOT LIKE clause and its ``%term%`` value per search term."""
        extra_sql = []
        vals = []
        for p in params:
            if p.startswith("-") and p != "-":
                vals.append("%" + p[1:] + "%")
                extra_sql.append(field + " NOT LIKE ?")
            else:
                vals.append("%" + p + "%")
                extra_sql.append(field + " LIKE ?")
        return extra_sql, vals

    def create_view(self, table) -> None:
        """Replace local ``table`` with a view onto the same table in the shared database.

        This does nothing when this instance is the shared database itself.
        """
        if self.shared is self:
            return
        self.exec(f"DROP TABLE if exists {table};")
        self.exec(f"CREATE OR REPLACE SQL SECURITY INVOKER VIEW {table} AS "
                  f"SELECT * FROM `{self.shared.pool.pool_name}`.{table};")

    def load_sql_file(self, sql_file: str, force_update=False, per_bot=False, pre_optimized=False) -> None:
        """Load ``sql_file`` if it is new, or if it is newer than the recorded version.

        The file's mtime is its version. It is recorded in ``db_version`` under the
        bot name (``per_bot``) or ``"global"``. ``force_update`` reloads an
        already-recorded file regardless of version.
        """
        filename = sql_file.replace("\\", "/")
        bot = "global"
        if per_bot:
            bot = self.name
        # NOTE(review): the version is read from self.shared but written through self.exec.
        # Unless this instance's db_version is a view onto the shared table, the written
        # version will not be found next time. Confirm this is intended.
        db_version = self.shared.get_db_version(filename, bot)
        file_version = self.get_file_version(filename)
        if db_version:
            if parse_version(file_version) > parse_version(db_version) or force_update:
                self.logger.debug("loading sql file '%s'" % sql_file)
                self._load_file(filename, pre_optimized)
                self.exec("UPDATE db_version SET version = ?, verified = 1 WHERE file = ? and bot = ?",
                          [int(file_version), filename, bot])
        else:
            self.logger.debug("loading sql file '%s'" % sql_file)
            self._load_file(filename, pre_optimized)
            self.exec("INSERT INTO db_version (file, version, bot, verified) VALUES (?, ?, ?, 1)",
                      [filename, int(file_version), bot])

    def get_file_version(self, filename) -> str:
        """Return the file's modification time in whole seconds, as a string."""
        return str(int(os.path.getmtime(filename)))

    def get_db_version(self, filename, bot) -> "str | None":
        """Return the recorded version of ``filename`` for ``bot``, or None if it was never loaded."""
        row = self.query_single("SELECT version FROM db_version WHERE file = ? and bot = ?", [filename, bot])
        return row.version if row else None

    def _load_optimized_file(self, filename):
        """Execute each non-empty, non-comment line of ``filename`` as one statement on the shared pool."""
        start = time.time()
        with open(filename, mode="r", encoding="UTF-8") as f:
            # Take a pool permit like _execute_wrapper does, so this checkout cannot push
            # the shared pool past its size while other queries are running.
            with self.shared.lock:
                with self.shared.pool.get_connection() as conn:
                    with conn.cursor() as cur:
                        for line in f:
                            line = line.strip()
                            if not line or line.startswith(("#", "--")):
                                continue
                            cur.execute(line)
        print(f"Runtime: {time.time() - start: .2f} for {filename}")

    def _load_file(self, filename, pre_optimized=False) -> None:
        """Load ``filename`` into the shared database.

        Only pre-optimized files (one complete statement per line) are supported.
        The former statement-batching loader was removed; see version control history.

        :raises SqlException: if ``pre_optimized`` is False.
        """
        if pre_optimized:
            self._load_optimized_file(filename)
            return
        raise SqlException(f"SQL File not optimized: {filename}")

    def get_type(self) -> str:
        """Return the connection type constant set by the connect method, or None before connecting."""
        return self.type
|
|
|
|
|
|
class SqlException(Exception):
    """Error raised by the DB helper when a statement fails or a SQL file cannot be loaded.

    :param message: human-readable description; becomes ``str(exc)`` and ``exc.args[0]``.
    """

    def __init__(self, message):
        Exception.__init__(self, message)
|