net: Manage SMSG pubkeys in BSX.

This commit is contained in:
tecnovert
2025-07-16 22:53:03 +02:00
parent e73e084a6d
commit 10d6b13930
3 changed files with 99 additions and 86 deletions

View File

@@ -7545,7 +7545,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
self.log.debug("Ignoring expired offer.")
return
_ = self.expandMessageNets(offer_data.message_nets)
_ = self.expandMessageNets(offer_data.message_nets) # Decode to validate
offer_rate: int = ci_from.make_int(
offer_data.amount_to / offer_data.amount_from, r=1
@@ -7988,6 +7988,8 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
bid_rate: int = ci_from.make_int(bid_data.amount_to / bid_data.amount, r=1)
self.validateBidAmount(offer, bid_data.amount, bid_rate)
_ = self.expandMessageNets(bid_data.message_nets) # Decode to validate
network_type: str = msg.get("msg_net", "smsg")
network_type_received_on_id: int = networkTypeToID(network_type)
bid_message_nets: str = self.selectMessageNetString([network_type_received_on_id, ], bid_data.message_nets)
@@ -8449,6 +8451,8 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
if ci_to.curve_type() == Curves.ed25519:
ensure(len(bid_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size")
_ = self.expandMessageNets(bid_data.message_nets) # Decode to validate
bid_id = bytes.fromhex(msg["msgid"])
network_type: str = msg.get("msg_net", "smsg")
@@ -9977,7 +9981,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
ensure(offer.swap_type == SwapTypes.XMR_SWAP, "Bid/offer swap type mismatch")
ensure(xmr_offer, f"Adaptor-sig offer not found: {self.log.id(offer_id)}.")
_ = self.expandMessageNets(bid_data.message_nets)
_ = self.expandMessageNets(bid_data.message_nets) # Decode to validate
ci_from = self.ci(offer.coin_to)
ci_to = self.ci(offer.coin_from)

View File

@@ -14,6 +14,9 @@ from basicswap.basicswap_util import (
MessageNetworks,
MessageTypes,
)
from basicswap.chainparams import (
Coins,
)
from basicswap.db import DirectMessageRoute
from basicswap.messages_npb import (
MessagePortalOffer,
@@ -29,6 +32,9 @@ from basicswap.network.simplex import (
sendSimplexMsg,
)
from basicswap.util import ensure
from basicswap.util.address import (
b58decode,
)
from basicswap.util.logging import LogCategories as LC
from basicswap.util.smsg import smsgGetID
@@ -213,9 +219,8 @@ class BSXNetwork:
# TODO: Ensure smsg is enabled for the active wallet.
if self._smsg_plaintext_version >= 2:
self.log.debug("TODO: disable addReceivedPubkeys")
# ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False])
# self.log.debug("smsgoptions {ro}")
ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False])
self.log.debug("smsgoptions {ro}")
def add_connection(self, host, port, peer_pubkey):
self.log.info(f"add_connection {host} {port} {peer_pubkey.hex()}.")
@@ -226,6 +231,66 @@ class BSXNetwork:
return {"Error": "Not Initialised"}
return self._network.get_info()
def getPrivkeyForAddress(self, cursor, addr: str) -> bytes:
ci_part = self.ci(Coins.PART)
try:
return ci_part.decodeKey(
self.callrpc(
"smsgdumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
try:
return ci_part.decodeKey(
ci_part.rpc_wallet(
"dumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
raise ValueError("key not found")
def getPubkeyForAddress(self, cursor, addr: str) -> bytes:
if addr == self.network_addr:
return bytes.fromhex(self.network_pubkey)
use_cursor = self.openDB(cursor)
try:
query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pubkey FROM smsgaddresses WHERE addr = :addr LIMIT 1"
rows = use_cursor.execute(query, {"addr": addr}).fetchall()
if len(rows) > 0:
return bytes.fromhex(rows[0][0])
finally:
if cursor is None:
self.closeDB(use_cursor, commit=False)
if self._have_smsg_rpc:
try:
rv = self.callrpc(
"smsggetpubkey",
[
addr,
],
)
return b58decode(rv["publickey"])
except Exception as e: # noqa: F841
pass
raise ValueError(f"Could not get public key for address: {addr}")
def addMessageNetworkLink(
self, linked_type, linked_id, link_type, network_id, cursor
) -> None:
@@ -469,7 +534,7 @@ class BSXNetwork:
self.forwardSmsg(smsg_msg)
else:
net_message_id, smsg_msg = self.sendSmsg(
addr_from, addr_to, payload_hex, msg_valid, return_msg=True
addr_from, addr_to, payload_hex, msg_valid, return_msg=True, cursor=cursor
)
elif network_type == MessageNetworks.SIMPLEX:
if smsg_msg:
@@ -561,20 +626,24 @@ class BSXNetwork:
payload_hex: bytes,
msg_valid: int,
return_msg: bool = False,
cursor=None,
) -> bytes:
options = {"decodehex": True, "ttl_is_seconds": True}
if self._smsg_plaintext_version >= 2:
options["plaintext_format_version"] = 2
options["compression"] = 0
send_to = self.getPubkeyForAddress(cursor, addr_to).hex()
else:
send_to = addr_to
if self._smsg_add_to_outbox is False:
options["add_to_outbox"] = False
if return_msg:
options["returnmsg"] = True
try:
ro = self.callrpc(
"smsgsend",
[addr_from, addr_to, payload_hex, False, msg_valid, False, options],
[addr_from, send_to, payload_hex, False, msg_valid, False, options],
)
self.num_smsg_messages_sent += 1
if return_msg:
@@ -692,7 +761,7 @@ class BSXNetwork:
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if network_from_id == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid)
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor)
elif network_from_id == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
@@ -741,15 +810,13 @@ class BSXNetwork:
payload_hex = (
str.format("{:02x}", MessageTypes.PORTAL_OFFER) + msg_buf.to_bytes().hex()
)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
cursor = self.openDB()
try:
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid)
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
net_message_id = sendSimplexMsg(
self,
network,
@@ -760,10 +827,10 @@ class BSXNetwork:
cursor,
portal.time_start,
)
finally:
self.closeDB(cursor)
else:
raise RuntimeError(f"Unknown network id {portal.network_from}")
finally:
self.closeDB(cursor)
portal.time_start = now
self.logD(
@@ -784,7 +851,7 @@ class BSXNetwork:
addr_to: str = portal.address_from
if portal.network_from == MessageNetworks.SMSG:
net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid, cursor=cursor)
elif portal.network_from == MessageNetworks.SIMPLEX:
network = self.getActiveNetwork(MessageNetworks.SIMPLEX)
@@ -958,6 +1025,8 @@ class BSXNetwork:
if now - self._last_checked_smsg >= self.check_smsg_seconds:
self._last_checked_smsg = now
options = {"encoding": "hex", "setread": True}
if self._smsg_plaintext_version >= 2:
options["pubkey_from"] = True
msgs = self.callrpc("smsginbox", ["unread", "", options])
for msg in msgs["messages"]:
self.processMsg(msg)

View File

@@ -23,7 +23,6 @@ from basicswap.chainparams import (
Coins,
)
from basicswap.util.address import (
b58decode,
decodeWif,
)
from basicswap.basicswap_util import AddressTypes
@@ -173,65 +172,6 @@ def waitForConnected(ws_thread, delay_event):
raise ValueError("waitForConnected timed-out.")
def getPrivkeyForAddress(self, cursor, addr: str) -> bytes:
ci_part = self.ci(Coins.PART)
try:
return ci_part.decodeKey(
self.callrpc(
"smsgdumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
try:
return ci_part.decodeKey(
ci_part.rpc_wallet(
"dumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
raise ValueError("key not found")
def getPubkeyForAddress(self, cursor, addr: str) -> bytes:
if self._have_smsg_rpc:
try:
rv = self.callrpc(
"smsggetpubkey",
[
addr,
],
)
return b58decode(rv["publickey"])
except Exception as e: # noqa: F841
pass
use_cursor = self.openDB(cursor)
try:
query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr}).fetchall()
if len(rows) > 0:
return rows[0][0]
query: str = "SELECT pubkey FROM smsgaddresses WHERE addr = :addr LIMIT 1"
rows = use_cursor.execute(query, {"addr": addr}).fetchall()
if len(rows) > 0:
return bytes.fromhex(rows[0][0])
raise ValueError(f"Could not get public key for address: {addr}")
finally:
if cursor is None:
self.closeDB(use_cursor, commit=False)
def encryptMsg(
self,
addr_from: str,
@@ -245,8 +185,8 @@ def encryptMsg(
) -> bytes:
self.log.debug("encryptMsg")
pubkey_to = getPubkeyForAddress(self, cursor, addr_to)
privkey_from = getPrivkeyForAddress(self, cursor, addr_from)
pubkey_to = self.getPubkeyForAddress(cursor, addr_to)
privkey_from = self.getPrivkeyForAddress(cursor, addr_from)
smsg_msg: bytes = smsgEncrypt(
privkey_from,
@@ -370,7 +310,7 @@ def decryptSimplexMsg(self, msg_data):
for row in addr_rows:
addr = row[0]
try:
vk_addr = getPrivkeyForAddress(self, cursor, addr)
vk_addr = self.getPrivkeyForAddress(cursor, addr)
decrypted = smsgDecrypt(vk_addr, msg_data, output_dict=True)
decrypted["from"] = ci_part.pubkey_to_address(
bytes.fromhex(decrypted["pubkey_from"])