diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index c43e240..7c71c43 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -10406,9 +10406,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): None, ) - def processZmqHashwtx(self) -> None: - self.zmqSubscriber.recv() - + def processZmqHashwtx(self, message) -> None: try: if Coins.PART not in self.coin_clients: return @@ -10561,11 +10559,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): if self._zmq_queue_enabled: try: if self._read_zmq_queue: - message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK) - if message == b"smsg": - self.processZmqSmsg() - elif message == b"hashwtx": - self.processZmqHashwtx() + topic, message, seq = self.zmqSubscriber.recv_multipart( + flags=zmq.NOBLOCK + ) + if topic == b"smsg": + self.processZmqSmsg(message) + elif topic == b"hashwtx": + self.processZmqHashwtx(message) except zmq.Again as e: # noqa: F841 pass except Exception as e: diff --git a/basicswap/bin/prepare.py b/basicswap/bin/prepare.py index b056db3..beb38d1 100755 --- a/basicswap/bin/prepare.py +++ b/basicswap/bin/prepare.py @@ -6,6 +6,7 @@ # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. +import base64 import contextlib import gnupg import hashlib @@ -26,6 +27,7 @@ import threading import time import urllib.parse import zipfile +import zmq from urllib.request import urlopen @@ -1374,6 +1376,13 @@ def prepareDataDir(coin, settings, chain, particl_mnemonic, extra_opts={}): COINS_RPCBIND_IP, settings["zmqport"] ) ) + zmqsecret = extra_opts.get("zmqsecret", None) + if zmqsecret: + try: + _ = base64.b64decode(zmqsecret) + except Exception as e: # noqa: F841 + raise ValueError("zmqsecret must be base64 encoded") + fp.write(f"serverkeyzmq={zmqsecret}\n") fp.write("spentindex=1\n") fp.write("txindex=1\n") fp.write("staking=0\n") @@ -3211,6 +3220,9 @@ def main(): for c in with_coins: withchainclients[c] = chainclients[c] + zmq_public_key, zmq_secret_key = zmq.curve_keypair() + extra_opts["zmqsecret"] = base64.b64encode(zmq_secret_key).decode("utf-8") + settings = { "debug": True, "zmqhost": f"tcp://{PART_RPC_HOST}", @@ -3226,6 +3238,7 @@ def main(): "check_watched_seconds": 60, "check_expired_seconds": 60, "wallet_update_timeout": 10, # Seconds to wait for wallet page update + "zmq_server_key": base64.b64encode(zmq_public_key).decode("utf-8"), } wshost: str = extra_opts.get("wshost", htmlhost) diff --git a/basicswap/network/bsx_network.py b/basicswap/network/bsx_network.py index 51bcdfc..ea160df 100644 --- a/basicswap/network/bsx_network.py +++ b/basicswap/network/bsx_network.py @@ -4,6 +4,7 @@ # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. +import base64 import json import random import zmq @@ -89,6 +90,7 @@ class BSXNetwork: self._poll_smsg = self.settings.get("poll_smsg", False) self.zmqContext = None self.zmqSubscriber = None + self.zmq_server_key = self.settings.get("zmq_server_key", None) self.SMSG_SECONDS_IN_HOUR = ( 60 * 60 @@ -145,12 +147,17 @@ class BSXNetwork: if self._zmq_queue_enabled: self.zmqContext = zmq.Context() self.zmqSubscriber = self.zmqContext.socket(zmq.SUB) - + if self.zmq_server_key is not None: + zmq_server_key = base64.b64decode(self.zmq_server_key) + public_key, secret_key = zmq.curve_keypair() + self.zmqSubscriber.setsockopt(zmq.CURVE_PUBLICKEY, public_key) + self.zmqSubscriber.setsockopt(zmq.CURVE_SECRETKEY, secret_key) + self.zmqSubscriber.setsockopt(zmq.CURVE_SERVERKEY, zmq_server_key) + self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "smsg") + self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "hashwtx") self.zmqSubscriber.connect( self.settings["zmqhost"] + ":" + str(self.settings["zmqport"]) ) - self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "smsg") - self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "hashwtx") ro = self.callrpc("smsglocalkeys") found = False @@ -725,11 +732,7 @@ class BSXNetwork: return bytes.fromhex(msg["hex"][2:-2]) return bytes.fromhex(msg["hex"][2:]) - def processZmqSmsg(self) -> None: - message = self.zmqSubscriber.recv() - # Clear - _ = self.zmqSubscriber.recv() - + def processZmqSmsg(self, message) -> None: if message[0] == 3: # Paid smsg return # TODO: Switch to paid?