diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 9384236..7ce680a 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -7545,7 +7545,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): self.log.debug("Ignoring expired offer.") return - _ = self.expandMessageNets(offer_data.message_nets) + _ = self.expandMessageNets(offer_data.message_nets) # Decode to validate offer_rate: int = ci_from.make_int( offer_data.amount_to / offer_data.amount_from, r=1 @@ -7988,6 +7988,8 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): bid_rate: int = ci_from.make_int(bid_data.amount_to / bid_data.amount, r=1) self.validateBidAmount(offer, bid_data.amount, bid_rate) + _ = self.expandMessageNets(bid_data.message_nets) # Decode to validate + network_type: str = msg.get("msg_net", "smsg") network_type_received_on_id: int = networkTypeToID(network_type) bid_message_nets: str = self.selectMessageNetString([network_type_received_on_id, ], bid_data.message_nets) @@ -8449,6 +8451,8 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): if ci_to.curve_type() == Curves.ed25519: ensure(len(bid_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size") + _ = self.expandMessageNets(bid_data.message_nets) # Decode to validate + bid_id = bytes.fromhex(msg["msgid"]) network_type: str = msg.get("msg_net", "smsg") @@ -9977,7 +9981,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): ensure(offer.swap_type == SwapTypes.XMR_SWAP, "Bid/offer swap type mismatch") ensure(xmr_offer, f"Adaptor-sig offer not found: {self.log.id(offer_id)}.") - _ = self.expandMessageNets(bid_data.message_nets) + _ = self.expandMessageNets(bid_data.message_nets) # Decode to validate ci_from = self.ci(offer.coin_to) ci_to = self.ci(offer.coin_from) diff --git a/basicswap/network/bsx_network.py b/basicswap/network/bsx_network.py index 9887adc..317dc88 100644 --- a/basicswap/network/bsx_network.py +++ b/basicswap/network/bsx_network.py @@ -14,6 +14,9 @@ from basicswap.basicswap_util import ( MessageNetworks, MessageTypes, ) +from basicswap.chainparams import ( + Coins, +) from basicswap.db import DirectMessageRoute from basicswap.messages_npb import ( MessagePortalOffer, @@ -29,6 +32,9 @@ from basicswap.network.simplex import ( sendSimplexMsg, ) from basicswap.util import ensure +from basicswap.util.address import ( + b58decode, +) from basicswap.util.logging import LogCategories as LC from basicswap.util.smsg import smsgGetID @@ -213,9 +219,8 @@ class BSXNetwork: # TODO: Ensure smsg is enabled for the active wallet. if self._smsg_plaintext_version >= 2: - self.log.debug("TODO: disable addReceivedPubkeys") - # ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False]) - # self.log.debug("smsgoptions {ro}") + ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False]) + self.log.debug("smsgoptions {ro}") def add_connection(self, host, port, peer_pubkey): self.log.info(f"add_connection {host} {port} {peer_pubkey.hex()}.") @@ -226,6 +231,66 @@ class BSXNetwork: return {"Error": "Not Initialised"} return self._network.get_info() + def getPrivkeyForAddress(self, cursor, addr: str) -> bytes: + ci_part = self.ci(Coins.PART) + try: + return ci_part.decodeKey( + self.callrpc( + "smsgdumpprivkey", + [ + addr, + ], + ) + ) + except Exception as e: # noqa: F841 + pass + try: + return ci_part.decodeKey( + ci_part.rpc_wallet( + "dumpprivkey", + [ + addr, + ], + ) + ) + except Exception as e: # noqa: F841 + pass + raise ValueError("key not found") + + def getPubkeyForAddress(self, cursor, addr: str) -> bytes: + if addr == self.network_addr: + return bytes.fromhex(self.network_pubkey) + + use_cursor = self.openDB(cursor) + try: + query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1" + rows = use_cursor.execute(query, {"addr_to": addr}).fetchall() + if len(rows) > 0: + return rows[0][0] + query: str = "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1" + rows = use_cursor.execute(query, {"addr_to": addr}).fetchall() + if len(rows) > 0: + return rows[0][0] + query: str = "SELECT pubkey FROM smsgaddresses WHERE addr = :addr LIMIT 1" + rows = use_cursor.execute(query, {"addr": addr}).fetchall() + if len(rows) > 0: + return bytes.fromhex(rows[0][0]) + finally: + if cursor is None: + self.closeDB(use_cursor, commit=False) + if self._have_smsg_rpc: + try: + rv = self.callrpc( + "smsggetpubkey", + [ + addr, + ], + ) + return b58decode(rv["publickey"]) + except Exception as e: # noqa: F841 + pass + raise ValueError(f"Could not get public key for address: {addr}") + def addMessageNetworkLink( self, linked_type, linked_id, link_type, network_id, cursor ) -> None: @@ -469,7 +534,7 @@ class BSXNetwork: self.forwardSmsg(smsg_msg) else: net_message_id, smsg_msg = self.sendSmsg( - addr_from, addr_to, payload_hex, msg_valid, return_msg=True + addr_from, addr_to, payload_hex, msg_valid, return_msg=True, cursor=cursor ) elif network_type == MessageNetworks.SIMPLEX: if smsg_msg: @@ -561,20 +626,24 @@ class BSXNetwork: payload_hex: bytes, msg_valid: int, return_msg: bool = False, + cursor=None, ) -> bytes: - options = {"decodehex": True, "ttl_is_seconds": True} if self._smsg_plaintext_version >= 2: options["plaintext_format_version"] = 2 options["compression"] = 0 + send_to = self.getPubkeyForAddress(cursor, addr_to).hex() + else: + send_to = addr_to if self._smsg_add_to_outbox is False: options["add_to_outbox"] = False if return_msg: options["returnmsg"] = True + try: ro = self.callrpc( "smsgsend", - [addr_from, addr_to, payload_hex, False, msg_valid, False, options], + [addr_from, send_to, payload_hex, False, msg_valid, False, options], ) self.num_smsg_messages_sent += 1 if return_msg: @@ -692,7 +761,7 @@ class BSXNetwork: msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid) if network_from_id == MessageNetworks.SMSG: - net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid) + net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor) elif network_from_id == MessageNetworks.SIMPLEX: network = self.getActiveNetwork(MessageNetworks.SIMPLEX) @@ -741,15 +810,13 @@ class BSXNetwork: payload_hex = ( str.format("{:02x}", MessageTypes.PORTAL_OFFER) + msg_buf.to_bytes().hex() ) - - msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid) - if portal.network_from == MessageNetworks.SMSG: - net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid) - elif portal.network_from == MessageNetworks.SIMPLEX: - network = self.getActiveNetwork(MessageNetworks.SIMPLEX) - - cursor = self.openDB() - try: + cursor = self.openDB() + try: + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, portal.time_valid) + if portal.network_from == MessageNetworks.SMSG: + net_message_id = self.sendSmsg(addr_portal, addr_to, payload_hex, msg_valid, cursor=cursor) + elif portal.network_from == MessageNetworks.SIMPLEX: + network = self.getActiveNetwork(MessageNetworks.SIMPLEX) net_message_id = sendSimplexMsg( self, network, @@ -760,10 +827,10 @@ class BSXNetwork: cursor, portal.time_start, ) - finally: - self.closeDB(cursor) - else: - raise RuntimeError(f"Unknown network id {portal.network_from}") + else: + raise RuntimeError(f"Unknown network id {portal.network_from}") + finally: + self.closeDB(cursor) portal.time_start = now self.logD( @@ -784,7 +851,7 @@ class BSXNetwork: addr_to: str = portal.address_from if portal.network_from == MessageNetworks.SMSG: - net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) + net_message_id = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid, cursor=cursor) elif portal.network_from == MessageNetworks.SIMPLEX: network = self.getActiveNetwork(MessageNetworks.SIMPLEX) @@ -958,6 +1025,8 @@ class BSXNetwork: if now - self._last_checked_smsg >= self.check_smsg_seconds: self._last_checked_smsg = now options = {"encoding": "hex", "setread": True} + if self._smsg_plaintext_version >= 2: + options["pubkey_from"] = True msgs = self.callrpc("smsginbox", ["unread", "", options]) for msg in msgs["messages"]: self.processMsg(msg) diff --git a/basicswap/network/simplex.py b/basicswap/network/simplex.py index bc6b0ca..6972c15 100644 --- a/basicswap/network/simplex.py +++ b/basicswap/network/simplex.py @@ -23,7 +23,6 @@ from basicswap.chainparams import ( Coins, ) from basicswap.util.address import ( - b58decode, decodeWif, ) from basicswap.basicswap_util import AddressTypes @@ -173,65 +172,6 @@ def waitForConnected(ws_thread, delay_event): raise ValueError("waitForConnected timed-out.") -def getPrivkeyForAddress(self, cursor, addr: str) -> bytes: - ci_part = self.ci(Coins.PART) - try: - return ci_part.decodeKey( - self.callrpc( - "smsgdumpprivkey", - [ - addr, - ], - ) - ) - except Exception as e: # noqa: F841 - pass - try: - return ci_part.decodeKey( - ci_part.rpc_wallet( - "dumpprivkey", - [ - addr, - ], - ) - ) - except Exception as e: # noqa: F841 - pass - raise ValueError("key not found") - - -def getPubkeyForAddress(self, cursor, addr: str) -> bytes: - if self._have_smsg_rpc: - try: - rv = self.callrpc( - "smsggetpubkey", - [ - addr, - ], - ) - return b58decode(rv["publickey"]) - except Exception as e: # noqa: F841 - pass - use_cursor = self.openDB(cursor) - try: - query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1" - rows = use_cursor.execute(query, {"addr_to": addr}).fetchall() - if len(rows) > 0: - return rows[0][0] - query: str = "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1" - rows = use_cursor.execute(query, {"addr_to": addr}).fetchall() - if len(rows) > 0: - return rows[0][0] - query: str = "SELECT pubkey FROM smsgaddresses WHERE addr = :addr LIMIT 1" - rows = use_cursor.execute(query, {"addr": addr}).fetchall() - if len(rows) > 0: - return bytes.fromhex(rows[0][0]) - raise ValueError(f"Could not get public key for address: {addr}") - finally: - if cursor is None: - self.closeDB(use_cursor, commit=False) - - def encryptMsg( self, addr_from: str, @@ -245,8 +185,8 @@ def encryptMsg( ) -> bytes: self.log.debug("encryptMsg") - pubkey_to = getPubkeyForAddress(self, cursor, addr_to) - privkey_from = getPrivkeyForAddress(self, cursor, addr_from) + pubkey_to = self.getPubkeyForAddress(cursor, addr_to) + privkey_from = self.getPrivkeyForAddress(cursor, addr_from) smsg_msg: bytes = smsgEncrypt( privkey_from, @@ -370,7 +310,7 @@ def decryptSimplexMsg(self, msg_data): for row in addr_rows: addr = row[0] try: - vk_addr = getPrivkeyForAddress(self, cursor, addr) + vk_addr = self.getPrivkeyForAddress(cursor, addr) decrypted = smsgDecrypt(vk_addr, msg_data, output_dict=True) decrypted["from"] = ci_part.pubkey_to_address( bytes.fromhex(decrypted["pubkey_from"])