diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9774c69..e54c881 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,7 @@ jobs: sudo apt-get install -y firefox fi python -m pip install --upgrade pip + pip install python-gnupg pip install -e .[dev] pip install -r requirements.txt --require-hashes - name: Install diff --git a/basicswap/base.py b/basicswap/base.py index 8b6cabf..a834aee 100644 --- a/basicswap/base.py +++ b/basicswap/base.py @@ -71,6 +71,7 @@ class BaseApp(DBMethods): self.default_socket = socket.socket self.default_socket_timeout = socket.getdefaulttimeout() self.default_socket_getaddrinfo = socket.getaddrinfo + self._force_db_upgrade = False def __del__(self): if self.fp: diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index e565da2..12be57a 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -27,13 +27,50 @@ import zmq from typing import Optional -from .interface.base import Curves -from .interface.part import PARTInterface, PARTInterfaceAnon, PARTInterfaceBlind - from . import __version__ +from .base import BaseApp +from .basicswap_util import ( + ActionTypes, + AddressTypes, + AutomationOverrideOptions, + BidStates, + canAcceptBidState, + ConnectionRequestTypes, + DebugTypes, + describeEventEntry, + EventLogTypes, + fiatTicker, + get_api_key_setting, + getLastBidState, + getVoutByAddress, + getVoutByScriptPubKey, + inactive_states, + isActiveBidState, + KeyTypes, + MessageNetworks, + MessageTypes, + NotificationTypes as NT, + OfferStates, + strBidState, + SwapTypes, + TxLockTypes, + TxStates, + TxTypes, + VisibilityOverrideOptions, + XmrSplitMsgTypes, +) +from .chainparams import ( + Coins, + chainparams, + Fiat, + ticker_map, +) +from .db_upgrades import upgradeDatabase, upgradeDatabaseData +from .db_util import remove_expired_data from .rpc import escape_rpcauth from .rpc_xmr import make_xmr_rpc2_func from .ui.util import getCoinName +from .ui.app import UIApp from .util import ( AutomationConstraint, AutomationConstraintTemporary, @@ -56,28 +93,22 @@ from .util.address import ( decodeAddress, pubkeyToAddress, ) -from .util.crypto import ( - sha256, -) -from basicswap.util.network import is_private_ip_address -from .chainparams import ( - Coins, - chainparams, - Fiat, - ticker_map, -) +from .util.crypto import sha256 +from .util.network import is_private_ip_address +from .util.smsg import smsgGetID +from .interface.base import Curves +from .interface.part import PARTInterface, PARTInterfaceAnon, PARTInterfaceBlind from .explorers import ( default_chart_api_key, default_coingecko_api_key, ) -from .script import ( - OpCodes, -) +from .script import OpCodes from .messages_npb import ( ADSBidIntentAcceptMessage, ADSBidIntentMessage, BidAcceptMessage, BidMessage, + ConnectReqMessage, OfferMessage, OfferRevokeMessage, XmrBidAcceptMessage, @@ -95,6 +126,8 @@ from .db import ( Concepts, create_db, CURRENT_DB_VERSION, + DirectMessageRoute, + DirectMessageRouteLink, EventLog, getOrderByStr, KnownIdentity, @@ -112,17 +145,19 @@ from .db import ( XmrSplitData, XmrSwap, ) -from .db_upgrades import upgradeDatabase, upgradeDatabaseData -from .base import BaseApp from .explorers import ( ExplorerInsight, ExplorerBitAps, ExplorerChainz, ) from .network.simplex import ( + closeSimplexChat, + encryptMsg, + getJoinedSimplexLink, + getResponseData, initialiseSimplexNetwork, - sendSimplexMsg, readSimplexMsgs, + sendSimplexMsg, ) from .network.util import ( getMsgPubkey, @@ -131,37 +166,7 @@ import basicswap.config as cfg import basicswap.network.network as bsn import basicswap.protocols.atomic_swap_1 as atomic_swap_1 import basicswap.protocols.xmr_swap_1 as xmr_swap_1 -from .basicswap_util import ( - ActionTypes, - AddressTypes, - AutomationOverrideOptions, - BidStates, - DebugTypes, - EventLogTypes, - fiatTicker, - get_api_key_setting, - KeyTypes, - MessageTypes, - NotificationTypes as NT, - OfferStates, - SwapTypes, - TxLockTypes, - TxStates, - TxTypes, - VisibilityOverrideOptions, - XmrSplitMsgTypes, - canAcceptBidState, - describeEventEntry, - getLastBidState, - getVoutByAddress, - getVoutByScriptPubKey, - inactive_states, - isActiveBidState, - strBidState, -) -from basicswap.db_util import ( - remove_expired_data, -) + PROTOCOL_VERSION_SECRET_HASH = 5 MINPROTO_VERSION_SECRET_HASH = 4 @@ -277,7 +282,7 @@ class WatchedTransaction: self.swap_type = swap_type -class BasicSwap(BaseApp): +class BasicSwap(BaseApp, UIApp): ws_server = None _read_zmq_queue: bool = True protocolInterfaces = { @@ -358,9 +363,14 @@ class BasicSwap(BaseApp): self._expire_db_records_after = self.get_int_setting( "expire_db_records_after", 7 * 86400, 0, 31 * 86400 ) # Seconds + self._expire_message_routes_after = self._expire_db_records_after = ( + self.get_int_setting( + "expire_message_routes_after", 48 * 3600, 10 * 60, 31 * 86400 + ) + ) # Seconds self._max_logfile_bytes = self.settings.get( "max_logfile_size", 100 - ) # In MB 0 to disable truncation + ) # In MB. Set to 0 to disable truncation if self._max_logfile_bytes > 0: self._max_logfile_bytes *= 1024 * 1024 self._max_logfiles = self.get_int_setting("max_logfiles", 10, 1, 100) @@ -369,6 +379,13 @@ class BasicSwap(BaseApp): self._is_encrypted = None self._is_locked = None + self.num_group_simplex_messages_received = 0 + self.num_direct_simplex_messages_received = 0 + + self._max_transient_errors = self.settings.get( + "max_transient_errors", 100 + ) # Number of retries before a bid will stop when encountering transient errors. + # Keep sensitive info out of the log file (WIP) self.log.safe_logs = self.settings.get("safe_logs", False) if self.log.safe_logs and self.debug: @@ -378,7 +395,7 @@ class BasicSwap(BaseApp): self.log.warning("Safe log enabled.") if "safe_logs_prefix" in self.settings: self.log.safe_logs_prefix = self.settings["safe_logs_prefix"].encode( - encoding="utf-8" + encoding="UTF-8" ) else: self.log.warning('Using random "safe_logs_prefix".') @@ -429,14 +446,11 @@ class BasicSwap(BaseApp): "restrict_unknown_seed_wallets", True ) self._max_check_loop_blocks = self.settings.get("max_check_loop_blocks", 100000) - self._bid_expired_leeway = 5 + self._use_direct_message_routes = True self.swaps_in_progress = dict() - self.dleag_split_size_init = 16000 - self.dleag_split_size = 17000 - self.SMSG_SECONDS_IN_HOUR = ( 60 * 60 ) # Note: Set smsgsregtestadjust=0 for regtest @@ -486,6 +500,7 @@ class BasicSwap(BaseApp): self.with_coins_override = extra_opts.get("with_coins", set()) self.without_coins_override = extra_opts.get("without_coins", set()) + self._force_db_upgrade = extra_opts.get("force_db_upgrade", False) for c in Coins: if c in chainparams: @@ -992,7 +1007,7 @@ class BasicSwap(BaseApp): pass else: with open(pidfilepath, "rb") as fp: - datadir_pid = int(fp.read().decode("utf-8")) + datadir_pid = int(fp.read().decode("UTF-8")) assert datadir_pid == cc["pid"], "Mismatched pid" assert os.path.exists(authcookiepath) break @@ -1006,7 +1021,7 @@ class BasicSwap(BaseApp): ): # Litecoin on windows doesn't write a pid file assert datadir_pid == cc["pid"], "Mismatched pid" with open(authcookiepath, "rb") as fp: - cc["rpcauth"] = escape_rpcauth(fp.read().decode("utf-8")) + cc["rpcauth"] = escape_rpcauth(fp.read().decode("UTF-8")) except Exception as e: self.log.error( "Unable to read authcookie for %s, %s, datadir pid %d, daemon pid %s. Error: %s", @@ -1033,6 +1048,10 @@ class BasicSwap(BaseApp): self.log.info(f"Python version: {platform.python_version()}") self.log.info(f"SQLite version: {sqlite3.sqlite_version}") self.log.debug(f"Timezone offset: {time.timezone} ({time.tzname[0]})") + gil_status: bool = True + if sys.version_info >= (3, 13): + gil_status = sys._is_gil_enabled() + self.log.debug(f"GIL enabled: {gil_status}") MIN_SQLITE_VERSION = (3, 35, 0) # Upsert if sqlite3.sqlite_version_info < MIN_SQLITE_VERSION: @@ -1543,6 +1562,25 @@ class BasicSwap(BaseApp): if cursor is None: self.closeDB(use_cursor) + def getMessageRoute( + self, network_id: int, address_from: str, address_to: str, cursor=None + ): + try: + use_cursor = self.openDB(cursor) + route = self.queryOne( + DirectMessageRoute, + use_cursor, + { + "network_id": network_id, + "smsg_addr_local": address_from, + "smsg_addr_remote": address_to, + }, + ) + return route + finally: + if cursor is None: + self.closeDB(use_cursor) + def activateBid(self, cursor, bid) -> None: if bid.bid_id in self.swaps_in_progress: self.log.debug(f"Bid {self.log.id(bid.bid_id)} is already in progress") @@ -1759,10 +1797,87 @@ class BasicSwap(BaseApp): bid_valid = (bid.expire_at - now) + 10 * 60 # Add 10 minute buffer return max(smsg_min_valid, min(smsg_max_valid, bid_valid)) + def getActiveNetwork(self, network_id: int): + # TODO: Add more network types + for network in self.active_networks: + if network["type"] == "simplex": + return network + raise RuntimeError("Network not found.") + + def getActiveNetworkInterface(self, network_id: int): + network = self.getActiveNetwork(network_id) + return network["ws_thread"] + def sendMessage( - self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int, cursor + self, + addr_from: str, + addr_to: str, + payload_hex: bytes, + msg_valid: int, + cursor, + linked_type=None, + linked_id=None, + timestamp=None, + deterministic=False, ) -> bytes: message_id: bytes = None + + message_route = self.getMessageRoute(1, addr_from, addr_to, cursor=cursor) + if message_route: + raise RuntimeError("Trying to send through an unestablished direct route.") + + message_route = self.getMessageRoute(2, addr_from, addr_to, cursor=cursor) + if message_route: + network = self.getActiveNetwork(2) + net_i = network["ws_thread"] + + remote_name = None + route_data = json.loads(message_route.route_data.decode("UTF-8")) + if "localDisplayName" in route_data: + remote_name = route_data["localDisplayName"] + else: + pccConnId = route_data["pccConnId"] + self.log.debug(f"Finding name for Simplex chat, ID: {pccConnId}") + cmd_id = net_i.send_command("/chats") + response = net_i.wait_for_command_response(cmd_id) + for chat in getResponseData(response, "chats"): + if ( + "chatInfo" not in chat + or "type" not in chat["chatInfo"] + or chat["chatInfo"]["type"] != "direct" + ): + continue + try: + if ( + chat["chatInfo"]["contact"]["activeConn"]["connId"] + == pccConnId + ): + remote_name = chat["chatInfo"]["contact"][ + "localDisplayName" + ] + break + except Exception as e: + self.log.debug(f"Error parsing chat: {e}") + + if remote_name is None: + raise RuntimeError( + f"Unable to find remote name for simplex direct chat, pccConnId: {pccConnId}" + ) + + message_id = sendSimplexMsg( + self, + network, + addr_from, + addr_to, + bytes.fromhex(payload_hex), + msg_valid, + cursor, + timestamp, + deterministic, + to_user_name=remote_name, + ) + return message_id + # First network in list will set message_id for network in self.active_networks: net_message_id = None @@ -1779,6 +1894,8 @@ class BasicSwap(BaseApp): bytes.fromhex(payload_hex), msg_valid, cursor, + timestamp, + deterministic, ) else: raise ValueError("Unknown network: {}".format(network["type"])) @@ -2000,7 +2117,7 @@ class BasicSwap(BaseApp): ) query_data: dict = {} - address = filters.get("address", None) + address: str = filters.get("address", None) if address is not None: query_str += " AND address = :address " query_data["address"] = address @@ -2215,7 +2332,7 @@ class BasicSwap(BaseApp): try: cursor = self.openDB() self.checkCoinsReady(coin_from_t, coin_to_t) - offer_addr = self.prepareSMSGAddress( + offer_addr: str = self.prepareSMSGAddress( addr_send_from, AddressTypes.OFFER, cursor ) @@ -2648,7 +2765,7 @@ class BasicSwap(BaseApp): self.callcoinrpc(Coins.PART, "extkey", ["info", evkey, path])[ "key_info" ]["result"], - "utf-8", + "UTF-8", ) ) @@ -2830,7 +2947,7 @@ class BasicSwap(BaseApp): _, is_locked = self.getLockedState() if is_locked is False: self.log.warning( - f"Setting seed id for coin {ci.coin_name()} from master key." + f"Setting seed ID for coin {ci.coin_name()} from master key." ) root_key = self.getWalletKey(c, 1) self.storeSeedIDForCoin(root_key, c) @@ -3254,68 +3371,93 @@ class BasicSwap(BaseApp): cursor = self.openDB() self.checkCoinsReady(coin_from, coin_to) - msg_buf = BidMessage() - msg_buf.protocol_version = PROTOCOL_VERSION_SECRET_HASH - msg_buf.offer_msg_id = offer_id - msg_buf.time_valid = valid_for_seconds - msg_buf.amount = amount # amount of coin_from - msg_buf.amount_to = amount_to - now: int = self.getTime() + encoded_proof_utxos = None if offer.swap_type == SwapTypes.SELLER_FIRST: proof_addr, proof_sig, proof_utxos = self.getProofOfFunds( coin_to, amount_to, offer_id ) - msg_buf.proof_address = proof_addr - msg_buf.proof_signature = proof_sig - if len(proof_utxos) > 0: - msg_buf.proof_utxos = ci_to.encodeProofUtxos(proof_utxos) - - contract_count = self.getNewContractId(cursor) - contract_pubkey = self.getContractPubkey( - dt.datetime.fromtimestamp(now).date(), contract_count - ) - msg_buf.pkhash_buyer = ci_from.pkh(contract_pubkey) - pkhash_buyer_to = ci_to.pkh(contract_pubkey) - if pkhash_buyer_to != msg_buf.pkhash_buyer: - # Different pubkey hash - msg_buf.pkhash_buyer_to = pkhash_buyer_to + encoded_proof_utxos = ci_to.encodeProofUtxos(proof_utxos) else: raise ValueError("TODO") - bid_bytes = msg_buf.to_bytes() - payload_hex = str.format("{:02x}", MessageTypes.BID) + bid_bytes.hex() + bid_addr: str = self.prepareSMSGAddress( + addr_send_from, AddressTypes.BID, cursor + ) + request_data = { + "offer_id": offer_id.hex(), + "amount_from": amount, + "amount_to": amount_to, + } + route_id, route_established = self.prepareMessageRoute( + MessageNetworks.SIMPLEX, + request_data, + bid_addr, + offer.addr_from, + cursor, + valid_for_seconds, + ) - bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) - msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - bid_id = self.sendMessage( - bid_addr, offer.addr_from, payload_hex, msg_valid, cursor + contract_count = self.getNewContractId(cursor) + contract_pubkey = self.getContractPubkey( + dt.datetime.fromtimestamp(now).date(), contract_count ) bid = Bid( - protocol_version=msg_buf.protocol_version, + protocol_version=PROTOCOL_VERSION_SECRET_HASH, active_ind=1, - bid_id=bid_id, offer_id=offer_id, - amount=msg_buf.amount, - amount_to=msg_buf.amount_to, + amount=amount, # amount of coin_from + amount_to=amount_to, rate=bid_rate, - pkhash_buyer=msg_buf.pkhash_buyer, - proof_address=msg_buf.proof_address, - proof_utxos=msg_buf.proof_utxos, + pkhash_buyer=ci_from.pkh(contract_pubkey), + proof_address=proof_addr, + proof_signature=proof_sig, + proof_utxos=encoded_proof_utxos, created_at=now, contract_count=contract_count, - expire_at=now + msg_buf.time_valid, + expire_at=now + valid_for_seconds, bid_addr=bid_addr, was_sent=True, chain_a_height_start=ci_from.getChainHeight(), chain_b_height_start=ci_to.getChainHeight(), ) - bid.setState(BidStates.BID_SENT) - if len(msg_buf.pkhash_buyer_to) > 0: - bid.pkhash_buyer_to = msg_buf.pkhash_buyer_to + pkhash_buyer_to = ci_to.pkh(contract_pubkey) + if pkhash_buyer_to != bid.pkhash_buyer: + # Different pubkey hash + bid.pkhash_buyer_to = pkhash_buyer_to + + if route_id and route_established is False: + msg_buf = self.getBidMessage(bid, offer) + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + encrypted_msg = encryptMsg( + self, + bid.bid_addr, + offer.addr_from, + bytes((MessageTypes.BID,)) + msg_buf.to_bytes(), + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=True, + ) + bid_id = smsgGetID(encrypted_msg) + bid.setState(BidStates.CONNECT_REQ_SENT) + else: + bid_id = self.sendBidMessage(bid, offer, cursor) + bid.setState(BidStates.BID_SENT) + if route_id: + message_route_link = DirectMessageRouteLink( + active_ind=2 if route_established else 1, + direct_message_route_id=route_id, + linked_type=Concepts.BID, + linked_id=bid_id, + created_at=bid.created_at, + ) + self.add(message_route_link, cursor) + + bid.bid_id = bid_id self.saveBidInSession(bid_id, bid, cursor) @@ -3628,7 +3770,7 @@ class BasicSwap(BaseApp): msg_valid: int = self.getAcceptBidMsgValidTime(bid) accept_msg_id = self.sendMessage( - offer.addr_from, bid.bid_addr, payload_hex, msg_valid, cursor + offer.addr_from, bid.bid_addr, payload_hex, msg_valid, use_cursor ) self.addMessageLink( @@ -3654,20 +3796,21 @@ class BasicSwap(BaseApp): msg_type, addr_from: str, addr_to: str, - bid_id: bytes, + xmr_swap, dleag: bytes, msg_valid: int, bid_msg_ids, cursor, ) -> None: - sent_bytes = self.dleag_split_size_init + dleag_split_size_init, dleag_split_size = xmr_swap.getMsgSplitInfo() + sent_bytes = dleag_split_size_init num_sent = 1 while sent_bytes < len(dleag): - size_to_send: int = min(self.dleag_split_size, len(dleag) - sent_bytes) + size_to_send: int = min(dleag_split_size, len(dleag) - sent_bytes) msg_buf = XmrSplitMessage( - msg_id=bid_id, + msg_id=xmr_swap.bid_id, msg_type=msg_type, sequence=num_sent, dleag=dleag[sent_bytes : sent_bytes + size_to_send], @@ -3682,6 +3825,223 @@ class BasicSwap(BaseApp): num_sent += 1 sent_bytes += size_to_send + def getADSBidIntentMessage(self, bid, offer) -> bytes: + valid_for_seconds: int = bid.expire_at - bid.created_at + msg_buf = ADSBidIntentMessage() + msg_buf.protocol_version = bid.protocol_version + msg_buf.offer_msg_id = bid.offer_id + msg_buf.time_valid = valid_for_seconds + msg_buf.amount_from = bid.amount_to + msg_buf.amount_to = bid.amount + + return msg_buf + + def sendADSBidIntentMessage(self, bid, offer, cursor) -> bytes: + valid_for_seconds: int = bid.expire_at - bid.created_at + msg_buf = self.getADSBidIntentMessage(bid, offer) + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + payload_hex = ( + str.format("{:02x}", MessageTypes.ADS_BID_LF) + msg_buf.to_bytes().hex() + ) + return self.sendMessage( + bid.bid_addr, + offer.addr_from, + payload_hex, + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=(False if bid.bid_id is None else True), + ) + + def getXmrBidMessage(self, bid, xmr_swap, offer) -> XmrBidMessage: + valid_for_seconds: int = bid.expire_at - bid.created_at + msg_buf = XmrBidMessage() + msg_buf.protocol_version = PROTOCOL_VERSION_ADAPTOR_SIG + msg_buf.offer_msg_id = bid.offer_id + msg_buf.time_valid = valid_for_seconds + msg_buf.amount = bid.amount + msg_buf.amount_to = bid.amount_to + + msg_buf.dest_af = xmr_swap.dest_af + msg_buf.pkaf = xmr_swap.pkaf + msg_buf.kbvf = xmr_swap.vkbvf + + dleag_split_size_init, _ = xmr_swap.getMsgSplitInfo() + if len(xmr_swap.kbsf_dleag) > dleag_split_size_init: + msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:dleag_split_size_init] + else: + msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag + + return msg_buf + + def sendXmrBidMessage(self, bid, xmr_swap, offer, cursor) -> bytes: + valid_for_seconds: int = bid.expire_at - bid.created_at + + ci_to = self.ci(offer.coin_to) + + msg_buf = self.getXmrBidMessage(bid, xmr_swap, offer) + + payload_hex = ( + str.format("{:02x}", MessageTypes.XMR_BID_FL) + msg_buf.to_bytes().hex() + ) + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + + bid_msg_id = self.sendMessage( + bid.bid_addr, + offer.addr_from, + payload_hex, + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=(False if bid.bid_id is None else True), + ) + bid_id = bid_msg_id + if bid.bid_id and bid_msg_id != bid.bid_id: + self.log.warning( + f"sendXmrBidMessage: Mismatched bid ids: {bid.bid_id.hex()}, {bid_msg_id.hex()}." + ) + + bid_msg_ids = {} + if xmr_swap.bid_id is None: + xmr_swap.bid_id = bid_id + if ci_to.curve_type() == Curves.ed25519: + self.sendXmrSplitMessages( + XmrSplitMsgTypes.BID, + bid.bid_addr, + offer.addr_from, + xmr_swap, + xmr_swap.kbsf_dleag, + msg_valid, + bid_msg_ids, + cursor, + ) + for k, msg_id in bid_msg_ids.items(): + self.addMessageLink( + Concepts.BID, + bid_id, + MessageTypes.BID, + msg_id, + msg_sequence=k, + cursor=cursor, + ) + + return bid_msg_id + + def getBidMessage(self, bid, offer) -> BidMessage: + valid_for_seconds: int = bid.expire_at - bid.created_at + msg_buf = BidMessage() + msg_buf.protocol_version = bid.protocol_version + msg_buf.offer_msg_id = bid.offer_id + msg_buf.time_valid = valid_for_seconds + msg_buf.amount = bid.amount + msg_buf.amount_to = bid.amount_to + + msg_buf.pkhash_buyer = bid.pkhash_buyer + if bid.pkhash_buyer_to: + msg_buf.pkhash_buyer_to = bid.pkhash_buyer_to + + msg_buf.proof_address = bid.proof_address + msg_buf.proof_signature = bid.proof_signature + + if bid.proof_utxos: + msg_buf.proof_utxos = bid.proof_utxos + + return msg_buf + + def sendBidMessage(self, bid, offer, cursor) -> bytes: + valid_for_seconds: int = bid.expire_at - bid.created_at + + msg_buf = self.getBidMessage(bid, offer) + + payload_hex = str.format("{:02x}", MessageTypes.BID) + msg_buf.to_bytes().hex() + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + + bid_msg_id = self.sendMessage( + bid.bid_addr, + offer.addr_from, + payload_hex, + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=(False if bid.bid_id is None else True), + ) + if bid.bid_id and bid_msg_id != bid.bid_id: + self.log.warning( + f"sendBidMessage: Mismatched bid ids: {bid.bid_id.hex()}, {bid_msg_id.hex()}." + ) + return bid_msg_id + + def prepareMessageRoute( + self, + network_id, + req_data, + addr_from: str, + addr_to: str, + cursor, + valid_for_seconds, + ) -> (int, bool): + if self._use_direct_message_routes is False: + return None, False + + try: + net_i = self.getActiveNetworkInterface(2) + except Exception as e: # noqa: F841 + return None, False + + # Look for active route + message_route = self.getMessageRoute(1, addr_from, addr_to, cursor=cursor) + self.log.debug(f"Using active message route: {message_route}") + if message_route: + return message_route.record_id, True + + # Look for route being established + message_route = self.getMessageRoute(2, addr_from, addr_to, cursor=cursor) + self.log.debug(f"Waiting for message route: {message_route}") + if message_route: + return message_route.record_id, False + + cmd_id = net_i.send_command("/connect") + response = net_i.wait_for_command_response(cmd_id) + connReqInvitation = getJoinedSimplexLink(response) + pccConnId = getResponseData(response, "connection")["pccConnId"] + req_data["bsx_address"] = addr_from + req_data["connection_req"] = connReqInvitation + + msg_buf = ConnectReqMessage() + msg_buf.network_type = MessageNetworks.SIMPLEX + msg_buf.network_data = b"bsx" + msg_buf.request_type = ConnectionRequestTypes.BID + msg_buf.request_data = json.dumps(req_data).encode("UTF-8") + + bid_bytes = msg_buf.to_bytes() + payload_hex = str.format("{:02x}", MessageTypes.CONNECT_REQ) + bid_bytes.hex() + + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + connect_req_msgid = self.sendMessage( + addr_from, addr_to, payload_hex, msg_valid, cursor + ) + + now: int = self.getTime() + message_route = DirectMessageRoute( + active_ind=2, + network_id=2, + linked_type=Concepts.OFFER, + smsg_addr_local=addr_from, + smsg_addr_remote=addr_to, + route_data=json.dumps( + { + "connection_req": connReqInvitation, + "connect_req_msgid": connect_req_msgid.hex(), + "pccConnId": pccConnId, + } + ).encode("UTF-8"), + created_at=now, + ) + message_route_id = self.add(message_route, cursor) + + self.log.info(f"Sent CONNECT_REQ {self.logIDB(connect_req_msgid)}") + return message_route_id, False + def postXmrBid( self, offer_id: bytes, amount: int, addr_send_from: str = None, extra_options={} ) -> bytes: @@ -3732,51 +4092,77 @@ class BasicSwap(BaseApp): ci_to, offer.swap_type, int(amount_to), estimated_fee, for_offer=False ) + bid_addr: str = self.prepareSMSGAddress( + addr_send_from, AddressTypes.BID, cursor + ) + + # return id of route waiting to be established + request_data = { + "offer_id": offer_id.hex(), + "amount_from": amount, + "amount_to": amount_to, + } + route_id, route_established = self.prepareMessageRoute( + MessageNetworks.SIMPLEX, + request_data, + bid_addr, + offer.addr_from, + cursor, + valid_for_seconds, + ) + reverse_bid: bool = self.is_reverse_ads_bid(coin_from, coin_to) if reverse_bid: reversed_rate: int = ci_to.make_int(amount / amount_to, r=1) - msg_buf = ADSBidIntentMessage() - msg_buf.protocol_version = PROTOCOL_VERSION_ADAPTOR_SIG - msg_buf.offer_msg_id = offer_id - msg_buf.time_valid = valid_for_seconds - msg_buf.amount_from = amount - msg_buf.amount_to = amount_to - - bid_bytes = msg_buf.to_bytes() - payload_hex = ( - str.format("{:02x}", MessageTypes.ADS_BID_LF) + bid_bytes.hex() - ) - xmr_swap = XmrSwap() xmr_swap.contract_count = self.getNewContractId(cursor) - - bid_addr = self.prepareSMSGAddress( - addr_send_from, AddressTypes.BID, cursor - ) - - msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - xmr_swap.bid_id = self.sendMessage( - bid_addr, offer.addr_from, payload_hex, msg_valid, cursor - ) + self.setMsgSplitInfo(xmr_swap) bid = Bid( - protocol_version=msg_buf.protocol_version, + protocol_version=PROTOCOL_VERSION_ADAPTOR_SIG, active_ind=1, - bid_id=xmr_swap.bid_id, offer_id=offer_id, - amount=msg_buf.amount_to, - amount_to=msg_buf.amount_from, + amount=amount_to, + amount_to=amount, rate=reversed_rate, created_at=bid_created_at, contract_count=xmr_swap.contract_count, - expire_at=bid_created_at + msg_buf.time_valid, + expire_at=bid_created_at + valid_for_seconds, bid_addr=bid_addr, was_sent=True, was_received=False, ) - bid.setState(BidStates.BID_REQUEST_SENT) + if route_id and route_established is False: + msg_buf = self.getADSBidIntentMessage(bid, offer) + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + encrypted_msg = encryptMsg( + self, + bid.bid_addr, + offer.addr_from, + bytes((MessageTypes.ADS_BID_LF,)) + msg_buf.to_bytes(), + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=True, + ) + bid_id = smsgGetID(encrypted_msg) + bid.setState(BidStates.CONNECT_REQ_SENT) + else: + bid_id = self.sendADSBidIntentMessage(bid, offer, cursor) + bid.setState(BidStates.BID_REQUEST_SENT) + if route_id: + message_route_link = DirectMessageRouteLink( + active_ind=2 if route_established else 1, + direct_message_route_id=route_id, + linked_type=Concepts.BID, + linked_id=bid_id, + created_at=bid_created_at, + ) + self.add(message_route_link, cursor) + bid.bid_id = bid_id + xmr_swap.bid_id = bid.bid_id self.saveBidInSession(xmr_swap.bid_id, bid, cursor, xmr_swap) self.commitDB() @@ -3784,25 +4170,18 @@ class BasicSwap(BaseApp): self.log.info(f"Sent ADS_BID_LF {self.logIDB(xmr_swap.bid_id)}") return xmr_swap.bid_id - msg_buf = XmrBidMessage() - msg_buf.protocol_version = PROTOCOL_VERSION_ADAPTOR_SIG - msg_buf.offer_msg_id = offer_id - msg_buf.time_valid = valid_for_seconds - msg_buf.amount = int(amount) # Amount of coin_from - msg_buf.amount_to = amount_to + xmr_swap = XmrSwap() + xmr_swap.contract_count = self.getNewContractId(cursor) + self.setMsgSplitInfo(xmr_swap) address_out = self.getReceiveAddressFromPool( coin_from, offer_id, TxTypes.XMR_SWAP_A_LOCK, cursor=cursor ) if coin_from in (Coins.PART_BLIND,): addrinfo = ci_from.rpc("getaddressinfo", [address_out]) - msg_buf.dest_af = bytes.fromhex(addrinfo["pubkey"]) + xmr_swap.dest_af = bytes.fromhex(addrinfo["pubkey"]) else: - msg_buf.dest_af = ci_from.decodeAddress(address_out) - - xmr_swap = XmrSwap() - xmr_swap.contract_count = self.getNewContractId(cursor) - xmr_swap.dest_af = msg_buf.dest_af + xmr_swap.dest_af = ci_from.decodeAddress(address_out) for_ed25519: bool = True if ci_to.curve_type() == Curves.ed25519 else False kbvf = self.getPathKey( @@ -3839,7 +4218,6 @@ class BasicSwap(BaseApp): if ci_to.curve_type() == Curves.ed25519: xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf) xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33] - msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[: self.dleag_split_size_init] elif ci_to.curve_type() == Curves.secp256k1: for i in range(10): xmr_swap.kbsf_dleag = ci_to.signRecoverable( @@ -3853,54 +4231,54 @@ class BasicSwap(BaseApp): self.log.debug("kbsl recovered pubkey mismatch, retrying.") assert pk_recovered == xmr_swap.pkbsf xmr_swap.pkasf = xmr_swap.pkbsf - msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag else: raise ValueError("Unknown curve") assert xmr_swap.pkasf == ci_from.getPubkey(kbsf) - msg_buf.pkaf = xmr_swap.pkaf - msg_buf.kbvf = kbvf - - bid_bytes = msg_buf.to_bytes() - payload_hex = ( - str.format("{:02x}", MessageTypes.XMR_BID_FL) + bid_bytes.hex() - ) - - bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) - - msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - xmr_swap.bid_id = self.sendMessage( - bid_addr, offer.addr_from, payload_hex, msg_valid, cursor - ) - - bid_msg_ids = {} - if ci_to.curve_type() == Curves.ed25519: - self.sendXmrSplitMessages( - XmrSplitMsgTypes.BID, - bid_addr, - offer.addr_from, - xmr_swap.bid_id, - xmr_swap.kbsf_dleag, - msg_valid, - bid_msg_ids, - cursor, - ) - bid = Bid( - protocol_version=msg_buf.protocol_version, + protocol_version=PROTOCOL_VERSION_ADAPTOR_SIG, active_ind=1, - bid_id=xmr_swap.bid_id, offer_id=offer_id, - amount=msg_buf.amount, - amount_to=msg_buf.amount_to, + amount=amount, + amount_to=amount_to, rate=bid_rate, created_at=bid_created_at, contract_count=xmr_swap.contract_count, - expire_at=bid_created_at + msg_buf.time_valid, + expire_at=bid_created_at + valid_for_seconds, bid_addr=bid_addr, was_sent=True, ) + if route_id and route_established is False: + msg_buf = self.getXmrBidMessage(bid, xmr_swap, offer) + msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) + encrypted_msg = encryptMsg( + self, + bid.bid_addr, + offer.addr_from, + bytes((MessageTypes.XMR_BID_FL,)) + msg_buf.to_bytes(), + msg_valid, + cursor, + timestamp=bid.created_at, + deterministic=True, + ) + bid_id = smsgGetID(encrypted_msg) + bid.setState(BidStates.CONNECT_REQ_SENT) + else: + bid_id = self.sendXmrBidMessage(bid, xmr_swap, offer, cursor) + bid.setState(BidStates.BID_SENT) + if route_id: + message_route_link = DirectMessageRouteLink( + active_ind=2 if route_established else 1, + direct_message_route_id=route_id, + linked_type=Concepts.BID, + linked_id=bid_id, + created_at=bid_created_at, + ) + self.add(message_route_link, cursor) + bid.bid_id = bid_id + xmr_swap.bid_id = bid.bid_id + bid.chain_a_height_start = ci_from.getChainHeight() bid.chain_b_height_start = ci_to.getChainHeight() @@ -3911,19 +4289,7 @@ class BasicSwap(BaseApp): f"Adaptor-sig swap restore height clamped to {wallet_restore_height}" ) - bid.setState(BidStates.BID_SENT) - - self.saveBidInSession(xmr_swap.bid_id, bid, cursor, xmr_swap) - for k, msg_id in bid_msg_ids.items(): - self.addMessageLink( - Concepts.BID, - xmr_swap.bid_id, - MessageTypes.BID, - msg_id, - msg_sequence=k, - cursor=cursor, - ) - + self.saveBidInSession(bid.bid_id, bid, cursor, xmr_swap) self.log.info(f"Sent XMR_BID_FL {self.logIDB(xmr_swap.bid_id)}") return xmr_swap.bid_id finally: @@ -4160,9 +4526,10 @@ class BasicSwap(BaseApp): msg_buf.pkal = xmr_swap.pkal msg_buf.kbvl = kbvl + dleag_split_size_init, _ = xmr_swap.getMsgSplitInfo() if ci_to.curve_type() == Curves.ed25519: xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl) - msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[: self.dleag_split_size_init] + msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:dleag_split_size_init] elif ci_to.curve_type() == Curves.secp256k1: for i in range(10): xmr_swap.kbsl_dleag = ci_to.signRecoverable( @@ -4206,7 +4573,7 @@ class BasicSwap(BaseApp): XmrSplitMsgTypes.BID_ACCEPT, addr_from, addr_to, - xmr_swap.bid_id, + xmr_swap, xmr_swap.kbsl_dleag, msg_valid, bid_msg_ids, @@ -4325,6 +4692,7 @@ class BasicSwap(BaseApp): xmr_swap_1.setDLEAG(xmr_swap, ci_to, kbsf) assert xmr_swap.pkasf == ci_from.getPubkey(kbsf) + dleag_split_size_init, _ = xmr_swap.getMsgSplitInfo() msg_buf = ADSBidIntentAcceptMessage() msg_buf.bid_msg_id = bid_id msg_buf.dest_af = xmr_swap.dest_af @@ -4332,8 +4700,8 @@ class BasicSwap(BaseApp): msg_buf.kbvf = kbvf msg_buf.kbsf_dleag = ( xmr_swap.kbsf_dleag - if len(xmr_swap.kbsf_dleag) < self.dleag_split_size_init - else xmr_swap.kbsf_dleag[: self.dleag_split_size_init] + if len(xmr_swap.kbsf_dleag) < dleag_split_size_init + else xmr_swap.kbsf_dleag[:dleag_split_size_init] ) bid_bytes = msg_buf.to_bytes() @@ -4354,7 +4722,7 @@ class BasicSwap(BaseApp): XmrSplitMsgTypes.BID, addr_from, addr_to, - xmr_swap.bid_id, + xmr_swap, xmr_swap.kbsf_dleag, msg_valid, bid_msg_ids, @@ -6776,9 +7144,80 @@ class BasicSwap(BaseApp): last_height_checked += 1 self.updateCheckedBlock(ci, c, block) + def expireMessageRoutes(self) -> None: + if self._is_locked is True: + self.log.debug("Not expiring message routes while system is locked") + return + + num_removed: int = 0 + now: int = self.getTime() + cursor = self.openDB() + try: + query_str = ( + "SELECT record_id, network_id, created_at, active_ind, route_data FROM direct_message_routes " + + "WHERE 1 = 1 " + ) + rows = cursor.execute(query_str).fetchall() + for row in rows: + record_id, network_id, created_at, active_ind, route_data = row + + route_data = json.loads(route_data.decode("UTF-8")) + + if now - created_at < self._expire_message_routes_after: + continue + + # unestablished routes + if active_ind == 2: + pass + else: + query_str = ( + "SELECT MAX(created_at) FROM direct_message_route_links " + + "WHERE direct_message_route_id = :message_route_id " + ) + max_link_created_at = cursor.execute( + query_str, {"message_route_id": record_id} + ).fetchone()[0] + + if now - max_link_created_at < self._expire_message_routes_after: + continue + + query_str = ( + "SELECT COUNT(*) FROM direct_message_route_links rl " + + "INNER JOIN bids b ON b.bid_id = rl.linked_id " + + "INNER JOIN bidstates s ON s.state_id = b.state " + + "WHERE rl.direct_message_route_id = :message_route_id AND rl.linked_type = :link_type_bid " + + "AND (b.in_progress OR s.in_progress OR (s.swap_ended = 0 AND b.expire_at > :now))" + ) + num_active_bids = cursor.execute( + query_str, + { + "message_route_id": record_id, + "link_type_bid": Concepts.BID, + "now": now, + }, + ).fetchone()[0] + if num_active_bids > 0: + self.log.warning( + f"Not expiring message route {record_id} with {num_active_bids} active bids." + ) + continue + + self.closeMessageRoute(record_id, network_id, route_data, cursor) + num_removed += 1 + finally: + self.closeDB(cursor) + + if num_removed > 0: + self.log.info( + "Expired {} message route{}.".format( + num_removed, + "s" if num_removed != 1 else "", + ) + ) + def expireMessages(self) -> None: if self._is_locked is True: - self.log.debug("Not expiring messages while system locked") + self.log.debug("Not expiring messages while system is locked") return self.mxDB.acquire() @@ -7288,7 +7727,7 @@ class BasicSwap(BaseApp): ) return - signature_enc = base64.b64encode(msg_data.signature).decode("utf-8") + signature_enc = base64.b64encode(msg_data.signature).decode("UTF-8") passed = self.callcoinrpc( Coins.PART, @@ -7392,7 +7831,7 @@ class BasicSwap(BaseApp): use_cursor, {"active_ind": 1, "record_id": link.strategy_id}, ) - opts = json.loads(strategy.data.decode("utf-8")) + opts = json.loads(strategy.data.decode("UTF-8")) bid_amount: int = bid.amount bid_rate: int = bid.rate @@ -7504,6 +7943,35 @@ class BasicSwap(BaseApp): if cursor is None: self.closeDB(use_cursor) + def addRecvBidNetworkLink(self, msg, bid_id): + if "chat_type" not in msg or msg["chat_type"] != "direct": + return + conn_id = msg["conn_id"] + query_str = ( + "SELECT record_id, network_id, route_data FROM direct_message_routes" + ) + try: + cursor = self.openDB() + + rows = cursor.execute(query_str).fetchall() + + for row in rows: + record_id, network_id, route_data = row + route_data = json.loads(route_data.decode("UTF-8")) + + if conn_id == route_data["pccConnId"]: + message_route_link = DirectMessageRouteLink( + active_ind=2, + direct_message_route_id=record_id, + linked_type=Concepts.BID, + linked_id=bid_id, + created_at=self.getTime(), + ) + self.add(message_route_link, cursor) + break + finally: + self.closeDB(cursor) + def processBid(self, msg) -> None: self.log.debug("Processing bid msg {}.".format(self.log.id(msg["msgid"]))) now: int = self.getTime() @@ -7598,6 +8066,7 @@ class BasicSwap(BaseApp): bid.proof_address = bid_data.proof_address bid.setState(BidStates.BID_RECEIVED) + self.addRecvBidNetworkLink(msg, bid_id) self.saveBid(bid_id, bid) self.notify( @@ -8009,6 +8478,7 @@ class BasicSwap(BaseApp): pkbvf=ci_to.getPubkey(bid_data.kbvf), kbsf_dleag=bid_data.kbsf_dleag, ) + self.setMsgSplitInfo(xmr_swap) wallet_restore_height = self.getWalletRestoreHeight(ci_to) if bid.chain_b_height_start < wallet_restore_height: bid.chain_b_height_start = wallet_restore_height @@ -8025,6 +8495,7 @@ class BasicSwap(BaseApp): bid.was_received = True bid.setState(BidStates.BID_RECEIVING) + self.addRecvBidNetworkLink(msg, bid_id) self.log.info( f"Receiving adaptor-sig bid {self.log.id(bid_id)} for offer {self.log.id(bid_data.offer_msg_id)}." @@ -8210,7 +8681,7 @@ class BasicSwap(BaseApp): ) # TODO: Split BID_ACCEPTED into received and sent ensure( bid.state in allowed_states, - "Invalid state for bid {}".format(bid.state), + f"Invalid state for bid {bid.state}", ) bid.setState(BidStates.BID_RECEIVING_ACC) self.saveBid(bid.bid_id, bid, xmr_swap=xmr_swap) @@ -8618,10 +9089,10 @@ class BasicSwap(BaseApp): bid, EventLogTypes.FAILED_TX_B_LOCK_PUBLISH, cursor ) if num_retries > 0: - error_msg += ", retry no. {}".format(num_retries) + error_msg += f", retry no. {num_retries} / {self._max_transient_errors}" self.log.error(error_msg) - if num_retries < 5 and ( + if num_retries < self._max_transient_errors and ( ci_to.is_transient_error(ex) or self.is_transient_error(ex) ): delay = self.get_delay_retry_seconds() @@ -8935,10 +9406,10 @@ class BasicSwap(BaseApp): bid, EventLogTypes.FAILED_TX_B_SPEND, cursor ) if num_retries > 0: - error_msg += ", retry no. {}".format(num_retries) + error_msg += f", retry no. {num_retries} / {self._max_transient_errors}" self.log.error(error_msg) - if num_retries < 100 and ( + if num_retries < self._max_transient_errors and ( ci_to.is_transient_error(ex) or self.is_transient_error(ex) ): delay = self.get_delay_retry_seconds() @@ -9049,11 +9520,11 @@ class BasicSwap(BaseApp): bid, EventLogTypes.FAILED_TX_B_REFUND, cursor ) if num_retries > 0: - error_msg += ", retry no. {}".format(num_retries) + error_msg += f", retry no. {num_retries} / {self._max_transient_errors}" self.log.error(error_msg) str_error = str(ex) - if num_retries < 100 and ( + if num_retries < self._max_transient_errors and ( ci_to.is_transient_error(ex) or self.is_transient_error(ex) ): delay = self.get_delay_retry_seconds() @@ -9522,6 +9993,7 @@ class BasicSwap(BaseApp): xmr_swap = XmrSwap( bid_id=bid_id, ) + self.setMsgSplitInfo(xmr_swap) wallet_restore_height = self.getWalletRestoreHeight(ci_to) if bid.chain_b_height_start < wallet_restore_height: bid.chain_b_height_start = wallet_restore_height @@ -9538,6 +10010,7 @@ class BasicSwap(BaseApp): bid.was_received = True bid.setState(BidStates.BID_RECEIVED) # BID_REQUEST_RECEIVED + self.addRecvBidNetworkLink(msg, bid_id) self.log.info( f"Received reverse adaptor-sig bid {self.log.id(bid_id)} for offer {self.log.id(bid_data.offer_msg_id)}." @@ -9648,6 +10121,203 @@ class BasicSwap(BaseApp): finally: self.closeDB(cursor) + def processConnectRequest(self, msg) -> None: + self.log.debug( + "Processing connection request msg {}.".format(self.log.id(msg["msgid"])) + ) + msg_bytes = bytes.fromhex(msg["hex"][2:-2]) + msg_data = ConnectReqMessage(init_all=False) + msg_data.from_bytes(msg_bytes) + + req_data = json.loads(msg_data.request_data) + + offer_id = bytes.fromhex(req_data["offer_id"]) + bidder_addr = req_data["bsx_address"] + + net_i = self.getActiveNetworkInterface(2) + try: + cursor = self.openDB() + offer = self.getOffer(offer_id, cursor) + ensure(offer, f"Offer not found: {self.log.id(offer_id)}.") + ensure(offer.expire_at > self.getTime(), "Offer has expired") + ensure(msg["from"] == bidder_addr, "Mismatched from address") + ensure(msg["to"] == offer.addr_from, "Mismatched to address") + + self.log.debug( + f"Opening direct message route from {offer.addr_from} to {bidder_addr}" + ) + message_route = self.getMessageRoute( + 2, bidder_addr, offer.addr_from, cursor=cursor + ) + if message_route: + raise ValueError("Direct message route already exists") + + connReqInvitation = req_data["connection_req"] + cmd_id = net_i.send_command(f"/connect {connReqInvitation}") + response = net_i.wait_for_command_response(cmd_id) + pccConnId = getResponseData(response, "connection")["pccConnId"] + + now: int = self.getTime() + message_route = DirectMessageRoute( + active_ind=2, + network_id=2, + linked_type=Concepts.OFFER, + smsg_addr_local=offer.addr_from, + smsg_addr_remote=bidder_addr, + route_data=json.dumps( + {"connection_req": connReqInvitation, "pccConnId": pccConnId} + ).encode("UTF-8"), + created_at=now, + ) + message_route_id = self.add(message_route, cursor) + + message_route_link = DirectMessageRouteLink( + active_ind=1, + direct_message_route_id=message_route_id, + linked_type=Concepts.OFFER, + linked_id=offer_id, + created_at=now, + ) + self.add(message_route_link, cursor) + + finally: + self.closeDB(cursor) + + def routeEstablishedForBid(self, bid_id: bytes, cursor): + self.log.info(f"Route established for bid {self.log.id(bid_id)}") + + bid, offer = self.getBidAndOffer(bid_id, cursor) + ensure(bid, "Bid not found") + ensure(offer, "Offer not found") + + coin_from = Coins(offer.coin_from) + coin_to = Coins(offer.coin_to) + + if offer.swap_type == SwapTypes.XMR_SWAP: + xmr_swap = self.queryOne(XmrSwap, cursor, {"bid_id": bid.bid_id}) + + reverse_bid: bool = self.is_reverse_ads_bid(coin_from, coin_to) + if reverse_bid: + bid_id = self.sendADSBidIntentMessage(bid, offer, cursor) + bid.setState(BidStates.BID_REQUEST_SENT) + self.log.info(f"Sent ADS_BID_LF {self.logIDB(xmr_swap.bid_id)}") + else: + bid_id = self.sendXmrBidMessage(bid, xmr_swap, offer, cursor) + bid.setState(BidStates.BID_SENT) + self.log.info(f"Sent XMR_BID_FL {self.logIDB(xmr_swap.bid_id)}") + self.saveBidInSession(bid.bid_id, bid, cursor, xmr_swap) + else: + bid_id = self.sendBidMessage(bid, offer, cursor) + bid.setState(BidStates.BID_SENT) + self.log.info(f"Sent BID {self.log.id(bid_id)}") + self.saveBidInSession(bid_id, bid, cursor) + + def processContactConnected(self, event_data) -> None: + contact_data = getResponseData(event_data, "contact") + connId = contact_data["activeConn"]["connId"] + localDisplayName = contact_data["localDisplayName"] + self.log.debug( + f"Processing Contact Connected event, ID: {connId}, contact name: {localDisplayName}." + ) + + try: + cursor = self.openDB() + + query_str = ( + "SELECT record_id, network_id, smsg_addr_local, smsg_addr_remote, route_data FROM direct_message_routes " + + "WHERE active_ind = 2" + ) + rows = cursor.execute(query_str).fetchall() + + found_direct_message_route = None + for row in rows: + record_id, network_id, smsg_addr_local, smsg_addr_remote, route_data = ( + row + ) + route_data = json.loads(route_data.decode("UTF-8")) + + if connId == route_data["pccConnId"]: + self.log.debug( + f"Direct message route established local: {smsg_addr_local}, remote: {smsg_addr_remote}." + ) + # route_data["localDisplayName"] = localDisplayName + + cursor.execute(query_str) + # query = "UPDATE direct_message_routes SET active_ind = 1, route_data = :route_data WHERE record_id = :record_id " + query = "UPDATE direct_message_routes SET active_ind = 1 WHERE record_id = :record_id " + cursor.execute(query, {"record_id": record_id}) + found_direct_message_route = record_id + break + + if found_direct_message_route: + query_str = ( + "SELECT record_id, linked_type, linked_id FROM direct_message_route_links " + + "WHERE active_ind = 1" + ) + rows = cursor.execute(query_str).fetchall() + for row in rows: + record_id, linked_type, linked_id = row + + if linked_type == Concepts.BID: + self.routeEstablishedForBid(linked_id, cursor) + query = "UPDATE direct_message_route_links SET active_ind = 2 WHERE record_id = :record_id " + cursor.execute(query, {"record_id": record_id}) + elif linked_type == Concepts.OFFER: + pass + else: + self.log.warning( + f"Unknown direct_message_route_link type: {linked_type}, {self.log.id(linked_id)}." + ) + else: + self.log.warning( + f"Unknown direct message route connected, connId: {connId}" + ) + finally: + self.closeDB(cursor) + + def processContactDisconnected(self, event_data) -> None: + net_i = self.getActiveNetworkInterface(2) + connId = getResponseData(event_data, "contact")["activeConn"]["connId"] + self.log.info(f"Direct message route disconnected, connId: {connId}") + closeSimplexChat(self, net_i, connId) + + query_str = "SELECT record_id, network_id, smsg_addr_local, smsg_addr_remote, route_data FROM direct_message_routes" + try: + cursor = self.openDB() + + rows = cursor.execute(query_str).fetchall() + + for row in rows: + record_id, network_id, smsg_addr_local, smsg_addr_remote, route_data = ( + row + ) + route_data = json.loads(route_data.decode("UTF-8")) + + if connId == route_data["pccConnId"]: + self.log.debug(f"Removing direct message route: {record_id}.") + cursor.execute( + "DELETE FROM direct_message_routes WHERE record_id = :record_id ", + {"record_id": record_id}, + ) + break + finally: + self.closeDB(cursor) + + def closeMessageRoute(self, record_id, network_id, route_data, cursor): + net_i = self.getActiveNetworkInterface(2) + + connId = route_data["pccConnId"] + + self.log.info(f"Closing Simplex chat, id: {connId}") + closeSimplexChat(self, net_i, connId) + + self.log.debug(f"Removing direct message route: {record_id}.") + cursor.execute( + "DELETE FROM direct_message_routes WHERE record_id = :record_id ", + {"record_id": record_id}, + ) + self.commitDB() + def processMsg(self, msg) -> None: try: if "hex" not in msg: @@ -9685,6 +10355,8 @@ class BasicSwap(BaseApp): self.processADSBidReversed(msg) elif msg_type == MessageTypes.ADS_BID_ACCEPT_FL: self.processADSBidReversedAccept(msg) + elif msg_type == MessageTypes.CONNECT_REQ: + self.processConnectRequest(msg) except InactiveCoin as ex: self.log.debug( @@ -9920,6 +10592,7 @@ class BasicSwap(BaseApp): if now - self._last_checked_expired >= self.check_expired_seconds: self.expireMessages() + self.expireMessageRoutes() self.expireDBRecords() self.checkAcceptedBids() self._last_checked_expired = now @@ -10105,7 +10778,7 @@ class BasicSwap(BaseApp): settings_changed = True else: # Encode value as hex to avoid escaping - new_value = new_value.encode("utf-8").hex() + new_value = new_value.encode("UTF-8").hex() if settings_copy.get("chart_api_key_enc", "") != new_value: settings_copy["chart_api_key_enc"] = new_value if "chart_api_key" in settings_copy: @@ -10127,7 +10800,7 @@ class BasicSwap(BaseApp): settings_changed = True else: # Encode value as hex to avoid escaping - new_value = new_value.encode("utf-8").hex() + new_value = new_value.encode("UTF-8").hex() if settings_copy.get("coingecko_api_key_enc", "") != new_value: settings_copy["coingecko_api_key_enc"] = new_value if "coingecko_api_key" in settings_copy: @@ -10963,7 +11636,7 @@ class BasicSwap(BaseApp): AutomationStrategy, cursor, {"record_id": strategy_id} ) if "data" in data: - strategy.data = json.dumps(data["data"]).encode("utf-8") + strategy.data = json.dumps(data["data"]).encode("UTF-8") self.log.debug("data {}".format(data["data"])) if "note" in data: strategy.note = data["note"] @@ -10981,10 +11654,10 @@ class BasicSwap(BaseApp): strategy_data = ( {} if strategy.data is None - else json.loads(strategy.data.decode("utf-8")) + else json.loads(strategy.data.decode("UTF-8")) ) strategy_data["max_concurrent_bids"] = new_max_concurrent_bids - strategy.data = json.dumps(strategy_data).encode("utf-8") + strategy.data = json.dumps(strategy_data).encode("UTF-8") self.updateDB(strategy, cursor, ["record_id"]) finally: @@ -11267,7 +11940,7 @@ class BasicSwap(BaseApp): def isOfferRevoked(self, offer_id: bytes, offer_addr_from) -> bool: for pair in self._possibly_revoked_offers: if offer_id == pair[0]: - signature_enc = base64.b64encode(pair[1]).decode("utf-8") + signature_enc = base64.b64encode(pair[1]).decode("UTF-8") passed = self.callcoinrpc( Coins.PART, "verifymessage", @@ -11545,6 +12218,13 @@ class BasicSwap(BaseApp): return rv + def setMsgSplitInfo(self, xmr_swap) -> None: + for network in self.active_networks: + if network["type"] == "simplex": + xmr_swap.msg_split_info = "9000:11000" + return + xmr_swap.msg_split_info = "16000:17000" + def setFilters(self, prefix, filters): key_str = "saved_filters_" + prefix value_str = json.dumps(filters) diff --git a/basicswap/basicswap_util.py b/basicswap/basicswap_util.py index 39403d0..0687a90 100644 --- a/basicswap/basicswap_util.py +++ b/basicswap/basicswap_util.py @@ -36,6 +36,11 @@ class KeyTypes(IntEnum): KAF = 6 +class MessageNetworks(IntEnum): + SMSG = auto() + SIMPLEX = auto() + + class MessageTypes(IntEnum): OFFER = auto() BID = auto() @@ -53,6 +58,8 @@ class MessageTypes(IntEnum): ADS_BID_LF = auto() ADS_BID_ACCEPT_FL = auto() + CONNECT_REQ = auto() + class AddressTypes(IntEnum): OFFER = auto() @@ -111,6 +118,7 @@ class BidStates(IntEnum): BID_EXPIRED = 31 BID_AACCEPT_DELAY = 32 BID_AACCEPT_FAIL = 33 + CONNECT_REQ_SENT = 34 class TxStates(IntEnum): @@ -228,6 +236,10 @@ class NotificationTypes(IntEnum): BID_ACCEPTED = auto() +class ConnectionRequestTypes(IntEnum): + BID = 1 + + class AutomationOverrideOptions(IntEnum): DEFAULT = 0 ALWAYS_ACCEPT = 1 @@ -339,6 +351,8 @@ def strBidState(state): return "Auto accept delay" if state == BidStates.BID_AACCEPT_FAIL: return "Auto accept failed" + if state == BidStates.CONNECT_REQ_SENT: + return "Connect request sent" return "Unknown" + " " + str(state) diff --git a/basicswap/bin/run.py b/basicswap/bin/run.py index 2835ff0..4085d66 100755 --- a/basicswap/bin/run.py +++ b/basicswap/bin/run.py @@ -17,12 +17,13 @@ import traceback import basicswap.config as cfg from basicswap import __version__ -from basicswap.ui.util import getCoinName from basicswap.basicswap import BasicSwap from basicswap.chainparams import chainparams, Coins, isKnownCoinName -from basicswap.http_server import HttpThread from basicswap.contrib.websocket_server import WebsocketServer - +from basicswap.http_server import HttpThread +from basicswap.network.simplex_chat import startSimplexClient +from basicswap.ui.util import getCoinName +from basicswap.util.daemon import Daemon initial_logger = logging.getLogger() initial_logger.level = logging.DEBUG @@ -31,16 +32,6 @@ if not len(initial_logger.handlers): logger = initial_logger swap_client = None -with_coins = set() -without_coins = set() - - -class Daemon: - __slots__ = ("handle", "files") - - def __init__(self, handle, files): - self.handle = handle - self.files = files def signal_handler(sig, frame): @@ -131,6 +122,7 @@ def startDaemon(node_dir, bin_dir, daemon_bin, opts=[], extra_config={}): cwd=datadir_path, ), opened_files, + os.path.basename(daemon_bin), ) @@ -161,6 +153,7 @@ def startXmrDaemon(node_dir, bin_dir, daemon_bin, opts=[]): cwd=datadir_path, ), [file_stdout, file_stderr], + os.path.basename(daemon_bin), ) @@ -224,6 +217,7 @@ def startXmrWalletDaemon(node_dir, bin_dir, wallet_bin, opts=[]): cwd=data_dir, ), [wallet_stdout, wallet_stderr], + os.path.basename(wallet_bin), ) @@ -284,8 +278,32 @@ def getCoreBinArgs(coin_id: int, coin_settings, prepare=False, use_tor_proxy=Fal return extra_args +def mainLoop(daemons, update: bool = True): + while not swap_client.delay_event.wait(0.5): + if update: + swap_client.update() + else: + pass + + for daemon in daemons: + if daemon.running is False: + continue + poll = daemon.handle.poll() + if poll is None: + pass # Process is running + else: + daemon.running = False + swap_client.log.error( + f"Process {daemon.handle.pid} for {daemon.name} terminated unexpectedly returning {poll}." + ) + + def runClient( - data_dir: str, chain: str, start_only_coins: bool, log_prefix: str = "BasicSwap" + data_dir: str, + chain: str, + start_only_coins: bool, + log_prefix: str = "BasicSwap", + extra_opts=dict(), ) -> int: global swap_client, logger daemons = [] @@ -311,13 +329,6 @@ def runClient( with open(settings_path) as fs: settings = json.load(fs) - extra_opts = dict() - if len(with_coins) > 0: - with_coins.add("particl") - extra_opts["with_coins"] = with_coins - if len(without_coins) > 0: - extra_opts["without_coins"] = without_coins - swap_client = BasicSwap( data_dir, settings, chain, log_name=log_prefix, extra_opts=extra_opts ) @@ -334,12 +345,46 @@ def runClient( # Settings may have been modified settings = swap_client.settings + try: # Try start daemons + for network in settings.get("networks", []): + network_type = network.get("type", "unknown") + if network_type == "simplex": + simplex_dir = os.path.join(data_dir, "simplex") + + log_level = "debug" if swap_client.debug else "info" + + socks_proxy = None + if "socks_proxy_override" in network: + socks_proxy = network["socks_proxy_override"] + elif swap_client.use_tor_proxy: + socks_proxy = ( + f"{swap_client.tor_proxy_host}:{swap_client.tor_proxy_port}" + ) + + daemons.append( + startSimplexClient( + network["client_path"], + simplex_dir, + network["server_address"], + network["ws_port"], + logger, + swap_client.delay_event, + socks_proxy=socks_proxy, + log_level=log_level, + ) + ) + pid = daemons[-1].handle.pid + swap_client.log.info(f"Started Simplex client {pid}") + for c, v in settings["chainclients"].items(): if len(start_only_coins) > 0 and c not in start_only_coins: continue - if (len(with_coins) > 0 and c not in with_coins) or c in without_coins: + if ( + len(swap_client.with_coins_override) > 0 + and c not in swap_client.with_coins_override + ) or c in swap_client.without_coins_override: if v.get("manage_daemon", False) or v.get( "manage_wallet_daemon", False ): @@ -497,8 +542,7 @@ def runClient( logger.info( f"Only running {start_only_coins}. Manually exit with Ctrl + c when ready." ) - while not swap_client.delay_event.wait(0.5): - pass + mainLoop(daemons, update=False) else: swap_client.start() if "htmlhost" in settings: @@ -536,8 +580,7 @@ def runClient( swap_client.ws_server.run_forever(threaded=True) logger.info("Exit with Ctrl + c.") - while not swap_client.delay_event.wait(0.5): - swap_client.update() + mainLoop(daemons) except Exception as e: # noqa: F841 traceback.print_exc() @@ -560,13 +603,13 @@ def runClient( closed_pids = [] for d in daemons: - swap_client.log.info(f"Interrupting {d.handle.pid}") + swap_client.log.info(f"Interrupting {d.name} {d.handle.pid}") try: d.handle.send_signal( signal.CTRL_C_EVENT if os.name == "nt" else signal.SIGINT ) except Exception as e: - swap_client.log.info(f"Interrupting {d.handle.pid}, error {e}") + swap_client.log.info(f"Interrupting {d.name} {d.handle.pid}, error {e}") for d in daemons: try: d.handle.wait(timeout=120) @@ -623,6 +666,9 @@ def printHelp(): "--startonlycoin Only start the provides coin daemon/s, use this if a chain requires extra processing." ) print("--logprefix Specify log prefix.") + print( + "--forcedbupgrade Recheck database against schema regardless of version." + ) def main(): @@ -630,6 +676,9 @@ def main(): chain = "mainnet" start_only_coins = set() log_prefix: str = "BasicSwap" + options = dict() + with_coins = set() + without_coins = set() for v in sys.argv[1:]: if len(v) < 2 or v[0] != "-": @@ -665,6 +714,9 @@ def main(): ensure_coin_valid(coin) without_coins.add(coin) continue + if name == "forcedbupgrade": + options["force_db_upgrade"] = True + continue if len(s) == 2: if name == "datadir": data_dir = os.path.expanduser(s[1]) @@ -693,8 +745,14 @@ def main(): if not os.path.exists(data_dir): os.makedirs(data_dir) + if len(with_coins) > 0: + with_coins.add("particl") + options["with_coins"] = with_coins + if len(without_coins) > 0: + options["without_coins"] = without_coins + logger.info(os.path.basename(sys.argv[0]) + ", version: " + __version__ + "\n\n") - fail_code = runClient(data_dir, chain, start_only_coins, log_prefix) + fail_code = runClient(data_dir, chain, start_only_coins, log_prefix, options) print("Done.") return fail_code diff --git a/basicswap/db.py b/basicswap/db.py index 1867c6e..25e2e94 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -13,7 +13,7 @@ from enum import IntEnum, auto from typing import Optional -CURRENT_DB_VERSION = 28 +CURRENT_DB_VERSION = 29 CURRENT_DB_DATA_VERSION = 6 @@ -219,6 +219,7 @@ class Bid(Table): bid_addr = Column("string") pk_bid_addr = Column("blob") proof_address = Column("string") + proof_signature = Column("blob") proof_utxos = Column("blob") # Address to spend lock tx to - address from wallet if empty TODO withdraw_to_addr = Column("string") @@ -485,6 +486,14 @@ class XmrSwap(Table): b_lock_tx_id = Column("blob") + msg_split_info = Column("string") + + def getMsgSplitInfo(self): + if self.msg_split_info is None: + return 16000, 17000 + msg_split_info = self.msg_split_info.split(":") + return int(msg_split_info[0]), int(msg_split_info[1]) + class XmrSplitData(Table): __tablename__ = "xmr_split_data" @@ -658,10 +667,44 @@ class CoinRates(Table): last_updated = Column("integer") -def create_db_(con, log) -> None: - c = con.cursor() +class MessageNetworks(Table): + __tablename__ = "message_networks" + record_id = Column("integer", primary_key=True, autoincrement=True) + active_ind = Column("integer") + name = Column("string") + created_at = Column("integer") + + +class DirectMessageRoute(Table): + __tablename__ = "direct_message_routes" + + record_id = Column("integer", primary_key=True, autoincrement=True) + active_ind = Column("integer") + network_id = Column("integer") + linked_type = Column("integer") + linked_id = Column("blob") + smsg_addr_local = Column("string") + smsg_addr_remote = Column("string") + # smsg_addr_id_local = Column("integer") # SmsgAddress + # smsg_addr_id_remote = Column("integer") # KnownIdentity + route_data = Column("blob") + created_at = Column("integer") + + +class DirectMessageRouteLink(Table): + __tablename__ = "direct_message_route_links" + record_id = Column("integer", primary_key=True, autoincrement=True) + active_ind = Column("integer") + direct_message_route_id = Column("integer") + linked_type = Column("integer") + linked_id = Column("blob") + created_at = Column("integer") + + +def extract_schema() -> dict: g = globals().copy() + tables = {} for name, obj in g.items(): if not inspect.isclass(obj): continue @@ -671,15 +714,13 @@ def create_db_(con, log) -> None: continue table_name: str = obj.__tablename__ - query: str = f"CREATE TABLE {table_name} (" - + table = {} + columns = {} primary_key = None constraints = [] indices = [] - num_columns: int = 0 for m in inspect.getmembers(obj): m_name, m_obj = m - if hasattr(m_obj, "__sqlite3_primary_key__"): primary_key = m_obj continue @@ -690,46 +731,103 @@ def create_db_(con, log) -> None: indices.append(m_obj) continue if hasattr(m_obj, "__sqlite3_column__"): - if num_columns > 0: - query += "," - col_type: str = m_obj.column_type.upper() if col_type == "BOOL": col_type = "INTEGER" - query += f" {m_name} {col_type} " - - if m_obj.primary_key: - query += "PRIMARY KEY ASC " - if m_obj.unique: - query += "UNIQUE " - num_columns += 1 - + columns[m_name] = { + "type": col_type, + "primary_key": m_obj.primary_key, + "unique": m_obj.unique, + } + table["columns"] = columns if primary_key is not None: - query += f", PRIMARY KEY ({primary_key.column_1}" + table["primary_key"] = {"column_1": primary_key.column_1} if primary_key.column_2: - query += f", {primary_key.column_2}" + table["primary_key"]["column_2"] = primary_key.column_2 if primary_key.column_3: - query += f", {primary_key.column_3}" - query += ") " + table["primary_key"]["column_3"] = primary_key.column_3 for constraint in constraints: - query += f", UNIQUE ({constraint.column_1}" + if "constraints" not in table: + table["constraints"] = [] + table_constraint = {"column_1": constraint.column_1} if constraint.column_2: - query += f", {constraint.column_2}" + table_constraint["column_2"] = constraint.column_2 if constraint.column_3: - query += f", {constraint.column_3}" - query += ") " + table_constraint["column_3"] = constraint.column_3 + table["constraints"].append(table_constraint) + for i in indices: + if "indices" not in table: + table["indices"] = [] + table_index = {"index_name": i.name, "column_1": i.column_1} + if i.column_2 is not None: + table_index["column_2"] = i.column_2 + if i.column_3 is not None: + table_index["column_3"] = i.column_3 + table["indices"].append(table_index) + + tables[table_name] = table + return tables + + +def create_table(c, table_name, table) -> None: + query: str = f"CREATE TABLE {table_name} (" + + for i, (colname, column) in enumerate(table["columns"].items()): + col_type = column["type"] + query += ("," if i > 0 else "") + f" {colname} {col_type} " + if column["primary_key"]: + query += "PRIMARY KEY ASC " + if column["unique"]: + query += "UNIQUE " + + if "primary_key" in table: + column_1 = table["primary_key"]["column_1"] + column_2 = table["primary_key"].get("column_2", None) + column_3 = table["primary_key"].get("column_3", None) + query += f", PRIMARY KEY ({column_1}" + if column_2: + query += f", {column_2}" + if column_3: + query += f", {column_3}" + query += ") " + + constraints = table.get("constraints", []) + for constraint in constraints: + column_1 = constraint["column_1"] + column_2 = constraint.get("column_2", None) + column_3 = constraint.get("column_3", None) + query += f", UNIQUE ({column_1}" + if column_2: + query += f", {column_2}" + if column_3: + query += f", {column_3}" + query += ") " + + query += ")" + c.execute(query) + + indices = table.get("indices", []) + for index in indices: + index_name = index["index_name"] + column_1 = index["column_1"] + column_2 = index.get("column_2", None) + column_3 = index.get("column_3", None) + query: str = f"CREATE INDEX {index_name} ON {table_name} ({column_1}" + if column_2: + query += f", {column_2}" + if column_3: + query += f", {column_3}" query += ")" c.execute(query) - for i in indices: - query: str = f"CREATE INDEX {i.name} ON {table_name} ({i.column_1}" - if i.column_2 is not None: - query += f", {i.column_2}" - if i.column_3 is not None: - query += f", {i.column_3}" - query += ")" - c.execute(query) + + +def create_db_(con, log) -> None: + db_schema = extract_schema() + c = con.cursor() + for table_name, table in db_schema.items(): + create_table(c, table_name, table) def create_db(db_path: str, log) -> None: @@ -915,6 +1013,7 @@ class DBMethods: query += f"{key}=:{key}" cursor.execute(query, values) + return cursor.lastrowid def query( self, diff --git a/basicswap/db_upgrades.py b/basicswap/db_upgrades.py index 2ed8357..4385780 100644 --- a/basicswap/db_upgrades.py +++ b/basicswap/db_upgrades.py @@ -12,8 +12,10 @@ from .db import ( AutomationStrategy, BidState, Concepts, + create_table, CURRENT_DB_DATA_VERSION, CURRENT_DB_VERSION, + extract_schema, ) from .basicswap_util import ( @@ -49,10 +51,9 @@ def upgradeDatabaseData(self, data_version): return self.log.info( - "Upgrading database records from version %d to %d.", - data_version, - CURRENT_DB_DATA_VERSION, + f"Upgrading database records from version {data_version} to {CURRENT_DB_DATA_VERSION}." ) + cursor = self.openDB() try: now = int(time.time()) @@ -138,313 +139,137 @@ def upgradeDatabaseData(self, data_version): self.db_data_version = CURRENT_DB_DATA_VERSION self.setIntKV("db_data_version", self.db_data_version, cursor) self.commitDB() - self.log.info( - "Upgraded database records to version {}".format(self.db_data_version) - ) + self.log.info(f"Upgraded database records to version {self.db_data_version}") finally: self.closeDB(cursor, commit=False) def upgradeDatabase(self, db_version): - if db_version >= CURRENT_DB_VERSION: + if self._force_db_upgrade is False and db_version >= CURRENT_DB_VERSION: return self.log.info( f"Upgrading database from version {db_version} to {CURRENT_DB_VERSION}." ) - while True: - try: - cursor = self.openDB() + # db_version, tablename, oldcolumnname, newcolumnname + rename_columns = [ + (13, "actions", "event_id", "action_id"), + (13, "actions", "event_type", "action_type"), + (13, "actions", "event_data", "action_data"), + ( + 14, + "xmr_swaps", + "coin_a_lock_refund_spend_tx_msg_id", + "coin_a_lock_spend_tx_msg_id", + ), + ] - current_version = db_version - if current_version == 6: - cursor.execute("ALTER TABLE bids ADD COLUMN security_token BLOB") - cursor.execute("ALTER TABLE offers ADD COLUMN security_token BLOB") - db_version += 1 - elif current_version == 7: - cursor.execute("ALTER TABLE transactions ADD COLUMN block_hash BLOB") + expect_schema = extract_schema() + have_tables = {} + try: + cursor = self.openDB() + + for rename_column in rename_columns: + dbv, table_name, colname_from, colname_to = rename_column + if db_version < dbv: cursor.execute( - "ALTER TABLE transactions ADD COLUMN block_height INTEGER" - ) - cursor.execute("ALTER TABLE transactions ADD COLUMN block_time INTEGER") - db_version += 1 - elif current_version == 8: - cursor.execute( - """ - CREATE TABLE wallets ( - record_id INTEGER NOT NULL, - coin_id INTEGER, - wallet_name VARCHAR, - wallet_data VARCHAR, - balance_type INTEGER, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - db_version += 1 - elif current_version == 9: - cursor.execute("ALTER TABLE wallets ADD COLUMN wallet_data VARCHAR") - db_version += 1 - elif current_version == 10: - cursor.execute( - "ALTER TABLE smsgaddresses ADD COLUMN active_ind INTEGER" - ) - cursor.execute( - "ALTER TABLE smsgaddresses ADD COLUMN created_at INTEGER" - ) - cursor.execute("ALTER TABLE smsgaddresses ADD COLUMN note VARCHAR") - cursor.execute("ALTER TABLE smsgaddresses ADD COLUMN pubkey VARCHAR") - cursor.execute( - "UPDATE smsgaddresses SET active_ind = 1, created_at = 1" + f"ALTER TABLE {table_name} RENAME COLUMN {colname_from} TO {colname_to}" ) - cursor.execute("ALTER TABLE offers ADD COLUMN addr_to VARCHAR") - cursor.execute(f'UPDATE offers SET addr_to = "{self.network_addr}"') - db_version += 1 - elif current_version == 11: - cursor.execute( - "ALTER TABLE bids ADD COLUMN chain_a_height_start INTEGER" - ) - cursor.execute( - "ALTER TABLE bids ADD COLUMN chain_b_height_start INTEGER" - ) - cursor.execute("ALTER TABLE bids ADD COLUMN protocol_version INTEGER") - cursor.execute("ALTER TABLE offers ADD COLUMN protocol_version INTEGER") - cursor.execute("ALTER TABLE transactions ADD COLUMN tx_data BLOB") - db_version += 1 - elif current_version == 12: - cursor.execute( - """ - CREATE TABLE knownidentities ( - record_id INTEGER NOT NULL, - address VARCHAR, - label VARCHAR, - publickey BLOB, - num_sent_bids_successful INTEGER, - num_recv_bids_successful INTEGER, - num_sent_bids_rejected INTEGER, - num_recv_bids_rejected INTEGER, - num_sent_bids_failed INTEGER, - num_recv_bids_failed INTEGER, - note VARCHAR, - updated_at BIGINT, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - cursor.execute("ALTER TABLE bids ADD COLUMN reject_code INTEGER") - cursor.execute("ALTER TABLE bids ADD COLUMN rate INTEGER") - cursor.execute( - "ALTER TABLE offers ADD COLUMN amount_negotiable INTEGER" - ) - cursor.execute("ALTER TABLE offers ADD COLUMN rate_negotiable INTEGER") - db_version += 1 - elif current_version == 13: - db_version += 1 - cursor.execute( - """ - CREATE TABLE automationstrategies ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - label VARCHAR, - type_ind INTEGER, - only_known_identities INTEGER, - num_concurrent INTEGER, - data BLOB, - - note VARCHAR, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - - cursor.execute( - """ - CREATE TABLE automationlinks ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - - linked_type INTEGER, - linked_id BLOB, - strategy_id INTEGER, - - data BLOB, - repeat_limit INTEGER, - repeat_count INTEGER, - - note VARCHAR, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - - cursor.execute( - """ - CREATE TABLE history ( - record_id INTEGER NOT NULL, - concept_type INTEGER, - concept_id INTEGER, - changed_data BLOB, - - note VARCHAR, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - - cursor.execute( - """ - CREATE TABLE bidstates ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - state_id INTEGER, - label VARCHAR, - in_progress INTEGER, - - note VARCHAR, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - - cursor.execute("ALTER TABLE wallets ADD COLUMN active_ind INTEGER") - cursor.execute( - "ALTER TABLE knownidentities ADD COLUMN active_ind INTEGER" - ) - cursor.execute("ALTER TABLE eventqueue RENAME TO actions") - cursor.execute( - "ALTER TABLE actions RENAME COLUMN event_id TO action_id" - ) - cursor.execute( - "ALTER TABLE actions RENAME COLUMN event_type TO action_type" - ) - cursor.execute( - "ALTER TABLE actions RENAME COLUMN event_data TO action_data" - ) - elif current_version == 14: - db_version += 1 - cursor.execute( - "ALTER TABLE xmr_swaps ADD COLUMN coin_a_lock_release_msg_id BLOB" - ) - cursor.execute( - "ALTER TABLE xmr_swaps RENAME COLUMN coin_a_lock_refund_spend_tx_msg_id TO coin_a_lock_spend_tx_msg_id" - ) - elif current_version == 15: - db_version += 1 - cursor.execute( - """ - CREATE TABLE notifications ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - event_type INTEGER, - event_data BLOB, - created_at BIGINT, - PRIMARY KEY (record_id))""" - ) - elif current_version == 16: - db_version += 1 - cursor.execute( - """ - CREATE TABLE prefunded_transactions ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - created_at BIGINT, - linked_type INTEGER, - linked_id BLOB, - tx_type INTEGER, - tx_data BLOB, - used_by BLOB, - PRIMARY KEY (record_id))""" - ) - elif current_version == 17: - db_version += 1 - cursor.execute( - "ALTER TABLE knownidentities ADD COLUMN automation_override INTEGER" - ) - cursor.execute( - "ALTER TABLE knownidentities ADD COLUMN visibility_override INTEGER" - ) - cursor.execute("ALTER TABLE knownidentities ADD COLUMN data BLOB") - cursor.execute("UPDATE knownidentities SET active_ind = 1") - elif current_version == 18: - db_version += 1 - cursor.execute("ALTER TABLE xmr_split_data ADD COLUMN addr_from STRING") - cursor.execute("ALTER TABLE xmr_split_data ADD COLUMN addr_to STRING") - elif current_version == 19: - db_version += 1 - cursor.execute("ALTER TABLE bidstates ADD COLUMN in_error INTEGER") - cursor.execute("ALTER TABLE bidstates ADD COLUMN swap_failed INTEGER") - cursor.execute("ALTER TABLE bidstates ADD COLUMN swap_ended INTEGER") - elif current_version == 20: - db_version += 1 - cursor.execute( - """ - CREATE TABLE message_links ( - record_id INTEGER NOT NULL, - active_ind INTEGER, - created_at BIGINT, - - linked_type INTEGER, - linked_id BLOB, - - msg_type INTEGER, - msg_sequence INTEGER, - msg_id BLOB, - PRIMARY KEY (record_id))""" - ) - cursor.execute("ALTER TABLE offers ADD COLUMN bid_reversed INTEGER") - elif current_version == 21: - db_version += 1 - cursor.execute("ALTER TABLE offers ADD COLUMN proof_utxos BLOB") - cursor.execute("ALTER TABLE bids ADD COLUMN proof_utxos BLOB") - elif current_version == 22: - db_version += 1 - cursor.execute("ALTER TABLE offers ADD COLUMN amount_to INTEGER") - elif current_version == 23: - db_version += 1 - cursor.execute( - """ - CREATE TABLE checkedblocks ( - record_id INTEGER NOT NULL, - created_at BIGINT, - coin_type INTEGER, - block_height INTEGER, - block_hash BLOB, - block_time INTEGER, - PRIMARY KEY (record_id))""" - ) - cursor.execute("ALTER TABLE bids ADD COLUMN pkhash_buyer_to BLOB") - elif current_version == 24: - db_version += 1 - cursor.execute("ALTER TABLE bidstates ADD COLUMN can_accept INTEGER") - elif current_version == 25: - db_version += 1 - cursor.execute( - """ - CREATE TABLE coinrates ( - record_id INTEGER NOT NULL, - currency_from INTEGER, - currency_to INTEGER, - rate VARCHAR, - source VARCHAR, - last_updated INTEGER, - PRIMARY KEY (record_id))""" - ) - elif current_version == 26: - db_version += 1 - cursor.execute("ALTER TABLE offers ADD COLUMN auto_accept_type INTEGER") - elif current_version == 27: - db_version += 1 - cursor.execute("ALTER TABLE offers ADD COLUMN pk_from BLOB") - cursor.execute("ALTER TABLE bids ADD COLUMN pk_bid_addr BLOB") - - if current_version != db_version: - self.db_version = db_version - self.setIntKV("db_version", db_version, cursor) - self.commitDB() - self.log.info("Upgraded database to version {}".format(self.db_version)) + query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" + tables = cursor.execute(query).fetchall() + for table in tables: + table_name = table[0] + if table_name in ("sqlite_sequence",): continue - except Exception as e: - self.log.error("Upgrade failed {}".format(e)) - self.rollbackDB() - finally: - self.closeDB(cursor, commit=False) - break - if db_version != CURRENT_DB_VERSION: - raise ValueError("Unable to upgrade database.") + have_table = {} + have_columns = {} + query = "SELECT * FROM PRAGMA_TABLE_INFO(:table_name) ORDER BY cid DESC;" + columns = cursor.execute(query, {"table_name": table_name}).fetchall() + for column in columns: + cid, name, data_type, notnull, default_value, primary_key = column + have_columns[name] = {"type": data_type, "primary_key": primary_key} + + have_table["columns"] = have_columns + + cursor.execute(f"PRAGMA INDEX_LIST('{table_name}');") + indices = cursor.fetchall() + for index in indices: + seq, index_name, unique, origin, partial = index + + if origin == "pk": # Created by a PRIMARY KEY constraint + continue + + cursor.execute(f"PRAGMA INDEX_INFO('{index_name}');") + index_info = cursor.fetchall() + + add_index = {"index_name": index_name} + for index_columns in index_info: + seqno, cid, name = index_columns + if origin == "u": # Created by a UNIQUE constraint + have_columns[name]["unique"] = 1 + else: + if "column_1" not in add_index: + add_index["column_1"] = name + elif "column_2" not in add_index: + add_index["column_2"] = name + elif "column_3" not in add_index: + add_index["column_3"] = name + else: + raise RuntimeError("Add more index columns.") + if origin == "c": + if "indices" not in table: + have_table["indices"] = [] + have_table["indices"].append(add_index) + + have_tables[table_name] = have_table + + for table_name, table in expect_schema.items(): + if table_name not in have_tables: + self.log.info(f"Creating table {table_name}.") + create_table(cursor, table_name, table) + continue + + have_table = have_tables[table_name] + have_columns = have_table["columns"] + for colname, column in table["columns"].items(): + if colname not in have_columns: + col_type = column["type"] + self.log.info(f"Adding column {colname} to table {table_name}.") + cursor.execute( + f"ALTER TABLE {table_name} ADD COLUMN {colname} {col_type}" + ) + indices = table.get("indices", []) + have_indices = have_table.get("indices", []) + for index in indices: + index_name = index["index_name"] + if not any( + have_idx.get("index_name") == index_name + for have_idx in have_indices + ): + self.log.info(f"Adding index {index_name} to table {table_name}.") + column_1 = index["column_1"] + column_2 = index.get("column_2", None) + column_3 = index.get("column_3", None) + query: str = ( + f"CREATE INDEX {index_name} ON {table_name} ({column_1}" + ) + if column_2: + query += f", {column_2}" + if column_3: + query += f", {column_3}" + query += ")" + cursor.execute(query) + + if CURRENT_DB_VERSION != db_version: + self.db_version = CURRENT_DB_VERSION + self.setIntKV("db_version", CURRENT_DB_VERSION, cursor) + self.log.info(f"Upgraded database to version {self.db_version}") + self.commitDB() + except Exception as e: + self.log.error(f"Upgrade failed {e}") + self.rollbackDB() + finally: + self.closeDB(cursor, commit=False) diff --git a/basicswap/db_util.py b/basicswap/db_util.py index 2cf4ec9..73ef7c4 100644 --- a/basicswap/db_util.py +++ b/basicswap/db_util.py @@ -76,6 +76,10 @@ def remove_expired_data(self, time_offset: int = 0): "DELETE FROM message_links WHERE linked_type = :type_ind AND linked_id = :linked_id", {"type_ind": int(Concepts.BID), "linked_id": bid_row[0]}, ) + cursor.execute( + "DELETE FROM direct_message_route_links WHERE linked_type = :type_ind AND linked_id = :linked_id", + {"type_ind": int(Concepts.BID), "linked_id": bid_row[0]}, + ) cursor.execute( "DELETE FROM eventlog WHERE eventlog.linked_type = :type_ind AND eventlog.linked_id = :offer_id", diff --git a/basicswap/js_server.py b/basicswap/js_server.py index a1288c3..6761c3a 100644 --- a/basicswap/js_server.py +++ b/basicswap/js_server.py @@ -858,7 +858,7 @@ def js_automationstrategies(self, url_split, post_string: str, is_json: bool) -> "label": strat_data.label, "type_ind": strat_data.type_ind, "only_known_identities": strat_data.only_known_identities, - "data": json.loads(strat_data.data.decode("utf-8")), + "data": json.loads(strat_data.data.decode("UTF-8")), "note": "" if strat_data.note is None else strat_data.note, } return bytes(json.dumps(rv), "UTF-8") @@ -992,7 +992,7 @@ def js_unlock(self, url_split, post_string, is_json) -> bytes: swap_client = self.server.swap_client post_data = getFormData(post_string, is_json) - password = get_data_entry(post_data, "password") + password: str = get_data_entry(post_data, "password") if have_data_entry(post_data, "coin"): coin = getCoinType(str(get_data_entry(post_data, "coin"))) @@ -1167,6 +1167,49 @@ def js_coinprices(self, url_split, post_string, is_json) -> bytes: ) +def js_messageroutes(self, url_split, post_string, is_json) -> bytes: + swap_client = self.server.swap_client + post_data = {} if post_string == "" else getFormData(post_string, is_json) + + filters = { + "page_no": 1, + "limit": PAGE_LIMIT, + "sort_by": "created_at", + "sort_dir": "desc", + } + + if have_data_entry(post_data, "sort_by"): + sort_by = get_data_entry(post_data, "sort_by") + ensure( + sort_by + in [ + "created_at", + ], + "Invalid sort by", + ) + filters["sort_by"] = sort_by + if have_data_entry(post_data, "sort_dir"): + sort_dir = get_data_entry(post_data, "sort_dir") + ensure(sort_dir in ["asc", "desc"], "Invalid sort dir") + filters["sort_dir"] = sort_dir + + if have_data_entry(post_data, "offset"): + filters["offset"] = int(get_data_entry(post_data, "offset")) + if have_data_entry(post_data, "limit"): + filters["limit"] = int(get_data_entry(post_data, "limit")) + ensure(filters["limit"] > 0, "Invalid limit") + + if have_data_entry(post_data, "address_from"): + filters["address_from"] = get_data_entry(post_data, "address_from") + if have_data_entry(post_data, "address_to"): + filters["address_to"] = get_data_entry(post_data, "address_to") + + action = get_data_entry_or(post_data, "action", None) + + message_routes = swap_client.listMessageRoutes(filters, action) + return bytes(json.dumps(message_routes), "UTF-8") + + endpoints = { "coins": js_coins, "wallets": js_wallets, @@ -1194,6 +1237,7 @@ endpoints = { "readurl": js_readurl, "active": js_active, "coinprices": js_coinprices, + "messageroutes": js_messageroutes, } diff --git a/basicswap/messages_npb.py b/basicswap/messages_npb.py index ef2c98e..1c5dafd 100644 --- a/basicswap/messages_npb.py +++ b/basicswap/messages_npb.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2024 tecnovert +# Copyright (c) 2025 The Basicswap developers # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. @@ -23,6 +24,13 @@ protobuf ParseFromString would reset the whole object, from_bytes won't. from basicswap.util.integer import encode_varint, decode_varint +NPBW_INT = 0 +NPBW_BYTES = 2 + +NPBF_STR = 1 +NPBF_BOOL = 2 + + class NonProtobufClass: def __init__(self, init_all: bool = True, **kwargs): for key, value in kwargs.items(): @@ -34,7 +42,7 @@ class NonProtobufClass: found_field = True break if found_field is False: - raise ValueError(f"got an unexpected keyword argument '{key}'") + raise ValueError(f"Got an unexpected keyword argument '{key}'") if init_all: self.init_fields() @@ -117,151 +125,160 @@ class NonProtobufClass: class OfferMessage(NonProtobufClass): _map = { - 1: ("protocol_version", 0, 0), - 2: ("coin_from", 0, 0), - 3: ("coin_to", 0, 0), - 4: ("amount_from", 0, 0), - 5: ("amount_to", 0, 0), - 6: ("min_bid_amount", 0, 0), - 7: ("time_valid", 0, 0), - 8: ("lock_type", 0, 0), - 9: ("lock_value", 0, 0), - 10: ("swap_type", 0, 0), - 11: ("proof_address", 2, 1), - 12: ("proof_signature", 2, 1), - 13: ("pkhash_seller", 2, 0), - 14: ("secret_hash", 2, 0), - 15: ("fee_rate_from", 0, 0), - 16: ("fee_rate_to", 0, 0), - 17: ("amount_negotiable", 0, 2), - 18: ("rate_negotiable", 0, 2), - 19: ("proof_utxos", 2, 0), + 1: ("protocol_version", NPBW_INT, 0), + 2: ("coin_from", NPBW_INT, 0), + 3: ("coin_to", NPBW_INT, 0), + 4: ("amount_from", NPBW_INT, 0), + 5: ("amount_to", NPBW_INT, 0), + 6: ("min_bid_amount", NPBW_INT, 0), + 7: ("time_valid", NPBW_INT, 0), + 8: ("lock_type", NPBW_INT, 0), + 9: ("lock_value", NPBW_INT, 0), + 10: ("swap_type", NPBW_INT, 0), + 11: ("proof_address", NPBW_BYTES, NPBF_STR), + 12: ("proof_signature", NPBW_BYTES, NPBF_STR), + 13: ("pkhash_seller", NPBW_BYTES, 0), + 14: ("secret_hash", NPBW_BYTES, 0), + 15: ("fee_rate_from", NPBW_INT, 0), + 16: ("fee_rate_to", NPBW_INT, 0), + 17: ("amount_negotiable", NPBW_INT, NPBF_BOOL), + 18: ("rate_negotiable", NPBW_INT, NPBF_BOOL), + 19: ("proof_utxos", NPBW_BYTES, 0), 20: ("auto_accept_type", 0, 0), } class BidMessage(NonProtobufClass): _map = { - 1: ("protocol_version", 0, 0), - 2: ("offer_msg_id", 2, 0), - 3: ("time_valid", 0, 0), - 4: ("amount", 0, 0), - 5: ("amount_to", 0, 0), - 6: ("pkhash_buyer", 2, 0), - 7: ("proof_address", 2, 1), - 8: ("proof_signature", 2, 1), - 9: ("proof_utxos", 2, 0), - 10: ("pkhash_buyer_to", 2, 0), + 1: ("protocol_version", NPBW_INT, 0), + 2: ("offer_msg_id", NPBW_BYTES, 0), + 3: ("time_valid", NPBW_INT, 0), + 4: ("amount", NPBW_INT, 0), + 5: ("amount_to", NPBW_INT, 0), + 6: ("pkhash_buyer", NPBW_BYTES, 0), + 7: ("proof_address", NPBW_BYTES, NPBF_STR), + 8: ("proof_signature", NPBW_BYTES, NPBF_STR), + 9: ("proof_utxos", NPBW_BYTES, 0), + 10: ("pkhash_buyer_to", NPBW_BYTES, 0), } class BidAcceptMessage(NonProtobufClass): # Step 3, seller -> buyer _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("initiate_txid", 2, 0), - 3: ("contract_script", 2, 0), - 4: ("pkhash_seller", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("initiate_txid", NPBW_BYTES, 0), + 3: ("contract_script", NPBW_BYTES, 0), + 4: ("pkhash_seller", NPBW_BYTES, 0), } class OfferRevokeMessage(NonProtobufClass): _map = { - 1: ("offer_msg_id", 2, 0), - 2: ("signature", 2, 0), + 1: ("offer_msg_id", NPBW_BYTES, 0), + 2: ("signature", NPBW_BYTES, 0), } class BidRejectMessage(NonProtobufClass): _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("reject_code", 0, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("reject_code", NPBW_INT, 0), } class XmrBidMessage(NonProtobufClass): # MSG1L, F -> L _map = { - 1: ("protocol_version", 0, 0), - 2: ("offer_msg_id", 2, 0), - 3: ("time_valid", 0, 0), - 4: ("amount", 0, 0), - 5: ("amount_to", 0, 0), - 6: ("pkaf", 2, 0), - 7: ("kbvf", 2, 0), - 8: ("kbsf_dleag", 2, 0), - 9: ("dest_af", 2, 0), + 1: ("protocol_version", NPBW_INT, 0), + 2: ("offer_msg_id", NPBW_BYTES, 0), + 3: ("time_valid", NPBW_INT, 0), + 4: ("amount", NPBW_INT, 0), + 5: ("amount_to", NPBW_INT, 0), + 6: ("pkaf", NPBW_BYTES, 0), + 7: ("kbvf", NPBW_BYTES, 0), + 8: ("kbsf_dleag", NPBW_BYTES, 0), + 9: ("dest_af", NPBW_BYTES, 0), } class XmrSplitMessage(NonProtobufClass): _map = { - 1: ("msg_id", 2, 0), - 2: ("msg_type", 0, 0), - 3: ("sequence", 0, 0), - 4: ("dleag", 2, 0), + 1: ("msg_id", NPBW_BYTES, 0), + 2: ("msg_type", NPBW_INT, 0), + 3: ("sequence", NPBW_INT, 0), + 4: ("dleag", NPBW_BYTES, 0), } class XmrBidAcceptMessage(NonProtobufClass): _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("pkal", 2, 0), - 3: ("kbvl", 2, 0), - 4: ("kbsl_dleag", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("pkal", NPBW_BYTES, 0), + 3: ("kbvl", NPBW_BYTES, 0), + 4: ("kbsl_dleag", NPBW_BYTES, 0), # MSG2F - 5: ("a_lock_tx", 2, 0), - 6: ("a_lock_tx_script", 2, 0), - 7: ("a_lock_refund_tx", 2, 0), - 8: ("a_lock_refund_tx_script", 2, 0), - 9: ("a_lock_refund_spend_tx", 2, 0), - 10: ("al_lock_refund_tx_sig", 2, 0), + 5: ("a_lock_tx", NPBW_BYTES, 0), + 6: ("a_lock_tx_script", NPBW_BYTES, 0), + 7: ("a_lock_refund_tx", NPBW_BYTES, 0), + 8: ("a_lock_refund_tx_script", NPBW_BYTES, 0), + 9: ("a_lock_refund_spend_tx", NPBW_BYTES, 0), + 10: ("al_lock_refund_tx_sig", NPBW_BYTES, 0), } class XmrBidLockTxSigsMessage(NonProtobufClass): # MSG3L _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("af_lock_refund_spend_tx_esig", 2, 0), - 3: ("af_lock_refund_tx_sig", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("af_lock_refund_spend_tx_esig", NPBW_BYTES, 0), + 3: ("af_lock_refund_tx_sig", NPBW_BYTES, 0), } class XmrBidLockSpendTxMessage(NonProtobufClass): # MSG4F _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("a_lock_spend_tx", 2, 0), - 3: ("kal_sig", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("a_lock_spend_tx", NPBW_BYTES, 0), + 3: ("kal_sig", NPBW_BYTES, 0), } class XmrBidLockReleaseMessage(NonProtobufClass): # MSG5F _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("al_lock_spend_tx_esig", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("al_lock_spend_tx_esig", NPBW_BYTES, 0), } class ADSBidIntentMessage(NonProtobufClass): # L -> F Sent from bidder, construct a reverse bid _map = { - 1: ("protocol_version", 0, 0), - 2: ("offer_msg_id", 2, 0), - 3: ("time_valid", 0, 0), - 4: ("amount_from", 0, 0), - 5: ("amount_to", 0, 0), + 1: ("protocol_version", NPBW_INT, 0), + 2: ("offer_msg_id", NPBW_BYTES, 0), + 3: ("time_valid", NPBW_INT, 0), + 4: ("amount_from", NPBW_INT, 0), + 5: ("amount_to", NPBW_INT, 0), } class ADSBidIntentAcceptMessage(NonProtobufClass): # F -> L Sent from offerer, construct a reverse bid _map = { - 1: ("bid_msg_id", 2, 0), - 2: ("pkaf", 2, 0), - 3: ("kbvf", 2, 0), - 4: ("kbsf_dleag", 2, 0), - 5: ("dest_af", 2, 0), + 1: ("bid_msg_id", NPBW_BYTES, 0), + 2: ("pkaf", NPBW_BYTES, 0), + 3: ("kbvf", NPBW_BYTES, 0), + 4: ("kbsf_dleag", NPBW_BYTES, 0), + 5: ("dest_af", NPBW_BYTES, 0), + } + + +class ConnectReqMessage(NonProtobufClass): + _map = { + 1: ("network_type", NPBW_INT, 0), + 2: ("network_data", NPBW_BYTES, 0), + 3: ("request_type", NPBW_INT, 0), + 4: ("request_data", NPBW_BYTES, 0), } diff --git a/basicswap/network/simplex.py b/basicswap/network/simplex.py index ad404ff..3841ca8 100644 --- a/basicswap/network/simplex.py +++ b/basicswap/network/simplex.py @@ -8,6 +8,7 @@ import base64 import json import threading +import traceback import websocket @@ -25,9 +26,6 @@ from basicswap.util.address import ( b58decode, decodeWif, ) -from basicswap.basicswap_util import ( - BidStates, -) def encode_base64(data: bytes) -> str: @@ -52,6 +50,20 @@ class WebSocketThread(threading.Thread): self.recv_queue = Queue() self.cmd_recv_queue = Queue() + self.delayed_events_queue = Queue() + + self.ignore_events: bool = False + + self.num_messages_received: int = 0 + + def disable_debug_mode(self): + self.ignore_events = False + for i in range(100): + try: + message = self.delayed_events_queue.get(block=False) + except Empty: + break + self.recv_queue.put(message) def on_message(self, ws, message): if self.logger: @@ -62,6 +74,7 @@ class WebSocketThread(threading.Thread): if message.startswith('{"corrId"'): self.cmd_recv_queue.put(message) else: + self.num_messages_received += 1 self.recv_queue.put(message) def queue_get(self): @@ -106,6 +119,20 @@ class WebSocketThread(threading.Thread): self.ws.send(cmd) return self.corrId + def wait_for_command_response(self, cmd_id, num_tries: int = 200): + cmd_id = str(cmd_id) + for i in range(num_tries): + message = self.cmd_queue_get() + if message is not None: + data = json.loads(message) + if "corrId" in data: + if data["corrId"] == cmd_id: + return data + self.delay_event.wait(0.5) + raise ValueError( + f"wait_for_command_response timed-out waiting for ID: {cmd_id}" + ) + def run(self): self.ws = websocket.WebSocketApp( self.url, @@ -126,16 +153,15 @@ class WebSocketThread(threading.Thread): def waitForResponse(ws_thread, sent_id, delay_event): sent_id = str(sent_id) - for i in range(100): + for i in range(200): message = ws_thread.cmd_queue_get() if message is not None: data = json.loads(message) - # print(f"json: {json.dumps(data, indent=4)}") if "corrId" in data: if data["corrId"] == sent_id: return data delay_event.wait(0.5) - raise ValueError(f"waitForResponse timed-out waiting for id: {sent_id}") + raise ValueError(f"waitForResponse timed-out waiting for ID: {sent_id}") def waitForConnected(ws_thread, delay_event): @@ -174,10 +200,17 @@ def getPrivkeyForAddress(self, addr) -> bytes: raise ValueError("key not found") -def sendSimplexMsg( - self, network, addr_from: str, addr_to: str, payload: bytes, msg_valid: int, cursor +def encryptMsg( + self, + addr_from: str, + addr_to: str, + payload: bytes, + msg_valid: int, + cursor, + timestamp=None, + deterministic=False, ) -> bytes: - self.log.debug("sendSimplexMsg") + self.log.debug("encryptMsg") try: rv = self.callrpc( @@ -210,14 +243,40 @@ def sendSimplexMsg( privkey_from = getPrivkeyForAddress(self, addr_from) payload += bytes((0,)) # Include null byte to match smsg - smsg_msg: bytes = smsgEncrypt(privkey_from, pubkey_to, payload) + smsg_msg: bytes = smsgEncrypt( + privkey_from, pubkey_to, payload, timestamp, deterministic + ) + return smsg_msg + + +def sendSimplexMsg( + self, + network, + addr_from: str, + addr_to: str, + payload: bytes, + msg_valid: int, + cursor, + timestamp: int = None, + deterministic: bool = False, + to_user_name: str = None, +) -> bytes: + self.log.debug("sendSimplexMsg") + + smsg_msg: bytes = encryptMsg( + self, addr_from, addr_to, payload, msg_valid, cursor, timestamp, deterministic + ) smsg_id = smsgGetID(smsg_msg) ws_thread = network["ws_thread"] - sent_id = ws_thread.send_command("#bsx " + encode_base64(smsg_msg)) + if to_user_name is not None: + to = "@" + to_user_name + " " + else: + to = "#bsx " + sent_id = ws_thread.send_command(to + encode_base64(smsg_msg)) response = waitForResponse(ws_thread, sent_id, self.delay_event) - if response["resp"]["type"] != "newChatItems": + if getResponseData(response, "type") != "newChatItems": json_str = json.dumps(response, indent=4) self.log.debug(f"Response {json_str}") raise ValueError("Send failed") @@ -243,8 +302,10 @@ def decryptSimplexMsg(self, msg_data): # Try with all active bid/offer addresses query: str = """SELECT DISTINCT address FROM ( - SELECT bid_addr AS address FROM bids WHERE active_ind = 1 - AND (in_progress = 1 OR (state > :bid_received AND state < :bid_completed) OR (state IN (:bid_received, :bid_sent) AND expire_at > :now)) + SELECT b.bid_addr AS address FROM bids b + JOIN bidstates s ON b.state = s.state_id + WHERE b.active_ind = 1 + AND (s.in_progress OR (s.swap_ended = 0 AND b.expire_at > :now)) UNION SELECT addr_from AS address FROM offers WHERE active_ind = 1 AND expire_at > :now )""" @@ -253,15 +314,7 @@ def decryptSimplexMsg(self, msg_data): try: cursor = self.openDB() - addr_rows = cursor.execute( - query, - { - "bid_received": int(BidStates.BID_RECEIVED), - "bid_completed": int(BidStates.SWAP_COMPLETED), - "bid_sent": int(BidStates.BID_SENT), - "now": now, - }, - ).fetchall() + addr_rows = cursor.execute(query, {"now": now}).fetchall() finally: self.closeDB(cursor, commit=False) @@ -283,46 +336,125 @@ def decryptSimplexMsg(self, msg_data): return decrypted +def parseSimplexMsg(self, chat_item): + item_status = chat_item["chatItem"]["meta"]["itemStatus"] + dir_type = item_status["type"] + if dir_type not in ("sndRcvd", "rcvNew"): + return None + + snd_progress = item_status.get("sndProgress", None) + if snd_progress and snd_progress != "complete": + item_id = chat_item["chatItem"]["meta"]["itemId"] + self.log.debug(f"simplex chat item {item_id} {snd_progress}") + return None + + conn_id = None + msg_dir: str = "recv" if dir_type == "rcvNew" else "sent" + chat_type: str = chat_item["chatInfo"]["type"] + if chat_type == "group": + chat_name = chat_item["chatInfo"]["groupInfo"]["localDisplayName"] + conn_id = chat_item["chatInfo"]["groupInfo"]["groupId"] + self.num_group_simplex_messages_received += 1 + elif chat_type == "direct": + chat_name = chat_item["chatInfo"]["contact"]["localDisplayName"] + conn_id = chat_item["chatInfo"]["contact"]["activeConn"]["connId"] + self.num_direct_simplex_messages_received += 1 + else: + return None + + msg_content = chat_item["chatItem"]["content"]["msgContent"]["text"] + try: + msg_data: bytes = decode_base64(msg_content) + decrypted_msg = decryptSimplexMsg(self, msg_data) + if decrypted_msg is None: + return None + decrypted_msg["chat_type"] = chat_type + decrypted_msg["chat_name"] = chat_name + decrypted_msg["conn_id"] = conn_id + decrypted_msg["msg_dir"] = msg_dir + return decrypted_msg + except Exception as e: # noqa: F841 + # self.log.debug(f"decryptSimplexMsg error: {e}") + self.log.debug(f"decryptSimplexMsg error: {e}") + pass + return None + + +def processEvent(self, ws_thread, msg_type: str, data) -> bool: + if ws_thread.ignore_events: + if msg_type not in ("contactConnected", "contactDeletedByContact"): + return False + ws_thread.delayed_events_queue.put(json.dumps(data)) + return True + + if msg_type == "contactConnected": + self.processContactConnected(data) + elif msg_type == "contactDeletedByContact": + self.processContactDisconnected(data) + else: + return False + return True + + def readSimplexMsgs(self, network): ws_thread = network["ws_thread"] - for i in range(100): message = ws_thread.queue_get() if message is None: break + if self.delay_event.is_set(): + break data = json.loads(message) - # self.log.debug(f"message 1: {json.dumps(data, indent=4)}") + # self.log.debug(f"Message: {json.dumps(data, indent=4)}") try: - if data["resp"]["type"] in ("chatItemsStatusesUpdated", "newChatItems"): - for chat_item in data["resp"]["chatItems"]: - item_status = chat_item["chatItem"]["meta"]["itemStatus"] - if item_status["type"] in ("sndRcvd", "rcvNew"): - snd_progress = item_status.get("sndProgress", None) - if snd_progress: - if snd_progress != "complete": - item_id = chat_item["chatItem"]["meta"]["itemId"] - self.log.debug( - f"simplex chat item {item_id} {snd_progress}" - ) - continue - try: - msg_data: bytes = decode_base64( - chat_item["chatItem"]["content"]["msgContent"]["text"] - ) - decrypted_msg = decryptSimplexMsg(self, msg_data) - if decrypted_msg is None: - continue - self.processMsg(decrypted_msg) - except Exception as e: # noqa: F841 - # self.log.debug(f"decryptSimplexMsg error: {e}") - pass + msg_type: str = getResponseData(data, "type") + if msg_type in ("chatItemsStatusesUpdated", "newChatItems"): + for chat_item in getResponseData(data, "chatItems"): + decrypted_msg = parseSimplexMsg(self, chat_item) + if decrypted_msg is None: + continue + self.processMsg(decrypted_msg) + elif msg_type == "chatError": + # self.log.debug(f"chatError Message: {json.dumps(data, indent=4)}") + pass + elif processEvent(self, ws_thread, msg_type, data): + pass + else: + self.log.debug(f"Unknown msg_type: {msg_type}") + # self.log.debug(f"Message: {json.dumps(data, indent=4)}") except Exception as e: self.log.debug(f"readSimplexMsgs error: {e}") + if self.debug: + self.log.error(traceback.format_exc()) self.delay_event.wait(0.05) +def getResponseData(data, tag=None): + if "Right" in data["resp"]: + if tag: + return data["resp"]["Right"][tag] + return data["resp"]["Right"] + if tag: + return data["resp"][tag] + return data["resp"] + + +def getNewSimplexLink(data): + response_data = getResponseData(data) + if "connLinkContact" in response_data: + return response_data["connLinkContact"]["connFullLink"] + return response_data["connReqContact"] + + +def getJoinedSimplexLink(data): + response_data = getResponseData(data) + if "connLinkInvitation" in response_data: + return response_data["connLinkInvitation"]["connFullLink"] + return response_data["connReqInvitation"] + + def initialiseSimplexNetwork(self, network_config) -> None: self.log.debug("initialiseSimplexNetwork") @@ -337,10 +469,10 @@ def initialiseSimplexNetwork(self, network_config) -> None: sent_id = ws_thread.send_command("/groups") response = waitForResponse(ws_thread, sent_id, self.delay_event) - if len(response["resp"]["groups"]) < 1: + if len(getResponseData(response, "groups")) < 1: sent_id = ws_thread.send_command("/c " + network_config["group_link"]) response = waitForResponse(ws_thread, sent_id, self.delay_event) - assert "groupLinkId" in response["resp"]["connection"] + assert "groupLinkId" in getResponseData(response, "connection") network = { "type": "simplex", @@ -348,3 +480,44 @@ def initialiseSimplexNetwork(self, network_config) -> None: } self.active_networks.append(network) + + +def closeSimplexChat(self, net_i, connId) -> bool: + try: + cmd_id = net_i.send_command("/chats") + response = net_i.wait_for_command_response(cmd_id, num_tries=500) + remote_name = None + for chat in getResponseData(response, "chats"): + if ( + "chatInfo" not in chat + or "type" not in chat["chatInfo"] + or chat["chatInfo"]["type"] != "direct" + ): + continue + try: + if chat["chatInfo"]["contact"]["activeConn"]["connId"] == connId: + remote_name = chat["chatInfo"]["contact"]["localDisplayName"] + break + except Exception as e: + self.log.debug(f"Error parsing chat: {e}") + + if remote_name is None: + self.log.warning( + f"Unable to find remote name for simplex direct chat, ID: {connId}" + ) + return False + + self.log.debug(f"Deleting simplex chat @{remote_name}, connID {connId}") + cmd_id = net_i.send_command(f"/delete @{remote_name}") + cmd_response = net_i.wait_for_command_response(cmd_id) + + if getResponseData(cmd_response, "type") != "contactDeleted": + self.log.warning(f"Failed to delete simplex chat, ID: {connId}") + self.log.debug( + "cmd_response: {}".format(json.dumps(cmd_response, indent=4)) + ) + return False + except Exception as e: + self.log.warning(f"Error deleting simplex chat, ID: {connId} - {e}") + return False + return True diff --git a/basicswap/network/simplex_chat.py b/basicswap/network/simplex_chat.py index df1951c..1a8b0c5 100644 --- a/basicswap/network/simplex_chat.py +++ b/basicswap/network/simplex_chat.py @@ -7,10 +7,11 @@ import os import select +import sqlite3 import subprocess import time -from basicswap.bin.run import Daemon +from basicswap.util.daemon import Daemon def initSimplexClient(args, logger, delay_event): @@ -29,7 +30,7 @@ def initSimplexClient(args, logger, delay_event): def readOutput(): buf = os.read(pipe_r, 1024).decode("utf-8") response = None - # logging.debug(f"simplex-chat output: {buf}") + # logger.debug(f"simplex-chat output: {buf}") if "display name:" in buf: logger.debug("Setting display name") response = b"user\n" @@ -45,7 +46,7 @@ def initSimplexClient(args, logger, delay_event): max_wait_seconds: int = 60 while p.poll() is None: if time.time() > start_time + max_wait_seconds: - raise ValueError("Timed out") + raise RuntimeError("Timed out") if os.name == "nt": readOutput() delay_event.wait(0.1) @@ -70,22 +71,45 @@ def startSimplexClient( websocket_port: int, logger, delay_event, + socks_proxy=None, + log_level: str = "debug", ) -> Daemon: logger.info("Starting Simplex client") if not os.path.exists(data_path): os.makedirs(data_path) - db_path = os.path.join(data_path, "simplex_client_data") + simplex_data_prefix = os.path.join(data_path, "simplex_client_data") + simplex_db_path = simplex_data_prefix + "_chat.db" + args = [bin_path, "-d", simplex_data_prefix, "-p", str(websocket_port)] - args = [bin_path, "-d", db_path, "-s", server_address, "-p", str(websocket_port)] + if socks_proxy: + args += ["--socks-proxy", socks_proxy] - if not os.path.exists(db_path): + if not os.path.exists(simplex_db_path): # Need to set initial profile through CLI # TODO: Must be a better way? - init_args = args + ["-e", "/help"] # Run command ro exit client + init_args = args + ["-e", "/help"] # Run command to exit client + init_args += ["-s", server_address] initSimplexClient(init_args, logger, delay_event) + else: + # Workaround to avoid error: + # SQLite3 returned ErrorConstraint while attempting to perform step: UNIQUE constraint failed: protocol_servers.user_id, protocol_servers.host, protocol_servers.port + # TODO: Remove? + with sqlite3.connect(simplex_db_path) as con: + c = con.cursor() + if ":" in server_address: + host, port = server_address.split(":") + else: + host = server_address + port = "" + query: str = ( + "SELECT COUNT(*) FROM protocol_servers WHERE host = :host and port = :port" + ) + q = c.execute(query, {"host": host, "port": port}).fetchone() + if q[0] < 1: + args += ["-s", server_address] - args += ["-l", "debug"] + args += ["-l", log_level] opened_files = [] stdout_dest = open( @@ -104,4 +128,5 @@ def startSimplexClient( cwd=data_path, ), opened_files, + "simplex-chat", ) diff --git a/basicswap/ui/app.py b/basicswap/ui/app.py new file mode 100644 index 0000000..c78b79c --- /dev/null +++ b/basicswap/ui/app.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import json + +from basicswap.db import getOrderByStr + + +class UIApp: + def listMessageRoutes(self, filters={}, action=None): + cursor = self.openDB() + try: + rv = [] + + query_data: dict = {} + filter_query_str: str = "" + address_from: str = filters.get("address_from", None) + if address_from is not None: + filter_query_str += " AND smsg_addr_local = :address_from " + query_data["address_from"] = address_from + + address_to: str = filters.get("address_to", None) + if address_from is not None: + filter_query_str += " AND smsg_addr_remote = :address_to " + query_data["address_to"] = address_to + + if action is None: + pass + elif action == "clear": + self.log.info("Clearing message routes") + query_str: str = ( + "SELECT record_id, network_id, route_data" + + " FROM direct_message_routes " + ) + query_str += filter_query_str + rows = cursor.execute(query_str, query_data).fetchall() + for row in rows: + record_id, network_id, route_data = row + route_data = json.loads(route_data.decode("UTF-8")) + self.closeMessageRoute(record_id, network_id, route_data, cursor) + + else: + raise ValueError("Unknown action") + + query_str: str = ( + "SELECT record_id, network_id, active_ind, linked_type, linked_id, " + + " smsg_addr_local, smsg_addr_remote, route_data, created_at" + + " FROM direct_message_routes " + + " WHERE active_ind > 0 " + ) + + query_str += filter_query_str + query_str += getOrderByStr(filters) + + limit = filters.get("limit", None) + if limit is not None: + query_str += " LIMIT :limit" + query_data["limit"] = limit + offset = filters.get("offset", None) + if offset is not None: + query_str += " OFFSET :offset" + query_data["offset"] = offset + + q = cursor.execute(query_str, query_data) + rv = [] + for row in q: + ( + record_id, + network_id, + active_ind, + linked_type, + linked_id, + smsg_addr_local, + smsg_addr_remote, + route_data, + created_at, + ) = row + rv.append( + { + "record_id": record_id, + "network_id": network_id, + "active_ind": active_ind, + "smsg_addr_local": smsg_addr_local, + "smsg_addr_remote": smsg_addr_remote, + "route_data": json.loads(route_data.decode("UTF-8")), + } + ) + + return rv + finally: + self.closeDB(cursor, commit=False) diff --git a/basicswap/util/daemon.py b/basicswap/util/daemon.py new file mode 100644 index 0000000..f760e74 --- /dev/null +++ b/basicswap/util/daemon.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + + +class Daemon: + __slots__ = ("handle", "files", "name", "running") + + def __init__(self, handle, files, name): + self.handle = handle + self.files = files + self.name = name + self.running = True diff --git a/basicswap/util/smsg.py b/basicswap/util/smsg.py index 6ab1fba..dd30884 100644 --- a/basicswap/util/smsg.py +++ b/basicswap/util/smsg.py @@ -83,19 +83,35 @@ def smsgGetID(smsg_message: bytes) -> bytes: return smsg_timestamp.to_bytes(8, byteorder="big") + ripemd160(smsg_message[8:]) -def smsgEncrypt(privkey_from: bytes, pubkey_to: bytes, payload: bytes) -> bytes: +def smsgEncrypt( + privkey_from: bytes, + pubkey_to: bytes, + payload: bytes, + smsg_timestamp: int = None, + deterministic: bool = False, +) -> bytes: # assert len(payload) < 128 # Requires lz4 if payload > 128 bytes # TODO: Add lz4 to match core smsg - smsg_timestamp = int(time.time()) - r = getSecretInt().to_bytes(32, byteorder="big") + if deterministic: + assert smsg_timestamp is not None + h = hashlib.sha256(b"smsg") + h.update(privkey_from) + h.update(pubkey_to) + h.update(payload) + h.update(smsg_timestamp.to_bytes(8, byteorder="big")) + r = h.digest() + smsg_iv: bytes = hashlib.sha256(b"smsg_iv" + r).digest()[:16] + else: + r = getSecretInt().to_bytes(32, byteorder="big") + smsg_iv: bytes = secrets.token_bytes(16) + if smsg_timestamp is None: + smsg_timestamp = int(time.time()) R = PublicKey.from_secret(r).format() p = PrivateKey(r).ecdh(pubkey_to) H = hashlib.sha512(p).digest() key_e: bytes = H[:32] key_m: bytes = H[32:] - smsg_iv: bytes = secrets.token_bytes(16) - payload_hash: bytes = sha256(sha256(payload)) signature: bytes = PrivateKey(privkey_from).sign_recoverable( payload_hash, hasher=None diff --git a/doc/notes.md b/doc/notes.md index 3b9c3fa..0d4017e 100644 --- a/doc/notes.md +++ b/doc/notes.md @@ -142,7 +142,7 @@ Observe progress with ## Start a subset of the configured coins using docker - docker compose run --service-ports swapclient basicswap-run -datadir=/coindata -withcoins=monero + docker compose run --rm --service-ports swapclient basicswap-run -datadir=/coindata -withcoins=monero diff --git a/tests/basicswap/common.py b/tests/basicswap/common.py index 977f36e..7e131c2 100644 --- a/tests/basicswap/common.py +++ b/tests/basicswap/common.py @@ -184,7 +184,7 @@ def wait_for_bid( swap_client.log.debug( f"TEST: wait_for_bid {bid_id.hex()}: Bid not found." ) - raise ValueError("wait_for_bid timed out.") + raise ValueError(f"wait_for_bid timed out {bid_id.hex()}.") def wait_for_bid_tx_state( @@ -331,7 +331,7 @@ def wait_for_balance( delay_event.wait(delay_time) i += 1 if i > iterations: - raise ValueError("Expect {} {}".format(balance_key, expect_amount)) + raise ValueError(f"Expect {balance_key} {expect_amount}") def wait_for_unspent( @@ -347,11 +347,11 @@ def wait_for_unspent( delay_event.wait(delay_time) i += 1 if i > iterations: - raise ValueError("wait_for_unspent {}".format(expect_amount)) + raise ValueError(f"wait_for_unspent {expect_amount}") def delay_for(delay_event, delay_for=60): - logging.info("Delaying for {} seconds.".format(delay_for)) + logging.info(f"Delaying for {delay_for} seconds.") delay_event.wait(delay_for) @@ -375,9 +375,7 @@ def waitForRPC(rpc_func, delay_event, rpc_command="getwalletinfo", max_tries=7): except Exception as ex: if i < max_tries: logging.warning( - "Can't connect to RPC: %s. Retrying in %d second/s.", - str(ex), - (i + 1), + f"Can't connect to RPC: {ex}. Retrying in {i + 1} second/s." ) delay_event.wait(i + 1) raise ValueError("waitForRPC failed") diff --git a/tests/basicswap/common_xmr.py b/tests/basicswap/common_xmr.py index 3c847a1..dcdb687 100644 --- a/tests/basicswap/common_xmr.py +++ b/tests/basicswap/common_xmr.py @@ -89,15 +89,19 @@ DOGECOIN_RPC_PORT_BASE = int(os.getenv("DOGECOIN_RPC_PORT_BASE", DOGE_BASE_RPC_P EXTRA_CONFIG_JSON = json.loads(os.getenv("EXTRA_CONFIG_JSON", "{}")) -def waitForBidState(delay_event, port, bid_id, state_str, wait_for=60): +def waitForBidState(delay_event, port, bid_id, wait_for_state, wait_for=60): for i in range(wait_for): if delay_event.is_set(): raise ValueError("Test stopped.") bid = json.loads( urlopen("http://127.0.0.1:12700/json/bids/{}".format(bid_id)).read() ) - if bid["bid_state"] == state_str: - return + if isinstance(wait_for_state, (list, tuple)): + if bid["bid_state"] in wait_for_state: + return + else: + if bid["bid_state"] == wait_for_state: + return delay_event.wait(1) raise ValueError("waitForBidState failed") diff --git a/tests/basicswap/extended/test_simplex.py b/tests/basicswap/extended/test_simplex.py index 83af4a8..c43bddb 100644 --- a/tests/basicswap/extended/test_simplex.py +++ b/tests/basicswap/extended/test_simplex.py @@ -17,8 +17,7 @@ docker run \ Fingerprint: Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I= -export SIMPLEX_SERVER_ADDRESS=smp://Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=:password@127.0.0.1:5223,443 - +export SIMPLEX_SERVER_ADDRESS=smp://Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=:password@127.0.0.1:5223 https://github.com/simplex-chat/simplex-chat/issues/4127 json: {"corrId":"3","cmd":"/_send #1 text test123"} @@ -43,9 +42,12 @@ from basicswap.basicswap import ( from basicswap.chainparams import Coins from basicswap.network.simplex import ( - WebSocketThread, + getJoinedSimplexLink, + getNewSimplexLink, + getResponseData, waitForConnected, waitForResponse, + WebSocketThread, ) from basicswap.network.simplex_chat import startSimplexClient from tests.basicswap.common import ( @@ -53,10 +55,14 @@ from tests.basicswap.common import ( wait_for_bid, wait_for_offer, ) +from tests.basicswap.util import read_json_api from tests.basicswap.test_xmr import BaseTest, test_delay_event, RESET_TEST - -SIMPLEX_SERVER_ADDRESS = os.getenv("SIMPLEX_SERVER_ADDRESS") +SIMPLEX_SERVER_FINGERPRINT = os.getenv("SIMPLEX_SERVER_FINGERPRINT", "") +SIMPLEX_SERVER_ADDRESS = os.getenv( + "SIMPLEX_SERVER_ADDRESS", + f"smp://{SIMPLEX_SERVER_FINGERPRINT}:password@127.0.0.1:5223", +) SIMPLEX_CLIENT_PATH = os.path.expanduser(os.getenv("SIMPLEX_CLIENT_PATH")) TEST_DIR = cfg.TEST_DATADIRS @@ -67,6 +73,35 @@ if not len(logger.handlers): logger.addHandler(logging.StreamHandler(sys.stdout)) +def parse_message(msg_data): + if getResponseData(msg_data, "type") not in ( + "chatItemsStatusesUpdated", + "newChatItems", + ): + return None + + for chat_item in getResponseData(msg_data, "chatItems"): + chat_type: str = chat_item["chatInfo"]["type"] + if chat_type == "group": + chat_name = chat_item["chatInfo"]["groupInfo"]["localDisplayName"] + elif chat_type == "direct": + chat_name = chat_item["chatInfo"]["contact"]["localDisplayName"] + else: + return None + + dir_type = chat_item["chatItem"]["meta"]["itemStatus"]["type"] + msg_dir = "recv" if dir_type == "rcvNew" else "sent" + if dir_type in ("sndRcvd", "rcvNew"): + msg_content = chat_item["chatItem"]["content"]["msgContent"]["text"] + return { + "text": msg_content, + "chat_type": chat_type, + "chat_name": chat_name, + "msg_dir": msg_dir, + } + return None + + class TestSimplex(unittest.TestCase): daemons = [] remove_testdir: bool = False @@ -79,10 +114,10 @@ class TestSimplex(unittest.TestCase): if os.path.isdir(TEST_DIR): if RESET_TEST: - logging.info("Removing " + TEST_DIR) + logger.info("Removing " + TEST_DIR) shutil.rmtree(TEST_DIR) else: - logging.info("Restoring instance from " + TEST_DIR) + logger.info("Restoring instance from " + TEST_DIR) if not os.path.exists(TEST_DIR): os.makedirs(TEST_DIR) @@ -126,7 +161,7 @@ class TestSimplex(unittest.TestCase): waitForConnected(ws_thread, test_delay_event) sent_id = ws_thread.send_command("/group bsx") response = waitForResponse(ws_thread, sent_id, test_delay_event) - assert response["resp"]["type"] == "groupCreated" + assert getResponseData(response, "type") == "groupCreated" ws_thread.send_command("/set voice #bsx off") ws_thread.send_command("/set files #bsx off") @@ -134,82 +169,221 @@ class TestSimplex(unittest.TestCase): ws_thread.send_command("/set reactions #bsx off") ws_thread.send_command("/set reports #bsx off") ws_thread.send_command("/set disappear #bsx on week") - sent_id = ws_thread.send_command("/create link #bsx") - connReqContact = None + sent_id = ws_thread.send_command("/create link #bsx") connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event) - connReqContact = connReqMsgData["resp"]["connReqContact"] + connReqContact = getNewSimplexLink(connReqMsgData) group_link = "https://simplex.chat" + connReqContact[8:] logger.info(f"group_link: {group_link}") sent_id = ws_thread2.send_command("/c " + group_link) response = waitForResponse(ws_thread2, sent_id, test_delay_event) - assert "groupLinkId" in response["resp"]["connection"] + assert "groupLinkId" in getResponseData(response, "connection") sent_id = ws_thread2.send_command("/groups") response = waitForResponse(ws_thread2, sent_id, test_delay_event) - assert len(response["resp"]["groups"]) == 1 + assert len(getResponseData(response, "groups")) == 1 - ws_thread.send_command("#bsx test msg 1") + sent_id = ws_thread2.send_command("/connect") + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + with open(os.path.join(client2_dir, "chat_inv.txt"), "w") as fp: + fp.write(json.dumps(response, indent=4)) + + connReqInvitation = getJoinedSimplexLink(response) + logger.info(f"direct_link: {connReqInvitation}") + pccConnId_2_sent = getResponseData(response, "connection")["pccConnId"] + print(f"pccConnId_2_sent: {pccConnId_2_sent}") + + sent_id = ws_thread.send_command(f"/connect {connReqInvitation}") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + with open(os.path.join(client1_dir, "chat_inv_accept.txt"), "w") as fp: + fp.write(json.dumps(response, indent=4)) + pccConnId_1_accepted = getResponseData(response, "connection")["pccConnId"] + print(f"pccConnId_1_accepted: {pccConnId_1_accepted}") + + sent_id = ws_thread.send_command("/chats") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + with open(os.path.join(client1_dir, "chats.txt"), "w") as fp: + fp.write(json.dumps(response, indent=4)) + + direct_local_name_1 = None + for chat in getResponseData(response, "chats"): + print(f"chat: {chat}") + if ( + chat["chatInfo"]["contact"]["activeConn"]["connId"] + == pccConnId_1_accepted + ): + direct_local_name_1 = chat["chatInfo"]["contact"][ + "localDisplayName" + ] + break + print(f"direct_local_name_1: {direct_local_name_1}") + + sent_id = ws_thread2.send_command("/chats") + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + with open(os.path.join(client2_dir, "chats.txt"), "w") as fp: + fp.write(json.dumps(response, indent=4)) + + direct_local_name_2 = None + for chat in getResponseData(response, "chats"): + print(f"chat: {chat}") + if ( + chat["chatInfo"]["contact"]["activeConn"]["connId"] + == pccConnId_2_sent + ): + direct_local_name_2 = chat["chatInfo"]["contact"][ + "localDisplayName" + ] + break + print(f"direct_local_name_2: {direct_local_name_2}") + # localDisplayName in chats doesn't match the contactConnected message. + assert direct_local_name_1 == "user_1" + assert direct_local_name_2 == "user_1" + + sent_id = ws_thread.send_command("#bsx test msg 1") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + assert getResponseData(response, "type") == "newChatItems" + sent_id = ws_thread.send_command("@user_1 test msg 2") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + assert getResponseData(response, "type") == "newChatItems" + + msg_counter1: int = 0 + msg_counter2: int = 0 + found = [dict(), dict()] + found_connected = [dict(), dict()] - found_1 = False - found_2 = False for i in range(100): - message = ws_thread.queue_get() - if message is not None: + if test_delay_event.is_set(): + break + for k in range(100): + message = ws_thread.queue_get() + if message is None or test_delay_event.is_set(): + break + msg_counter1 += 1 data = json.loads(message) - # print(f"message 1: {json.dumps(data, indent=4)}") try: - if data["resp"]["type"] in ( - "chatItemsStatusesUpdated", - "newChatItems", - ): - for chat_item in data["resp"]["chatItems"]: - # print(f"chat_item 1: {json.dumps(chat_item, indent=4)}") - if chat_item["chatItem"]["meta"]["itemStatus"][ - "type" - ] in ("sndRcvd", "rcvNew"): - if ( - chat_item["chatItem"]["content"]["msgContent"][ - "text" - ] - == "test msg 1" - ): - found_1 = True + msg_type = getResponseData(data, "type") + except Exception as e: + print(f"msg_type error: {e}") + msg_type = "None" + with open( + os.path.join( + client1_dir, f"recv_{msg_counter1}_{msg_type}.txt" + ), + "w", + ) as fp: + fp.write(json.dumps(data, indent=4)) + if msg_type == "contactConnected": + found_connected[0][msg_counter1] = data + continue + try: + simplex_msg = parse_message(data) + if simplex_msg: + simplex_msg["msg_id"] = msg_counter1 + found[0][msg_counter1] = simplex_msg except Exception as e: print(f"error 1: {e}") - message = ws_thread2.queue_get() - if message is not None: + for k in range(100): + message = ws_thread2.queue_get() + if message is None or test_delay_event.is_set(): + break + msg_counter2 += 1 data = json.loads(message) - # print(f"message 2: {json.dumps(data, indent=4)}") try: - if data["resp"]["type"] in ( - "chatItemsStatusesUpdated", - "newChatItems", - ): - for chat_item in data["resp"]["chatItems"]: - # print(f"chat_item 1: {json.dumps(chat_item, indent=4)}") - if chat_item["chatItem"]["meta"]["itemStatus"][ - "type" - ] in ("sndRcvd", "rcvNew"): - if ( - chat_item["chatItem"]["content"]["msgContent"][ - "text" - ] - == "test msg 1" - ): - found_2 = True + msg_type = getResponseData(data, "type") + except Exception as e: + print(f"msg_type error: {e}") + msg_type = "None" + with open( + os.path.join( + client2_dir, f"recv_{msg_counter2}_{msg_type}.txt" + ), + "w", + ) as fp: + fp.write(json.dumps(data, indent=4)) + if msg_type == "contactConnected": + found_connected[1][msg_counter2] = data + continue + try: + simplex_msg = parse_message(data) + if simplex_msg: + simplex_msg["msg_id"] = msg_counter2 + found[1][msg_counter2] = simplex_msg except Exception as e: print(f"error 2: {e}") - - if found_1 and found_2: + if ( + len(found[0]) >= 2 + and len(found[1]) >= 2 + and len(found_connected[0]) >= 1 + and len(found_connected[1]) >= 1 + ): break test_delay_event.wait(0.5) - assert found_1 is True - assert found_2 is True + assert len(found_connected[0]) == 1 + node1_connect = list(found_connected[0].values())[0] + assert ( + getResponseData(node1_connect, "contact")["activeConn"]["connId"] + == pccConnId_1_accepted + ) + assert ( + getResponseData(node1_connect, "contact")["localDisplayName"] + == "user_2" + ) + + assert len(found_connected[1]) == 1 + node2_connect = list(found_connected[1].values())[0] + assert ( + getResponseData(node2_connect, "contact")["activeConn"]["connId"] + == pccConnId_2_sent + ) + assert ( + getResponseData(node2_connect, "contact")["localDisplayName"] + == "user_2" + ) + + node1_msg1 = [m for m in found[0].values() if m["text"] == "test msg 1"] + assert len(node1_msg1) == 1 + node1_msg1 = node1_msg1[0] + assert node1_msg1["chat_type"] == "group" + assert node1_msg1["chat_name"] == "bsx" + assert node1_msg1["msg_dir"] == "sent" + node1_msg2 = [m for m in found[0].values() if m["text"] == "test msg 2"] + assert len(node1_msg2) == 1 + node1_msg2 = node1_msg2[0] + assert node1_msg2["chat_type"] == "direct" + assert node1_msg2["chat_name"] == "user_1" + assert node1_msg2["msg_dir"] == "sent" + + node2_msg1 = [m for m in found[1].values() if m["text"] == "test msg 1"] + assert len(node2_msg1) == 1 + node2_msg1 = node2_msg1[0] + assert node2_msg1["chat_type"] == "group" + assert node2_msg1["chat_name"] == "bsx" + assert node2_msg1["msg_dir"] == "recv" + node2_msg2 = [m for m in found[1].values() if m["text"] == "test msg 2"] + assert len(node2_msg2) == 1 + node2_msg2 = node2_msg2[0] + assert node2_msg2["chat_type"] == "direct" + assert node2_msg2["chat_name"] == "user_1" + assert node2_msg2["msg_dir"] == "recv" + + sent_id = ws_thread.send_command("/delete @user_1") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + assert getResponseData(response, "type") == "contactDeleted" + + sent_id = ws_thread2.send_command("/delete @user_1") + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + assert getResponseData(response, "type") == "contactDeleted" + + sent_id = ws_thread2.send_command("/chats") + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + with open(os.path.join(client2_dir, "chats_after_delete.txt"), "w") as fp: + fp.write(json.dumps(response, indent=4)) + + assert len(getResponseData(response, "chats")) == 4 finally: for t in threads: @@ -254,7 +428,7 @@ class Test(BaseTest): waitForConnected(ws_thread, test_delay_event) sent_id = ws_thread.send_command("/group bsx") response = waitForResponse(ws_thread, sent_id, test_delay_event) - assert response["resp"]["type"] == "groupCreated" + assert getResponseData(response, "type") == "groupCreated" ws_thread.send_command("/set voice #bsx off") ws_thread.send_command("/set files #bsx off") @@ -266,7 +440,7 @@ class Test(BaseTest): connReqContact = None connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event) - connReqContact = connReqMsgData["resp"]["connReqContact"] + connReqContact = getNewSimplexLink(connReqMsgData) cls.group_link = "https://simplex.chat" + connReqContact[8:] logger.info(f"BSX group_link: {cls.group_link}") @@ -277,13 +451,12 @@ class Test(BaseTest): @classmethod def tearDownClass(cls): - logging.info("Finalising Test") + logger.info("Finalising Test") super(Test, cls).tearDownClass() stopDaemons(cls.daemons) @classmethod def addCoinSettings(cls, settings, datadir, node_id): - settings["networks"] = [ { "type": "simplex", @@ -295,17 +468,22 @@ class Test(BaseTest): ] def test_01_swap(self): - logging.info("---------- Test xmr swap") + logger.info("---------- Test adaptor sig swap") swap_clients = self.swap_clients for sc in swap_clients: - sc.dleag_split_size_init = 9000 - sc.dleag_split_size = 11000 + sc._use_direct_message_routes = False assert len(swap_clients[0].active_networks) == 1 assert swap_clients[0].active_networks[0]["type"] == "simplex" + num_direct_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + coin_from = Coins.BTC coin_to = self.coin_to @@ -340,3 +518,506 @@ class Test(BaseTest): sent=True, wait_for=320, ) + + for i in range(3): + assert ( + num_direct_messages_received_before[i] + == swap_clients[i].num_direct_simplex_messages_received + ) + + def test_01_swap_reverse(self): + logger.info("---------- Test adaptor sig swap reverse") + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = False + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + + coin_from = self.coin_to + coin_to = Coins.BTC + + ci_from = swap_clients[1].ci(coin_from) + ci_to = swap_clients[0].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[1].postOffer( + coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP + ) + + wait_for_offer(test_delay_event, swap_clients[0], offer_id) + offer = swap_clients[0].getOffer(offer_id) + bid_id = swap_clients[0].postBid(offer_id, offer.amount_from) + + wait_for_bid(test_delay_event, swap_clients[1], bid_id, BidStates.BID_RECEIVED) + swap_clients[1].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + for i in range(3): + assert ( + num_direct_messages_received_before[i] + == swap_clients[i].num_direct_simplex_messages_received + ) + + def test_02_direct(self): + logger.info("---------- Test adaptor sig swap with direct messages") + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = True + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + num_group_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + num_group_messages_received_before[i] = swap_clients[ + i + ].num_group_simplex_messages_received + + coin_from = Coins.BTC + coin_to = self.coin_to + + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[1].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer( + coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP + ) + + wait_for_offer(test_delay_event, swap_clients[1], offer_id) + offer = swap_clients[1].getOffer(offer_id) + bid_id = swap_clients[1].postBid(offer_id, offer.amount_from) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.BID_RECEIVED, + wait_for=60, + ) + swap_clients[0].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + for i in range(3): + swap_clients[ + i + ].num_group_simplex_messages_received == num_group_messages_received_before[ + i + ] + 2 + swap_clients[ + 2 + ].num_direct_simplex_messages_received == num_direct_messages_received_before[2] + + def test_02_direct_reverse(self): + logger.info( + "---------- Test test_02_direct_reverse adaptor sig swap with direct messages" + ) + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = True + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + num_group_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + num_group_messages_received_before[i] = swap_clients[ + i + ].num_group_simplex_messages_received + + coin_from = self.coin_to + coin_to = Coins.BTC + + ci_from = swap_clients[1].ci(coin_from) + ci_to = swap_clients[0].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[1].postOffer( + coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP + ) + + wait_for_offer(test_delay_event, swap_clients[0], offer_id) + offer = swap_clients[0].getOffer(offer_id) + bid_id = swap_clients[0].postBid(offer_id, offer.amount_from) + + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.BID_RECEIVED, + wait_for=60, + ) + swap_clients[1].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + for i in range(3): + swap_clients[ + i + ].num_group_simplex_messages_received == num_group_messages_received_before[ + i + ] + 2 + swap_clients[ + 2 + ].num_direct_simplex_messages_received == num_direct_messages_received_before[2] + + def test_03_hltc(self): + logger.info("---------- Test secret hash swap") + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = False + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + num_group_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + num_group_messages_received_before[i] = swap_clients[ + i + ].num_group_simplex_messages_received + + coin_from = Coins.PART + coin_to = Coins.BTC + + self.prepare_balance(coin_to, 200.0, 1801, 1800) + + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[1].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer( + coin_from, + coin_to, + swap_value, + rate_swap, + swap_value, + SwapTypes.SELLER_FIRST, + ) + + wait_for_offer(test_delay_event, swap_clients[1], offer_id) + offer = swap_clients[1].getOffer(offer_id) + bid_id = swap_clients[1].postBid(offer_id, offer.amount_from) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.BID_RECEIVED, + wait_for=90, + ) + swap_clients[0].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + for i in range(3): + assert ( + num_direct_messages_received_before[i] + == swap_clients[i].num_direct_simplex_messages_received + ) + + def test_03_direct_hltc(self): + logger.info("---------- Test secret hash swap with direct messages") + + for i in range(3): + message_routes = read_json_api( + 1800 + i, "messageroutes", {"action": "clear"} + ) + assert len(message_routes) == 0 + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = True + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + num_group_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + num_group_messages_received_before[i] = swap_clients[ + i + ].num_group_simplex_messages_received + + coin_from = Coins.PART + coin_to = Coins.BTC + + self.prepare_balance(coin_to, 200.0, 1801, 1800) + + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[1].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer( + coin_from, + coin_to, + swap_value, + rate_swap, + swap_value, + SwapTypes.SELLER_FIRST, + ) + + wait_for_offer(test_delay_event, swap_clients[1], offer_id) + offer = swap_clients[1].getOffer(offer_id) + bid_id = swap_clients[1].postBid(offer_id, offer.amount_from) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.BID_RECEIVED, + wait_for=90, + ) + swap_clients[0].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + message_routes = read_json_api(1800, "messageroutes") + assert len(message_routes) == 1 + + for i in range(3): + swap_clients[ + i + ].num_group_simplex_messages_received == num_group_messages_received_before[ + i + ] + 2 + swap_clients[ + 2 + ].num_direct_simplex_messages_received == num_direct_messages_received_before[2] + + def test_04_multiple(self): + logger.info("---------- Test multiple swaps with direct messages") + + for i in range(3): + message_routes = read_json_api( + 1800 + i, "messageroutes", {"action": "clear"} + ) + assert len(message_routes) == 0 + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc._use_direct_message_routes = True + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + num_direct_messages_received_before = [0] * 3 + num_group_messages_received_before = [0] * 3 + for i in range(3): + num_direct_messages_received_before[i] = swap_clients[ + i + ].num_direct_simplex_messages_received + num_group_messages_received_before[i] = swap_clients[ + i + ].num_group_simplex_messages_received + + coin_from = Coins.BTC + coin_to = self.coin_to + + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[1].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer( + coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP + ) + + swap_clients[1].active_networks[0]["ws_thread"].ignore_events = True + + wait_for_offer(test_delay_event, swap_clients[1], offer_id) + offer = swap_clients[1].getOffer(offer_id) + + addr1_bids = swap_clients[1].getReceiveAddressForCoin(Coins.PART) + + bid_ids = [] + for i in range(2): + bid_ids.append( + swap_clients[1].postBid( + offer_id, offer.amount_from, addr_send_from=addr1_bids + ) + ) + + swap_clients[1].active_networks[0]["ws_thread"].disable_debug_mode() + + bid_ids.append(swap_clients[1].postBid(offer_id, offer.amount_from)) + + for i in range(len(bid_ids)): + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_ids[i], + BidStates.BID_RECEIVED, + wait_for=60, + ) + swap_clients[0].acceptBid(bid_ids[i]) + + logger.info("Message routes with active bids shouldn't expire") + swap_clients[0].mock_time_offset = ( + swap_clients[0]._expire_message_routes_after + 1 + ) + swap_clients[0].expireMessageRoutes() + swap_clients[0].mock_time_offset = 0 + message_routes_0 = read_json_api(1800, "messageroutes") + assert len(message_routes_0) == 2 + + for i in range(len(bid_ids)): + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_ids[i], + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_ids[i], + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) + + for i in range(3): + swap_clients[ + i + ].num_group_simplex_messages_received == num_group_messages_received_before[ + i + ] + 2 + swap_clients[ + 2 + ].num_direct_simplex_messages_received == num_direct_messages_received_before[2] + + message_routes_0 = read_json_api(1800, "messageroutes") + assert len(message_routes_0) == 2 + message_routes_1 = read_json_api(1801, "messageroutes") + assert len(message_routes_1) == 2 + + logger.info("Test closing routes") + read_json_api(1800, "messageroutes", {"action": "clear"}) + + def waitForNumMessageRoutes( + port: int = 1800, num_routes: int = 0, num_tries: int = 40 + ): + logger.info( + f"Waiting for {num_routes} message route{'s' if num_routes != 1 else ''}, port: {port}." + ) + for i in range(num_tries): + test_delay_event.wait(1) + if test_delay_event.is_set(): + raise ValueError("Test stopped.") + message_routes = read_json_api(port, "messageroutes") + if len(message_routes) == num_routes: + return True + raise ValueError("waitForNumMessageRoutes timed out.") + + waitForNumMessageRoutes(1800, 0) + waitForNumMessageRoutes(1801, 0) diff --git a/tests/basicswap/extended/test_xmr_persistent.py b/tests/basicswap/extended/test_xmr_persistent.py index d96ba9d..4a81fb6 100644 --- a/tests/basicswap/extended/test_xmr_persistent.py +++ b/tests/basicswap/extended/test_xmr_persistent.py @@ -35,6 +35,9 @@ import sys import threading import time import unittest + +import basicswap.config as cfg + from unittest.mock import patch from basicswap.rpc_xmr import ( @@ -87,6 +90,19 @@ TEST_COINS_LIST = os.getenv("TEST_COINS_LIST", "bitcoin,monero") NUM_NODES = int(os.getenv("NUM_NODES", 3)) EXTRA_CONFIG_JSON = json.loads(os.getenv("EXTRA_CONFIG_JSON", "{}")) +SIMPLEX_SERVER_FINGERPRINT = os.getenv("SIMPLEX_SERVER_FINGERPRINT", "") +SIMPLEX_SERVER_PASSWORD = os.getenv("SIMPLEX_SERVER_PASSWORD", "password") +SIMPLEX_SERVER_HOST = os.getenv("SIMPLEX_SERVER_HOST", "127.0.0.1") +SIMPLEX_SERVER_PORT = os.getenv("SIMPLEX_SERVER_PORT", "5223") +SIMPLEX_SERVER_ADDRESS = os.getenv( + "SIMPLEX_SERVER_ADDRESS", + f"smp://{SIMPLEX_SERVER_FINGERPRINT}:{SIMPLEX_SERVER_PASSWORD}@{SIMPLEX_SERVER_HOST}:{SIMPLEX_SERVER_PORT}", +) +SIMPLEX_WS_PORT = int(os.getenv("SIMPLEX_WS_PORT", "5225")) +SIMPLEX_GROUP_LINK = os.getenv("SIMPLEX_GROUP_LINK", "") +SIMPLEX_CLIENT_PATH = os.path.expanduser(os.getenv("SIMPLEX_CLIENT_PATH", "")) +SIMPLEX_SERVER_SOCKS_PROXY = os.getenv("SIMPLEX_SERVER_SOCKS_PROXY", "") + logger = logging.getLogger() logger.level = logging.DEBUG if not len(logger.handlers): @@ -318,7 +334,7 @@ def start_processes(self): self.btc_addr = callbtcrpc(0, "getnewaddress", ["mining_addr", "bech32"]) num_blocks: int = 500 # Mine enough to activate segwit if callbtcrpc(0, "getblockcount") < num_blocks: - logging.info("Mining %d Bitcoin blocks to %s", num_blocks, self.btc_addr) + logging.info(f"Mining {num_blocks} Bitcoin blocks to {self.btc_addr}") callbtcrpc(0, "generatetoaddress", [num_blocks, self.btc_addr]) logging.info("BTC blocks: %d", callbtcrpc(0, "getblockcount")) @@ -476,6 +492,28 @@ def start_processes(self): assert particl_blocks >= num_blocks +def modifyConfig(test_path, i): + config_path = os.path.join(test_path, f"client{i}", cfg.CONFIG_FILENAME) + with open(config_path) as fp: + settings = json.load(fp) + + if SIMPLEX_CLIENT_PATH != "": + simplex_options = { + "type": "simplex", + "server_address": SIMPLEX_SERVER_ADDRESS, + "client_path": SIMPLEX_CLIENT_PATH, + "ws_port": SIMPLEX_WS_PORT + i, + "group_link": SIMPLEX_GROUP_LINK, + } + if SIMPLEX_SERVER_SOCKS_PROXY != "": + simplex_options["socks_proxy_override"] = SIMPLEX_SERVER_SOCKS_PROXY + + settings["networks"] = [simplex_options] + + with open(config_path, "w") as fp: + json.dump(settings, fp, indent=4) + + class BaseTestWithPrepare(unittest.TestCase): __test__ = False @@ -522,6 +560,9 @@ class BaseTestWithPrepare(unittest.TestCase): PORT_OFS, ) + for i in range(NUM_NODES): + modifyConfig(test_path, i) + signal.signal( signal.SIGINT, lambda signal, frame: signal_handler(cls, signal, frame) ) diff --git a/tests/basicswap/test_xmr_bids_offline.py b/tests/basicswap/test_xmr_bids_offline.py index a1323a5..2f48732 100644 --- a/tests/basicswap/test_xmr_bids_offline.py +++ b/tests/basicswap/test_xmr_bids_offline.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2021-2022 tecnovert -# Copyright (c) 2024 The Basicswap developers +# Copyright (c) 2024-2025 The Basicswap developers # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. @@ -17,12 +17,9 @@ python tests/basicswap/test_xmr_bids_offline.py """ import sys -import json import logging import unittest import multiprocessing -from urllib import parse -from urllib.request import urlopen from tests.basicswap.util import ( read_json_api, @@ -64,21 +61,11 @@ class Test(XmrTestBase): "lockhrs": 24, "automation_strat_id": 1, } - rv = json.loads( - urlopen( - "http://127.0.0.1:12700/json/offers/new", - data=parse.urlencode(offer_data).encode(), - ).read() - ) + rv = read_json_api(12700, "offers/new", offer_data) offer0_id = rv["offer_id"] offer_data["amt_from"] = "2" - rv = json.loads( - urlopen( - "http://127.0.0.1:12700/json/offers/new", - data=parse.urlencode(offer_data).encode(), - ).read() - ) + rv = read_json_api(12700, "offers/new", offer_data) offer1_id = rv["offer_id"] summary = read_json_api(12700) @@ -92,52 +79,26 @@ class Test(XmrTestBase): c0.terminate() c0.join() - offers = json.loads( - urlopen("http://127.0.0.1:12701/json/offers/{}".format(offer0_id)).read() - ) + offers = read_json_api(12701, f"offers/{offer0_id}") assert len(offers) == 1 offer0 = offers[0] post_data = {"coin_from": "PART"} - test_post_offers = json.loads( - urlopen( - "http://127.0.0.1:12701/json/offers", - data=parse.urlencode(post_data).encode(), - ).read() - ) + test_post_offers = read_json_api(12701, "offers", post_data) assert len(test_post_offers) == 2 post_data["coin_from"] = "2" - test_post_offers = json.loads( - urlopen( - "http://127.0.0.1:12701/json/offers", - data=parse.urlencode(post_data).encode(), - ).read() - ) + test_post_offers = read_json_api(12701, "offers", post_data) assert len(test_post_offers) == 0 bid_data = {"offer_id": offer0_id, "amount_from": offer0["amount_from"]} + bid0_id = read_json_api(12701, "bids/new", bid_data)["bid_id"] - bid0_id = json.loads( - urlopen( - "http://127.0.0.1:12701/json/bids/new", - data=parse.urlencode(bid_data).encode(), - ).read() - )["bid_id"] - - offers = json.loads( - urlopen("http://127.0.0.1:12701/json/offers/{}".format(offer1_id)).read() - ) + offers = read_json_api(12701, f"offers/{offer1_id}") assert len(offers) == 1 offer1 = offers[0] bid_data = {"offer_id": offer1_id, "amount_from": offer1["amount_from"]} - - bid1_id = json.loads( - urlopen( - "http://127.0.0.1:12701/json/bids/new", - data=parse.urlencode(bid_data).encode(), - ).read() - )["bid_id"] + bid1_id = read_json_api(12701, "bids/new", bid_data)["bid_id"] logger.info("Delaying for 5 seconds.") self.delay_event.wait(5) @@ -149,26 +110,17 @@ class Test(XmrTestBase): waitForServer(self.delay_event, 12700) waitForNumBids(self.delay_event, 12700, 2) - waitForBidState(self.delay_event, 12700, bid0_id, "Received") - waitForBidState(self.delay_event, 12700, bid1_id, "Received") + waitForBidState(self.delay_event, 12700, bid0_id, ("Received", "Delaying")) + waitForBidState(self.delay_event, 12700, bid1_id, ("Received", "Delaying")) # Manually accept on top of auto-accept for extra chaos - data = parse.urlencode({"accept": True}).encode() try: - rv = json.loads( - urlopen( - "http://127.0.0.1:12700/json/bids/{}".format(bid0_id), data=data - ).read() - ) + rv = read_json_api(12700, f"bids/{bid0_id}", {"accept": True}) assert rv["bid_state"] == "Accepted" except Exception as e: print("Accept bid failed", str(e), rv) try: - rv = json.loads( - urlopen( - "http://127.0.0.1:12700/json/bids/{}".format(bid1_id), data=data - ).read() - ) + rv = read_json_api(12700, f"bids/{bid1_id}", {"accept": True}) assert rv["bid_state"] == "Accepted" except Exception as e: print("Accept bid failed", str(e), rv) @@ -179,8 +131,8 @@ class Test(XmrTestBase): raise ValueError("Test stopped.") self.delay_event.wait(4) - rv0 = read_json_api(12700, "bids/{}".format(bid0_id)) - rv1 = read_json_api(12700, "bids/{}".format(bid1_id)) + rv0 = read_json_api(12700, f"bids/{bid0_id}") + rv1 = read_json_api(12700, f"bids/{bid1_id}") if rv0["bid_state"] == "Completed" and rv1["bid_state"] == "Completed": break assert rv0["bid_state"] == "Completed"