69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
from core.db import DB
|
|
from core.logger import Logger
|
|
from core.registry import Registry
|
|
|
|
# Shared database handle, fetched from the Registry under the key "db".
# Every helper in this module issues its SQL through this object.
db = Registry.get_instance("db")

# Logger for this module, named "core.upgrade".
logger = Logger("core.upgrade")
|
|
|
|
|
|
def table_info(table_name):
    """Describe the columns of ``table_name``.

    Only MariaDB is supported. The rows returned by ``DESCRIBE <table_name>``
    are annotated in place with two extra attributes before being returned:

    * ``name`` -- a copy of the row's ``Field`` value
    * ``type`` -- the row's ``Type`` value, upper-cased

    :param table_name: table to describe. It is interpolated directly into
        the SQL text, so it must be a trusted identifier.
    :return: list of the (mutated) result rows, in query order.
    :raises Exception: if ``db.type`` is not ``DB.MARIADB``.
    """
    if db.type != DB.MARIADB:
        raise Exception("Unknown database type '%s'" % db.type)

    described = []
    for column in db.query("DESCRIBE %s" % table_name):
        # Normalise: mirror MariaDB's Field/Type columns as name/type.
        column.name = column.Field
        column.type = column.Type.upper()
        described.append(column)
    return described
|
|
|
|
|
|
def table_exists(table_name):
    """Report whether a trivial SELECT against ``table_name`` succeeds.

    Runs ``SELECT * FROM <table_name> LIMIT 1``. Any exception raised while
    building or running that probe is reported as False. That covers a
    missing table, but also any other database error.

    :param table_name: table to probe. It is interpolated unescaped into the
        SQL text, so it must be a trusted identifier.
    :return: True if the probe query ran, otherwise False.
    """
    # noinspection PyBroadException
    try:
        probe = f"SELECT * FROM {table_name} LIMIT 1"
        db.query(probe)
    except Exception:
        return False
    else:
        return True
|
|
|
|
|
|
def column_exists(table_name, column_name):
    """Report whether ``column_name`` can be selected from ``table_name``.

    Runs ``SELECT <column_name> FROM <table_name> LIMIT 1``. Any exception
    is reported as False: an unknown column, a missing table, or any other
    database error. Both names are interpolated unescaped into the SQL text,
    so they must be trusted identifiers.

    :return: True if the probe query ran, otherwise False.
    """
    # noinspection PyBroadException
    try:
        db.query(f"SELECT {column_name} FROM {table_name} LIMIT 1")
    except Exception:
        return False
    return True
|
|
|
|
|
|
def update_version(v):
    """Advance the stored schema version for this bot by one step.

    Logs the new version, then issues an UPDATE on the ``db_version`` row
    whose ``file`` is ``'db_version'`` and whose ``bot`` is ``db.name``.

    :param v: the version the database is currently at.
    :return: the new version number, ``v + 1``.
    """
    next_version = v + 1
    logger.info("Upgrading db to version '%d'" % next_version)
    db.exec("UPDATE db_version SET version = ? WHERE file = 'db_version' and bot =?", [next_version, db.name])
    return next_version
|
|
|
|
|
|
def get_version():
    """Return this bot's stored schema version as an int.

    Reads ``version`` from the ``db_version`` row for ``file = 'db_version'``
    and ``bot = db.name``. Returns 0 when ``db.query_single`` gives back a
    falsy result (presumably: no matching row).
    """
    record = db.query_single("SELECT version FROM db_version WHERE file = 'db_version' and bot=?", [db.name])
    return int(record.version) if record else 0
|
|
|
|
|
|
def run_upgrades():
    """Apply pending schema upgrades for this bot, one version step at a time.

    Reads the current version with get_version(), then falls through a series
    of ``if version == N`` steps. Each step that runs ends with
    update_version(), which persists N + 1 and returns it, so the following
    step's condition can match on the same call.
    """
    # NOTE(review): the copy this was recovered from had lost its indentation.
    # The nesting below is a reconstruction. In particular:
    #   * db.create_view("db_version") is placed at function level (runs on
    #     every start-up), but it may have belonged inside the version == 0
    #     block;
    #   * the final update_version() is placed at the `version == 1` level.
    # Confirm both against the upstream file.
    version = get_version()
    logger.info("Database at version '%d'" % version)

    # Step 0 -> 1: seed this bot's db_version row at 0, then bump it to 1.
    # NOTE(review): get_version() also returns 0 when a row already exists
    # with version 0. In that case this INSERT would add a second row unless
    # a unique key on db_version prevents it -- confirm the schema.
    if version == 0:
        db.exec("INSERT INTO db_version (file, version, bot, verified) VALUES ('db_version', ?, ?, 1)", [0, db.name])
        version = update_version(version)

    db.create_view("db_version")

    # Step 1 -> 2: add account.auto_invite when the account table exists and
    # lacks the column.
    if version == 1:
        if table_exists("account"):
            if not column_exists("account", "auto_invite"):
                db.exec("ALTER TABLE account ADD COLUMN auto_invite INT(2) default 0")
        # Bumped even when the account table is absent (per the reconstruction above).
        version = update_version(version)
|