245 lines
12 KiB
Python
245 lines
12 KiB
Python
import base64
|
|
import json
|
|
import threading
|
|
|
|
from cryptography.fernet import Fernet
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
|
from core.decorators import instance, timerevent
|
|
from core.dict_object import DictObject
|
|
from core.logger import Logger
|
|
from core.setting_service import SettingService
|
|
from core.setting_types import ColorSettingType, TextSettingType, HiddenSettingType, BooleanSettingType
|
|
from modules.standard.datanet.ws_worker import WebsocketRelayWorker
|
|
|
|
|
|
@instance()
class RelayController:
    """Bridge between the bot's message hub and a websocket relay server.

    Outgoing: hub messages from the ``private_channel`` and ``org_channel`` sources are
    serialized to JSON (Fernet-encrypted when an encryption key is configured) and sent
    through a ``WebsocketRelayWorker`` whose ``run`` method executes on a daemon thread.

    Incoming: ``self.queue`` is handed to the worker and drained once per second by
    ``handle_queue_event``; relay messages are decrypted, formatted and re-published to
    the hub under ``MESSAGE_SOURCE``.
    """

    MESSAGE_SOURCE = "websocket_relay"

    def __init__(self):
        # Thread running self.worker.run; None while disconnected.
        self.dthread = None
        # Inbox passed to the worker in connect() and drained by handle_queue_event.
        # NOTE(review): presumably the worker appends objects with .type/.payload/.client_id — confirm.
        self.queue = []
        self.logger = Logger(__name__)
        # Active WebsocketRelayWorker; None while disconnected.
        self.worker = None
        # Fernet instance when an encryption key is set, otherwise None (messages travel in plaintext).
        self.encrypter = None

    def inject(self, registry):
        """Resolve the services this controller depends on from the registry."""
        self.bot = registry.get_instance("bot")
        self.db = registry.get_instance("db")
        self.util = registry.get_instance("util")
        self.setting_service: SettingService = registry.get_instance("setting_service")
        self.event_service = registry.get_instance("event_service")
        self.character_service = registry.get_instance("character_service")
        self.pork_service = registry.get_instance("pork_service")
        self.online_controller = registry.get_instance("online_controller")
        self.public_channel_service = registry.get_instance("public_channel_service")
        self.message_hub_service = registry.get_instance("message_hub_service")

    def pre_start(self):
        """Declare this controller as a message source on the hub."""
        self.message_hub_service.register_message_source(self.MESSAGE_SOURCE)

    def start(self):
        """Subscribe to the hub, register settings, build the encrypter and watch setting changes."""
        self.message_hub_service.register_message_destination(self.MESSAGE_SOURCE,
                                                              self.handle_message_from_hub,
                                                              ["private_channel", "org_channel"],
                                                              [self.MESSAGE_SOURCE])

        self.setting_service.register(self.module_name, "websocket_relay_enabled", False, BooleanSettingType(),
                                      "Enable the websocket relay")
        self.setting_service.register(self.module_name, "websocket_relay_server_address",
                                      "ws://localhost/subscribe/relay",
                                      TextSettingType(["ws://localhost/subscribe/relay"]),
                                      "The address of the websocket relay server",
                                      "All bots on the relay must connect to the same server and channel. "
                                      "If using the public relay server, use a unique channel name.")
        self.setting_service.register(self.module_name, "websocket_relay_channel_color", "#FFFF00",
                                      ColorSettingType(), "Color of the channel in websocket relay messages")
        self.setting_service.register(self.module_name, "websocket_relay_message_color", "#FCA712",
                                      ColorSettingType(),
                                      "Color of the message content in websocket relay messages")
        self.setting_service.register(self.module_name, "websocket_relay_sender_color", "#00DE42",
                                      ColorSettingType(), "Color of the sender in websocket relay messages")
        self.setting_service.register(self.module_name, "websocket_encryption_key", "",
                                      HiddenSettingType(allow_empty=True),
                                      "An encryption key used to encrypt messages over a public websocket relay")
        self.setting_service.register(self.module_name, "ws_relay_prefix", "", TextSettingType(allow_empty=True),
                                      "Name of this relay (if you don't want to use org or bot name)")
        self.setting_service.register(self.module_name, "ws_msg_relay_prefix", "||",
                                      TextSettingType(["!", "#", "*", "@", "$", "+", "-"]),
                                      "Prefix for Messages which should get relayed")
        self.setting_service.register(self.module_name, "ws_relay_type", "with_symbol",
                                      TextSettingType(["with_symbol", "unless_symbol", "always"]),
                                      "Relay Messages", )
        self.setting_service.register(self.module_name, "ws_relay_salt", "IGNCore",
                                      TextSettingType(["IGNCore", "IgnCore"]),
                                      "Salt for relayed messages (used together with the encryption key)")

        self.initialize_encrypter(self.setting_service.get("websocket_encryption_key").get_value())

        self.setting_service.register_change_listener("websocket_relay_enabled", self.websocket_relay_update)
        self.setting_service.register_change_listener("websocket_relay_server_address", self.websocket_relay_update)
        self.setting_service.register_change_listener("websocket_encryption_key", self.websocket_relay_update)
        # The salt is an input to the key derivation, so a new salt must rebuild the encrypter as well.
        self.setting_service.register_change_listener("ws_relay_salt", self.websocket_relay_update)

    def get_org_channel_prefix(self):
        """Return the relay name: the ws_relay_prefix setting, else the org name, else the bot's name."""
        return self.setting_service.get_value("ws_relay_prefix") or self.public_channel_service.get_org_name() or self.bot.get_char_name()

    def initialize_encrypter(self, password, salt=None):
        """(Re)build ``self.encrypter`` from an encryption key.

        :param password: the encryption key; a falsy value disables encryption.
        :param salt: salt for the key derivation; defaults to the current ``ws_relay_salt`` setting.
            Passed explicitly from the salt change listener so the new value is used.
        """
        if not password:
            self.encrypter = None
            return

        if salt is None:
            salt = self.setting_service.get("ws_relay_salt").get_value()
        # Every bot on the relay must derive the same key, so the KDF parameters (SHA256, 32 bytes,
        # 10000 iterations) must not change independently of the other bots.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt.encode("utf-8"),
            iterations=10000, )
        key = base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8")))
        self.encrypter = Fernet(key)

    @timerevent(budatime="1s", description="Relay messages from Text relay to the internal message hub", is_hidden=True,
                is_enabled=False)
    def handle_queue_event(self, _, _1):
        """Drain the inbox shared with the worker.

        ``message`` items go to process_relay_message; ``ping`` items are echoed back to the
        server with the same payload. An item that fails to process is logged and dropped so
        the remaining items are still handled in this tick.
        """
        while self.queue:
            obj = self.queue.pop(0)
            try:
                if obj.type == "message":
                    self.process_relay_message(obj.client_id, obj.payload)
                elif obj.type == "ping":
                    # The worker may have been torn down after this ping was queued.
                    if self.worker:
                        self.worker.send_message(json.dumps({"type": "ping", "payload": obj.payload}))
            except Exception as e:
                # Payloads come from other bots: a mismatched key or malformed JSON must not abort the drain.
                self.logger.warning("Could not process websocket relay queue item: %s" % e)

    @timerevent(budatime="1m", description="Ensure the bot is connected to Text relay", is_hidden=True,
                is_enabled=False, run_at_startup=True)
    def handle_connect_event(self, _, _1):
        """Connect when there is no worker or its thread is no longer running."""
        if not self.worker or not self.dthread or not self.dthread.is_alive():
            self.connect()

    def process_relay_message(self, _, message):
        """Decode one relay payload and publish it to the message hub.

        :param _: client id reported with the payload (unused).
        :param message: JSON text, or a Fernet token of JSON text when encryption is enabled.

        Only objects whose ``type`` is ``"message"`` are published. Decryption or JSON errors
        propagate to the caller (handle_queue_event logs them).
        """
        if self.encrypter:
            message = self.encrypter.decrypt(message.encode('utf-8'))
        obj = DictObject(json.loads(message))

        if obj.type != "message":
            return

        channel = self.get_channel_name(obj.source)
        parts = ["[%s] " % self.setting_service.get("websocket_relay_channel_color").format_text(channel)]
        if obj.user:
            parts.append("%s: " % self.setting_service.get("websocket_relay_sender_color").format_text(obj.user.name))
        parts.append(self.setting_service.get("websocket_relay_message_color").format_text(obj.message))

        self.message_hub_service.send_message(self.MESSAGE_SOURCE, obj.get("user", None), obj.message, "".join(parts))

    def send_relay_event(self, char_id, event_type, source):
        """Send a non-message event (``event_type``) about ``char_id`` to the relay."""
        char_name = self.character_service.resolve_char_to_name(char_id)
        obj = {"user": {"id": char_id,
                        "name": char_name},
               "type": event_type,
               "source": self.create_source_obj(source)}
        self.send_relay_message(obj)

    def send_relay_message(self, message):
        """Serialize ``message`` to JSON, encrypt it if configured, and send it; no-op while disconnected."""
        if self.worker:
            message = json.dumps(message)
            if self.encrypter:
                message = self.encrypter.encrypt(message.encode('utf-8')).decode('utf-8')
            obj = json.dumps({"type": "message", "payload": message})
            self.worker.send_message(obj)

    def handle_message_from_hub(self, ctx):
        """Relay a hub message according to the ``ws_relay_type`` mode.

        - ``with_symbol``: relay only messages starting with ``ws_msg_relay_prefix``; the prefix is stripped.
        - ``unless_symbol``: relay everything except messages starting with the prefix.
        - ``always`` (or any other value): relay every message.

        Messages without a sender, and messages that are empty after prefix removal and
        whitespace stripping, are not relayed.
        """
        if not ctx.sender or not self.worker:
            return

        method = self.setting_service.get_value("ws_relay_type")
        symbol = self.setting_service.get_value("ws_msg_relay_prefix") or ""
        plain_msg = ctx.message or ctx.formatted_message
        if not plain_msg:
            return

        if method == "with_symbol":
            if not plain_msg.startswith(symbol):
                return
            plain_msg = plain_msg[len(symbol):]
        elif method == "unless_symbol":
            if plain_msg.startswith(symbol):
                return

        text = plain_msg.strip()
        if not text:
            # e.g. a bare prefix in with_symbol mode; do not relay empty messages.
            return

        obj = {"user": self.create_user_obj(ctx.sender), "message": text,
               "type": "message",
               "source": self.create_source_obj(ctx.source)}
        self.send_relay_message(obj)

    def connect(self):
        """Replace any existing connection with a new worker for the configured server address."""
        # Tear down the previous worker first so an address/key change does not leave it running.
        self.disconnect()

        address = self.setting_service.get("websocket_relay_server_address").get_value()
        self.worker = WebsocketRelayWorker(self.queue, address, True)
        self.dthread = threading.Thread(target=self.worker.run, daemon=True)
        self.dthread.start()

    def disconnect(self):
        """Close the current worker, if any, and wait for its thread to finish."""
        if self.worker:
            self.worker.close()
            self.worker = None
        if self.dthread:
            # Blocks until worker.run returns after close().
            self.dthread.join()
            self.dthread = None

    def websocket_relay_update(self, setting_name, _, new_value):
        """Change listener for the relay settings registered in start()."""
        if setting_name == "websocket_relay_enabled":
            # Enable/disable both timer events together with the relay itself.
            for handler in (self.handle_connect_event, self.handle_queue_event):
                event_handler = self.util.get_handler_name(handler)
                event_base_type, event_sub_type = self.event_service.get_event_type_parts(handler.event.event_type)
                self.event_service.update_event_status(event_base_type, event_sub_type, event_handler,
                                                       1 if new_value else 0)

            if not new_value:
                self.disconnect()
        elif setting_name == "websocket_relay_server_address":
            if self.setting_service.get("websocket_relay_enabled").get_value():
                self.connect()
        elif setting_name == "websocket_encryption_key":
            self.initialize_encrypter(new_value)
            if self.setting_service.get("websocket_relay_enabled").get_value():
                self.connect()
        elif setting_name == "ws_relay_salt":
            self.initialize_encrypter(self.setting_service.get("websocket_encryption_key").get_value(), new_value)

    def get_channel_name(self, source):
        """Build the display name for an incoming source: its label (or name) plus its channel, if any."""
        channel_name = source.label or source.name
        if source.channel:
            channel_name += " " + source.channel
        return channel_name

    def create_user_obj(self, sender):
        """Return the relay ``user`` dict for a hub sender, or None when there is no sender."""
        if not sender:
            return None
        return {
            "id": sender.get("char_id", None),
            "name": sender.name
        }

    def create_source_obj(self, source):
        """Describe this bot and the originating hub channel for outgoing relay payloads."""
        # Channel labelling depends on whether the bot really is in an org; the relay-name
        # fallback is only used for the "name" field below.
        org_name = self.public_channel_service.get_org_name()
        if source == "private_channel":
            # With an org, the private channel is the secondary "Guest" channel; without one it is the main channel.
            channel = "Guest" if org_name else ""
        elif source == "org_channel" and org_name:
            channel = ""
        else:
            channel = source.capitalize()

        channel_type = {"private_channel": "priv", "org_channel": "org"}.get(source, source)

        return {
            "name": org_name or self.get_org_channel_prefix(),
            "label": self.setting_service.get("ws_relay_prefix").get_value() or "",
            "channel": channel,
            "type": channel_type,
            "server": self.bot.dimension
        }
|