# igncore/modules/standard/datanet/relay_controller.py
import base64
import json
import threading

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from core.decorators import instance, timerevent
from core.dict_object import DictObject
from core.logger import Logger
from core.setting_service import SettingService
from core.setting_types import ColorSettingType, TextSettingType, HiddenSettingType, BooleanSettingType
from modules.standard.datanet.ws_worker import WebsocketRelayWorker


@instance()
class RelayController:
    """Bridge between the internal message hub and a websocket relay server.

    Hub messages are filtered according to the ``ws_relay_type`` /
    ``ws_msg_relay_prefix`` settings, optionally Fernet-encrypted, and sent
    through a ``WebsocketRelayWorker`` that runs on a daemon thread. Frames
    the worker hands back via ``self.queue`` are drained by a timer event,
    decoded, and forwarded to the hub.
    """

    MESSAGE_SOURCE = "websocket_relay"

    def __init__(self):
        self.dthread = None
        # Passed to WebsocketRelayWorker (presumably filled by it from the
        # worker thread) and drained by handle_queue_event.
        self.queue = []
        self.logger = Logger(__name__)
        self.worker = None
        # Fernet instance when an encryption key is configured, else None.
        self.encrypter = None

    def inject(self, registry):
        """Resolve the services this controller depends on from the registry."""
        self.bot = registry.get_instance("bot")
        self.db = registry.get_instance("db")
        self.util = registry.get_instance("util")
        self.setting_service: SettingService = registry.get_instance("setting_service")
        self.event_service = registry.get_instance("event_service")
        self.character_service = registry.get_instance("character_service")
        self.pork_service = registry.get_instance("pork_service")
        self.online_controller = registry.get_instance("online_controller")
        self.public_channel_service = registry.get_instance("public_channel_service")
        self.message_hub_service = registry.get_instance("message_hub_service")

    def pre_start(self):
        """Register this controller as a message source on the hub."""
        self.message_hub_service.register_message_source(self.MESSAGE_SOURCE)

    def start(self):
        """Register the hub destination, module settings and setting change listeners."""
        self.message_hub_service.register_message_destination(self.MESSAGE_SOURCE,
                                                              self.handle_message_from_hub,
                                                              ["private_channel", "org_channel",
                                                               "discord", "tell_relay"],
                                                              [self.MESSAGE_SOURCE])

        self.setting_service.register_new(self.module_name, "websocket_relay_enabled", False, BooleanSettingType(),
                                          "Enable the websocket relay")
        self.setting_service.register_new(self.module_name, "websocket_relay_server_address",
                                          "ws://localhost/subscribe/relay",
                                          TextSettingType(["ws://localhost/subscribe/relay"]),
                                          "The address of the websocket relay server",
                                          "All bots on the relay must connect to the same server and channel. "
                                          "If using the public relay server, use a unique channel name.")
        self.setting_service.register_new(self.module_name, "websocket_relay_channel_color", "#FFFF00",
                                          ColorSettingType(), "Color of the channel in websocket relay messages")
        self.setting_service.register_new(self.module_name, "websocket_relay_message_color", "#FCA712",
                                          ColorSettingType(),
                                          "Color of the message content in websocket relay messages")
        self.setting_service.register_new(self.module_name, "websocket_relay_sender_color", "#00DE42",
                                          ColorSettingType(), "Color of the sender in websocket relay messages")
        self.setting_service.register_new(self.module_name, "websocket_encryption_key", "",
                                          HiddenSettingType(allow_empty=True),
                                          "An encryption key used to encrypt messages over a public websocket relay")
        self.setting_service.register_new(self.module_name, "ws_relay_prefix", "", TextSettingType(allow_empty=True),
                                          "Name of this relay (if you don't want to use org or bot name)")
        self.setting_service.register_new(self.module_name, "ws_msg_relay_prefix", "||",
                                          TextSettingType(["!", "#", "*", "@", "$", "+", "-"]),
                                          "Prefix for Messages which should get relayed")
        self.setting_service.register_new(self.module_name, "ws_relay_type", "with_symbol",
                                          TextSettingType(["with_symbol", "unless_symbol", "always"]),
                                          "Relay Messages")

        self.initialize_encrypter(self.setting_service.get("websocket_encryption_key").get_value())

        self.setting_service.register_change_listener("websocket_relay_enabled", self.websocket_relay_update)
        self.setting_service.register_change_listener("websocket_relay_server_address", self.websocket_relay_update)
        self.setting_service.register_change_listener("websocket_encryption_key", self.websocket_relay_update)

    def get_org_channel_prefix(self):
        """Return the relay name: ws_relay_prefix, else the org name, else the bot's character name."""
        return self.setting_service.get_value(
            "ws_relay_prefix") or self.public_channel_service.get_org_name() or self.bot.get_char_name()

    def initialize_encrypter(self, password):
        """Derive a Fernet key from ``password`` via PBKDF2-SHA256, or disable encryption if it is empty.

        :param password: the websocket_encryption_key setting value (str or empty)
        """
        if password:
            # A hard-coded salt weakens PBKDF2 (it allows precomputed/rainbow
            # attacks), but the derived key depends on salt and iteration count,
            # so every bot sharing the relay password must use these exact values.
            salt = b"tyrbot"
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=salt,
                iterations=10000,
            )
            key = base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8")))
            self.encrypter = Fernet(key)
        else:
            self.encrypter = None

    @timerevent(budatime="1s", description="Relay messages from Text relay to the internal message hub", is_hidden=True,
                is_enabled=False)
    def handle_queue_event(self, _, _1):
        """Drain queued frames: forward "message" payloads to the hub and answer "ping" frames."""
        while self.queue:
            obj = self.queue.pop(0)
            if obj.type == "message":
                self.process_relay_message(obj.client_id, obj.payload)
            elif obj.type == "ping":
                # The worker may already be torn down (relay disabled) while
                # frames are still queued; there is nothing to reply on then.
                if self.worker:
                    return_obj = json.dumps({"type": "ping", "payload": obj.payload})
                    self.worker.send_message(return_obj)

    @timerevent(budatime="1m", description="Ensure the bot is connected to Text relay", is_hidden=True,
                is_enabled=False, run_at_startup=True)
    def handle_connect_event(self, _, _1):
        """(Re)connect when there is no worker or its thread is no longer running."""
        if not self.worker or not self.dthread or not self.dthread.is_alive():
            self.connect()

    def process_relay_message(self, _, message):
        """Decode one relay payload and forward it to the message hub.

        :param _: client id reported with the frame (unused)
        :param message: payload string; a Fernet token when encryption is
            enabled, otherwise a JSON document
        """
        try:
            if self.encrypter:
                message = self.encrypter.decrypt(message.encode("utf-8"))
            obj = DictObject(json.loads(message))
        except (InvalidToken, ValueError) as e:
            # Payloads from a bot using a different (or no) encryption key cannot
            # be decoded; skip them instead of aborting the rest of the queue drain.
            self.logger.warning("Discarding websocket relay message that could not be decoded (%s)"
                                % type(e).__name__)
            return

        if obj.type == "message":
            channel = self.get_channel_name(obj.source)
            formatted = "[%s] " % self.setting_service.get("websocket_relay_channel_color").format_text(channel)
            if obj.user:
                formatted += "%s: " % self.setting_service.get("websocket_relay_sender_color").format_text(obj.user.name)
            formatted += self.setting_service.get("websocket_relay_message_color").format_text(obj.message)

            self.message_hub_service.send_message(self.MESSAGE_SOURCE, obj.get("user", None), obj.message, formatted)

    def send_relay_event(self, char_id, event_type, source):
        """Send a non-chat event (``event_type``) about ``char_id`` to the relay."""
        char_name = self.character_service.resolve_char_to_name(char_id)
        obj = {"user": {"id": char_id,
                        "name": char_name},
               "type": event_type,
               "source": self.create_source_obj(source)}
        self.send_relay_message(obj)

    def send_relay_message(self, message):
        """JSON-encode (and encrypt, if configured) ``message`` and send it; no-op when disconnected."""
        if self.worker:
            message = json.dumps(message)
            if self.encrypter:
                message = self.encrypter.encrypt(message.encode("utf-8")).decode("utf-8")
            obj = json.dumps({"type": "message", "payload": message})
            self.worker.send_message(obj)

    def handle_message_from_hub(self, ctx):
        """Relay a hub message to the websocket server according to ``ws_relay_type``.

        - ``with_symbol``: only messages starting with ``ws_msg_relay_prefix``
          are relayed, with the prefix removed.
        - ``unless_symbol``: messages starting with the prefix are not relayed.
        - ``always`` (or any other value): every message is relayed.

        Messages without a sender, and messages that are empty after
        stripping, are never relayed.
        """
        if not ctx.sender or not self.worker:
            return

        method = self.setting_service.get_value("ws_relay_type")
        symbol = self.setting_service.get_value("ws_msg_relay_prefix")
        plain_msg = ctx.message or ctx.formatted_message or ""

        if method == "unless_symbol":
            if plain_msg.startswith(symbol):
                return
        elif method == "with_symbol":
            if not plain_msg.startswith(symbol):
                return
            plain_msg = plain_msg[len(symbol):]

        plain_msg = plain_msg.strip()
        if not plain_msg:
            # e.g. the message consisted only of the relay symbol
            return

        obj = {"user": self.create_user_obj(ctx.sender),
               "message": plain_msg,
               "type": "message",
               "source": self.create_source_obj(ctx.source)}
        self.send_relay_message(obj)

    def connect(self):
        """Start a fresh relay worker thread, tearing down any existing one first."""
        self.disconnect()

        self.worker = WebsocketRelayWorker(self.queue,
                                           self.setting_service.get("websocket_relay_server_address").get_value(),
                                           True)
        self.dthread = threading.Thread(target=self.worker.run, daemon=True)
        self.dthread.start()

    def disconnect(self):
        """Close the current relay worker (if any) and wait for its thread to finish."""
        if self.worker:
            self.worker.close()
            self.worker = None
        if self.dthread:
            self.dthread.join()
            self.dthread = None

    def websocket_relay_update(self, setting_name, _, new_value):
        """React to changes of the relay enable flag, server address or encryption key."""
        if setting_name == "websocket_relay_enabled":
            # Toggle the timer events that drive (re)connection and queue draining.
            event_handlers = [self.handle_connect_event, self.handle_queue_event]
            for handler in event_handlers:
                event_handler = self.util.get_handler_name(handler)
                event_base_type, event_sub_type = self.event_service.get_event_type_parts(handler.event.event_type)
                self.event_service.update_event_status(event_base_type, event_sub_type, event_handler,
                                                       1 if new_value else 0)

            if not new_value:
                self.disconnect()
        elif setting_name == "websocket_relay_server_address":
            if self.setting_service.get("websocket_relay_enabled").get_value():
                self.connect()
        elif setting_name == "websocket_encryption_key":
            self.initialize_encrypter(new_value)
            if self.setting_service.get("websocket_relay_enabled").get_value():
                self.connect()

    def get_channel_name(self, source):
        """Build the display name "<label or name>[ <channel>]" for a relay source object."""
        channel_name = source.label or source.name
        if source.channel:
            channel_name += " " + source.channel
        return channel_name

    def create_user_obj(self, sender):
        """Return the relay ``user`` dict for ``sender``, or None when there is no sender."""
        if sender:
            return {
                "id": sender.get("char_id", None),
                "name": sender.name
            }
        else:
            return None

    def create_source_obj(self, source):
        """Return the relay ``source`` dict describing hub channel ``source`` on this bot."""
        # get_org_channel_prefix() falls back to the bot's character name, so
        # org_name is normally non-empty here.
        org_name = self.public_channel_service.get_org_name() or self.get_org_channel_prefix()

        if source == "private_channel":
            channel = "Guest" if org_name else ""
        elif org_name and source == "org_channel":
            channel = ""
        else:
            channel = source.capitalize()

        channel_type = source
        if source == "private_channel":
            channel_type = "priv"
        elif source == "org_channel":
            channel_type = "org"

        return {
            "name": org_name or self.bot.get_char_name(),
            "label": self.setting_service.get("ws_relay_prefix").get_value() or "",
            "channel": channel,
            "type": channel_type,
            "server": self.bot.dimension
        }