net: Manage SMSG pubkeys in BSX.

This commit is contained in:
tecnovert
2025-07-16 22:53:03 +02:00
parent e73e084a6d
commit 10d6b13930
3 changed files with 99 additions and 86 deletions

View File

@@ -14,6 +14,9 @@ from basicswap.basicswap_util import (
MessageNetworks,
MessageTypes,
)
from basicswap.chainparams import (
Coins,
)
from basicswap.db import DirectMessageRoute
from basicswap.messages_npb import (
MessagePortalOffer,
@@ -29,6 +32,9 @@ from basicswap.network.simplex import (
sendSimplexMsg,
)
from basicswap.util import ensure
from basicswap.util.address import (
b58decode,
)
from basicswap.util.logging import LogCategories as LC
from basicswap.util.smsg import smsgGetID
@@ -213,9 +219,8 @@ class BSXNetwork:
# TODO: Ensure smsg is enabled for the active wallet.
if self._smsg_plaintext_version >= 2:
self.log.debug("TODO: disable addReceivedPubkeys")
# ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False])
# self.log.debug("smsgoptions {ro}")
ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False])
self.log.debug("smsgoptions {ro}")
def add_connection(self, host, port, peer_pubkey):
self.log.info(f"add_connection {host} {port} {peer_pubkey.hex()}.")
@@ -226,6 +231,66 @@ class BSXNetwork:
return {"Error": "Not Initialised"}
return self._network.get_info()
def getPrivkeyForAddress(self, cursor, addr: str) -> bytes:
ci_part = self.ci(Coins.PART)
try:
return ci_part.decodeKey(
self.callrpc(
"smsgdumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
try:
return ci_part.decodeKey(
ci_part.rpc_wallet(
"dumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
raise ValueError("key not found")
def getPubkeyForAddress(self, cursor, addr: str) -> bytes:
if addr == self.network_addr:
return bytes.fromhex(self.network_pubkey)
use_cursor = self.openDB(cursor)
try:
query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pubkey FROM smsgaddresses WHERE addr = :addr LIMIT 1"
rows = use_cursor.execute(query, {"addr": addr}).fetchall()
if len(rows) > 0:
return bytes.fromhex(rows[0][0])
finally:
if cursor is None:
self.closeDB(use_cursor, commit=False)
if self._have_smsg_rpc:
try:
rv = self.callrpc(
"smsggetpubkey",
[
addr,
],
)
return b58decode(rv["publickey"])
except Exception as e: # noqa: F841
pass
raise ValueError(f"Could not get public key for address: {addr}")
def addMessageNetworkLink(
self, linked_type, linked_id, link_type, network_id, cursor
) -> None:
@@ -469,7 +534,7 @@ class BSXNetwork:
self.forwardSmsg(smsg_msg)
else:
net_message_id, smsg_msg = self.sendSmsg(
addr_from, addr_to, payload_hex, msg_valid, return_msg=True
addr_from, addr_to, payload_hex, msg_valid, return_msg=True, cursor=cursor
)
elif network_type == MessageNetworks.SIMPLEX:
if smsg_msg:
@@ -561,20 +626,24 @@ class BSXNetwork:
payload_hex: bytes,
msg_valid: int,
return_msg: bool = False,
cursor=None,
) -> bytes:
options = {"decodehex": True, "ttl_is_seconds": True}
if self._smsg_plaintext_version >= 2:
options["plaintext_format_version"] = 2
options["compression"] = 0
send_to = self.getPubkeyForAddress(cursor, addr_to).hex()
else:
send_to = addr_to
if self._smsg_add_to_outbox is False:
options["add_to_outbox"] = False
if return_msg:
options["returnmsg"] = True
try:
ro = self.callrpc(
"smsgsend",
[addr_from, addr_to, payload_hex, False, msg_valid, False, options],
[addr_from, send_to, payload_hex, False, msg_valid, False, options],
)
self.num_smsg_messages_sent += 1
if return_msg:
@@ -692,7 +761,7 @@ class BSXNetwork:
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if network_from_id == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid)
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor)
elif network_from_id == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
@@ -741,15 +810,13 @@ class BSXNetwork:
payload_hex = (
str.format("{:02x}", MessageTypes.PORTAL_OFFER) + msg_buf.to_bytes().hex()
)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
cursor = self.openDB()
try:
cursor = self.openDB()
try:
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
net_message_id = sendSimplexMsg(
self,
network,
@@ -760,10 +827,10 @@ class BSXNetwork:
cursor,
portal.time_start,
)
finally:
self.closeDB(cursor)
else:
raise RuntimeError(f"Unknown network id {portal.network_from}")
else:
raise RuntimeError(f"Unknown network id {portal.network_from}")
finally:
self.closeDB(cursor)
portal.time_start = now
self.logD(
@@ -784,7 +851,7 @@ class BSXNetwork:
addr_to: str = portal.address_from
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid, cursor=cursor)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
@@ -958,6 +1025,8 @@ class BSXNetwork:
if now - self._last_checked_smsg >= self.check_smsg_seconds:
self._last_checked_smsg = now
options = {"encoding": "hex", "setread": True}
if self._smsg_plaintext_version >= 2:
options["pubkey_from"] = True
msgs = self.callrpc("smsginbox", ["unread", "", options])
for msg in msgs["messages"]:
self.processMsg(msg)