From fa0760b17204646c847a6d8efe54044909db330d Mon Sep 17 00:00:00 2001 From: tecnovert Date: Fri, 11 Apr 2025 01:00:19 +0200 Subject: [PATCH] Add simplex chat test. --- basicswap/basicswap.py | 174 ++- basicswap/contrib/test_framework/messages.py | 784 ++++++++++---- basicswap/contrib/test_framework/p2p.py | 1006 ++++++++++++++++++ basicswap/contrib/test_framework/util.py | 392 +++---- basicswap/db.py | 4 +- basicswap/db_upgrades.py | 5 + basicswap/network/__init__.py | 0 basicswap/{ => network}/network.py | 0 basicswap/network/simplex.py | 350 ++++++ basicswap/network/simplex_chat.py | 107 ++ basicswap/network/util.py | 20 + basicswap/util/smsg.py | 229 ++++ requirements.in | 1 + requirements.txt | 6 +- tests/basicswap/extended/test_doge.py | 7 +- tests/basicswap/extended/test_simplex.py | 342 ++++++ tests/basicswap/extended/test_smsg.py | 147 +++ tests/basicswap/test_bch_xmr.py | 19 +- tests/basicswap/test_btc_xmr.py | 23 +- tests/basicswap/test_run.py | 7 +- tests/basicswap/test_xmr.py | 25 +- 21 files changed, 3115 insertions(+), 533 deletions(-) create mode 100755 basicswap/contrib/test_framework/p2p.py create mode 100644 basicswap/network/__init__.py rename basicswap/{ => network}/network.py (100%) create mode 100644 basicswap/network/simplex.py create mode 100644 basicswap/network/simplex_chat.py create mode 100644 basicswap/network/util.py create mode 100644 basicswap/util/smsg.py create mode 100644 tests/basicswap/extended/test_simplex.py create mode 100644 tests/basicswap/extended/test_smsg.py diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 7935f50..2fe97b7 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -122,8 +122,16 @@ from .explorers import ( ExplorerBitAps, ExplorerChainz, ) +from .network.simplex import ( + initialiseSimplexNetwork, + sendSimplexMsg, + readSimplexMsgs, +) +from .network.util import ( + getMsgPubkey, +) import basicswap.config as cfg -import basicswap.network as bsn +import basicswap.network.network as bsn import basicswap.protocols.atomic_swap_1 as atomic_swap_1 import basicswap.protocols.xmr_swap_1 as xmr_swap_1 from .basicswap_util import ( @@ -428,6 +436,9 @@ class BasicSwap(BaseApp): self.swaps_in_progress = dict() + self.dleag_split_size_init = 16000 + self.dleag_split_size = 17000 + self.SMSG_SECONDS_IN_HOUR = ( 60 * 60 ) # Note: Set smsgsregtestadjust=0 for regtest @@ -526,6 +537,8 @@ class BasicSwap(BaseApp): self._network = None for t in self.threads: + if hasattr(t, "stop") and callable(t.stop): + t.stop() t.join() if sys.version_info[1] >= 9: @@ -1078,6 +1091,17 @@ class BasicSwap(BaseApp): f"network_key {self.network_key}\nnetwork_pubkey {self.network_pubkey}\nnetwork_addr {self.network_addr}" ) + self.active_networks = [] + network_config_list = self.settings.get("networks", []) + if len(network_config_list) < 1: + network_config_list = [{"type": "smsg", "enabled": True}] + + for network in network_config_list: + if network["type"] == "smsg": + self.active_networks.append({"type": "smsg"}) + elif network["type"] == "simplex": + initialiseSimplexNetwork(self, network) + ro = self.callrpc("smsglocalkeys") found = False for k in ro["smsg_keys"]: @@ -1655,6 +1679,33 @@ class BasicSwap(BaseApp): bid_valid = (bid.expire_at - now) + 10 * 60 # Add 10 minute buffer return max(smsg_min_valid, min(smsg_max_valid, bid_valid)) + def sendMessage( + self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int, cursor + ) -> bytes: + message_id: bytes = None + # First network in list will set message_id + for 
network in self.active_networks: + net_message_id = None + if network["type"] == "smsg": + net_message_id = self.sendSmsg( + addr_from, addr_to, payload_hex, msg_valid + ) + elif network["type"] == "simplex": + net_message_id = sendSimplexMsg( + self, + network, + addr_from, + addr_to, + bytes.fromhex(payload_hex), + msg_valid, + cursor, + ) + else: + raise ValueError("Unknown network: {}".format(network["type"])) + if not message_id: + message_id = net_message_id + return message_id + def sendSmsg( self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int ) -> bytes: @@ -2200,7 +2251,9 @@ class BasicSwap(BaseApp): offer_bytes = msg_buf.to_bytes() payload_hex = str.format("{:02x}", MessageTypes.OFFER) + offer_bytes.hex() msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - offer_id = self.sendSmsg(offer_addr, offer_addr_to, payload_hex, msg_valid) + offer_id = self.sendMessage( + offer_addr, offer_addr_to, payload_hex, msg_valid, cursor + ) security_token = extra_options.get("security_token", None) if security_token is not None and len(security_token) != 20: @@ -2305,8 +2358,8 @@ class BasicSwap(BaseApp): ) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, offer.time_valid) - msg_id = self.sendSmsg( - offer.addr_from, self.network_addr, payload_hex, msg_valid + msg_id = self.sendMessage( + offer.addr_from, self.network_addr, payload_hex, msg_valid, cursor ) self.log.debug( f"Revoked offer {self.log.id(offer_id)} in msg {self.log.id(msg_id)}" @@ -3152,7 +3205,9 @@ class BasicSwap(BaseApp): bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid) + bid_id = self.sendMessage( + bid_addr, offer.addr_from, payload_hex, msg_valid, cursor + ) bid = Bid( protocol_version=msg_buf.protocol_version, @@ -3488,8 +3543,8 @@ class BasicSwap(BaseApp): ) msg_valid: int = self.getAcceptBidMsgValidTime(bid) - accept_msg_id = self.sendSmsg( - offer.addr_from, bid.bid_addr, payload_hex, msg_valid + accept_msg_id = self.sendMessage( + offer.addr_from, bid.bid_addr, payload_hex, msg_valid, cursor ) self.addMessageLink( @@ -3519,20 +3574,29 @@ class BasicSwap(BaseApp): dleag: bytes, msg_valid: int, bid_msg_ids, + cursor, ) -> None: - msg_buf2 = XmrSplitMessage( - msg_id=bid_id, msg_type=msg_type, sequence=1, dleag=dleag[16000:32000] - ) - msg_bytes = msg_buf2.to_bytes() - payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex() - bid_msg_ids[1] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) - msg_buf3 = XmrSplitMessage( - msg_id=bid_id, msg_type=msg_type, sequence=2, dleag=dleag[32000:] - ) - msg_bytes = msg_buf3.to_bytes() - payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex() - bid_msg_ids[2] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) + sent_bytes = self.dleag_split_size_init + + num_sent = 1 + while sent_bytes < len(dleag): + size_to_send: int = min(self.dleag_split_size, len(dleag) - sent_bytes) + msg_buf = XmrSplitMessage( + msg_id=bid_id, + msg_type=msg_type, + sequence=num_sent, + dleag=dleag[sent_bytes : sent_bytes + size_to_send], + ) + msg_bytes = msg_buf.to_bytes() + payload_hex = ( + str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex() + ) + bid_msg_ids[num_sent] = self.sendMessage( + addr_from, addr_to, payload_hex, msg_valid, cursor + ) + num_sent += 1 + sent_bytes += size_to_send def postXmrBid( self, offer_id: bytes, 
amount: int, addr_send_from: str = None, extra_options={} @@ -3608,8 +3672,8 @@ class BasicSwap(BaseApp): ) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - xmr_swap.bid_id = self.sendSmsg( - bid_addr, offer.addr_from, payload_hex, msg_valid + xmr_swap.bid_id = self.sendMessage( + bid_addr, offer.addr_from, payload_hex, msg_valid, cursor ) bid = Bid( @@ -3691,7 +3755,7 @@ class BasicSwap(BaseApp): if ci_to.curve_type() == Curves.ed25519: xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf) xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33] - msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:16000] + msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[: self.dleag_split_size_init] elif ci_to.curve_type() == Curves.secp256k1: for i in range(10): xmr_swap.kbsf_dleag = ci_to.signRecoverable( @@ -3721,8 +3785,8 @@ class BasicSwap(BaseApp): bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) - xmr_swap.bid_id = self.sendSmsg( - bid_addr, offer.addr_from, payload_hex, msg_valid + xmr_swap.bid_id = self.sendMessage( + bid_addr, offer.addr_from, payload_hex, msg_valid, cursor ) bid_msg_ids = {} @@ -3735,6 +3799,7 @@ class BasicSwap(BaseApp): xmr_swap.kbsf_dleag, msg_valid, bid_msg_ids, + cursor, ) bid = Bid( @@ -4013,7 +4078,7 @@ class BasicSwap(BaseApp): if ci_to.curve_type() == Curves.ed25519: xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl) - msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:16000] + msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[: self.dleag_split_size_init] elif ci_to.curve_type() == Curves.secp256k1: for i in range(10): xmr_swap.kbsl_dleag = ci_to.signRecoverable( @@ -4048,7 +4113,9 @@ class BasicSwap(BaseApp): msg_valid: int = self.getAcceptBidMsgValidTime(bid) bid_msg_ids = {} - bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) + bid_msg_ids[0] = self.sendMessage( + addr_from, addr_to, payload_hex, msg_valid, use_cursor + ) if ci_to.curve_type() == Curves.ed25519: self.sendXmrSplitMessages( @@ -4059,6 +4126,7 @@ class BasicSwap(BaseApp): xmr_swap.kbsl_dleag, msg_valid, bid_msg_ids, + use_cursor, ) bid.setState(BidStates.BID_ACCEPTED) # ADS @@ -4180,8 +4248,8 @@ class BasicSwap(BaseApp): msg_buf.kbvf = kbvf msg_buf.kbsf_dleag = ( xmr_swap.kbsf_dleag - if len(xmr_swap.kbsf_dleag) < 16000 - else xmr_swap.kbsf_dleag[:16000] + if len(xmr_swap.kbsf_dleag) < self.dleag_split_size_init + else xmr_swap.kbsf_dleag[: self.dleag_split_size_init] ) bid_bytes = msg_buf.to_bytes() @@ -4193,7 +4261,9 @@ class BasicSwap(BaseApp): addr_to: str = bid.bid_addr msg_valid: int = self.getAcceptBidMsgValidTime(bid) bid_msg_ids = {} - bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) + bid_msg_ids[0] = self.sendMessage( + addr_from, addr_to, payload_hex, msg_valid, use_cursor + ) if ci_to.curve_type() == Curves.ed25519: self.sendXmrSplitMessages( @@ -4204,6 +4274,7 @@ class BasicSwap(BaseApp): xmr_swap.kbsf_dleag, msg_valid, bid_msg_ids, + use_cursor, ) bid.setState(BidStates.BID_REQUEST_ACCEPTED) @@ -6808,6 +6879,11 @@ class BasicSwap(BaseApp): now: int = self.getTime() ttl_xmr_split_messages = 60 * 60 bid_cursor = None + + dleag_proof_len: int = 48893 # coincurve.dleag.dleag_proof_len() + expect_segments: int = -( + (dleag_proof_len - self.dleag_split_size_init) // -self.dleag_split_size + ) # ceiling division try: cursor = self.openDB() bid_cursor = self.getNewDBCursor() @@ -6820,7 +6896,7 @@ class BasicSwap(BaseApp): {"bid_id": bid.bid_id, "msg_type": int(XmrSplitMsgTypes.BID)}, 
).fetchone() num_segments = q[0] - if num_segments > 1: + if num_segments >= expect_segments: try: self.receiveXmrBid(bid, cursor) except Exception as ex: @@ -6866,7 +6942,7 @@ class BasicSwap(BaseApp): }, ).fetchone() num_segments = q[0] - if num_segments > 1: + if num_segments >= expect_segments: try: self.receiveXmrBidAccept(bid, cursor) except Exception as ex: @@ -7029,6 +7105,7 @@ class BasicSwap(BaseApp): if self.isOfferRevoked(offer_id, msg["from"]): raise ValueError("Offer has been revoked {}.".format(offer_id.hex())) + pk_from: bytes = getMsgPubkey(self, msg) try: cursor = self.openDB() # Offers must be received on the public network_addr or manually created addresses @@ -7069,6 +7146,7 @@ class BasicSwap(BaseApp): rate_negotiable=offer_data.rate_negotiable, addr_to=msg["to"], addr_from=msg["from"], + pk_from=pk_from, created_at=msg["sent"], expire_at=msg["sent"] + offer_data.time_valid, was_sent=False, @@ -7417,6 +7495,7 @@ class BasicSwap(BaseApp): bid = self.getBid(bid_id) if bid is None: + pk_from: bytes = getMsgPubkey(self, msg) bid = Bid( active_ind=1, bid_id=bid_id, @@ -7431,6 +7510,7 @@ class BasicSwap(BaseApp): created_at=msg["sent"], expire_at=msg["sent"] + bid_data.time_valid, bid_addr=msg["from"], + pk_bid_addr=pk_from, was_received=True, chain_a_height_start=ci_from.getChainHeight(), chain_b_height_start=ci_to.getChainHeight(), @@ -7829,12 +7909,13 @@ class BasicSwap(BaseApp): ) if ci_to.curve_type() == Curves.ed25519: - ensure(len(bid_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size") + ensure(len(bid_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size") bid_id = bytes.fromhex(msg["msgid"]) bid, xmr_swap = self.getXmrBid(bid_id) if bid is None: + pk_from: bytes = getMsgPubkey(self, msg) bid = Bid( active_ind=1, bid_id=bid_id, @@ -7846,6 +7927,7 @@ class BasicSwap(BaseApp): created_at=msg["sent"], expire_at=msg["sent"] + bid_data.time_valid, bid_addr=msg["from"], + pk_bid_addr=pk_from, was_received=True, chain_a_height_start=ci_from.getChainHeight(), chain_b_height_start=ci_to.getChainHeight(), @@ -8175,8 +8257,8 @@ class BasicSwap(BaseApp): msg_valid: int = self.getActiveBidMsgValidTime() addr_send_from: str = offer.addr_from if reverse_bid else bid.bid_addr addr_send_to: str = bid.bid_addr if reverse_bid else offer.addr_from - coin_a_lock_tx_sigs_l_msg_id = self.sendSmsg( - addr_send_from, addr_send_to, payload_hex, msg_valid + coin_a_lock_tx_sigs_l_msg_id = self.sendMessage( + addr_send_from, addr_send_to, payload_hex, msg_valid, cursor ) self.addMessageLink( Concepts.BID, @@ -8544,8 +8626,8 @@ class BasicSwap(BaseApp): addr_send_from: str = bid.bid_addr if reverse_bid else offer.addr_from addr_send_to: str = offer.addr_from if reverse_bid else bid.bid_addr msg_valid: int = self.getActiveBidMsgValidTime() - coin_a_lock_release_msg_id = self.sendSmsg( - addr_send_from, addr_send_to, payload_hex, msg_valid + coin_a_lock_release_msg_id = self.sendMessage( + addr_send_from, addr_send_to, payload_hex, msg_valid, cursor ) self.addMessageLink( Concepts.BID, @@ -8964,8 +9046,8 @@ class BasicSwap(BaseApp): ) msg_valid: int = self.getActiveBidMsgValidTime() - xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendSmsg( - addr_send_from, addr_send_to, payload_hex, msg_valid + xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendMessage( + addr_send_from, addr_send_to, payload_hex, msg_valid, cursor ) bid.setState(BidStates.XMR_SWAP_MSG_SCRIPT_LOCK_SPEND_TX) @@ -9347,6 +9429,7 @@ class BasicSwap(BaseApp): bid, xmr_swap = self.getXmrBid(bid_id) if bid is None: + pk_from: 
bytes = getMsgPubkey(self, msg) bid = Bid( active_ind=1, bid_id=bid_id, @@ -9358,6 +9441,7 @@ class BasicSwap(BaseApp): created_at=msg["sent"], expire_at=msg["sent"] + bid_data.time_valid, bid_addr=msg["from"], + pk_bid_addr=pk_from, was_sent=False, was_received=True, chain_a_height_start=ci_from.getChainHeight(), @@ -9460,7 +9544,7 @@ class BasicSwap(BaseApp): "Invalid destination address", ) if ci_to.curve_type() == Curves.ed25519: - ensure(len(msg_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size") + ensure(len(msg_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size") xmr_swap.dest_af = msg_data.dest_af xmr_swap.pkaf = msg_data.pkaf @@ -9495,6 +9579,14 @@ class BasicSwap(BaseApp): def processMsg(self, msg) -> None: try: + if "hex" not in msg: + if self.debug: + if "error" in msg: + self.log.debug( + "Message error {}: {}.".format(msg["msgid"], msg["error"]) + ) + raise ValueError("Invalid msg received {}.".format(msg["msgid"])) + return msg_type = int(msg["hex"][:2], 16) if msg_type == MessageTypes.OFFER: @@ -9708,6 +9800,10 @@ class BasicSwap(BaseApp): self.processMsg(msg) try: + for network in self.active_networks: + if network["type"] == "simplex": + readSimplexMsgs(self, network) + # TODO: Wait for blocks / txns, would need to check multiple coins now: int = self.getTime() self.expireBidsAndOffers(now) diff --git a/basicswap/contrib/test_framework/messages.py b/basicswap/contrib/test_framework/messages.py index 7ae6b11..a8b7a95 100755 --- a/basicswap/contrib/test_framework/messages.py +++ b/basicswap/contrib/test_framework/messages.py @@ -28,39 +28,62 @@ import struct import time from .siphash import siphash256 -from .util import hex_str_to_bytes, assert_equal +from .util import assert_equal -MIN_VERSION_SUPPORTED = 60001 -#MY_VERSION = 70014 # past bip-31 for ping/pong MY_VERSION = 90009 -MY_SUBVERSION = b"/python-mininode-tester:0.0.3/" -MY_RELAY = 1 # from version 70001 onwards, fRelay should be appended to version messages (BIP37) - MAX_LOCATOR_SZ = 101 -MAX_BLOCK_BASE_SIZE = 1000000 +MAX_BLOCK_WEIGHT = 4000000 MAX_BLOOM_FILTER_SIZE = 36000 MAX_BLOOM_HASH_FUNCS = 50 COIN = 100000000 # 1 btc in satoshis MAX_MONEY = 21000000 * COIN -BIP125_SEQUENCE_NUMBER = 0xfffffffd # Sequence number that is BIP 125 opt-in and BIP 68-opt-out +MAX_BIP125_RBF_SEQUENCE = 0xfffffffd # Sequence number that is rbf-opt-in (BIP 125) and csv-opt-out (BIP 68) +SEQUENCE_FINAL = 0xffffffff # Sequence number that disables nLockTime if set for every input of a tx +MAX_PROTOCOL_MESSAGE_LENGTH = 4000000 # Maximum length of incoming protocol messages +MAX_HEADERS_RESULTS = 2000 # Number of headers sent in one getheaders result +MAX_INV_SIZE = 50000 # Maximum number of entries in an 'inv' protocol message + +NODE_NONE = 0 NODE_NETWORK = (1 << 0) -NODE_GETUTXO = (1 << 1) NODE_BLOOM = (1 << 2) NODE_WITNESS = (1 << 3) +NODE_SMSG = (1 << 5) +NODE_COMPACT_FILTERS = (1 << 6) NODE_NETWORK_LIMITED = (1 << 10) +NODE_P2P_V2 = (1 << 11) MSG_TX = 1 MSG_BLOCK = 2 MSG_FILTERED_BLOCK = 3 MSG_CMPCT_BLOCK = 4 +MSG_WTX = 5 MSG_WITNESS_FLAG = 1 << 30 MSG_TYPE_MASK = 0xffffffff >> 2 +MSG_WITNESS_TX = MSG_TX | MSG_WITNESS_FLAG FILTER_TYPE_BASIC = 0 +WITNESS_SCALE_FACTOR = 4 + +DEFAULT_ANCESTOR_LIMIT = 25 # default max number of in-mempool ancestors +DEFAULT_DESCENDANT_LIMIT = 25 # default max number of in-mempool descendants + +# Default setting for -datacarriersize. 80 bytes of data, +1 for OP_RETURN, +2 for the pushdata opcodes. 
+MAX_OP_RETURN_RELAY = 83 + +DEFAULT_MEMPOOL_EXPIRY_HOURS = 336 # hours + +MAGIC_BYTES = { + "mainnet": b"\xfb\xf2\xef\xb4", # mainnet + "testnet3": b"\x08\x11\x05\x0b", # testnet3 + "regtest": b"\x09\x12\x06\x0c", # regtest + "signet": b"\x0a\x03\xcf\x40", # signet +} + +PARTICL_BLOCK_VERSION = 0xa0 PARTICL_TX_VERSION = 0xa0 PARTICL_TX_ANON_MARKER = 0xffffffa0 OUTPUT_TYPE_STANDARD = 1 @@ -69,64 +92,61 @@ OUTPUT_TYPE_RINGCT = 3 OUTPUT_TYPE_DATA = 4 -# Serialization/deserialization tools def sha256(s): - return hashlib.new('sha256', s).digest() + return hashlib.sha256(s).digest() + + +def sha3(s): + return hashlib.sha3_256(s).digest() + def hash256(s): return sha256(sha256(s)) + def ser_compact_size(l): r = b"" if l < 253: - r = struct.pack("B", l) + r = l.to_bytes(1, "little") elif l < 0x10000: - r = struct.pack(">= 32 - return rs + return u.to_bytes(32, 'little') def uint256_from_str(s): - r = 0 - t = struct.unpack("H", f.read(2))[0] + def __eq__(self, other): + return self.net == other.net and self.ip == other.ip and self.nServices == other.nServices and self.port == other.port and self.time == other.time - def serialize(self, with_time=True): + def deserialize(self, f, *, with_time=True): + """Deserialize from addrv1 format (pre-BIP155)""" + if with_time: + # VERSION messages serialize CAddress objects without time + self.time = int.from_bytes(f.read(4), "little") + self.nServices = int.from_bytes(f.read(8), "little") + # We only support IPv4 which means skip 12 bytes and read the next 4 as IPv4 address. + f.read(12) + self.net = self.NET_IPV4 + self.ip = socket.inet_ntoa(f.read(4)) + self.port = int.from_bytes(f.read(2), "big") + + def serialize(self, *, with_time=True): + """Serialize in addrv1 format (pre-BIP155)""" + assert self.net == self.NET_IPV4 r = b"" if with_time: - r += struct.pack("H", self.port) + r += self.port.to_bytes(2, "big") + return r + + def deserialize_v2(self, f): + """Deserialize from addrv2 format (BIP155)""" + self.time = int.from_bytes(f.read(4), "little") + + self.nServices = deser_compact_size(f) + + self.net = int.from_bytes(f.read(1), "little") + assert self.net in self.ADDRV2_NET_NAME + + address_length = deser_compact_size(f) + assert address_length == self.ADDRV2_ADDRESS_LENGTH[self.net] + + addr_bytes = f.read(address_length) + if self.net == self.NET_IPV4: + self.ip = socket.inet_ntoa(addr_bytes) + elif self.net == self.NET_IPV6: + self.ip = socket.inet_ntop(socket.AF_INET6, addr_bytes) + elif self.net == self.NET_TORV3: + prefix = b".onion checksum" + version = bytes([3]) + checksum = sha3(prefix + addr_bytes + version)[:2] + self.ip = b32encode(addr_bytes + checksum + version).decode("ascii").lower() + ".onion" + elif self.net == self.NET_I2P: + self.ip = b32encode(addr_bytes)[0:-len(self.I2P_PAD)].decode("ascii").lower() + ".b32.i2p" + elif self.net == self.NET_CJDNS: + self.ip = socket.inet_ntop(socket.AF_INET6, addr_bytes) + else: + raise Exception(f"Address type not supported") + + self.port = int.from_bytes(f.read(2), "big") + + def serialize_v2(self): + """Serialize in addrv2 format (BIP155)""" + assert self.net in self.ADDRV2_NET_NAME + r = b"" + r += self.time.to_bytes(4, "little") + r += ser_compact_size(self.nServices) + r += self.net.to_bytes(1, "little") + r += ser_compact_size(self.ADDRV2_ADDRESS_LENGTH[self.net]) + if self.net == self.NET_IPV4: + r += socket.inet_aton(self.ip) + elif self.net == self.NET_IPV6: + r += socket.inet_pton(socket.AF_INET6, self.ip) + elif self.net == self.NET_TORV3: + sfx = ".onion" + assert 
self.ip.endswith(sfx) + r += b32decode(self.ip[0:-len(sfx)], True)[0:32] + elif self.net == self.NET_I2P: + sfx = ".b32.i2p" + assert self.ip.endswith(sfx) + r += b32decode(self.ip[0:-len(sfx)] + self.I2P_PAD, True) + elif self.net == self.NET_CJDNS: + r += socket.inet_pton(socket.AF_INET6, self.ip) + else: + raise Exception(f"Address type not supported") + r += self.port.to_bytes(2, "big") return r def __repr__(self): - return "CAddress(nServices=%i ip=%s port=%i)" % (self.nServices, - self.ip, self.port) + return ("CAddress(nServices=%i net=%s addr=%s port=%i)" + % (self.nServices, self.ADDRV2_NET_NAME[self.net], self.ip, self.port)) class CInv: @@ -245,7 +383,8 @@ class CInv: MSG_TX | MSG_WITNESS_FLAG: "WitnessTx", MSG_BLOCK | MSG_WITNESS_FLAG: "WitnessBlock", MSG_FILTERED_BLOCK: "filtered Block", - 4: "CompactBlock" + MSG_CMPCT_BLOCK: "CompactBlock", + MSG_WTX: "WTX", } def __init__(self, t=0, h=0): @@ -253,12 +392,12 @@ class CInv: self.hash = h def deserialize(self, f): - self.type = struct.unpack(" 0 and with_pos_sig and self.nVersion == PARTICL_BLOCK_VERSION: + r += ser_string(self.blocksig) return r # Calculate the merkle root given a vector of transaction hashes @@ -779,6 +943,13 @@ class CBlock(CBlockHeader): self.nNonce += 1 self.rehash() + # Calculate the block weight using witness and non-witness + # serialization size (does NOT use sigops). + def get_weight(self): + with_witness_size = len(self.serialize(with_witness=True)) + without_witness_size = len(self.serialize(with_witness=False)) + return (WITNESS_SCALE_FACTOR - 1) * without_witness_size + with_witness_size + def __repr__(self): return "CBlock(nVersion=%i hashPrevBlock=%064x hashMerkleRoot=%064x nTime=%s nBits=%08x nNonce=%08x vtx=%s)" \ % (self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, @@ -831,12 +1002,12 @@ class P2PHeaderAndShortIDs: def deserialize(self, f): self.header.deserialize(f) - self.nonce = struct.unpack("= 70001: - # Relay field is optional for version 70001 onwards - try: - self.nRelay = struct.unpack(" 0: + self._on_data() + + # Socket read methods + + def data_received(self, t): + """asyncio callback when data is read from the socket.""" + if len(t) > 0: + self.recvbuf += t + if self.supports_v2_p2p and not self.v2_state.tried_v2_handshake: + self._on_data_v2_handshake() + else: + self._on_data() + + def _on_data(self): + """Try to read P2P messages from the recv buffer. + + This method reads data from the buffer in a loop. 
It deserializes, + parses and verifies the P2P header, then passes the P2P payload to + the on_message callback for processing.""" + try: + while True: + if self.supports_v2_p2p: + # v2 P2P messages are read + msglen, msg = self.v2_state.v2_receive_packet(self.recvbuf) + if msglen == -1: + raise ValueError("invalid v2 mac tag " + repr(self.recvbuf)) + elif msglen == 0: # need to receive more bytes in recvbuf + return + self.recvbuf = self.recvbuf[msglen:] + + if msg is None: # ignore decoy messages + return + assert msg # application layer messages (which aren't decoy messages) are non-empty + shortid = msg[0] # 1-byte short message type ID + if shortid == 0: + # next 12 bytes are interpreted as ASCII message type if shortid is b'\x00' + if len(msg) < 13: + raise IndexError("msg needs minimum required length of 13 bytes") + msgtype = msg[1:13].rstrip(b'\x00') + msg = msg[13:] # msg is set to be payload + else: + # a 1-byte short message type ID + msgtype = SHORTID.get(shortid, f"unknown-{shortid}") + msg = msg[1:] + else: + # v1 P2P messages are read + if len(self.recvbuf) < 4: + return + if self.recvbuf[:4] != self.magic_bytes: + raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf))) + if len(self.recvbuf) < 4 + 12 + 4 + 4: + return + msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0] + msglen = struct.unpack(" 500: + log_message += "... (msg truncated)" + logger.debug(log_message) + + +class P2PInterface(P2PConnection): + """A high-level P2P interface class for communicating with a Bitcoin node. + + This class provides high-level callbacks for processing P2P message + payloads, as well as convenience methods for interacting with the + node over P2P. + + Individual testcases should subclass this and override the on_* methods + if they want to alter message handling behaviour.""" + def __init__(self, support_addrv2=False, wtxidrelay=True): + super().__init__() + + # Track number of messages of each type received. + # Should be read-only in a test. + self.message_count = defaultdict(int) + + # Track the most recent message of each type. + # To wait for a message to be received, pop that message from + # this and use self.wait_until. + self.last_message = {} + + # A count of the number of ping messages we've sent to the node + self.ping_counter = 1 + + # The network services received from the peer + self.nServices = 0 + + self.support_addrv2 = support_addrv2 + + # If the peer supports wtxid-relay + self.wtxidrelay = wtxidrelay + + def peer_connect_send_version(self, services): + # Send a version msg + vt = msg_version() + vt.nVersion = P2P_VERSION + vt.strSubVer = P2P_SUBVERSION + vt.relay = P2P_VERSION_RELAY + vt.nServices = services + vt.addrTo.ip = self.dstaddr + vt.addrTo.port = self.dstport + vt.addrFrom.ip = "0.0.0.0" + vt.addrFrom.port = 0 + self.on_connection_send_msg = vt # Will be sent in connection_made callback + + def peer_connect(self, *, services=P2P_SERVICES, send_version, **kwargs): + create_conn = super().peer_connect(**kwargs) + + if send_version: + self.peer_connect_send_version(services) + + return create_conn + + def peer_accept_connection(self, *args, services=P2P_SERVICES, **kwargs): + create_conn = super().peer_accept_connection(*args, **kwargs) + self.peer_connect_send_version(services) + + return create_conn + + # Message receiving methods + + def on_message(self, message): + """Receive message and dispatch message to appropriate callback. 
+ + We keep a count of how many of each message type has been received + and the most recent message of each type.""" + with p2p_lock: + try: + msgtype = message.msgtype.decode('ascii') + self.message_count[msgtype] += 1 + self.last_message[msgtype] = message + getattr(self, 'on_' + msgtype)(message) + except Exception: + print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0])) + raise + + # Callback methods. Can be overridden by subclasses in individual test + # cases to provide custom message handling behaviour. + + def on_open(self): + pass + + def on_close(self): + pass + + def on_addr(self, message): pass + def on_addrv2(self, message): pass + def on_block(self, message): pass + def on_blocktxn(self, message): pass + def on_cfcheckpt(self, message): pass + def on_cfheaders(self, message): pass + def on_cfilter(self, message): pass + def on_cmpctblock(self, message): pass + def on_feefilter(self, message): pass + def on_filteradd(self, message): pass + def on_filterclear(self, message): pass + def on_filterload(self, message): pass + def on_getaddr(self, message): pass + def on_getblocks(self, message): pass + def on_getblocktxn(self, message): pass + def on_getdata(self, message): pass + def on_getheaders(self, message): pass + def on_headers(self, message): pass + def on_mempool(self, message): pass + def on_merkleblock(self, message): pass + def on_notfound(self, message): pass + def on_pong(self, message): pass + def on_sendaddrv2(self, message): pass + def on_sendcmpct(self, message): pass + def on_sendheaders(self, message): pass + def on_sendtxrcncl(self, message): pass + def on_tx(self, message): pass + def on_wtxidrelay(self, message): pass + + def on_inv(self, message): + want = msg_getdata() + for i in message.inv: + if i.type != 0: + want.inv.append(i) + if len(want.inv): + self.send_message(want) + + def on_ping(self, message): + self.send_message(msg_pong(message.nonce)) + + def on_verack(self, message): + pass + + def on_version(self, message): + assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, "Version {} received. 
Test framework only supports versions greater than {}".format(message.nVersion, MIN_P2P_VERSION_SUPPORTED) + # for inbound connections, reply to version with own version message + # (could be due to v1 reconnect after a failed v2 handshake) + if not self.p2p_connected_to_node: + self.send_version() + self.reconnect = False + if message.nVersion >= 70016 and self.wtxidrelay: + self.send_message(msg_wtxidrelay()) + if self.support_addrv2: + self.send_message(msg_sendaddrv2()) + self.send_message(msg_verack()) + self.nServices = message.nServices + self.relay = message.relay + if self.p2p_connected_to_node: + self.send_message(msg_getaddr()) + + # Connection helper methods + + def wait_until(self, test_function_in, *, timeout=60, check_connected=True): + def test_function(): + if check_connected: + assert self.is_connected + return test_function_in() + + wait_until_helper_internal(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor) + + def wait_for_connect(self, *, timeout=60): + test_function = lambda: self.is_connected + self.wait_until(test_function, timeout=timeout, check_connected=False) + + def wait_for_disconnect(self, *, timeout=60): + test_function = lambda: not self.is_connected + self.wait_until(test_function, timeout=timeout, check_connected=False) + + def wait_for_reconnect(self, *, timeout=60): + def test_function(): + return self.is_connected and self.last_message.get('version') and not self.supports_v2_p2p + self.wait_until(test_function, timeout=timeout, check_connected=False) + + # Message receiving helper methods + + def wait_for_tx(self, txid, *, timeout=60): + def test_function(): + if not self.last_message.get('tx'): + return False + return self.last_message['tx'].tx.rehash() == txid + + self.wait_until(test_function, timeout=timeout) + + def wait_for_block(self, blockhash, *, timeout=60): + def test_function(): + return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash + + self.wait_until(test_function, timeout=timeout) + + def wait_for_header(self, blockhash, *, timeout=60): + def test_function(): + last_headers = self.last_message.get('headers') + if not last_headers: + return False + return last_headers.headers[0].rehash() == int(blockhash, 16) + + self.wait_until(test_function, timeout=timeout) + + def wait_for_merkleblock(self, blockhash, *, timeout=60): + def test_function(): + last_filtered_block = self.last_message.get('merkleblock') + if not last_filtered_block: + return False + return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16) + + self.wait_until(test_function, timeout=timeout) + + def wait_for_getdata(self, hash_list, *, timeout=60): + """Waits for a getdata message. + + The object hashes in the inventory vector must match the provided hash_list.""" + def test_function(): + last_data = self.last_message.get("getdata") + if not last_data: + return False + return [x.hash for x in last_data.inv] == hash_list + + self.wait_until(test_function, timeout=timeout) + + def wait_for_getheaders(self, block_hash=None, *, timeout=60): + """Waits for a getheaders message containing a specific block hash. 
+ + If no block hash is provided, checks whether any getheaders message has been received by the node.""" + def test_function(): + last_getheaders = self.last_message.pop("getheaders", None) + if block_hash is None: + return last_getheaders + if last_getheaders is None: + return False + return block_hash == last_getheaders.locator.vHave[0] + + self.wait_until(test_function, timeout=timeout) + + def wait_for_inv(self, expected_inv, *, timeout=60): + """Waits for an INV message and checks that the first inv object in the message was as expected.""" + if len(expected_inv) > 1: + raise NotImplementedError("wait_for_inv() will only verify the first inv object") + + def test_function(): + return self.last_message.get("inv") and \ + self.last_message["inv"].inv[0].type == expected_inv[0].type and \ + self.last_message["inv"].inv[0].hash == expected_inv[0].hash + + self.wait_until(test_function, timeout=timeout) + + def wait_for_verack(self, *, timeout=60): + def test_function(): + return "verack" in self.last_message + + self.wait_until(test_function, timeout=timeout) + + # Message sending helper functions + + def send_version(self): + if self.on_connection_send_msg: + self.send_message(self.on_connection_send_msg) + self.on_connection_send_msg = None # Never used again + + def send_and_ping(self, message, *, timeout=60): + self.send_message(message) + self.sync_with_ping(timeout=timeout) + + def sync_with_ping(self, *, timeout=60): + """Ensure ProcessMessages and SendMessages is called on this connection""" + # Sending two pings back-to-back, requires that the node calls + # `ProcessMessage` twice, and thus ensures `SendMessages` must have + # been called at least once + self.send_message(msg_ping(nonce=0)) + self.send_message(msg_ping(nonce=self.ping_counter)) + + def test_function(): + return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter + + self.wait_until(test_function, timeout=timeout) + self.ping_counter += 1 + + +# One lock for synchronizing all data access between the network event loop (see +# NetworkThread below) and the thread running the test logic. For simplicity, +# P2PConnection acquires this lock whenever delivering a message to a P2PInterface. +# This lock should be acquired in the thread running the test logic to synchronize +# access to any data shared with the P2PInterface or P2PConnection. +p2p_lock = threading.Lock() + + +class NetworkThread(threading.Thread): + network_event_loop = None + + def __init__(self): + super().__init__(name="NetworkThread") + # There is only one event loop and no more than one thread must be created + assert not self.network_event_loop + + NetworkThread.listeners = {} + NetworkThread.protos = {} + if platform.system() == 'Windows': + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + NetworkThread.network_event_loop = asyncio.new_event_loop() + + def run(self): + """Start the network thread.""" + self.network_event_loop.run_forever() + + def close(self, *, timeout=10): + """Close the connections and network event loop.""" + self.network_event_loop.call_soon_threadsafe(self.network_event_loop.stop) + wait_until_helper_internal(lambda: not self.network_event_loop.is_running(), timeout=timeout) + self.network_event_loop.close() + self.join(timeout) + # Safe to remove event loop. 
+ NetworkThread.network_event_loop = None + + @classmethod + def listen(cls, p2p, callback, port=None, addr=None, idx=1): + """ Ensure a listening server is running on the given port, and run the + protocol specified by `p2p` on the next connection to it. Once ready + for connections, call `callback`.""" + + if port is None: + assert 0 < idx <= MAX_NODES + port = p2p_port(MAX_NODES - idx) + if addr is None: + addr = '127.0.0.1' + + def exception_handler(loop, context): + if not p2p.reconnect: + loop.default_exception_handler(context) + + cls.network_event_loop.set_exception_handler(exception_handler) + coroutine = cls.create_listen_server(addr, port, callback, p2p) + cls.network_event_loop.call_soon_threadsafe(cls.network_event_loop.create_task, coroutine) + + @classmethod + async def create_listen_server(cls, addr, port, callback, proto): + def peer_protocol(): + """Returns a function that does the protocol handling for a new + connection. To allow different connections to have different + behaviors, the protocol function is first put in the cls.protos + dict. When the connection is made, the function removes the + protocol function from that dict, and returns it so the event loop + can start executing it.""" + response = cls.protos.get((addr, port)) + # remove protocol function from dict only when reconnection doesn't need to happen/already happened + if not proto.reconnect: + cls.protos[(addr, port)] = None + return response + + if (addr, port) not in cls.listeners: + # When creating a listener on a given (addr, port) we only need to + # do it once. If we want different behaviors for different + # connections, we can accomplish this by providing different + # `proto` functions + + listener = await cls.network_event_loop.create_server(peer_protocol, addr, port) + logger.info("Listening server on %s:%d should be started" % (addr, port)) + cls.listeners[(addr, port)] = listener + + cls.protos[(addr, port)] = proto + callback(addr, port) + + +class P2PDataStore(P2PInterface): + """A P2P data store class. + + Keeps a block and transaction store and responds correctly to getdata and getheaders requests.""" + + def __init__(self): + super().__init__() + # store of blocks. key is block hash, value is a CBlock object + self.block_store = {} + self.last_block_hash = '' + # store of txs. key is txid, value is a CTransaction object + self.tx_store = {} + self.getdata_requests = [] + + def on_getdata(self, message): + """Check for the tx/block in our stores and if found, reply with an inv message.""" + for inv in message.inv: + self.getdata_requests.append(inv.hash) + if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys(): + self.send_message(msg_tx(self.tx_store[inv.hash])) + elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys(): + self.send_message(msg_block(self.block_store[inv.hash])) + else: + logger.debug('getdata message type {} received.'.format(hex(inv.type))) + + def on_getheaders(self, message): + """Search back through our block store for the locator, and reply with a headers message if found.""" + + locator, hash_stop = message.locator, message.hashstop + + # Assume that the most recent block added is the tip + if not self.block_store: + return + + headers_list = [self.block_store[self.last_block_hash]] + while headers_list[-1].sha256 not in locator.vHave: + # Walk back through the block store, adding headers to headers_list + # as we go. 
+ prev_block_hash = headers_list[-1].hashPrevBlock + if prev_block_hash in self.block_store: + prev_block_header = CBlockHeader(self.block_store[prev_block_hash]) + headers_list.append(prev_block_header) + if prev_block_header.sha256 == hash_stop: + # if this is the hashstop header, stop here + break + else: + logger.debug('block hash {} not found in block store'.format(hex(prev_block_hash))) + break + + # Truncate the list if there are too many headers + headers_list = headers_list[:-MAX_HEADERS_RESULTS - 1:-1] + response = msg_headers(headers_list) + + if response is not None: + self.send_message(response) + + def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60, is_decoy=False): + """Send blocks to test node and test whether the tip advances. + + - add all blocks to our block_store + - send a headers message for the final block + - the on_getheaders handler will ensure that any getheaders are responded to + - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will + ensure that any getdata messages are responded to. Otherwise send the full block unsolicited. + - if success is True: assert that the node's tip advances to the most recent block + - if success is False: assert that the node's tip doesn't advance + - if reject_reason is set: assert that the correct reject message is logged""" + + with p2p_lock: + for block in blocks: + self.block_store[block.sha256] = block + self.last_block_hash = block.sha256 + + reject_reason = [reject_reason] if reject_reason else [] + with node.assert_debug_log(expected_msgs=reject_reason): + if is_decoy: # since decoy messages are ignored by the recipient - no need to wait for response + force_send = True + if force_send: + for b in blocks: + self.send_message(msg_block(block=b), is_decoy) + else: + self.send_message(msg_headers([CBlockHeader(block) for block in blocks])) + self.wait_until( + lambda: blocks[-1].sha256 in self.getdata_requests, + timeout=timeout, + check_connected=success, + ) + + if expect_disconnect: + self.wait_for_disconnect(timeout=timeout) + else: + self.sync_with_ping(timeout=timeout) + + if success: + self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout) + else: + assert node.getbestblockhash() != blocks[-1].hash + + def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None): + """Send txs to test node and test whether they're accepted to the mempool. 
+ + - add all txs to our tx_store + - send tx messages for all txs + - if success is True/False: assert that the txs are/are not accepted to the mempool + - if expect_disconnect is True: Skip the sync with ping + - if reject_reason is set: assert that the correct reject message is logged.""" + + with p2p_lock: + for tx in txs: + self.tx_store[tx.sha256] = tx + + reject_reason = [reject_reason] if reject_reason else [] + with node.assert_debug_log(expected_msgs=reject_reason): + for tx in txs: + self.send_message(msg_tx(tx)) + + if expect_disconnect: + self.wait_for_disconnect() + else: + self.sync_with_ping() + + raw_mempool = node.getrawmempool() + if success: + # Check that all txs are now in the mempool + for tx in txs: + assert tx.hash in raw_mempool, "{} not found in mempool".format(tx.hash) + else: + # Check that none of the txs are now in the mempool + for tx in txs: + assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash) + +class P2PTxInvStore(P2PInterface): + """A P2PInterface which stores a count of how many times each txid has been announced.""" + def __init__(self): + super().__init__() + self.tx_invs_received = defaultdict(int) + + def on_inv(self, message): + super().on_inv(message) # Send getdata in response. + # Store how many times invs have been received for each tx. + for i in message.inv: + if (i.type == MSG_TX) or (i.type == MSG_WTX): + # save txid + self.tx_invs_received[i.hash] += 1 + + def get_invs(self): + with p2p_lock: + return list(self.tx_invs_received.keys()) + + def wait_for_broadcast(self, txns, *, timeout=60): + """Waits for the txns (list of txids) to complete initial broadcast. + The mempool should mark unbroadcast=False for these transactions. + """ + # Wait until invs have been received (and getdatas sent) for each txid. + self.wait_until(lambda: set(self.tx_invs_received.keys()) == set([int(tx, 16) for tx in txns]), timeout=timeout) + # Flush messages and wait for the getdatas to be processed + self.sync_with_ping() diff --git a/basicswap/contrib/test_framework/util.py b/basicswap/contrib/test_framework/util.py index 3c1d035..dc78029 100644 --- a/basicswap/contrib/test_framework/util.py +++ b/basicswap/contrib/test_framework/util.py @@ -5,20 +5,22 @@ """Helpful routines for regression testing.""" from base64 import b64encode -from binascii import unhexlify from decimal import Decimal, ROUND_DOWN from subprocess import CalledProcessError +import hashlib import inspect import json import logging import os -import random +import pathlib +import platform import re import time from . 
import coverage from .authproxy import AuthServiceProxy, JSONRPCException -from io import BytesIO +from collections.abc import Callable +from typing import Optional logger = logging.getLogger("TestFramework.utils") @@ -28,23 +30,46 @@ logger = logging.getLogger("TestFramework.utils") def assert_approx(v, vexp, vspan=0.00001): """Assert that `v` is within `vspan` of `vexp`""" + if isinstance(v, Decimal) or isinstance(vexp, Decimal): + v=Decimal(v) + vexp=Decimal(vexp) + vspan=Decimal(vspan) if v < vexp - vspan: raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) if v > vexp + vspan: raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) -def assert_fee_amount(fee, tx_size, fee_per_kB): - """Assert the fee was in range""" - target_fee = round(tx_size * fee_per_kB / 1000, 8) +def assert_fee_amount(fee, tx_size, feerate_BTC_kvB): + """Assert the fee is in range.""" + assert isinstance(tx_size, int) + target_fee = get_fee(tx_size, feerate_BTC_kvB) if fee < target_fee: raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee))) # allow the wallet's estimation to be at most 2 bytes off - if fee > (tx_size + 2) * fee_per_kB / 1000: + high_fee = get_fee(tx_size + 2, feerate_BTC_kvB) + if fee > high_fee: raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee))) +def summarise_dict_differences(thing1, thing2): + if not isinstance(thing1, dict) or not isinstance(thing2, dict): + return thing1, thing2 + d1, d2 = {}, {} + for k in sorted(thing1.keys()): + if k not in thing2: + d1[k] = thing1[k] + elif thing1[k] != thing2[k]: + d1[k], d2[k] = summarise_dict_differences(thing1[k], thing2[k]) + for k in sorted(thing2.keys()): + if k not in thing1: + d2[k] = thing2[k] + return d1, d2 + def assert_equal(thing1, thing2, *args): + if thing1 != thing2 and not args and isinstance(thing1, dict) and isinstance(thing2, dict): + d1,d2 = summarise_dict_differences(thing1, thing2) + raise AssertionError("not(%s == %s)\n in particular not(%s == %s)" % (thing1, thing2, d1, d2)) if thing1 != thing2 or any(thing1 != arg for arg in args): raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args)) @@ -79,7 +104,7 @@ def assert_raises_message(exc, message, fun, *args, **kwds): raise AssertionError("No exception raised") -def assert_raises_process_error(returncode, output, fun, *args, **kwds): +def assert_raises_process_error(returncode: int, output: str, fun: Callable, *args, **kwds): """Execute a process and asserts the process return code and output. Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError @@ -87,9 +112,9 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds): no CalledProcessError was raised or if the return code and output are not as expected. Args: - returncode (int): the process return code. - output (string): [a substring of] the process output. - fun (function): the function to call. This should execute a process. + returncode: the process return code. + output: [a substring of] the process output. + fun: the function to call. This should execute a process. args*: positional arguments for the function. kwds**: named arguments for the function. 
""" @@ -104,7 +129,7 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds): raise AssertionError("No exception raised") -def assert_raises_rpc_error(code, message, fun, *args, **kwds): +def assert_raises_rpc_error(code: Optional[int], message: Optional[str], fun: Callable, *args, **kwds): """Run an RPC and verify that a specific JSONRPC exception code and message is raised. Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException @@ -112,11 +137,11 @@ def assert_raises_rpc_error(code, message, fun, *args, **kwds): no JSONRPCException was raised or if the error code/message are not as expected. Args: - code (int), optional: the error code returned by the RPC call (defined - in src/rpc/protocol.h). Set to None if checking the error code is not required. - message (string), optional: [a substring of] the error string returned by the - RPC call. Set to None if checking the error string is not required. - fun (function): the function to call. This should be the name of an RPC. + code: the error code returned by the RPC call (defined in src/rpc/protocol.h). + Set to None if checking the error code is not required. + message: [a substring of] the error string returned by the RPC call. + Set to None if checking the error string is not required. + fun: the function to call. This should be the name of an RPC. args*: positional arguments for the function. kwds**: named arguments for the function. """ @@ -203,29 +228,45 @@ def check_json_precision(): raise RuntimeError("JSON encode/decode loses precision") -def EncodeDecimal(o): - if isinstance(o, Decimal): - return str(o) - raise TypeError(repr(o) + " is not JSON serializable") - - def count_bytes(hex_string): return len(bytearray.fromhex(hex_string)) -def hex_str_to_bytes(hex_str): - return unhexlify(hex_str.encode('ascii')) - - def str_to_b64str(string): return b64encode(string.encode('utf-8')).decode('ascii') +def ceildiv(a, b): + """ + Divide 2 ints and round up to next int rather than round down + Implementation requires python integers, which have a // operator that does floor division. + Other types like decimal.Decimal whose // operator truncates towards 0 will not work. + """ + assert isinstance(a, int) + assert isinstance(b, int) + return -(-a // b) + + +def get_fee(tx_size, feerate_btc_kvb): + """Calculate the fee in BTC given a feerate is BTC/kvB. Reflects CFeeRate::GetFee""" + feerate_sat_kvb = int(feerate_btc_kvb * Decimal(1e8)) # Fee in sat/kvb as an int to avoid float precision errors + target_fee_sat = ceildiv(feerate_sat_kvb * tx_size, 1000) # Round calculated fee up to nearest sat + return target_fee_sat / Decimal(1e8) # Return result in BTC + + def satoshi_round(amount): return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) -def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0): +def wait_until_helper_internal(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0): + """Sleep until the predicate resolves to be True. + + Warning: Note that this method is not recommended to be used in tests as it is + not aware of the context of the test framework. Using the `wait_until()` members + from `BitcoinTestFramework` or `P2PInterface` class ensures the timeout is + properly scaled. Furthermore, `wait_until()` from `P2PInterface` class in + `p2p.py` has a preset lock. 
+ """ if attempts == float('inf') and timeout == float('inf'): timeout = 60 timeout = timeout * timeout_factor @@ -253,6 +294,16 @@ def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=N raise RuntimeError('Unreachable') +def sha256sum_file(filename): + h = hashlib.sha256() + with open(filename, 'rb') as f: + d = f.read(4096) + while len(d) > 0: + h.update(d) + d = f.read(4096) + return h.digest() + + # RPC/P2P connection constants and functions ############################################ @@ -269,15 +320,15 @@ class PortSeed: n = None -def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None): +def get_rpc_proxy(url: str, node_number: int, *, timeout: Optional[int]=None, coveragedir: Optional[str]=None) -> coverage.AuthServiceProxyWrapper: """ Args: - url (str): URL of the RPC server to call - node_number (int): the node number (or id) that this calls to + url: URL of the RPC server to call + node_number: the node number (or id) that this calls to Kwargs: - timeout (int): HTTP timeout in seconds - coveragedir (str): Directory + timeout: HTTP timeout in seconds + coveragedir: Directory Returns: AuthServiceProxy. convenience object for making RPC calls. @@ -288,11 +339,10 @@ def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None): proxy_kwargs['timeout'] = int(timeout) proxy = AuthServiceProxy(url, **proxy_kwargs) - proxy.url = url # store URL on proxy for info coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None - return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile) + return coverage.AuthServiceProxyWrapper(proxy, url, coverage_logfile) def p2p_port(n): @@ -321,38 +371,76 @@ def rpc_url(datadir, i, chain, rpchost): ################ -def initialize_datadir(dirname, n, chain): +def initialize_datadir(dirname, n, chain, disable_autoconnect=True): datadir = get_datadir_path(dirname, n) if not os.path.isdir(datadir): os.makedirs(datadir) - # Translate chain name to config name - if chain == 'testnet3': + write_config(os.path.join(datadir, "particl.conf"), n=n, chain=chain, disable_autoconnect=disable_autoconnect) + os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True) + os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True) + return datadir + + +def write_config(config_path, *, n, chain, extra_config="", disable_autoconnect=True): + # Translate chain subdirectory name to config name + if chain == 'testnet': chain_name_conf_arg = 'testnet' chain_name_conf_section = 'test' else: chain_name_conf_arg = chain chain_name_conf_section = chain - with open(os.path.join(datadir, "particl.conf"), 'w', encoding='utf8') as f: - f.write("{}=1\n".format(chain_name_conf_arg)) - f.write("[{}]\n".format(chain_name_conf_section)) + with open(config_path, 'w', encoding='utf8') as f: + if chain_name_conf_arg: + f.write("{}=1\n".format(chain_name_conf_arg)) + if chain_name_conf_section: + f.write("[{}]\n".format(chain_name_conf_section)) f.write("port=" + str(p2p_port(n)) + "\n") f.write("rpcport=" + str(rpc_port(n)) + "\n") + # Disable server-side timeouts to avoid intermittent issues + f.write("rpcservertimeout=99000\n") + f.write("rpcdoccheck=1\n") f.write("fallbackfee=0.0002\n") f.write("server=1\n") f.write("keypool=1\n") f.write("discover=0\n") f.write("dnsseed=0\n") + f.write("fixedseeds=0\n") f.write("listenonion=0\n") + # Increase peertimeout to avoid disconnects while using mocktime. 
+ # peertimeout is measured in mock time, so setting it large enough to + # cover any duration in mock time is sufficient. It can be overridden + # in tests. + f.write("peertimeout=999999999\n") f.write("printtoconsole=0\n") f.write("upnp=0\n") + f.write("natpmp=0\n") f.write("shrinkdebugfile=0\n") - os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True) - os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True) - return datadir + f.write("deprecatedrpc=create_bdb\n") # Required to run the tests + # To improve SQLite wallet performance so that the tests don't timeout, use -unsafesqlitesync + f.write("unsafesqlitesync=1\n") + if disable_autoconnect: + f.write("connect=0\n") + f.write(extra_config) def get_datadir_path(dirname, n): - return os.path.join(dirname, "node" + str(n)) + return pathlib.Path(dirname) / f"node{n}" + + +def get_temp_default_datadir(temp_dir: pathlib.Path) -> tuple[dict, pathlib.Path]: + """Return os-specific environment variables that can be set to make the + GetDefaultDataDir() function return a datadir path under the provided + temp_dir, as well as the complete path it would return.""" + if platform.system() == "Windows": + env = dict(APPDATA=str(temp_dir)) + datadir = temp_dir / "Particl" + else: + env = dict(HOME=str(temp_dir)) + if platform.system() == "Darwin": + datadir = temp_dir / "Library/Application Support/Particl" + else: + datadir = temp_dir / ".particl" + return env, datadir def append_config(datadir, options): @@ -395,7 +483,7 @@ def delete_cookie_file(datadir, chain): def softfork_active(node, key): """Return whether a softfork is active.""" - return node.getblockchaininfo()['softforks'][key]['active'] + return node.getdeploymentinfo()['deployments'][key]['active'] def set_node_times(nodes, t): @@ -403,208 +491,51 @@ def set_node_times(nodes, t): node.setmocktime(t) -def disconnect_nodes(from_connection, node_num): - def get_peer_ids(): - result = [] - for peer in from_connection.getpeerinfo(): - if "testnode{}".format(node_num) in peer['subver']: - result.append(peer['id']) - return result - - peer_ids = get_peer_ids() - if not peer_ids: - logger.warning("disconnect_nodes: {} and {} were not connected".format( - from_connection.index, - node_num, - )) - return - for peer_id in peer_ids: - try: - from_connection.disconnectnode(nodeid=peer_id) - except JSONRPCException as e: - # If this node is disconnected between calculating the peer id - # and issuing the disconnect, don't worry about it. - # This avoids a race condition if we're mass-disconnecting peers. 
- if e.error['code'] != -29: # RPC_CLIENT_NODE_NOT_CONNECTED - raise - - # wait to disconnect - wait_until(lambda: not get_peer_ids(), timeout=5) - - -def connect_nodes(from_connection, node_num): - ip_port = "127.0.0.1:" + str(p2p_port(node_num)) - from_connection.addnode(ip_port, "onetry") - # poll until version handshake complete to avoid race conditions - # with transaction relaying - # See comments in net_processing: - # * Must have a version message before anything else - # * Must have a verack message before anything else - wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo())) - wait_until(lambda: all(peer['bytesrecv_per_msg'].pop('verack', 0) == 24 for peer in from_connection.getpeerinfo())) +def check_node_connections(*, node, num_in, num_out): + info = node.getnetworkinfo() + assert_equal(info["connections_in"], num_in) + assert_equal(info["connections_out"], num_out) # Transaction/Block functions ############################# -def find_output(node, txid, amount, *, blockhash=None): - """ - Return index to output of txid with value amount - Raises exception if there is none. - """ - txdata = node.getrawtransaction(txid, 1, blockhash) - for i in range(len(txdata["vout"])): - if txdata["vout"][i]["value"] == amount: - return i - raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount))) - - -def gather_inputs(from_node, amount_needed, confirmations_required=1): - """ - Return a random set of unspent txouts that are enough to pay amount_needed - """ - assert confirmations_required >= 0 - utxo = from_node.listunspent(confirmations_required) - random.shuffle(utxo) - inputs = [] - total_in = Decimal("0.00000000") - while total_in < amount_needed and len(utxo) > 0: - t = utxo.pop() - total_in += t["amount"] - inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]}) - if total_in < amount_needed: - raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in)) - return (total_in, inputs) - - -def make_change(from_node, amount_in, amount_out, fee): - """ - Create change output(s), return them - """ - outputs = {} - amount = amount_out + fee - change = amount_in - amount - if change > amount * 2: - # Create an extra change output to break up big inputs - change_address = from_node.getnewaddress() - # Split change in two, being careful of rounding: - outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) - change = amount_in - amount - outputs[change_address] - if change > 0: - outputs[from_node.getnewaddress()] = change - return outputs - - -def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants): - """ - Create a random transaction. - Returns (txid, hex-encoded-transaction-data, fee) - """ - from_node = random.choice(nodes) - to_node = random.choice(nodes) - fee = min_fee + fee_increment * random.randint(0, fee_variants) - - (total_in, inputs) = gather_inputs(from_node, amount + fee) - outputs = make_change(from_node, total_in, amount, fee) - outputs[to_node.getnewaddress()] = float(amount) - - rawtx = from_node.createrawtransaction(inputs, outputs) - signresult = from_node.signrawtransactionwithwallet(rawtx) - txid = from_node.sendrawtransaction(signresult["hex"], 0) - - return (txid, signresult["hex"], fee) - - -# Helper to create at least "count" utxos -# Pass in a fee that is sufficient for relay and mining new transactions. 
-def create_confirmed_utxos(fee, node, count): - to_generate = int(0.5 * count) + 101 - while to_generate > 0: - node.generate(min(25, to_generate)) - to_generate -= 25 - utxos = node.listunspent() - iterations = count - len(utxos) - addr1 = node.getnewaddress() - addr2 = node.getnewaddress() - if iterations <= 0: - return utxos - for i in range(iterations): - t = utxos.pop() - inputs = [] - inputs.append({"txid": t["txid"], "vout": t["vout"]}) - outputs = {} - send_value = t['amount'] - fee - outputs[addr1] = satoshi_round(send_value / 2) - outputs[addr2] = satoshi_round(send_value / 2) - raw_tx = node.createrawtransaction(inputs, outputs) - signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"] - node.sendrawtransaction(signed_tx) - - while (node.getmempoolinfo()['size'] > 0): - node.generate(1) - - utxos = node.listunspent() - assert len(utxos) >= count - return utxos - - # Create large OP_RETURN txouts that can be appended to a transaction -# to make it large (helper for constructing large transactions). +# to make it large (helper for constructing large transactions). The +# total serialized size of the txouts is about 66k vbytes. def gen_return_txouts(): - # Some pre-processing to create a bunch of OP_RETURN txouts to insert into transactions we create - # So we have big transactions (and therefore can't fit very many into each block) - # create one script_pubkey - script_pubkey = "6a4d0200" # OP_RETURN OP_PUSH2 512 bytes - for i in range(512): - script_pubkey = script_pubkey + "01" - # concatenate 128 txouts of above script_pubkey which we'll insert before the txout for change - txouts = [] from .messages import CTxOut - txout = CTxOut() - txout.nValue = 0 - txout.scriptPubKey = hex_str_to_bytes(script_pubkey) - for k in range(128): - txouts.append(txout) + from .script import CScript, OP_RETURN + txouts = [CTxOut(nValue=0, scriptPubKey=CScript([OP_RETURN, b'\x01'*67437]))] + assert_equal(sum([len(txout.serialize()) for txout in txouts]), 67456) return txouts # Create a spend of each passed-in utxo, splicing in "txouts" to each raw # transaction to make it large. See gen_return_txouts() above. 
-def create_lots_of_big_transactions(node, txouts, utxos, num, fee): - addr = node.getnewaddress() +def create_lots_of_big_transactions(mini_wallet, node, fee, tx_batch_size, txouts, utxos=None): txids = [] - from .messages import CTransaction - for _ in range(num): - t = utxos.pop() - inputs = [{"txid": t["txid"], "vout": t["vout"]}] - outputs = {} - change = t['amount'] - fee - outputs[addr] = satoshi_round(change) - rawtx = node.createrawtransaction(inputs, outputs) - tx = CTransaction() - tx.deserialize(BytesIO(hex_str_to_bytes(rawtx))) - for txout in txouts: - tx.vout.append(txout) - newtx = tx.serialize().hex() - signresult = node.signrawtransactionwithwallet(newtx, None, "NONE") - txid = node.sendrawtransaction(signresult["hex"], 0) - txids.append(txid) + use_internal_utxos = utxos is None + for _ in range(tx_batch_size): + tx = mini_wallet.create_self_transfer( + utxo_to_spend=None if use_internal_utxos else utxos.pop(), + fee=fee, + )["tx"] + tx.vout.extend(txouts) + res = node.testmempoolaccept([tx.serialize().hex()])[0] + assert_equal(res['fees']['base'], fee) + txids.append(node.sendrawtransaction(tx.serialize().hex())) return txids -def mine_large_block(node, utxos=None): +def mine_large_block(test_framework, mini_wallet, node): # generate a 66k transaction, # and 14 of them is close to the 1MB block limit - num = 14 txouts = gen_return_txouts() - utxos = utxos if utxos is not None else [] - if len(utxos) < num: - utxos.clear() - utxos.extend(node.listunspent()) fee = 100 * node.getnetworkinfo()["relayfee"] - create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee) - node.generate(1) + create_lots_of_big_transactions(mini_wallet, node, fee, 14, txouts) + test_framework.generate(node, 1) def find_vout_for_address(node, txid, addr): @@ -614,11 +545,6 @@ def find_vout_for_address(node, txid, addr): """ tx = node.getrawtransaction(txid, True) for i in range(len(tx["vout"])): - scriptPubKey = tx["vout"][i]["scriptPubKey"] - if "addresses" in scriptPubKey: - if any([addr == a for a in scriptPubKey["addresses"]]): - return i - elif "address" in scriptPubKey: - if addr == scriptPubKey["address"]: - return i + if addr == tx["vout"][i]["scriptPubKey"]["address"]: + return i raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr)) diff --git a/basicswap/db.py b/basicswap/db.py index 1cea42d..7b83b78 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -13,7 +13,7 @@ from enum import IntEnum, auto from typing import Optional -CURRENT_DB_VERSION = 27 +CURRENT_DB_VERSION = 28 CURRENT_DB_DATA_VERSION = 6 @@ -174,6 +174,7 @@ class Offer(Table): secret_hash = Column("blob") addr_from = Column("string") + pk_from = Column("blob") addr_to = Column("string") created_at = Column("integer") expire_at = Column("integer") @@ -216,6 +217,7 @@ class Bid(Table): created_at = Column("integer") expire_at = Column("integer") bid_addr = Column("string") + pk_bid_addr = Column("blob") proof_address = Column("string") proof_utxos = Column("blob") # Address to spend lock tx to - address from wallet if empty TODO diff --git a/basicswap/db_upgrades.py b/basicswap/db_upgrades.py index 16da49a..2ed8357 100644 --- a/basicswap/db_upgrades.py +++ b/basicswap/db_upgrades.py @@ -428,6 +428,11 @@ def upgradeDatabase(self, db_version): elif current_version == 26: db_version += 1 cursor.execute("ALTER TABLE offers ADD COLUMN auto_accept_type INTEGER") + elif current_version == 27: + db_version += 1 + cursor.execute("ALTER TABLE offers ADD COLUMN pk_from BLOB") + cursor.execute("ALTER 
TABLE bids ADD COLUMN pk_bid_addr BLOB") + if current_version != db_version: self.db_version = db_version self.setIntKV("db_version", db_version, cursor) diff --git a/basicswap/network/__init__.py b/basicswap/network/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/basicswap/network.py b/basicswap/network/network.py similarity index 100% rename from basicswap/network.py rename to basicswap/network/network.py diff --git a/basicswap/network/simplex.py b/basicswap/network/simplex.py new file mode 100644 index 0000000..ad404ff --- /dev/null +++ b/basicswap/network/simplex.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import base64 +import json +import threading +import websocket + + +from queue import Queue, Empty + +from basicswap.util.smsg import ( + smsgEncrypt, + smsgDecrypt, + smsgGetID, +) +from basicswap.chainparams import ( + Coins, +) +from basicswap.util.address import ( + b58decode, + decodeWif, +) +from basicswap.basicswap_util import ( + BidStates, +) + + +def encode_base64(data: bytes) -> str: + return base64.b64encode(data).decode("utf-8") + + +def decode_base64(encoded_data: str) -> bytes: + return base64.b64decode(encoded_data) + + +class WebSocketThread(threading.Thread): + def __init__(self, url: str, tag: str = None, logger=None): + super().__init__() + self.url: str = url + self.tag = tag + self.logger = logger + self.ws = None + self.mutex = threading.Lock() + self.corrId: int = 0 + self.connected: bool = False + self.delay_event = threading.Event() + + self.recv_queue = Queue() + self.cmd_recv_queue = Queue() + + def on_message(self, ws, message): + if self.logger: + self.logger.debug("Simplex received msg") + else: + print(f"{self.tag} - Received msg") + + if message.startswith('{"corrId"'): + self.cmd_recv_queue.put(message) + else: + self.recv_queue.put(message) + + def queue_get(self): + try: + return self.recv_queue.get(block=False) + except Empty: + return None + + def cmd_queue_get(self): + try: + return self.cmd_recv_queue.get(block=False) + except Empty: + return None + + def on_error(self, ws, error): + if self.logger: + self.logger.error(f"Simplex ws - {error}") + else: + print(f"{self.tag} - Error: {error}") + + def on_close(self, ws, close_status_code, close_msg): + if self.logger: + self.logger.info(f"Simplex ws - Closed: {close_status_code}, {close_msg}") + else: + print(f"{self.tag} - Closed: {close_status_code}, {close_msg}") + + def on_open(self, ws): + if self.logger: + self.logger.info("Simplex ws - Connection opened") + else: + print(f"{self.tag}: WebSocket connection opened") + self.connected = True + + def send_command(self, cmd_str: str): + with self.mutex: + self.corrId += 1 + if self.logger: + self.logger.debug(f"Simplex sent command {self.corrId}") + else: + print(f"{self.tag}: sent command {self.corrId}") + cmd = json.dumps({"corrId": str(self.corrId), "cmd": cmd_str}) + self.ws.send(cmd) + return self.corrId + + def run(self): + self.ws = websocket.WebSocketApp( + self.url, + on_message=self.on_message, + on_error=self.on_error, + on_open=self.on_open, + on_close=self.on_close, + ) + while not self.delay_event.is_set(): + self.ws.run_forever() + self.delay_event.wait(0.5) + + def stop(self): + self.delay_event.set() + if self.ws: + self.ws.close() + + +def waitForResponse(ws_thread, sent_id, delay_event): + 
sent_id = str(sent_id) + for i in range(100): + message = ws_thread.cmd_queue_get() + if message is not None: + data = json.loads(message) + # print(f"json: {json.dumps(data, indent=4)}") + if "corrId" in data: + if data["corrId"] == sent_id: + return data + delay_event.wait(0.5) + raise ValueError(f"waitForResponse timed-out waiting for id: {sent_id}") + + +def waitForConnected(ws_thread, delay_event): + for i in range(100): + if ws_thread.connected: + return True + delay_event.wait(0.5) + raise ValueError("waitForConnected timed-out.") + + +def getPrivkeyForAddress(self, addr) -> bytes: + + ci_part = self.ci(Coins.PART) + try: + return ci_part.decodeKey( + self.callrpc( + "smsgdumpprivkey", + [ + addr, + ], + ) + ) + except Exception as e: # noqa: F841 + pass + try: + return ci_part.decodeKey( + ci_part.rpc_wallet( + "dumpprivkey", + [ + addr, + ], + ) + ) + except Exception as e: # noqa: F841 + pass + raise ValueError("key not found") + + +def sendSimplexMsg( + self, network, addr_from: str, addr_to: str, payload: bytes, msg_valid: int, cursor +) -> bytes: + self.log.debug("sendSimplexMsg") + + try: + rv = self.callrpc( + "smsggetpubkey", + [ + addr_to, + ], + ) + pubkey_to: bytes = b58decode(rv["publickey"]) + except Exception as e: # noqa: F841 + use_cursor = self.openDB(cursor) + try: + query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1" + rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall() + if len(rows) > 0: + pubkey_to = rows[0][0] + else: + query: str = ( + "SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1" + ) + rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall() + if len(rows) > 0: + pubkey_to = rows[0][0] + else: + raise ValueError(f"Could not get public key for address {addr_to}") + finally: + if cursor is None: + self.closeDB(use_cursor, commit=False) + + privkey_from = getPrivkeyForAddress(self, addr_from) + + payload += bytes((0,)) # Include null byte to match smsg + smsg_msg: bytes = smsgEncrypt(privkey_from, pubkey_to, payload) + + smsg_id = smsgGetID(smsg_msg) + + ws_thread = network["ws_thread"] + sent_id = ws_thread.send_command("#bsx " + encode_base64(smsg_msg)) + response = waitForResponse(ws_thread, sent_id, self.delay_event) + if response["resp"]["type"] != "newChatItems": + json_str = json.dumps(response, indent=4) + self.log.debug(f"Response {json_str}") + raise ValueError("Send failed") + + return smsg_id + + +def decryptSimplexMsg(self, msg_data): + ci_part = self.ci(Coins.PART) + + # Try with the network key first + network_key: bytes = decodeWif(self.network_key) + try: + decrypted = smsgDecrypt(network_key, msg_data, output_dict=True) + decrypted["from"] = ci_part.pubkey_to_address( + bytes.fromhex(decrypted["pk_from"]) + ) + decrypted["to"] = self.network_addr + decrypted["msg_net"] = "simplex" + return decrypted + except Exception as e: # noqa: F841 + pass + + # Try with all active bid/offer addresses + query: str = """SELECT DISTINCT address FROM ( + SELECT bid_addr AS address FROM bids WHERE active_ind = 1 + AND (in_progress = 1 OR (state > :bid_received AND state < :bid_completed) OR (state IN (:bid_received, :bid_sent) AND expire_at > :now)) + UNION + SELECT addr_from AS address FROM offers WHERE active_ind = 1 AND expire_at > :now + )""" + + now: int = self.getTime() + + try: + cursor = self.openDB() + addr_rows = cursor.execute( + query, + { + "bid_received": int(BidStates.BID_RECEIVED), + "bid_completed": int(BidStates.SWAP_COMPLETED), + "bid_sent": int(BidStates.BID_SENT), + "now": 
now, + }, + ).fetchall() + finally: + self.closeDB(cursor, commit=False) + + decrypted = None + for row in addr_rows: + addr = row[0] + try: + vk_addr = getPrivkeyForAddress(self, addr) + decrypted = smsgDecrypt(vk_addr, msg_data, output_dict=True) + decrypted["from"] = ci_part.pubkey_to_address( + bytes.fromhex(decrypted["pk_from"]) + ) + decrypted["to"] = addr + decrypted["msg_net"] = "simplex" + return decrypted + except Exception as e: # noqa: F841 + pass + + return decrypted + + +def readSimplexMsgs(self, network): + ws_thread = network["ws_thread"] + + for i in range(100): + message = ws_thread.queue_get() + if message is None: + break + + data = json.loads(message) + # self.log.debug(f"message 1: {json.dumps(data, indent=4)}") + try: + if data["resp"]["type"] in ("chatItemsStatusesUpdated", "newChatItems"): + for chat_item in data["resp"]["chatItems"]: + item_status = chat_item["chatItem"]["meta"]["itemStatus"] + if item_status["type"] in ("sndRcvd", "rcvNew"): + snd_progress = item_status.get("sndProgress", None) + if snd_progress: + if snd_progress != "complete": + item_id = chat_item["chatItem"]["meta"]["itemId"] + self.log.debug( + f"simplex chat item {item_id} {snd_progress}" + ) + continue + try: + msg_data: bytes = decode_base64( + chat_item["chatItem"]["content"]["msgContent"]["text"] + ) + decrypted_msg = decryptSimplexMsg(self, msg_data) + if decrypted_msg is None: + continue + self.processMsg(decrypted_msg) + except Exception as e: # noqa: F841 + # self.log.debug(f"decryptSimplexMsg error: {e}") + pass + except Exception as e: + self.log.debug(f"readSimplexMsgs error: {e}") + + self.delay_event.wait(0.05) + + +def initialiseSimplexNetwork(self, network_config) -> None: + self.log.debug("initialiseSimplexNetwork") + + client_host: str = network_config.get("client_host", "127.0.0.1") + ws_port: str = network_config.get("ws_port") + + ws_thread = WebSocketThread(f"ws://{client_host}:{ws_port}", logger=self.log) + self.threads.append(ws_thread) + ws_thread.start() + waitForConnected(ws_thread, self.delay_event) + + sent_id = ws_thread.send_command("/groups") + response = waitForResponse(ws_thread, sent_id, self.delay_event) + + if len(response["resp"]["groups"]) < 1: + sent_id = ws_thread.send_command("/c " + network_config["group_link"]) + response = waitForResponse(ws_thread, sent_id, self.delay_event) + assert "groupLinkId" in response["resp"]["connection"] + + network = { + "type": "simplex", + "ws_thread": ws_thread, + } + + self.active_networks.append(network) diff --git a/basicswap/network/simplex_chat.py b/basicswap/network/simplex_chat.py new file mode 100644 index 0000000..df1951c --- /dev/null +++ b/basicswap/network/simplex_chat.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. 
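+
+# Process-management helpers for a local simplex-chat CLI client used as a
+# BSX message transport.  startSimplexClient() launches the client listening
+# on a websocket port (which WebSocketThread in network/simplex.py connects
+# to); on first run initSimplexClient() answers the interactive
+# "display name:" prompt over a pipe so a profile exists.
+#
+# Minimal usage sketch; the binary path, data directory and server address
+# below are illustrative placeholders, not defaults:
+#
+#   daemon = startSimplexClient(
+#       "/opt/simplex/simplex-chat",
+#       "/tmp/bsx/simplex_client",
+#       "smp://<fingerprint>:<password>@127.0.0.1:5223",
+#       5225,
+#       logging.getLogger(),
+#       threading.Event(),
+#   )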
+
+import os
+import select
+import subprocess
+import time
+
+from basicswap.bin.run import Daemon
+
+
+def initSimplexClient(args, logger, delay_event):
+    logger.info("Initialising Simplex client")
+
+    (pipe_r, pipe_w) = os.pipe()  # subprocess.PIPE is buffered, blocks when read
+
+    if os.name == "nt":
+        str_args = " ".join(args)
+        p = subprocess.Popen(
+            str_args, shell=True, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w
+        )
+    else:
+        p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w)
+
+    def readOutput():
+        buf = os.read(pipe_r, 1024).decode("utf-8")
+        response = None
+        # logging.debug(f"simplex-chat output: {buf}")
+        if "display name:" in buf:
+            logger.debug("Setting display name")
+            response = b"user\n"
+        else:
+            logger.debug(f"Unexpected output: {buf}")
+            return
+        if response is not None:
+            p.stdin.write(response)
+            p.stdin.flush()
+
+    try:
+        start_time: float = time.time()
+        max_wait_seconds: int = 60
+        while p.poll() is None:
+            if time.time() > start_time + max_wait_seconds:
+                raise ValueError("Timed out")
+            if os.name == "nt":
+                readOutput()
+                delay_event.wait(0.1)
+                continue
+            while len(select.select([pipe_r], [], [], 0)[0]) == 1:
+                readOutput()
+            delay_event.wait(0.1)
+    except Exception as e:
+        logger.error(f"initSimplexClient: {e}")
+    finally:
+        if p.poll() is None:
+            p.terminate()
+        os.close(pipe_r)
+        os.close(pipe_w)
+        p.stdin.close()
+
+
+def startSimplexClient(
+    bin_path: str,
+    data_path: str,
+    server_address: str,
+    websocket_port: int,
+    logger,
+    delay_event,
+) -> Daemon:
+    logger.info("Starting Simplex client")
+    if not os.path.exists(data_path):
+        os.makedirs(data_path)
+
+    db_path = os.path.join(data_path, "simplex_client_data")
+
+    args = [bin_path, "-d", db_path, "-s", server_address, "-p", str(websocket_port)]
+
+    if not os.path.exists(db_path):
+        # Need to set the initial profile through the CLI
+        # TODO: There must be a better way
+        init_args = args + ["-e", "/help"]  # Run command to exit the client
+        initSimplexClient(init_args, logger, delay_event)
+
+    args += ["-l", "debug"]
+
+    opened_files = []
+    stdout_dest = open(
+        os.path.join(data_path, "simplex_stdout.log"),
+        "w",
+    )
+    opened_files.append(stdout_dest)
+    stderr_dest = stdout_dest
+    return Daemon(
+        subprocess.Popen(
+            args,
+            shell=False,
+            stdin=subprocess.PIPE,
+            stdout=stdout_dest,
+            stderr=stderr_dest,
+            cwd=data_path,
+        ),
+        opened_files,
+    )
diff --git a/basicswap/network/util.py b/basicswap/network/util.py
new file mode 100644
index 0000000..0ec019f
--- /dev/null
+++ b/basicswap/network/util.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2025 The Basicswap developers
+# Distributed under the MIT software license, see the accompanying
+# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
+
+from basicswap.util.address import b58decode
+
+
+def getMsgPubkey(self, msg) -> bytes:
+    if "pk_from" in msg:
+        return bytes.fromhex(msg["pk_from"])
+    rv = self.callrpc(
+        "smsggetpubkey",
+        [
+            msg["from"],
+        ],
+    )
+    return b58decode(rv["publickey"])
diff --git a/basicswap/util/smsg.py b/basicswap/util/smsg.py
new file mode 100644
index 0000000..6ab1fba
--- /dev/null
+++ b/basicswap/util/smsg.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2025 The Basicswap developers
+# Distributed under the MIT software license, see the accompanying
+# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
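+
+# Standalone implementation of Particl SMSG (secure message) encryption and
+# decryption, so BSX payloads can be carried over transports other than a
+# Particl node's smsg network (e.g. the SimpleX transport).  lz4 compression
+# is not implemented (see the TODO in smsgEncrypt), unlike core smsg.
+#
+# Serialised message layout (SMSG_HDR_LEN = 108 header bytes):
+#   hash(4) | nonce(4) | version(2) | flags(1) | timestamp(8) | ttl(4)
+#   | iv(16) | R(33) | mac(32) | ciphertext_len(4) | ciphertext
+#
+# Round-trip sketch with illustrative keys (see
+# tests/basicswap/extended/test_smsg.py for usage against a node):
+#
+#   privkey_from = getSecretInt().to_bytes(32, byteorder="big")
+#   privkey_to = getSecretInt().to_bytes(32, byteorder="big")
+#   pubkey_to = PublicKey.from_secret(privkey_to).format()
+#   smsg = smsgEncrypt(privkey_from, pubkey_to, b"payload")
+#   assert smsgDecrypt(privkey_to, smsg) == b"payload"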
+ +import hashlib +import hmac +import secrets +import time + + +from typing import Union, Dict +from coincurve.keys import ( + PublicKey, + PrivateKey, +) +from Crypto.Cipher import AES + +from basicswap.util.crypto import hash160, sha256, ripemd160 +from basicswap.util.ecc import getSecretInt +from basicswap.contrib.test_framework.messages import ( + uint256_from_compact, + uint256_from_str, +) + + +AES_BLOCK_SIZE = 16 + + +def aes_pad(s: bytes): + c = AES_BLOCK_SIZE - len(s) % AES_BLOCK_SIZE + return s + (bytes((c,)) * c) + + +def aes_unpad(s: bytes): + return s[: -(s[len(s) - 1])] + + +def aes_encrypt(raw: bytes, pass_data: bytes, iv: bytes): + assert len(pass_data) == 32 + assert len(iv) == 16 + raw = aes_pad(raw) + cipher = AES.new(pass_data, AES.MODE_CBC, iv) + return cipher.encrypt(raw) + + +def aes_decrypt(enc, pass_data: bytes, iv: bytes): + assert len(pass_data) == 32 + assert len(iv) == 16 + cipher = AES.new(pass_data, AES.MODE_CBC, iv) + return aes_unpad(cipher.decrypt(enc)) + + +SMSG_MIN_TTL = 60 * 60 +SMSG_BUCKET_LEN = 60 * 60 +SMSG_HDR_LEN = ( + 108 # Length of unencrypted header, 4 + 4 + 2 + 1 + 8 + 4 + 16 + 33 + 32 + 4 +) +SMSG_PL_HDR_LEN = 1 + 20 + 65 + 4 # Length of encrypted header in payload + + +def smsgGetTimestamp(smsg_message: bytes) -> int: + assert len(smsg_message) > SMSG_HDR_LEN + return int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little") + + +def smsgGetPOWHash(smsg_message: bytes) -> bytes: + assert len(smsg_message) > SMSG_HDR_LEN + ofs: int = 4 + nonce: bytes = smsg_message[ofs : ofs + 4] + iv: bytes = nonce * 8 + + m = hmac.new(iv, digestmod="SHA256") + m.update(smsg_message[4:]) + return m.digest() + + +def smsgGetID(smsg_message: bytes) -> bytes: + assert len(smsg_message) > SMSG_HDR_LEN + smsg_timestamp = int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little") + return smsg_timestamp.to_bytes(8, byteorder="big") + ripemd160(smsg_message[8:]) + + +def smsgEncrypt(privkey_from: bytes, pubkey_to: bytes, payload: bytes) -> bytes: + # assert len(payload) < 128 # Requires lz4 if payload > 128 bytes + # TODO: Add lz4 to match core smsg + smsg_timestamp = int(time.time()) + r = getSecretInt().to_bytes(32, byteorder="big") + R = PublicKey.from_secret(r).format() + p = PrivateKey(r).ecdh(pubkey_to) + H = hashlib.sha512(p).digest() + key_e: bytes = H[:32] + key_m: bytes = H[32:] + + smsg_iv: bytes = secrets.token_bytes(16) + + payload_hash: bytes = sha256(sha256(payload)) + signature: bytes = PrivateKey(privkey_from).sign_recoverable( + payload_hash, hasher=None + ) + + # Convert format to BTC, add 4 to mark as compressed key + recid = signature[64] + signature = bytes((27 + recid + 4,)) + signature[:64] + + pubkey_from: bytes = PublicKey.from_secret(privkey_from).format() + pkh_from: bytes = hash160(pubkey_from) + + len_payload = len(payload) + address_version = 0 + plaintext_data: bytes = ( + bytes((address_version,)) + + pkh_from + + signature + + len_payload.to_bytes(4, byteorder="little") + + payload + ) + + ciphertext: bytes = aes_encrypt(plaintext_data, key_e, smsg_iv) + + m = hmac.new(key_m, digestmod="SHA256") + m.update(smsg_timestamp.to_bytes(8, byteorder="little")) + m.update(smsg_iv) + m.update(ciphertext) + mac: bytes = m.digest() + + smsg_hash = bytes((0,)) * 4 + smsg_nonce = bytes((0,)) * 4 + smsg_version = bytes((2, 1)) + smsg_flags = bytes((0,)) + + smsg_ttl = SMSG_MIN_TTL + + assert len(R) == 33 + assert len(mac) == 32 + + smsg_message: bytes = ( + smsg_hash + + smsg_nonce + + smsg_version + + smsg_flags + + 
smsg_timestamp.to_bytes(8, byteorder="little") + + smsg_ttl.to_bytes(4, byteorder="little") + + smsg_iv + + R + + mac + + len(ciphertext).to_bytes(4, byteorder="little") + + ciphertext + ) + + target: int = uint256_from_compact(0x1EFFFFFF) + + for i in range(1000000): + pow_hash = smsgGetPOWHash(smsg_message) + if uint256_from_str(pow_hash) > target: + smsg_nonce = (int.from_bytes(smsg_nonce, byteorder="little") + 1).to_bytes( + 4, byteorder="little" + ) + smsg_message = pow_hash[:4] + smsg_nonce + smsg_message[8:] + continue + smsg_message = pow_hash[:4] + smsg_message[4:] + return smsg_message + raise ValueError("Failed to set POW hash.") + + +def smsgDecrypt( + privkey_to: bytes, encrypted_message: bytes, output_dict: bool = False +) -> Union[bytes, Dict]: + # Without lz4 + + assert len(encrypted_message) > SMSG_HDR_LEN + smsg_timestamp = int.from_bytes(encrypted_message[11 : 11 + 8], byteorder="little") + ofs: int = 23 + smsg_iv = encrypted_message[ofs : ofs + 16] + + ofs += 16 + R = encrypted_message[ofs : ofs + 33] + ofs += 33 + mac = encrypted_message[ofs : ofs + 32] + ofs += 32 + ciphertextlen = int.from_bytes(encrypted_message[ofs : ofs + 4], byteorder="little") + ofs += 4 + ciphertext = encrypted_message[ofs:] + assert len(ciphertext) == ciphertextlen + + p = PrivateKey(privkey_to).ecdh(R) + H = hashlib.sha512(p).digest() + key_e: bytes = H[:32] + key_m: bytes = H[32:] + + m = hmac.new(key_m, digestmod="SHA256") + m.update(smsg_timestamp.to_bytes(8, byteorder="little")) + m.update(smsg_iv) + m.update(ciphertext) + mac_calculated: bytes = m.digest() + + assert mac == mac_calculated + + plaintext = aes_decrypt(ciphertext, key_e, smsg_iv) + + ofs = 1 + pkh_from = plaintext[ofs : ofs + 20] + ofs += 20 + signature = plaintext[ofs : ofs + 65] + ofs += 65 + ofs += 4 + payload = plaintext[ofs:] + payload_hash: bytes = sha256(sha256(payload)) + + # Convert format from BTC + recid = (signature[0] - 27) & 3 + signature = signature[1:] + bytes((recid,)) + + pubkey_signer = PublicKey.from_signature_and_message( + signature, payload_hash, hasher=None + ).format() + pkh_from_recovered: bytes = hash160(pubkey_signer) + assert pkh_from == pkh_from_recovered + + if output_dict: + return { + "msgid": smsgGetID(encrypted_message).hex(), + "sent": smsg_timestamp, + "hex": payload.hex(), + "pk_from": pubkey_signer.hex(), + } + return payload diff --git a/requirements.in b/requirements.in index 3f37b10..199ffc8 100644 --- a/requirements.in +++ b/requirements.in @@ -3,4 +3,5 @@ python-gnupg==0.5.4 Jinja2==3.1.6 pycryptodome==3.21.0 PySocks==1.7.1 +websocket-client==1.8.0 coincurve@https://github.com/basicswap/coincurve/archive/refs/tags/basicswap_v0.2.zip diff --git a/requirements.txt b/requirements.txt index cc1cd13..fff166f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.13 # by the following command: # # pip-compile --generate-hashes --output-file=requirements.txt requirements.in @@ -305,3 +305,7 @@ pyzmq==26.2.1 \ --hash=sha256:f9ba5def063243793dec6603ad1392f735255cbc7202a3a484c14f99ec290705 \ --hash=sha256:fc409c18884eaf9ddde516d53af4f2db64a8bc7d81b1a0c274b8aa4e929958e8 # via -r requirements.in +websocket-client==1.8.0 \ + --hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \ + --hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da + # via -r requirements.in diff --git 
a/tests/basicswap/extended/test_doge.py b/tests/basicswap/extended/test_doge.py index 0bcc63f..10a45d5 100644 --- a/tests/basicswap/extended/test_doge.py +++ b/tests/basicswap/extended/test_doge.py @@ -30,7 +30,6 @@ from basicswap.contrib.test_framework.messages import ( CTransaction, CTxIn, COutPoint, - ToHex, ) from basicswap.contrib.test_framework.script import ( CScript, @@ -318,7 +317,7 @@ class Test(TestFunctions): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -357,10 +356,10 @@ class Test(TestFunctions): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.099), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() tx_spend.nLockTime = chain_height + 2 - tx_spend_invalid_hex = ToHex(tx_spend) + tx_spend_invalid_hex = tx_spend.serialize().hex() for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: try: diff --git a/tests/basicswap/extended/test_simplex.py b/tests/basicswap/extended/test_simplex.py new file mode 100644 index 0000000..83af4a8 --- /dev/null +++ b/tests/basicswap/extended/test_simplex.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +""" +docker run \ + -e "ADDR=127.0.0.1" \ + -e "PASS=password" \ + -p 5223:5223 \ + -v /tmp/simplex/smp/config:/etc/opt/simplex:z \ + -v /tmp/simplex/smp/logs:/var/opt/simplex:z \ + -v /tmp/simplex/certs:/certificates \ + simplexchat/smp-server:latest + +Fingerprint: Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I= + +export SIMPLEX_SERVER_ADDRESS=smp://Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=:password@127.0.0.1:5223,443 + + +https://github.com/simplex-chat/simplex-chat/issues/4127 + json: {"corrId":"3","cmd":"/_send #1 text test123"} + direct message: {"corrId":"1","cmd":"/_send @2 text the message"} + +""" + +import json +import logging +import os +import random +import shutil +import sys +import unittest + +import basicswap.config as cfg + +from basicswap.basicswap import ( + BidStates, + SwapTypes, +) +from basicswap.chainparams import Coins + +from basicswap.network.simplex import ( + WebSocketThread, + waitForConnected, + waitForResponse, +) +from basicswap.network.simplex_chat import startSimplexClient +from tests.basicswap.common import ( + stopDaemons, + wait_for_bid, + wait_for_offer, +) +from tests.basicswap.test_xmr import BaseTest, test_delay_event, RESET_TEST + + +SIMPLEX_SERVER_ADDRESS = os.getenv("SIMPLEX_SERVER_ADDRESS") +SIMPLEX_CLIENT_PATH = os.path.expanduser(os.getenv("SIMPLEX_CLIENT_PATH")) +TEST_DIR = cfg.TEST_DATADIRS + + +logger = logging.getLogger() +logger.level = logging.DEBUG +if not len(logger.handlers): + logger.addHandler(logging.StreamHandler(sys.stdout)) + + +class TestSimplex(unittest.TestCase): + daemons = [] + remove_testdir: bool = False + + @classmethod + def tearDownClass(cls): + stopDaemons(cls.daemons) + + def test_basic(self): + + if os.path.isdir(TEST_DIR): + if RESET_TEST: + logging.info("Removing " + TEST_DIR) + shutil.rmtree(TEST_DIR) + else: + logging.info("Restoring instance from " + TEST_DIR) + if not os.path.exists(TEST_DIR): + os.makedirs(TEST_DIR) + + client1_dir = os.path.join(TEST_DIR, "client1") + 
if os.path.exists(client1_dir): + shutil.rmtree(client1_dir) + + client1_daemon = startSimplexClient( + SIMPLEX_CLIENT_PATH, + client1_dir, + SIMPLEX_SERVER_ADDRESS, + 5225, + logger, + test_delay_event, + ) + self.daemons.append(client1_daemon) + + client2_dir = os.path.join(TEST_DIR, "client2") + if os.path.exists(client2_dir): + shutil.rmtree(client2_dir) + client2_daemon = startSimplexClient( + SIMPLEX_CLIENT_PATH, + client2_dir, + SIMPLEX_SERVER_ADDRESS, + 5226, + logger, + test_delay_event, + ) + self.daemons.append(client2_daemon) + + threads = [] + try: + ws_thread = WebSocketThread("ws://127.0.0.1:5225", tag="C1") + ws_thread.start() + threads.append(ws_thread) + + ws_thread2 = WebSocketThread("ws://127.0.0.1:5226", tag="C2") + ws_thread2.start() + threads.append(ws_thread2) + + waitForConnected(ws_thread, test_delay_event) + sent_id = ws_thread.send_command("/group bsx") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + assert response["resp"]["type"] == "groupCreated" + + ws_thread.send_command("/set voice #bsx off") + ws_thread.send_command("/set files #bsx off") + ws_thread.send_command("/set direct #bsx off") + ws_thread.send_command("/set reactions #bsx off") + ws_thread.send_command("/set reports #bsx off") + ws_thread.send_command("/set disappear #bsx on week") + sent_id = ws_thread.send_command("/create link #bsx") + + connReqContact = None + connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event) + connReqContact = connReqMsgData["resp"]["connReqContact"] + + group_link = "https://simplex.chat" + connReqContact[8:] + logger.info(f"group_link: {group_link}") + + sent_id = ws_thread2.send_command("/c " + group_link) + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + assert "groupLinkId" in response["resp"]["connection"] + + sent_id = ws_thread2.send_command("/groups") + response = waitForResponse(ws_thread2, sent_id, test_delay_event) + assert len(response["resp"]["groups"]) == 1 + + ws_thread.send_command("#bsx test msg 1") + + found_1 = False + found_2 = False + for i in range(100): + message = ws_thread.queue_get() + if message is not None: + data = json.loads(message) + # print(f"message 1: {json.dumps(data, indent=4)}") + try: + if data["resp"]["type"] in ( + "chatItemsStatusesUpdated", + "newChatItems", + ): + for chat_item in data["resp"]["chatItems"]: + # print(f"chat_item 1: {json.dumps(chat_item, indent=4)}") + if chat_item["chatItem"]["meta"]["itemStatus"][ + "type" + ] in ("sndRcvd", "rcvNew"): + if ( + chat_item["chatItem"]["content"]["msgContent"][ + "text" + ] + == "test msg 1" + ): + found_1 = True + except Exception as e: + print(f"error 1: {e}") + + message = ws_thread2.queue_get() + if message is not None: + data = json.loads(message) + # print(f"message 2: {json.dumps(data, indent=4)}") + try: + if data["resp"]["type"] in ( + "chatItemsStatusesUpdated", + "newChatItems", + ): + for chat_item in data["resp"]["chatItems"]: + # print(f"chat_item 1: {json.dumps(chat_item, indent=4)}") + if chat_item["chatItem"]["meta"]["itemStatus"][ + "type" + ] in ("sndRcvd", "rcvNew"): + if ( + chat_item["chatItem"]["content"]["msgContent"][ + "text" + ] + == "test msg 1" + ): + found_2 = True + except Exception as e: + print(f"error 2: {e}") + + if found_1 and found_2: + break + test_delay_event.wait(0.5) + + assert found_1 is True + assert found_2 is True + + finally: + for t in threads: + t.stop() + t.join() + + +class Test(BaseTest): + __test__ = True + start_ltc_nodes = False + start_xmr_nodes = True + group_link 
= None + daemons = [] + coin_to = Coins.XMR + # coin_to = Coins.PART + + @classmethod + def prepareTestDir(cls): + base_ws_port: int = 5225 + for i in range(cls.num_nodes): + + client_dir = os.path.join(TEST_DIR, f"simplex_client{i}") + if os.path.exists(client_dir): + shutil.rmtree(client_dir) + + client_daemon = startSimplexClient( + SIMPLEX_CLIENT_PATH, + client_dir, + SIMPLEX_SERVER_ADDRESS, + base_ws_port + i, + logger, + test_delay_event, + ) + cls.daemons.append(client_daemon) + + # Create the group for bsx + logger.info("Creating BSX group") + ws_thread = None + try: + ws_thread = WebSocketThread(f"ws://127.0.0.1:{base_ws_port}", tag="C0") + ws_thread.start() + waitForConnected(ws_thread, test_delay_event) + sent_id = ws_thread.send_command("/group bsx") + response = waitForResponse(ws_thread, sent_id, test_delay_event) + assert response["resp"]["type"] == "groupCreated" + + ws_thread.send_command("/set voice #bsx off") + ws_thread.send_command("/set files #bsx off") + ws_thread.send_command("/set direct #bsx off") + ws_thread.send_command("/set reactions #bsx off") + ws_thread.send_command("/set reports #bsx off") + ws_thread.send_command("/set disappear #bsx on week") + sent_id = ws_thread.send_command("/create link #bsx") + + connReqContact = None + connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event) + connReqContact = connReqMsgData["resp"]["connReqContact"] + cls.group_link = "https://simplex.chat" + connReqContact[8:] + logger.info(f"BSX group_link: {cls.group_link}") + + finally: + if ws_thread: + ws_thread.stop() + ws_thread.join() + + @classmethod + def tearDownClass(cls): + logging.info("Finalising Test") + super(Test, cls).tearDownClass() + stopDaemons(cls.daemons) + + @classmethod + def addCoinSettings(cls, settings, datadir, node_id): + + settings["networks"] = [ + { + "type": "simplex", + "server_address": SIMPLEX_SERVER_ADDRESS, + "client_path": SIMPLEX_CLIENT_PATH, + "ws_port": 5225 + node_id, + "group_link": cls.group_link, + }, + ] + + def test_01_swap(self): + logging.info("---------- Test xmr swap") + + swap_clients = self.swap_clients + + for sc in swap_clients: + sc.dleag_split_size_init = 9000 + sc.dleag_split_size = 11000 + + assert len(swap_clients[0].active_networks) == 1 + assert swap_clients[0].active_networks[0]["type"] == "simplex" + + coin_from = Coins.BTC + coin_to = self.coin_to + + ci_from = swap_clients[0].ci(coin_from) + ci_to = swap_clients[1].ci(coin_to) + + swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1) + rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1) + offer_id = swap_clients[0].postOffer( + coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP + ) + + wait_for_offer(test_delay_event, swap_clients[1], offer_id) + offer = swap_clients[1].getOffer(offer_id) + bid_id = swap_clients[1].postBid(offer_id, offer.amount_from) + + wait_for_bid(test_delay_event, swap_clients[0], bid_id, BidStates.BID_RECEIVED) + swap_clients[0].acceptBid(bid_id) + + wait_for_bid( + test_delay_event, + swap_clients[0], + bid_id, + BidStates.SWAP_COMPLETED, + wait_for=320, + ) + wait_for_bid( + test_delay_event, + swap_clients[1], + bid_id, + BidStates.SWAP_COMPLETED, + sent=True, + wait_for=320, + ) diff --git a/tests/basicswap/extended/test_smsg.py b/tests/basicswap/extended/test_smsg.py new file mode 100644 index 0000000..d6170c3 --- /dev/null +++ b/tests/basicswap/extended/test_smsg.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2025 The Basicswap developers +# 
Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import logging + +from basicswap.chainparams import Coins +from basicswap.util.smsg import ( + smsgEncrypt, + smsgDecrypt, + smsgGetID, + smsgGetTimestamp, + SMSG_BUCKET_LEN, +) +from basicswap.contrib.test_framework.messages import ( + NODE_SMSG, + msg_smsgPong, + msg_smsgMsg, +) +from basicswap.contrib.test_framework.p2p import ( + P2PInterface, + P2P_SERVICES, + NetworkThread, +) +from basicswap.contrib.test_framework.util import ( + PortSeed, +) + +from tests.basicswap.common import BASE_PORT +from tests.basicswap.test_xmr import BaseTest, test_delay_event + + +class P2PInterfaceSMSG(P2PInterface): + def __init__(self): + super().__init__() + self.is_part = True + + def on_smsgPing(self, msg): + logging.info("on_smsgPing") + self.send_message(msg_smsgPong(1)) + + def on_smsgPong(self, msg): + logging.info("on_smsgPong", msg) + + def on_smsgInv(self, msg): + logging.info("on_smsgInv") + + +def wait_for_smsg(ci, msg_id: str, wait_for=20) -> None: + for i in range(wait_for): + if test_delay_event.is_set(): + raise ValueError("Test stopped.") + try: + ci.rpc_wallet("smsg", [msg_id]) + return + except Exception as e: + logging.info(e) + test_delay_event.wait(1) + + raise ValueError("wait_for_smsg timed out.") + + +class Test(BaseTest): + __test__ = True + start_ltc_nodes = False + start_xmr_nodes = False + + @classmethod + def setUpClass(cls): + super(Test, cls).setUpClass() + PortSeed.n = 1 + + logging.info("Setting up network thread") + cls.network_thread = NetworkThread() + cls.network_thread.network_event_loop.set_debug(True) + cls.network_thread.start() + cls.network_thread.network_event_loop.set_debug(True) + + @classmethod + def run_loop_ended(cls): + logging.info("run_loop_ended") + logging.info("Closing down network thread") + cls.network_thread.close() + + @classmethod + def tearDownClass(cls): + logging.info("Finalising Test") + + # logging.info('Closing down network thread') + # cls.network_thread.close() + + super(Test, cls).tearDownClass() + + @classmethod + def coins_loop(cls): + super(Test, cls).coins_loop() + + def test_01_p2p(self): + swap_clients = self.swap_clients + + kwargs = {} + kwargs["dstport"] = BASE_PORT + kwargs["dstaddr"] = "127.0.0.1" + services = P2P_SERVICES | NODE_SMSG + p2p_conn = P2PInterfaceSMSG() + p2p_conn.p2p_connected_to_node = True + p2p_conn.peer_connect( + **kwargs, + services=services, + send_version=True, + net="regtest", + timeout_factor=99999, + supports_v2_p2p=False, + )() + + p2p_conn.wait_for_connect() + p2p_conn.wait_for_verack() + p2p_conn.sync_with_ping() + + ci0_part = swap_clients[0].ci(Coins.PART) + test_key_recv: bytes = ci0_part.getNewRandomKey() + test_key_recv_wif: str = ci0_part.encodeKey(test_key_recv) + test_key_recv_pk: bytes = ci0_part.getPubkey(test_key_recv) + ci0_part.rpc("smsgimportprivkey", [test_key_recv_wif, "test key"]) + + message_test: str = "Test message" + test_key_send: bytes = ci0_part.getNewRandomKey() + encrypted_message: bytes = smsgEncrypt( + test_key_send, test_key_recv_pk, message_test.encode("utf-8") + ) + + decrypted_message: bytes = smsgDecrypt(test_key_recv, encrypted_message) + assert decrypted_message.decode("utf-8") == message_test + + msg_id: bytes = smsgGetID(encrypted_message) + smsg_timestamp: int = smsgGetTimestamp(encrypted_message) + smsg_bucket: int = smsg_timestamp - (smsg_timestamp % SMSG_BUCKET_LEN) + + smsgMsg = msg_smsgMsg(1, smsg_bucket, 
encrypted_message) + p2p_conn.send_message(smsgMsg) + + wait_for_smsg(ci0_part, msg_id.hex()) + rv = ci0_part.rpc_wallet("smsg", [msg_id.hex()]) + assert rv["text"] == message_test diff --git a/tests/basicswap/test_bch_xmr.py b/tests/basicswap/test_bch_xmr.py index d45c438..be76fe5 100644 --- a/tests/basicswap/test_bch_xmr.py +++ b/tests/basicswap/test_bch_xmr.py @@ -26,7 +26,6 @@ from tests.basicswap.common import ( waitForRPC, ) from basicswap.contrib.test_framework.messages import ( - ToHex, CTxIn, COutPoint, CTransaction, @@ -251,7 +250,7 @@ class TestBCH(BasicSwapTest): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -285,10 +284,10 @@ class TestBCH(BasicSwapTest): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() tx_spend.nLockTime = chain_height + 2 - tx_spend_invalid_hex = ToHex(tx_spend) + tx_spend_invalid_hex = tx_spend.serialize().hex() for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: try: @@ -362,7 +361,7 @@ class TestBCH(BasicSwapTest): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -405,7 +404,7 @@ class TestBCH(BasicSwapTest): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() try: txid = ci.rpc( "sendrawtransaction", @@ -640,7 +639,7 @@ class TestBCH(BasicSwapTest): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -682,7 +681,7 @@ class TestBCH(BasicSwapTest): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() txid = ci.rpc( "sendrawtransaction", @@ -730,7 +729,7 @@ class TestBCH(BasicSwapTest): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -772,7 +771,7 @@ class TestBCH(BasicSwapTest): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() txid = ci.rpc( "sendrawtransaction", diff --git a/tests/basicswap/test_btc_xmr.py b/tests/basicswap/test_btc_xmr.py index c314f56..f8b8e41 100644 --- a/tests/basicswap/test_btc_xmr.py +++ b/tests/basicswap/test_btc_xmr.py @@ -46,8 +46,7 @@ from tests.basicswap.common import ( ) from basicswap.contrib.test_framework.descriptors import descsum_create from basicswap.contrib.test_framework.messages import ( - ToHex, - FromHex, + from_hex, CTxIn, COutPoint, CTransaction, @@ -860,7 +859,7 @@ class BasicSwapTest(TestFunctions): addr_p2sh_segwit, ], ) - decoded_tx = 
FromHex(CTransaction(), tx_funded) + decoded_tx = from_hex(CTransaction(), tx_funded) decoded_tx.vin[0].scriptSig = bytes.fromhex("16" + addr_p2sh_segwit_info["hex"]) txid_with_scriptsig = decoded_tx.rehash() assert txid_with_scriptsig == tx_signed_decoded["txid"] @@ -950,7 +949,7 @@ class BasicSwapTest(TestFunctions): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -979,10 +978,10 @@ class BasicSwapTest(TestFunctions): tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ script, ] - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() tx_spend.nLockTime = chain_height + 2 - tx_spend_invalid_hex = ToHex(tx_spend) + tx_spend_invalid_hex = tx_spend.serialize().hex() for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: try: @@ -1055,7 +1054,7 @@ class BasicSwapTest(TestFunctions): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -1094,7 +1093,7 @@ class BasicSwapTest(TestFunctions): tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ script, ] - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() try: txid = ci.rpc( "sendrawtransaction", @@ -1435,7 +1434,7 @@ class BasicSwapTest(TestFunctions): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -1477,7 +1476,7 @@ class BasicSwapTest(TestFunctions): ) ) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() txid = ci.rpc( "sendrawtransaction", @@ -1525,7 +1524,7 @@ class BasicSwapTest(TestFunctions): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = ci.rpc_wallet( @@ -1567,7 +1566,7 @@ class BasicSwapTest(TestFunctions): tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ script, ] - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() txid = ci.rpc( "sendrawtransaction", diff --git a/tests/basicswap/test_run.py b/tests/basicswap/test_run.py index c1569ca..9cfd2f0 100644 --- a/tests/basicswap/test_run.py +++ b/tests/basicswap/test_run.py @@ -56,7 +56,6 @@ from basicswap.contrib.test_framework.messages import ( CTransaction, CTxIn, CTxInWitness, - ToHex, ) from basicswap.contrib.test_framework.script import ( CScript, @@ -211,7 +210,7 @@ class Test(BaseTest): tx = CTransaction() tx.nVersion = ci.txVersion() tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) - tx_hex = ToHex(tx) + tx_hex = tx.serialize().hex() tx_funded = callnoderpc(0, "fundrawtransaction", [tx_hex]) utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 tx_signed = callnoderpc( @@ -248,10 +247,10 @@ class Test(BaseTest): 
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ script, ] - tx_spend_hex = ToHex(tx_spend) + tx_spend_hex = tx_spend.serialize().hex() tx_spend.nLockTime = chain_height + 2 - tx_spend_invalid_hex = ToHex(tx_spend) + tx_spend_invalid_hex = tx_spend.serialize().hex() for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: try: diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py index a85c7db..d8220b0 100644 --- a/tests/basicswap/test_xmr.py +++ b/tests/basicswap/test_xmr.py @@ -247,7 +247,7 @@ def ltcCli(cmd, node_id=0): def signal_handler(sig, frame): - logging.info("signal {} detected.".format(sig)) + logging.info(f"signal {sig} detected.") signal_event.set() test_delay_event.set() @@ -309,6 +309,7 @@ def run_loop(cls): for c in cls.swap_clients: c.update() test_delay_event.wait(1.0) + cls.run_loop_ended() class BaseTest(unittest.TestCase): @@ -322,12 +323,13 @@ class BaseTest(unittest.TestCase): ltc_daemons = [] xmr_daemons = [] xmr_wallet_auth = [] - restore_instance = False - extra_wait_time = 0 + restore_instance: bool = False + extra_wait_time: int = 0 + num_nodes: int = NUM_NODES - start_ltc_nodes = False - start_xmr_nodes = True - has_segwit = True + start_ltc_nodes: bool = False + start_xmr_nodes: bool = True + has_segwit: bool = True xmr_addr = None btc_addr = None @@ -392,6 +394,8 @@ class BaseTest(unittest.TestCase): cls.stream_fp.setFormatter(formatter) logger.addHandler(cls.stream_fp) + cls.prepareTestDir() + try: logging.info("Preparing coin nodes.") for i in range(NUM_NODES): @@ -645,6 +649,7 @@ class BaseTest(unittest.TestCase): start_nodes, cls, ) + basicswap_dir = os.path.join( os.path.join(TEST_DIR, "basicswap_" + str(i)) ) @@ -966,6 +971,10 @@ class BaseTest(unittest.TestCase): super(BaseTest, cls).tearDownClass() + @classmethod + def prepareTestDir(cls): + pass + @classmethod def addCoinSettings(cls, settings, datadir, node_id): pass @@ -995,6 +1004,10 @@ class BaseTest(unittest.TestCase): {"wallet_address": cls.xmr_addr, "amount_of_blocks": 1}, ) + @classmethod + def run_loop_ended(cls): + pass + @classmethod def waitForParticlHeight(cls, num_blocks, node_id=0): logging.info(