zmq: Use recv_multipart and set server keypair in prepare script.

This commit is contained in:
tecnovert
2025-07-29 02:00:08 +02:00
parent 53fc673e71
commit 6d4200f871
3 changed files with 32 additions and 16 deletions

View File

@@ -10406,9 +10406,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
None, None,
) )
def processZmqHashwtx(self) -> None: def processZmqHashwtx(self, message) -> None:
self.zmqSubscriber.recv()
try: try:
if Coins.PART not in self.coin_clients: if Coins.PART not in self.coin_clients:
return return
@@ -10561,11 +10559,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
if self._zmq_queue_enabled: if self._zmq_queue_enabled:
try: try:
if self._read_zmq_queue: if self._read_zmq_queue:
message = self.zmqSubscriber.recv(flags=zmq.NOBLOCK) topic, message, seq = self.zmqSubscriber.recv_multipart(
if message == b"smsg": flags=zmq.NOBLOCK
self.processZmqSmsg() )
elif message == b"hashwtx": if topic == b"smsg":
self.processZmqHashwtx() self.processZmqSmsg(message)
elif topic == b"hashwtx":
self.processZmqHashwtx(message)
except zmq.Again as e: # noqa: F841 except zmq.Again as e: # noqa: F841
pass pass
except Exception as e: except Exception as e:

View File

@@ -6,6 +6,7 @@
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php. # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import base64
import contextlib import contextlib
import gnupg import gnupg
import hashlib import hashlib
@@ -26,6 +27,7 @@ import threading
import time import time
import urllib.parse import urllib.parse
import zipfile import zipfile
import zmq
from urllib.request import urlopen from urllib.request import urlopen
@@ -1374,6 +1376,13 @@ def prepareDataDir(coin, settings, chain, particl_mnemonic, extra_opts={}):
COINS_RPCBIND_IP, settings["zmqport"] COINS_RPCBIND_IP, settings["zmqport"]
) )
) )
zmqsecret = extra_opts.get("zmqsecret", None)
if zmqsecret:
try:
_ = base64.b64decode(zmqsecret)
except Exception as e: # noqa: F841
raise ValueError("zmqsecret must be base64 encoded")
fp.write(f"serverkeyzmq={zmqsecret}\n")
fp.write("spentindex=1\n") fp.write("spentindex=1\n")
fp.write("txindex=1\n") fp.write("txindex=1\n")
fp.write("staking=0\n") fp.write("staking=0\n")
@@ -3211,6 +3220,9 @@ def main():
for c in with_coins: for c in with_coins:
withchainclients[c] = chainclients[c] withchainclients[c] = chainclients[c]
zmq_public_key, zmq_secret_key = zmq.curve_keypair()
extra_opts["zmqsecret"] = base64.b64encode(zmq_secret_key).decode("utf-8")
settings = { settings = {
"debug": True, "debug": True,
"zmqhost": f"tcp://{PART_RPC_HOST}", "zmqhost": f"tcp://{PART_RPC_HOST}",
@@ -3226,6 +3238,7 @@ def main():
"check_watched_seconds": 60, "check_watched_seconds": 60,
"check_expired_seconds": 60, "check_expired_seconds": 60,
"wallet_update_timeout": 10, # Seconds to wait for wallet page update "wallet_update_timeout": 10, # Seconds to wait for wallet page update
"zmq_server_key": base64.b64encode(zmq_public_key).decode("utf-8"),
} }
wshost: str = extra_opts.get("wshost", htmlhost) wshost: str = extra_opts.get("wshost", htmlhost)

View File

@@ -4,6 +4,7 @@
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php. # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import base64
import json import json
import random import random
import zmq import zmq
@@ -89,6 +90,7 @@ class BSXNetwork:
self._poll_smsg = self.settings.get("poll_smsg", False) self._poll_smsg = self.settings.get("poll_smsg", False)
self.zmqContext = None self.zmqContext = None
self.zmqSubscriber = None self.zmqSubscriber = None
self.zmq_server_key = self.settings.get("zmq_server_key", None)
self.SMSG_SECONDS_IN_HOUR = ( self.SMSG_SECONDS_IN_HOUR = (
60 * 60 60 * 60
@@ -145,12 +147,17 @@ class BSXNetwork:
if self._zmq_queue_enabled: if self._zmq_queue_enabled:
self.zmqContext = zmq.Context() self.zmqContext = zmq.Context()
self.zmqSubscriber = self.zmqContext.socket(zmq.SUB) self.zmqSubscriber = self.zmqContext.socket(zmq.SUB)
if self.zmq_server_key is not None:
zmq_server_key = base64.b64decode(self.zmq_server_key)
public_key, secret_key = zmq.curve_keypair()
self.zmqSubscriber.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
self.zmqSubscriber.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
self.zmqSubscriber.setsockopt(zmq.CURVE_SERVERKEY, zmq_server_key)
self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "smsg")
self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "hashwtx")
self.zmqSubscriber.connect( self.zmqSubscriber.connect(
self.settings["zmqhost"] + ":" + str(self.settings["zmqport"]) self.settings["zmqhost"] + ":" + str(self.settings["zmqport"])
) )
self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "smsg")
self.zmqSubscriber.setsockopt_string(zmq.SUBSCRIBE, "hashwtx")
ro = self.callrpc("smsglocalkeys") ro = self.callrpc("smsglocalkeys")
found = False found = False
@@ -725,11 +732,7 @@ class BSXNetwork:
return bytes.fromhex(msg["hex"][2:-2]) return bytes.fromhex(msg["hex"][2:-2])
return bytes.fromhex(msg["hex"][2:]) return bytes.fromhex(msg["hex"][2:])
def processZmqSmsg(self) -> None: def processZmqSmsg(self, message) -> None:
message = self.zmqSubscriber.recv()
# Clear
_ = self.zmqSubscriber.recv()
if message[0] == 3: # Paid smsg if message[0] == 3: # Paid smsg
return # TODO: Switch to paid? return # TODO: Switch to paid?