Merge pull request #295 from tecnovert/multinet

Add simplex chat test.
This commit is contained in:
tecnovert
2025-04-10 23:02:39 +00:00
committed by GitHub
21 changed files with 3168 additions and 591 deletions

View File

@@ -122,8 +122,16 @@ from .explorers import (
ExplorerBitAps, ExplorerBitAps,
ExplorerChainz, ExplorerChainz,
) )
from .network.simplex import (
initialiseSimplexNetwork,
sendSimplexMsg,
readSimplexMsgs,
)
from .network.util import (
getMsgPubkey,
)
import basicswap.config as cfg import basicswap.config as cfg
import basicswap.network as bsn import basicswap.network.network as bsn
import basicswap.protocols.atomic_swap_1 as atomic_swap_1 import basicswap.protocols.atomic_swap_1 as atomic_swap_1
import basicswap.protocols.xmr_swap_1 as xmr_swap_1 import basicswap.protocols.xmr_swap_1 as xmr_swap_1
from .basicswap_util import ( from .basicswap_util import (
@@ -428,6 +436,9 @@ class BasicSwap(BaseApp):
self.swaps_in_progress = dict() self.swaps_in_progress = dict()
self.dleag_split_size_init = 16000
self.dleag_split_size = 17000
self.SMSG_SECONDS_IN_HOUR = ( self.SMSG_SECONDS_IN_HOUR = (
60 * 60 60 * 60
) # Note: Set smsgsregtestadjust=0 for regtest ) # Note: Set smsgsregtestadjust=0 for regtest
@@ -526,6 +537,8 @@ class BasicSwap(BaseApp):
self._network = None self._network = None
for t in self.threads: for t in self.threads:
if hasattr(t, "stop") and callable(t.stop):
t.stop()
t.join() t.join()
if sys.version_info[1] >= 9: if sys.version_info[1] >= 9:
@@ -1078,6 +1091,17 @@ class BasicSwap(BaseApp):
f"network_key {self.network_key}\nnetwork_pubkey {self.network_pubkey}\nnetwork_addr {self.network_addr}" f"network_key {self.network_key}\nnetwork_pubkey {self.network_pubkey}\nnetwork_addr {self.network_addr}"
) )
self.active_networks = []
network_config_list = self.settings.get("networks", [])
if len(network_config_list) < 1:
network_config_list = [{"type": "smsg", "enabled": True}]
for network in network_config_list:
if network["type"] == "smsg":
self.active_networks.append({"type": "smsg"})
elif network["type"] == "simplex":
initialiseSimplexNetwork(self, network)
ro = self.callrpc("smsglocalkeys") ro = self.callrpc("smsglocalkeys")
found = False found = False
for k in ro["smsg_keys"]: for k in ro["smsg_keys"]:
@@ -1655,6 +1679,33 @@ class BasicSwap(BaseApp):
bid_valid = (bid.expire_at - now) + 10 * 60 # Add 10 minute buffer bid_valid = (bid.expire_at - now) + 10 * 60 # Add 10 minute buffer
return max(smsg_min_valid, min(smsg_max_valid, bid_valid)) return max(smsg_min_valid, min(smsg_max_valid, bid_valid))
def sendMessage(
self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int, cursor
) -> bytes:
message_id: bytes = None
# First network in list will set message_id
for network in self.active_networks:
net_message_id = None
if network["type"] == "smsg":
net_message_id = self.sendSmsg(
addr_from, addr_to, payload_hex, msg_valid
)
elif network["type"] == "simplex":
net_message_id = sendSimplexMsg(
self,
network,
addr_from,
addr_to,
bytes.fromhex(payload_hex),
msg_valid,
cursor,
)
else:
raise ValueError("Unknown network: {}".format(network["type"]))
if not message_id:
message_id = net_message_id
return message_id
def sendSmsg( def sendSmsg(
self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int self, addr_from: str, addr_to: str, payload_hex: bytes, msg_valid: int
) -> bytes: ) -> bytes:
@@ -2200,7 +2251,9 @@ class BasicSwap(BaseApp):
offer_bytes = msg_buf.to_bytes() offer_bytes = msg_buf.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.OFFER) + offer_bytes.hex() payload_hex = str.format("{:02x}", MessageTypes.OFFER) + offer_bytes.hex()
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
offer_id = self.sendSmsg(offer_addr, offer_addr_to, payload_hex, msg_valid) offer_id = self.sendMessage(
offer_addr, offer_addr_to, payload_hex, msg_valid, cursor
)
security_token = extra_options.get("security_token", None) security_token = extra_options.get("security_token", None)
if security_token is not None and len(security_token) != 20: if security_token is not None and len(security_token) != 20:
@@ -2305,8 +2358,8 @@ class BasicSwap(BaseApp):
) )
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, offer.time_valid) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, offer.time_valid)
msg_id = self.sendSmsg( msg_id = self.sendMessage(
offer.addr_from, self.network_addr, payload_hex, msg_valid offer.addr_from, self.network_addr, payload_hex, msg_valid, cursor
) )
self.log.debug( self.log.debug(
f"Revoked offer {self.log.id(offer_id)} in msg {self.log.id(msg_id)}" f"Revoked offer {self.log.id(offer_id)} in msg {self.log.id(msg_id)}"
@@ -3152,7 +3205,9 @@ class BasicSwap(BaseApp):
bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
bid_id = self.sendSmsg(bid_addr, offer.addr_from, payload_hex, msg_valid) bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
)
bid = Bid( bid = Bid(
protocol_version=msg_buf.protocol_version, protocol_version=msg_buf.protocol_version,
@@ -3488,8 +3543,8 @@ class BasicSwap(BaseApp):
) )
msg_valid: int = self.getAcceptBidMsgValidTime(bid) msg_valid: int = self.getAcceptBidMsgValidTime(bid)
accept_msg_id = self.sendSmsg( accept_msg_id = self.sendMessage(
offer.addr_from, bid.bid_addr, payload_hex, msg_valid offer.addr_from, bid.bid_addr, payload_hex, msg_valid, cursor
) )
self.addMessageLink( self.addMessageLink(
@@ -3519,20 +3574,29 @@ class BasicSwap(BaseApp):
dleag: bytes, dleag: bytes,
msg_valid: int, msg_valid: int,
bid_msg_ids, bid_msg_ids,
cursor,
) -> None: ) -> None:
msg_buf2 = XmrSplitMessage(
msg_id=bid_id, msg_type=msg_type, sequence=1, dleag=dleag[16000:32000]
)
msg_bytes = msg_buf2.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex()
bid_msg_ids[1] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid)
msg_buf3 = XmrSplitMessage( sent_bytes = self.dleag_split_size_init
msg_id=bid_id, msg_type=msg_type, sequence=2, dleag=dleag[32000:]
num_sent = 1
while sent_bytes < len(dleag):
size_to_send: int = min(self.dleag_split_size, len(dleag) - sent_bytes)
msg_buf = XmrSplitMessage(
msg_id=bid_id,
msg_type=msg_type,
sequence=num_sent,
dleag=dleag[sent_bytes : sent_bytes + size_to_send],
) )
msg_bytes = msg_buf3.to_bytes() msg_bytes = msg_buf.to_bytes()
payload_hex = str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex() payload_hex = (
bid_msg_ids[2] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) str.format("{:02x}", MessageTypes.XMR_BID_SPLIT) + msg_bytes.hex()
)
bid_msg_ids[num_sent] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, cursor
)
num_sent += 1
sent_bytes += size_to_send
def postXmrBid( def postXmrBid(
self, offer_id: bytes, amount: int, addr_send_from: str = None, extra_options={} self, offer_id: bytes, amount: int, addr_send_from: str = None, extra_options={}
@@ -3608,8 +3672,8 @@ class BasicSwap(BaseApp):
) )
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
xmr_swap.bid_id = self.sendSmsg( xmr_swap.bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
) )
bid = Bid( bid = Bid(
@@ -3691,7 +3755,7 @@ class BasicSwap(BaseApp):
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf) xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33] xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[:16000] msg_buf.kbsf_dleag = xmr_swap.kbsf_dleag[: self.dleag_split_size_init]
elif ci_to.curve_type() == Curves.secp256k1: elif ci_to.curve_type() == Curves.secp256k1:
for i in range(10): for i in range(10):
xmr_swap.kbsf_dleag = ci_to.signRecoverable( xmr_swap.kbsf_dleag = ci_to.signRecoverable(
@@ -3721,8 +3785,8 @@ class BasicSwap(BaseApp):
bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor) bid_addr = self.prepareSMSGAddress(addr_send_from, AddressTypes.BID, cursor)
msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds) msg_valid: int = max(self.SMSG_SECONDS_IN_HOUR, valid_for_seconds)
xmr_swap.bid_id = self.sendSmsg( xmr_swap.bid_id = self.sendMessage(
bid_addr, offer.addr_from, payload_hex, msg_valid bid_addr, offer.addr_from, payload_hex, msg_valid, cursor
) )
bid_msg_ids = {} bid_msg_ids = {}
@@ -3735,6 +3799,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsf_dleag, xmr_swap.kbsf_dleag,
msg_valid, msg_valid,
bid_msg_ids, bid_msg_ids,
cursor,
) )
bid = Bid( bid = Bid(
@@ -4013,7 +4078,7 @@ class BasicSwap(BaseApp):
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl) xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:16000] msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[: self.dleag_split_size_init]
elif ci_to.curve_type() == Curves.secp256k1: elif ci_to.curve_type() == Curves.secp256k1:
for i in range(10): for i in range(10):
xmr_swap.kbsl_dleag = ci_to.signRecoverable( xmr_swap.kbsl_dleag = ci_to.signRecoverable(
@@ -4048,7 +4113,9 @@ class BasicSwap(BaseApp):
msg_valid: int = self.getAcceptBidMsgValidTime(bid) msg_valid: int = self.getAcceptBidMsgValidTime(bid)
bid_msg_ids = {} bid_msg_ids = {}
bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) bid_msg_ids[0] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, use_cursor
)
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
self.sendXmrSplitMessages( self.sendXmrSplitMessages(
@@ -4059,6 +4126,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsl_dleag, xmr_swap.kbsl_dleag,
msg_valid, msg_valid,
bid_msg_ids, bid_msg_ids,
use_cursor,
) )
bid.setState(BidStates.BID_ACCEPTED) # ADS bid.setState(BidStates.BID_ACCEPTED) # ADS
@@ -4180,8 +4248,8 @@ class BasicSwap(BaseApp):
msg_buf.kbvf = kbvf msg_buf.kbvf = kbvf
msg_buf.kbsf_dleag = ( msg_buf.kbsf_dleag = (
xmr_swap.kbsf_dleag xmr_swap.kbsf_dleag
if len(xmr_swap.kbsf_dleag) < 16000 if len(xmr_swap.kbsf_dleag) < self.dleag_split_size_init
else xmr_swap.kbsf_dleag[:16000] else xmr_swap.kbsf_dleag[: self.dleag_split_size_init]
) )
bid_bytes = msg_buf.to_bytes() bid_bytes = msg_buf.to_bytes()
@@ -4193,7 +4261,9 @@ class BasicSwap(BaseApp):
addr_to: str = bid.bid_addr addr_to: str = bid.bid_addr
msg_valid: int = self.getAcceptBidMsgValidTime(bid) msg_valid: int = self.getAcceptBidMsgValidTime(bid)
bid_msg_ids = {} bid_msg_ids = {}
bid_msg_ids[0] = self.sendSmsg(addr_from, addr_to, payload_hex, msg_valid) bid_msg_ids[0] = self.sendMessage(
addr_from, addr_to, payload_hex, msg_valid, use_cursor
)
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
self.sendXmrSplitMessages( self.sendXmrSplitMessages(
@@ -4204,6 +4274,7 @@ class BasicSwap(BaseApp):
xmr_swap.kbsf_dleag, xmr_swap.kbsf_dleag,
msg_valid, msg_valid,
bid_msg_ids, bid_msg_ids,
use_cursor,
) )
bid.setState(BidStates.BID_REQUEST_ACCEPTED) bid.setState(BidStates.BID_REQUEST_ACCEPTED)
@@ -6808,75 +6879,61 @@ class BasicSwap(BaseApp):
now: int = self.getTime() now: int = self.getTime()
ttl_xmr_split_messages = 60 * 60 ttl_xmr_split_messages = 60 * 60
bid_cursor = None bid_cursor = None
dleag_proof_len: int = 48893 # coincurve.dleag.dleag_proof_len()
try: try:
cursor = self.openDB() cursor = self.openDB()
bid_cursor = self.getNewDBCursor() bid_cursor = self.getNewDBCursor()
q_bids = self.query( q_bids = self.query(
Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING)} Bid,
bid_cursor,
{
"state": (
int(BidStates.BID_RECEIVING),
int(BidStates.BID_RECEIVING_ACC),
)
},
) )
for bid in q_bids: for bid in q_bids:
q = cursor.execute( q = cursor.execute(
"SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type", "SELECT LENGTH(kbsl_dleag), LENGTH(kbsf_dleag) FROM xmr_swaps WHERE bid_id = :bid_id",
{"bid_id": bid.bid_id, "msg_type": int(XmrSplitMsgTypes.BID)},
).fetchone()
num_segments = q[0]
if num_segments > 1:
try:
self.receiveXmrBid(bid, cursor)
except Exception as ex:
self.log.info(
f"Verify adaptor-sig bid {self.log.id(bid.bid_id)} failed: {ex}"
)
if self.debug:
self.log.error(traceback.format_exc())
bid.setState(
BidStates.BID_ERROR, "Failed validation: " + str(ex)
)
self.updateDB(
bid,
cursor,
[
"bid_id",
],
)
self.updateBidInProgress(bid)
continue
if bid.created_at + ttl_xmr_split_messages < now:
self.log.debug(
f"Expiring partially received bid: {self.log.id(bid.bid_id)}."
)
bid.setState(BidStates.BID_ERROR, "Timed out")
self.updateDB(
bid,
cursor,
[
"bid_id",
],
)
q_bids = self.query(
Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING_ACC)}
)
for bid in q_bids:
q = cursor.execute(
"SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type",
{ {
"bid_id": bid.bid_id, "bid_id": bid.bid_id,
"msg_type": int(XmrSplitMsgTypes.BID_ACCEPT),
}, },
).fetchone() ).fetchone()
num_segments = q[0] kbsl_dleag_len: int = q[0]
if num_segments > 1: kbsf_dleag_len: int = q[1]
if bid.state == int(BidStates.BID_RECEIVING_ACC):
bid_type: str = "bid accept"
msg_type: int = int(XmrSplitMsgTypes.BID_ACCEPT)
total_dleag_size: int = kbsl_dleag_len
else:
bid_type: str = "bid"
msg_type: int = int(XmrSplitMsgTypes.BID)
total_dleag_size: int = kbsf_dleag_len
q = cursor.execute(
"SELECT COUNT(*), SUM(LENGTH(dleag)) AS total_dleag_size FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type",
{"bid_id": bid.bid_id, "msg_type": msg_type},
).fetchone()
total_dleag_size += 0 if q[1] is None else q[1]
if total_dleag_size >= dleag_proof_len:
try: try:
if bid.state == int(BidStates.BID_RECEIVING):
self.receiveXmrBid(bid, cursor)
elif bid.state == int(BidStates.BID_RECEIVING_ACC):
self.receiveXmrBidAccept(bid, cursor) self.receiveXmrBidAccept(bid, cursor)
else:
raise ValueError("Unexpected bid state")
except Exception as ex: except Exception as ex:
self.log.info(
f"Verify adaptor-sig {bid_type} {self.log.id(bid.bid_id)} failed: {ex}"
)
if self.debug: if self.debug:
self.log.error(traceback.format_exc()) self.log.error(traceback.format_exc())
self.log.info(
f"Verify adaptor-sig bid accept {self.log.id(bid.bid_id)} failed: {ex}."
)
bid.setState( bid.setState(
BidStates.BID_ERROR, "Failed accept validation: " + str(ex) BidStates.BID_ERROR, f"Failed {bid_type} validation: {ex}"
) )
self.updateDB( self.updateDB(
bid, bid,
@@ -6889,7 +6946,7 @@ class BasicSwap(BaseApp):
continue continue
if bid.created_at + ttl_xmr_split_messages < now: if bid.created_at + ttl_xmr_split_messages < now:
self.log.debug( self.log.debug(
f"Expiring partially received bid accept: {self.log.id(bid.bid_id)}." f"Expiring partially received {bid_type}: {self.log.id(bid.bid_id)}."
) )
bid.setState(BidStates.BID_ERROR, "Timed out") bid.setState(BidStates.BID_ERROR, "Timed out")
self.updateDB( self.updateDB(
@@ -6899,7 +6956,6 @@ class BasicSwap(BaseApp):
"bid_id", "bid_id",
], ],
) )
# Expire old records # Expire old records
cursor.execute( cursor.execute(
"DELETE FROM xmr_split_data WHERE created_at + :ttl < :now", "DELETE FROM xmr_split_data WHERE created_at + :ttl < :now",
@@ -7029,6 +7085,7 @@ class BasicSwap(BaseApp):
if self.isOfferRevoked(offer_id, msg["from"]): if self.isOfferRevoked(offer_id, msg["from"]):
raise ValueError("Offer has been revoked {}.".format(offer_id.hex())) raise ValueError("Offer has been revoked {}.".format(offer_id.hex()))
pk_from: bytes = getMsgPubkey(self, msg)
try: try:
cursor = self.openDB() cursor = self.openDB()
# Offers must be received on the public network_addr or manually created addresses # Offers must be received on the public network_addr or manually created addresses
@@ -7069,6 +7126,7 @@ class BasicSwap(BaseApp):
rate_negotiable=offer_data.rate_negotiable, rate_negotiable=offer_data.rate_negotiable,
addr_to=msg["to"], addr_to=msg["to"],
addr_from=msg["from"], addr_from=msg["from"],
pk_from=pk_from,
created_at=msg["sent"], created_at=msg["sent"],
expire_at=msg["sent"] + offer_data.time_valid, expire_at=msg["sent"] + offer_data.time_valid,
was_sent=False, was_sent=False,
@@ -7417,6 +7475,7 @@ class BasicSwap(BaseApp):
bid = self.getBid(bid_id) bid = self.getBid(bid_id)
if bid is None: if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid( bid = Bid(
active_ind=1, active_ind=1,
bid_id=bid_id, bid_id=bid_id,
@@ -7431,6 +7490,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"], created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid, expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"], bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_received=True, was_received=True,
chain_a_height_start=ci_from.getChainHeight(), chain_a_height_start=ci_from.getChainHeight(),
chain_b_height_start=ci_to.getChainHeight(), chain_b_height_start=ci_to.getChainHeight(),
@@ -7829,12 +7889,13 @@ class BasicSwap(BaseApp):
) )
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
ensure(len(bid_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size") ensure(len(bid_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size")
bid_id = bytes.fromhex(msg["msgid"]) bid_id = bytes.fromhex(msg["msgid"])
bid, xmr_swap = self.getXmrBid(bid_id) bid, xmr_swap = self.getXmrBid(bid_id)
if bid is None: if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid( bid = Bid(
active_ind=1, active_ind=1,
bid_id=bid_id, bid_id=bid_id,
@@ -7846,6 +7907,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"], created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid, expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"], bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_received=True, was_received=True,
chain_a_height_start=ci_from.getChainHeight(), chain_a_height_start=ci_from.getChainHeight(),
chain_b_height_start=ci_to.getChainHeight(), chain_b_height_start=ci_to.getChainHeight(),
@@ -8175,8 +8237,8 @@ class BasicSwap(BaseApp):
msg_valid: int = self.getActiveBidMsgValidTime() msg_valid: int = self.getActiveBidMsgValidTime()
addr_send_from: str = offer.addr_from if reverse_bid else bid.bid_addr addr_send_from: str = offer.addr_from if reverse_bid else bid.bid_addr
addr_send_to: str = bid.bid_addr if reverse_bid else offer.addr_from addr_send_to: str = bid.bid_addr if reverse_bid else offer.addr_from
coin_a_lock_tx_sigs_l_msg_id = self.sendSmsg( coin_a_lock_tx_sigs_l_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
) )
self.addMessageLink( self.addMessageLink(
Concepts.BID, Concepts.BID,
@@ -8544,8 +8606,8 @@ class BasicSwap(BaseApp):
addr_send_from: str = bid.bid_addr if reverse_bid else offer.addr_from addr_send_from: str = bid.bid_addr if reverse_bid else offer.addr_from
addr_send_to: str = offer.addr_from if reverse_bid else bid.bid_addr addr_send_to: str = offer.addr_from if reverse_bid else bid.bid_addr
msg_valid: int = self.getActiveBidMsgValidTime() msg_valid: int = self.getActiveBidMsgValidTime()
coin_a_lock_release_msg_id = self.sendSmsg( coin_a_lock_release_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
) )
self.addMessageLink( self.addMessageLink(
Concepts.BID, Concepts.BID,
@@ -8964,8 +9026,8 @@ class BasicSwap(BaseApp):
) )
msg_valid: int = self.getActiveBidMsgValidTime() msg_valid: int = self.getActiveBidMsgValidTime()
xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendSmsg( xmr_swap.coin_a_lock_refund_spend_tx_msg_id = self.sendMessage(
addr_send_from, addr_send_to, payload_hex, msg_valid addr_send_from, addr_send_to, payload_hex, msg_valid, cursor
) )
bid.setState(BidStates.XMR_SWAP_MSG_SCRIPT_LOCK_SPEND_TX) bid.setState(BidStates.XMR_SWAP_MSG_SCRIPT_LOCK_SPEND_TX)
@@ -9347,6 +9409,7 @@ class BasicSwap(BaseApp):
bid, xmr_swap = self.getXmrBid(bid_id) bid, xmr_swap = self.getXmrBid(bid_id)
if bid is None: if bid is None:
pk_from: bytes = getMsgPubkey(self, msg)
bid = Bid( bid = Bid(
active_ind=1, active_ind=1,
bid_id=bid_id, bid_id=bid_id,
@@ -9358,6 +9421,7 @@ class BasicSwap(BaseApp):
created_at=msg["sent"], created_at=msg["sent"],
expire_at=msg["sent"] + bid_data.time_valid, expire_at=msg["sent"] + bid_data.time_valid,
bid_addr=msg["from"], bid_addr=msg["from"],
pk_bid_addr=pk_from,
was_sent=False, was_sent=False,
was_received=True, was_received=True,
chain_a_height_start=ci_from.getChainHeight(), chain_a_height_start=ci_from.getChainHeight(),
@@ -9460,7 +9524,7 @@ class BasicSwap(BaseApp):
"Invalid destination address", "Invalid destination address",
) )
if ci_to.curve_type() == Curves.ed25519: if ci_to.curve_type() == Curves.ed25519:
ensure(len(msg_data.kbsf_dleag) == 16000, "Invalid kbsf_dleag size") ensure(len(msg_data.kbsf_dleag) <= 16000, "Invalid kbsf_dleag size")
xmr_swap.dest_af = msg_data.dest_af xmr_swap.dest_af = msg_data.dest_af
xmr_swap.pkaf = msg_data.pkaf xmr_swap.pkaf = msg_data.pkaf
@@ -9495,6 +9559,14 @@ class BasicSwap(BaseApp):
def processMsg(self, msg) -> None: def processMsg(self, msg) -> None:
try: try:
if "hex" not in msg:
if self.debug:
if "error" in msg:
self.log.debug(
"Message error {}: {}.".format(msg["msgid"], msg["error"])
)
raise ValueError("Invalid msg received {}.".format(msg["msgid"]))
return
msg_type = int(msg["hex"][:2], 16) msg_type = int(msg["hex"][:2], 16)
if msg_type == MessageTypes.OFFER: if msg_type == MessageTypes.OFFER:
@@ -9708,6 +9780,10 @@ class BasicSwap(BaseApp):
self.processMsg(msg) self.processMsg(msg)
try: try:
for network in self.active_networks:
if network["type"] == "simplex":
readSimplexMsgs(self, network)
# TODO: Wait for blocks / txns, would need to check multiple coins # TODO: Wait for blocks / txns, would need to check multiple coins
now: int = self.getTime() now: int = self.getTime()
self.expireBidsAndOffers(now) self.expireBidsAndOffers(now)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,20 +5,22 @@
"""Helpful routines for regression testing.""" """Helpful routines for regression testing."""
from base64 import b64encode from base64 import b64encode
from binascii import unhexlify
from decimal import Decimal, ROUND_DOWN from decimal import Decimal, ROUND_DOWN
from subprocess import CalledProcessError from subprocess import CalledProcessError
import hashlib
import inspect import inspect
import json import json
import logging import logging
import os import os
import random import pathlib
import platform
import re import re
import time import time
from . import coverage from . import coverage
from .authproxy import AuthServiceProxy, JSONRPCException from .authproxy import AuthServiceProxy, JSONRPCException
from io import BytesIO from collections.abc import Callable
from typing import Optional
logger = logging.getLogger("TestFramework.utils") logger = logging.getLogger("TestFramework.utils")
@@ -28,23 +30,46 @@ logger = logging.getLogger("TestFramework.utils")
def assert_approx(v, vexp, vspan=0.00001): def assert_approx(v, vexp, vspan=0.00001):
"""Assert that `v` is within `vspan` of `vexp`""" """Assert that `v` is within `vspan` of `vexp`"""
if isinstance(v, Decimal) or isinstance(vexp, Decimal):
v=Decimal(v)
vexp=Decimal(vexp)
vspan=Decimal(vspan)
if v < vexp - vspan: if v < vexp - vspan:
raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
if v > vexp + vspan: if v > vexp + vspan:
raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
def assert_fee_amount(fee, tx_size, fee_per_kB): def assert_fee_amount(fee, tx_size, feerate_BTC_kvB):
"""Assert the fee was in range""" """Assert the fee is in range."""
target_fee = round(tx_size * fee_per_kB / 1000, 8) assert isinstance(tx_size, int)
target_fee = get_fee(tx_size, feerate_BTC_kvB)
if fee < target_fee: if fee < target_fee:
raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee))) raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee)))
# allow the wallet's estimation to be at most 2 bytes off # allow the wallet's estimation to be at most 2 bytes off
if fee > (tx_size + 2) * fee_per_kB / 1000: high_fee = get_fee(tx_size + 2, feerate_BTC_kvB)
if fee > high_fee:
raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee))) raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee)))
def summarise_dict_differences(thing1, thing2):
if not isinstance(thing1, dict) or not isinstance(thing2, dict):
return thing1, thing2
d1, d2 = {}, {}
for k in sorted(thing1.keys()):
if k not in thing2:
d1[k] = thing1[k]
elif thing1[k] != thing2[k]:
d1[k], d2[k] = summarise_dict_differences(thing1[k], thing2[k])
for k in sorted(thing2.keys()):
if k not in thing1:
d2[k] = thing2[k]
return d1, d2
def assert_equal(thing1, thing2, *args): def assert_equal(thing1, thing2, *args):
if thing1 != thing2 and not args and isinstance(thing1, dict) and isinstance(thing2, dict):
d1,d2 = summarise_dict_differences(thing1, thing2)
raise AssertionError("not(%s == %s)\n in particular not(%s == %s)" % (thing1, thing2, d1, d2))
if thing1 != thing2 or any(thing1 != arg for arg in args): if thing1 != thing2 or any(thing1 != arg for arg in args):
raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args)) raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args))
@@ -79,7 +104,7 @@ def assert_raises_message(exc, message, fun, *args, **kwds):
raise AssertionError("No exception raised") raise AssertionError("No exception raised")
def assert_raises_process_error(returncode, output, fun, *args, **kwds): def assert_raises_process_error(returncode: int, output: str, fun: Callable, *args, **kwds):
"""Execute a process and asserts the process return code and output. """Execute a process and asserts the process return code and output.
Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError
@@ -87,9 +112,9 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds):
no CalledProcessError was raised or if the return code and output are not as expected. no CalledProcessError was raised or if the return code and output are not as expected.
Args: Args:
returncode (int): the process return code. returncode: the process return code.
output (string): [a substring of] the process output. output: [a substring of] the process output.
fun (function): the function to call. This should execute a process. fun: the function to call. This should execute a process.
args*: positional arguments for the function. args*: positional arguments for the function.
kwds**: named arguments for the function. kwds**: named arguments for the function.
""" """
@@ -104,7 +129,7 @@ def assert_raises_process_error(returncode, output, fun, *args, **kwds):
raise AssertionError("No exception raised") raise AssertionError("No exception raised")
def assert_raises_rpc_error(code, message, fun, *args, **kwds): def assert_raises_rpc_error(code: Optional[int], message: Optional[str], fun: Callable, *args, **kwds):
"""Run an RPC and verify that a specific JSONRPC exception code and message is raised. """Run an RPC and verify that a specific JSONRPC exception code and message is raised.
Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
@@ -112,11 +137,11 @@ def assert_raises_rpc_error(code, message, fun, *args, **kwds):
no JSONRPCException was raised or if the error code/message are not as expected. no JSONRPCException was raised or if the error code/message are not as expected.
Args: Args:
code (int), optional: the error code returned by the RPC call (defined code: the error code returned by the RPC call (defined in src/rpc/protocol.h).
in src/rpc/protocol.h). Set to None if checking the error code is not required. Set to None if checking the error code is not required.
message (string), optional: [a substring of] the error string returned by the message: [a substring of] the error string returned by the RPC call.
RPC call. Set to None if checking the error string is not required. Set to None if checking the error string is not required.
fun (function): the function to call. This should be the name of an RPC. fun: the function to call. This should be the name of an RPC.
args*: positional arguments for the function. args*: positional arguments for the function.
kwds**: named arguments for the function. kwds**: named arguments for the function.
""" """
@@ -203,29 +228,45 @@ def check_json_precision():
raise RuntimeError("JSON encode/decode loses precision") raise RuntimeError("JSON encode/decode loses precision")
def EncodeDecimal(o):
if isinstance(o, Decimal):
return str(o)
raise TypeError(repr(o) + " is not JSON serializable")
def count_bytes(hex_string): def count_bytes(hex_string):
return len(bytearray.fromhex(hex_string)) return len(bytearray.fromhex(hex_string))
def hex_str_to_bytes(hex_str):
return unhexlify(hex_str.encode('ascii'))
def str_to_b64str(string): def str_to_b64str(string):
return b64encode(string.encode('utf-8')).decode('ascii') return b64encode(string.encode('utf-8')).decode('ascii')
def ceildiv(a, b):
"""
Divide 2 ints and round up to next int rather than round down
Implementation requires python integers, which have a // operator that does floor division.
Other types like decimal.Decimal whose // operator truncates towards 0 will not work.
"""
assert isinstance(a, int)
assert isinstance(b, int)
return -(-a // b)
def get_fee(tx_size, feerate_btc_kvb):
"""Calculate the fee in BTC given a feerate is BTC/kvB. Reflects CFeeRate::GetFee"""
feerate_sat_kvb = int(feerate_btc_kvb * Decimal(1e8)) # Fee in sat/kvb as an int to avoid float precision errors
target_fee_sat = ceildiv(feerate_sat_kvb * tx_size, 1000) # Round calculated fee up to nearest sat
return target_fee_sat / Decimal(1e8) # Return result in BTC
def satoshi_round(amount): def satoshi_round(amount):
return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0): def wait_until_helper_internal(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0):
"""Sleep until the predicate resolves to be True.
Warning: Note that this method is not recommended to be used in tests as it is
not aware of the context of the test framework. Using the `wait_until()` members
from `BitcoinTestFramework` or `P2PInterface` class ensures the timeout is
properly scaled. Furthermore, `wait_until()` from `P2PInterface` class in
`p2p.py` has a preset lock.
"""
if attempts == float('inf') and timeout == float('inf'): if attempts == float('inf') and timeout == float('inf'):
timeout = 60 timeout = 60
timeout = timeout * timeout_factor timeout = timeout * timeout_factor
@@ -253,6 +294,16 @@ def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=N
raise RuntimeError('Unreachable') raise RuntimeError('Unreachable')
def sha256sum_file(filename):
h = hashlib.sha256()
with open(filename, 'rb') as f:
d = f.read(4096)
while len(d) > 0:
h.update(d)
d = f.read(4096)
return h.digest()
# RPC/P2P connection constants and functions # RPC/P2P connection constants and functions
############################################ ############################################
@@ -269,15 +320,15 @@ class PortSeed:
n = None n = None
def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None): def get_rpc_proxy(url: str, node_number: int, *, timeout: Optional[int]=None, coveragedir: Optional[str]=None) -> coverage.AuthServiceProxyWrapper:
""" """
Args: Args:
url (str): URL of the RPC server to call url: URL of the RPC server to call
node_number (int): the node number (or id) that this calls to node_number: the node number (or id) that this calls to
Kwargs: Kwargs:
timeout (int): HTTP timeout in seconds timeout: HTTP timeout in seconds
coveragedir (str): Directory coveragedir: Directory
Returns: Returns:
AuthServiceProxy. convenience object for making RPC calls. AuthServiceProxy. convenience object for making RPC calls.
@@ -288,11 +339,10 @@ def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None):
proxy_kwargs['timeout'] = int(timeout) proxy_kwargs['timeout'] = int(timeout)
proxy = AuthServiceProxy(url, **proxy_kwargs) proxy = AuthServiceProxy(url, **proxy_kwargs)
proxy.url = url # store URL on proxy for info
coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None
return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile) return coverage.AuthServiceProxyWrapper(proxy, url, coverage_logfile)
def p2p_port(n): def p2p_port(n):
@@ -321,38 +371,76 @@ def rpc_url(datadir, i, chain, rpchost):
################ ################
def initialize_datadir(dirname, n, chain): def initialize_datadir(dirname, n, chain, disable_autoconnect=True):
datadir = get_datadir_path(dirname, n) datadir = get_datadir_path(dirname, n)
if not os.path.isdir(datadir): if not os.path.isdir(datadir):
os.makedirs(datadir) os.makedirs(datadir)
# Translate chain name to config name write_config(os.path.join(datadir, "particl.conf"), n=n, chain=chain, disable_autoconnect=disable_autoconnect)
if chain == 'testnet3':
chain_name_conf_arg = 'testnet'
chain_name_conf_section = 'test'
else:
chain_name_conf_arg = chain
chain_name_conf_section = chain
with open(os.path.join(datadir, "particl.conf"), 'w', encoding='utf8') as f:
f.write("{}=1\n".format(chain_name_conf_arg))
f.write("[{}]\n".format(chain_name_conf_section))
f.write("port=" + str(p2p_port(n)) + "\n")
f.write("rpcport=" + str(rpc_port(n)) + "\n")
f.write("fallbackfee=0.0002\n")
f.write("server=1\n")
f.write("keypool=1\n")
f.write("discover=0\n")
f.write("dnsseed=0\n")
f.write("listenonion=0\n")
f.write("printtoconsole=0\n")
f.write("upnp=0\n")
f.write("shrinkdebugfile=0\n")
os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True) os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True)
os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True) os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True)
return datadir return datadir
def write_config(config_path, *, n, chain, extra_config="", disable_autoconnect=True):
# Translate chain subdirectory name to config name
if chain == 'testnet':
chain_name_conf_arg = 'testnet'
chain_name_conf_section = 'test'
else:
chain_name_conf_arg = chain
chain_name_conf_section = chain
with open(config_path, 'w', encoding='utf8') as f:
if chain_name_conf_arg:
f.write("{}=1\n".format(chain_name_conf_arg))
if chain_name_conf_section:
f.write("[{}]\n".format(chain_name_conf_section))
f.write("port=" + str(p2p_port(n)) + "\n")
f.write("rpcport=" + str(rpc_port(n)) + "\n")
# Disable server-side timeouts to avoid intermittent issues
f.write("rpcservertimeout=99000\n")
f.write("rpcdoccheck=1\n")
f.write("fallbackfee=0.0002\n")
f.write("server=1\n")
f.write("keypool=1\n")
f.write("discover=0\n")
f.write("dnsseed=0\n")
f.write("fixedseeds=0\n")
f.write("listenonion=0\n")
# Increase peertimeout to avoid disconnects while using mocktime.
# peertimeout is measured in mock time, so setting it large enough to
# cover any duration in mock time is sufficient. It can be overridden
# in tests.
f.write("peertimeout=999999999\n")
f.write("printtoconsole=0\n")
f.write("upnp=0\n")
f.write("natpmp=0\n")
f.write("shrinkdebugfile=0\n")
f.write("deprecatedrpc=create_bdb\n") # Required to run the tests
# To improve SQLite wallet performance so that the tests don't timeout, use -unsafesqlitesync
f.write("unsafesqlitesync=1\n")
if disable_autoconnect:
f.write("connect=0\n")
f.write(extra_config)
def get_datadir_path(dirname, n): def get_datadir_path(dirname, n):
return os.path.join(dirname, "node" + str(n)) return pathlib.Path(dirname) / f"node{n}"
def get_temp_default_datadir(temp_dir: pathlib.Path) -> tuple[dict, pathlib.Path]:
"""Return os-specific environment variables that can be set to make the
GetDefaultDataDir() function return a datadir path under the provided
temp_dir, as well as the complete path it would return."""
if platform.system() == "Windows":
env = dict(APPDATA=str(temp_dir))
datadir = temp_dir / "Particl"
else:
env = dict(HOME=str(temp_dir))
if platform.system() == "Darwin":
datadir = temp_dir / "Library/Application Support/Particl"
else:
datadir = temp_dir / ".particl"
return env, datadir
def append_config(datadir, options): def append_config(datadir, options):
@@ -395,7 +483,7 @@ def delete_cookie_file(datadir, chain):
def softfork_active(node, key): def softfork_active(node, key):
"""Return whether a softfork is active.""" """Return whether a softfork is active."""
return node.getblockchaininfo()['softforks'][key]['active'] return node.getdeploymentinfo()['deployments'][key]['active']
def set_node_times(nodes, t): def set_node_times(nodes, t):
@@ -403,208 +491,51 @@ def set_node_times(nodes, t):
node.setmocktime(t) node.setmocktime(t)
def disconnect_nodes(from_connection, node_num): def check_node_connections(*, node, num_in, num_out):
def get_peer_ids(): info = node.getnetworkinfo()
result = [] assert_equal(info["connections_in"], num_in)
for peer in from_connection.getpeerinfo(): assert_equal(info["connections_out"], num_out)
if "testnode{}".format(node_num) in peer['subver']:
result.append(peer['id'])
return result
peer_ids = get_peer_ids()
if not peer_ids:
logger.warning("disconnect_nodes: {} and {} were not connected".format(
from_connection.index,
node_num,
))
return
for peer_id in peer_ids:
try:
from_connection.disconnectnode(nodeid=peer_id)
except JSONRPCException as e:
# If this node is disconnected between calculating the peer id
# and issuing the disconnect, don't worry about it.
# This avoids a race condition if we're mass-disconnecting peers.
if e.error['code'] != -29: # RPC_CLIENT_NODE_NOT_CONNECTED
raise
# wait to disconnect
wait_until(lambda: not get_peer_ids(), timeout=5)
def connect_nodes(from_connection, node_num):
ip_port = "127.0.0.1:" + str(p2p_port(node_num))
from_connection.addnode(ip_port, "onetry")
# poll until version handshake complete to avoid race conditions
# with transaction relaying
# See comments in net_processing:
# * Must have a version message before anything else
# * Must have a verack message before anything else
wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo()))
wait_until(lambda: all(peer['bytesrecv_per_msg'].pop('verack', 0) == 24 for peer in from_connection.getpeerinfo()))
# Transaction/Block functions # Transaction/Block functions
############################# #############################
def find_output(node, txid, amount, *, blockhash=None):
"""
Return index to output of txid with value amount
Raises exception if there is none.
"""
txdata = node.getrawtransaction(txid, 1, blockhash)
for i in range(len(txdata["vout"])):
if txdata["vout"][i]["value"] == amount:
return i
raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount)))
def gather_inputs(from_node, amount_needed, confirmations_required=1):
"""
Return a random set of unspent txouts that are enough to pay amount_needed
"""
assert confirmations_required >= 0
utxo = from_node.listunspent(confirmations_required)
random.shuffle(utxo)
inputs = []
total_in = Decimal("0.00000000")
while total_in < amount_needed and len(utxo) > 0:
t = utxo.pop()
total_in += t["amount"]
inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]})
if total_in < amount_needed:
raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in))
return (total_in, inputs)
def make_change(from_node, amount_in, amount_out, fee):
"""
Create change output(s), return them
"""
outputs = {}
amount = amount_out + fee
change = amount_in - amount
if change > amount * 2:
# Create an extra change output to break up big inputs
change_address = from_node.getnewaddress()
# Split change in two, being careful of rounding:
outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
change = amount_in - amount - outputs[change_address]
if change > 0:
outputs[from_node.getnewaddress()] = change
return outputs
def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants):
"""
Create a random transaction.
Returns (txid, hex-encoded-transaction-data, fee)
"""
from_node = random.choice(nodes)
to_node = random.choice(nodes)
fee = min_fee + fee_increment * random.randint(0, fee_variants)
(total_in, inputs) = gather_inputs(from_node, amount + fee)
outputs = make_change(from_node, total_in, amount, fee)
outputs[to_node.getnewaddress()] = float(amount)
rawtx = from_node.createrawtransaction(inputs, outputs)
signresult = from_node.signrawtransactionwithwallet(rawtx)
txid = from_node.sendrawtransaction(signresult["hex"], 0)
return (txid, signresult["hex"], fee)
# Helper to create at least "count" utxos
# Pass in a fee that is sufficient for relay and mining new transactions.
def create_confirmed_utxos(fee, node, count):
to_generate = int(0.5 * count) + 101
while to_generate > 0:
node.generate(min(25, to_generate))
to_generate -= 25
utxos = node.listunspent()
iterations = count - len(utxos)
addr1 = node.getnewaddress()
addr2 = node.getnewaddress()
if iterations <= 0:
return utxos
for i in range(iterations):
t = utxos.pop()
inputs = []
inputs.append({"txid": t["txid"], "vout": t["vout"]})
outputs = {}
send_value = t['amount'] - fee
outputs[addr1] = satoshi_round(send_value / 2)
outputs[addr2] = satoshi_round(send_value / 2)
raw_tx = node.createrawtransaction(inputs, outputs)
signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"]
node.sendrawtransaction(signed_tx)
while (node.getmempoolinfo()['size'] > 0):
node.generate(1)
utxos = node.listunspent()
assert len(utxos) >= count
return utxos
# Create large OP_RETURN txouts that can be appended to a transaction # Create large OP_RETURN txouts that can be appended to a transaction
# to make it large (helper for constructing large transactions). # to make it large (helper for constructing large transactions). The
# total serialized size of the txouts is about 66k vbytes.
def gen_return_txouts(): def gen_return_txouts():
# Some pre-processing to create a bunch of OP_RETURN txouts to insert into transactions we create
# So we have big transactions (and therefore can't fit very many into each block)
# create one script_pubkey
script_pubkey = "6a4d0200" # OP_RETURN OP_PUSH2 512 bytes
for i in range(512):
script_pubkey = script_pubkey + "01"
# concatenate 128 txouts of above script_pubkey which we'll insert before the txout for change
txouts = []
from .messages import CTxOut from .messages import CTxOut
txout = CTxOut() from .script import CScript, OP_RETURN
txout.nValue = 0 txouts = [CTxOut(nValue=0, scriptPubKey=CScript([OP_RETURN, b'\x01'*67437]))]
txout.scriptPubKey = hex_str_to_bytes(script_pubkey) assert_equal(sum([len(txout.serialize()) for txout in txouts]), 67456)
for k in range(128):
txouts.append(txout)
return txouts return txouts
# Create a spend of each passed-in utxo, splicing in "txouts" to each raw # Create a spend of each passed-in utxo, splicing in "txouts" to each raw
# transaction to make it large. See gen_return_txouts() above. # transaction to make it large. See gen_return_txouts() above.
def create_lots_of_big_transactions(node, txouts, utxos, num, fee): def create_lots_of_big_transactions(mini_wallet, node, fee, tx_batch_size, txouts, utxos=None):
addr = node.getnewaddress()
txids = [] txids = []
from .messages import CTransaction use_internal_utxos = utxos is None
for _ in range(num): for _ in range(tx_batch_size):
t = utxos.pop() tx = mini_wallet.create_self_transfer(
inputs = [{"txid": t["txid"], "vout": t["vout"]}] utxo_to_spend=None if use_internal_utxos else utxos.pop(),
outputs = {} fee=fee,
change = t['amount'] - fee )["tx"]
outputs[addr] = satoshi_round(change) tx.vout.extend(txouts)
rawtx = node.createrawtransaction(inputs, outputs) res = node.testmempoolaccept([tx.serialize().hex()])[0]
tx = CTransaction() assert_equal(res['fees']['base'], fee)
tx.deserialize(BytesIO(hex_str_to_bytes(rawtx))) txids.append(node.sendrawtransaction(tx.serialize().hex()))
for txout in txouts:
tx.vout.append(txout)
newtx = tx.serialize().hex()
signresult = node.signrawtransactionwithwallet(newtx, None, "NONE")
txid = node.sendrawtransaction(signresult["hex"], 0)
txids.append(txid)
return txids return txids
def mine_large_block(node, utxos=None): def mine_large_block(test_framework, mini_wallet, node):
# generate a 66k transaction, # generate a 66k transaction,
# and 14 of them is close to the 1MB block limit # and 14 of them is close to the 1MB block limit
num = 14
txouts = gen_return_txouts() txouts = gen_return_txouts()
utxos = utxos if utxos is not None else []
if len(utxos) < num:
utxos.clear()
utxos.extend(node.listunspent())
fee = 100 * node.getnetworkinfo()["relayfee"] fee = 100 * node.getnetworkinfo()["relayfee"]
create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee) create_lots_of_big_transactions(mini_wallet, node, fee, 14, txouts)
node.generate(1) test_framework.generate(node, 1)
def find_vout_for_address(node, txid, addr): def find_vout_for_address(node, txid, addr):
@@ -614,11 +545,6 @@ def find_vout_for_address(node, txid, addr):
""" """
tx = node.getrawtransaction(txid, True) tx = node.getrawtransaction(txid, True)
for i in range(len(tx["vout"])): for i in range(len(tx["vout"])):
scriptPubKey = tx["vout"][i]["scriptPubKey"] if addr == tx["vout"][i]["scriptPubKey"]["address"]:
if "addresses" in scriptPubKey:
if any([addr == a for a in scriptPubKey["addresses"]]):
return i
elif "address" in scriptPubKey:
if addr == scriptPubKey["address"]:
return i return i
raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr)) raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr))

View File

@@ -13,7 +13,7 @@ from enum import IntEnum, auto
from typing import Optional from typing import Optional
CURRENT_DB_VERSION = 27 CURRENT_DB_VERSION = 28
CURRENT_DB_DATA_VERSION = 6 CURRENT_DB_DATA_VERSION = 6
@@ -174,6 +174,7 @@ class Offer(Table):
secret_hash = Column("blob") secret_hash = Column("blob")
addr_from = Column("string") addr_from = Column("string")
pk_from = Column("blob")
addr_to = Column("string") addr_to = Column("string")
created_at = Column("integer") created_at = Column("integer")
expire_at = Column("integer") expire_at = Column("integer")
@@ -216,6 +217,7 @@ class Bid(Table):
created_at = Column("integer") created_at = Column("integer")
expire_at = Column("integer") expire_at = Column("integer")
bid_addr = Column("string") bid_addr = Column("string")
pk_bid_addr = Column("blob")
proof_address = Column("string") proof_address = Column("string")
proof_utxos = Column("blob") proof_utxos = Column("blob")
# Address to spend lock tx to - address from wallet if empty TODO # Address to spend lock tx to - address from wallet if empty TODO
@@ -927,15 +929,12 @@ class DBMethods:
table_name: str = table_class.__tablename__ table_name: str = table_class.__tablename__
query: str = "SELECT " query: str = "SELECT "
columns = [] columns = []
for mc in inspect.getmembers(table_class): for mc in inspect.getmembers(table_class):
mc_name, mc_obj = mc mc_name, mc_obj = mc
if not hasattr(mc_obj, "__sqlite3_column__"): if not hasattr(mc_obj, "__sqlite3_column__"):
continue continue
if len(columns) > 0: if len(columns) > 0:
query += ", " query += ", "
query += mc_name query += mc_name
@@ -943,10 +942,29 @@ class DBMethods:
query += f" FROM {table_name} WHERE 1=1 " query += f" FROM {table_name} WHERE 1=1 "
query_data = {}
for ck in constraints: for ck in constraints:
if not validColumnName(ck): if not validColumnName(ck):
raise ValueError(f"Invalid constraint column: {ck}") raise ValueError(f"Invalid constraint column: {ck}")
constraint_value = constraints[ck]
if isinstance(constraint_value, tuple) or isinstance(
constraint_value, list
):
if len(constraint_value) < 2:
raise ValueError(f"Too few constraint values for list: {ck}")
query += f" AND {ck} IN ("
for i, cv in enumerate(constraint_value):
cv_name: str = f"{ck}_{i}"
if i > 0:
query += ","
query += ":" + cv_name
query_data[cv_name] = cv
query += ") "
else:
query += f" AND {ck} = :{ck} " query += f" AND {ck} = :{ck} "
query_data[ck] = constraint_value
for order_col, order_dir in order_by.items(): for order_col, order_dir in order_by.items():
if validColumnName(order_col) is False: if validColumnName(order_col) is False:
@@ -959,7 +977,6 @@ class DBMethods:
if query_suffix: if query_suffix:
query += query_suffix query += query_suffix
query_data = constraints.copy()
query_data.update(extra_query_data) query_data.update(extra_query_data)
rows = cursor.execute(query, query_data) rows = cursor.execute(query, query_data)
for row in rows: for row in rows:

View File

@@ -428,6 +428,11 @@ def upgradeDatabase(self, db_version):
elif current_version == 26: elif current_version == 26:
db_version += 1 db_version += 1
cursor.execute("ALTER TABLE offers ADD COLUMN auto_accept_type INTEGER") cursor.execute("ALTER TABLE offers ADD COLUMN auto_accept_type INTEGER")
elif current_version == 27:
db_version += 1
cursor.execute("ALTER TABLE offers ADD COLUMN pk_from BLOB")
cursor.execute("ALTER TABLE bids ADD COLUMN pk_bid_addr BLOB")
if current_version != db_version: if current_version != db_version:
self.db_version = db_version self.db_version = db_version
self.setIntKV("db_version", db_version, cursor) self.setIntKV("db_version", db_version, cursor)

View File

View File

@@ -0,0 +1,350 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import base64
import json
import threading
import websocket
from queue import Queue, Empty
from basicswap.util.smsg import (
smsgEncrypt,
smsgDecrypt,
smsgGetID,
)
from basicswap.chainparams import (
Coins,
)
from basicswap.util.address import (
b58decode,
decodeWif,
)
from basicswap.basicswap_util import (
BidStates,
)
def encode_base64(data: bytes) -> str:
return base64.b64encode(data).decode("utf-8")
def decode_base64(encoded_data: str) -> bytes:
return base64.b64decode(encoded_data)
class WebSocketThread(threading.Thread):
def __init__(self, url: str, tag: str = None, logger=None):
super().__init__()
self.url: str = url
self.tag = tag
self.logger = logger
self.ws = None
self.mutex = threading.Lock()
self.corrId: int = 0
self.connected: bool = False
self.delay_event = threading.Event()
self.recv_queue = Queue()
self.cmd_recv_queue = Queue()
def on_message(self, ws, message):
if self.logger:
self.logger.debug("Simplex received msg")
else:
print(f"{self.tag} - Received msg")
if message.startswith('{"corrId"'):
self.cmd_recv_queue.put(message)
else:
self.recv_queue.put(message)
def queue_get(self):
try:
return self.recv_queue.get(block=False)
except Empty:
return None
def cmd_queue_get(self):
try:
return self.cmd_recv_queue.get(block=False)
except Empty:
return None
def on_error(self, ws, error):
if self.logger:
self.logger.error(f"Simplex ws - {error}")
else:
print(f"{self.tag} - Error: {error}")
def on_close(self, ws, close_status_code, close_msg):
if self.logger:
self.logger.info(f"Simplex ws - Closed: {close_status_code}, {close_msg}")
else:
print(f"{self.tag} - Closed: {close_status_code}, {close_msg}")
def on_open(self, ws):
if self.logger:
self.logger.info("Simplex ws - Connection opened")
else:
print(f"{self.tag}: WebSocket connection opened")
self.connected = True
def send_command(self, cmd_str: str):
with self.mutex:
self.corrId += 1
if self.logger:
self.logger.debug(f"Simplex sent command {self.corrId}")
else:
print(f"{self.tag}: sent command {self.corrId}")
cmd = json.dumps({"corrId": str(self.corrId), "cmd": cmd_str})
self.ws.send(cmd)
return self.corrId
def run(self):
self.ws = websocket.WebSocketApp(
self.url,
on_message=self.on_message,
on_error=self.on_error,
on_open=self.on_open,
on_close=self.on_close,
)
while not self.delay_event.is_set():
self.ws.run_forever()
self.delay_event.wait(0.5)
def stop(self):
self.delay_event.set()
if self.ws:
self.ws.close()
def waitForResponse(ws_thread, sent_id, delay_event):
sent_id = str(sent_id)
for i in range(100):
message = ws_thread.cmd_queue_get()
if message is not None:
data = json.loads(message)
# print(f"json: {json.dumps(data, indent=4)}")
if "corrId" in data:
if data["corrId"] == sent_id:
return data
delay_event.wait(0.5)
raise ValueError(f"waitForResponse timed-out waiting for id: {sent_id}")
def waitForConnected(ws_thread, delay_event):
for i in range(100):
if ws_thread.connected:
return True
delay_event.wait(0.5)
raise ValueError("waitForConnected timed-out.")
def getPrivkeyForAddress(self, addr) -> bytes:
ci_part = self.ci(Coins.PART)
try:
return ci_part.decodeKey(
self.callrpc(
"smsgdumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
try:
return ci_part.decodeKey(
ci_part.rpc_wallet(
"dumpprivkey",
[
addr,
],
)
)
except Exception as e: # noqa: F841
pass
raise ValueError("key not found")
def sendSimplexMsg(
self, network, addr_from: str, addr_to: str, payload: bytes, msg_valid: int, cursor
) -> bytes:
self.log.debug("sendSimplexMsg")
try:
rv = self.callrpc(
"smsggetpubkey",
[
addr_to,
],
)
pubkey_to: bytes = b58decode(rv["publickey"])
except Exception as e: # noqa: F841
use_cursor = self.openDB(cursor)
try:
query: str = "SELECT pk_from FROM offers WHERE addr_from = :addr_to LIMIT 1"
rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall()
if len(rows) > 0:
pubkey_to = rows[0][0]
else:
query: str = (
"SELECT pk_bid_addr FROM bids WHERE bid_addr = :addr_to LIMIT 1"
)
rows = use_cursor.execute(query, {"addr_to": addr_to}).fetchall()
if len(rows) > 0:
pubkey_to = rows[0][0]
else:
raise ValueError(f"Could not get public key for address {addr_to}")
finally:
if cursor is None:
self.closeDB(use_cursor, commit=False)
privkey_from = getPrivkeyForAddress(self, addr_from)
payload += bytes((0,)) # Include null byte to match smsg
smsg_msg: bytes = smsgEncrypt(privkey_from, pubkey_to, payload)
smsg_id = smsgGetID(smsg_msg)
ws_thread = network["ws_thread"]
sent_id = ws_thread.send_command("#bsx " + encode_base64(smsg_msg))
response = waitForResponse(ws_thread, sent_id, self.delay_event)
if response["resp"]["type"] != "newChatItems":
json_str = json.dumps(response, indent=4)
self.log.debug(f"Response {json_str}")
raise ValueError("Send failed")
return smsg_id
def decryptSimplexMsg(self, msg_data):
ci_part = self.ci(Coins.PART)
# Try with the network key first
network_key: bytes = decodeWif(self.network_key)
try:
decrypted = smsgDecrypt(network_key, msg_data, output_dict=True)
decrypted["from"] = ci_part.pubkey_to_address(
bytes.fromhex(decrypted["pk_from"])
)
decrypted["to"] = self.network_addr
decrypted["msg_net"] = "simplex"
return decrypted
except Exception as e: # noqa: F841
pass
# Try with all active bid/offer addresses
query: str = """SELECT DISTINCT address FROM (
SELECT bid_addr AS address FROM bids WHERE active_ind = 1
AND (in_progress = 1 OR (state > :bid_received AND state < :bid_completed) OR (state IN (:bid_received, :bid_sent) AND expire_at > :now))
UNION
SELECT addr_from AS address FROM offers WHERE active_ind = 1 AND expire_at > :now
)"""
now: int = self.getTime()
try:
cursor = self.openDB()
addr_rows = cursor.execute(
query,
{
"bid_received": int(BidStates.BID_RECEIVED),
"bid_completed": int(BidStates.SWAP_COMPLETED),
"bid_sent": int(BidStates.BID_SENT),
"now": now,
},
).fetchall()
finally:
self.closeDB(cursor, commit=False)
decrypted = None
for row in addr_rows:
addr = row[0]
try:
vk_addr = getPrivkeyForAddress(self, addr)
decrypted = smsgDecrypt(vk_addr, msg_data, output_dict=True)
decrypted["from"] = ci_part.pubkey_to_address(
bytes.fromhex(decrypted["pk_from"])
)
decrypted["to"] = addr
decrypted["msg_net"] = "simplex"
return decrypted
except Exception as e: # noqa: F841
pass
return decrypted
def readSimplexMsgs(self, network):
ws_thread = network["ws_thread"]
for i in range(100):
message = ws_thread.queue_get()
if message is None:
break
data = json.loads(message)
# self.log.debug(f"message 1: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in ("chatItemsStatusesUpdated", "newChatItems"):
for chat_item in data["resp"]["chatItems"]:
item_status = chat_item["chatItem"]["meta"]["itemStatus"]
if item_status["type"] in ("sndRcvd", "rcvNew"):
snd_progress = item_status.get("sndProgress", None)
if snd_progress:
if snd_progress != "complete":
item_id = chat_item["chatItem"]["meta"]["itemId"]
self.log.debug(
f"simplex chat item {item_id} {snd_progress}"
)
continue
try:
msg_data: bytes = decode_base64(
chat_item["chatItem"]["content"]["msgContent"]["text"]
)
decrypted_msg = decryptSimplexMsg(self, msg_data)
if decrypted_msg is None:
continue
self.processMsg(decrypted_msg)
except Exception as e: # noqa: F841
# self.log.debug(f"decryptSimplexMsg error: {e}")
pass
except Exception as e:
self.log.debug(f"readSimplexMsgs error: {e}")
self.delay_event.wait(0.05)
def initialiseSimplexNetwork(self, network_config) -> None:
self.log.debug("initialiseSimplexNetwork")
client_host: str = network_config.get("client_host", "127.0.0.1")
ws_port: str = network_config.get("ws_port")
ws_thread = WebSocketThread(f"ws://{client_host}:{ws_port}", logger=self.log)
self.threads.append(ws_thread)
ws_thread.start()
waitForConnected(ws_thread, self.delay_event)
sent_id = ws_thread.send_command("/groups")
response = waitForResponse(ws_thread, sent_id, self.delay_event)
if len(response["resp"]["groups"]) < 1:
sent_id = ws_thread.send_command("/c " + network_config["group_link"])
response = waitForResponse(ws_thread, sent_id, self.delay_event)
assert "groupLinkId" in response["resp"]["connection"]
network = {
"type": "simplex",
"ws_thread": ws_thread,
}
self.active_networks.append(network)

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import os
import select
import subprocess
import time
from basicswap.bin.run import Daemon
def initSimplexClient(args, logger, delay_event):
logger.info("Initialising Simplex client")
(pipe_r, pipe_w) = os.pipe() # subprocess.PIPE is buffered, blocks when read
if os.name == "nt":
str_args = " ".join(args)
p = subprocess.Popen(
str_args, shell=True, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w
)
else:
p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=pipe_w, stderr=pipe_w)
def readOutput():
buf = os.read(pipe_r, 1024).decode("utf-8")
response = None
# logging.debug(f"simplex-chat output: {buf}")
if "display name:" in buf:
logger.debug("Setting display name")
response = b"user\n"
else:
logger.debug(f"Unexpected output: {buf}")
return
if response is not None:
p.stdin.write(response)
p.stdin.flush()
try:
start_time: int = time.time()
max_wait_seconds: int = 60
while p.poll() is None:
if time.time() > start_time + max_wait_seconds:
raise ValueError("Timed out")
if os.name == "nt":
readOutput()
delay_event.wait(0.1)
continue
while len(select.select([pipe_r], [], [], 0)[0]) == 1:
readOutput()
delay_event.wait(0.1)
except Exception as e:
logger.error(f"initSimplexClient: {e}")
finally:
if p.poll() is None:
p.terminate()
os.close(pipe_r)
os.close(pipe_w)
p.stdin.close()
def startSimplexClient(
bin_path: str,
data_path: str,
server_address: str,
websocket_port: int,
logger,
delay_event,
) -> Daemon:
logger.info("Starting Simplex client")
if not os.path.exists(data_path):
os.makedirs(data_path)
db_path = os.path.join(data_path, "simplex_client_data")
args = [bin_path, "-d", db_path, "-s", server_address, "-p", str(websocket_port)]
if not os.path.exists(db_path):
# Need to set initial profile through CLI
# TODO: Must be a better way?
init_args = args + ["-e", "/help"] # Run command ro exit client
initSimplexClient(init_args, logger, delay_event)
args += ["-l", "debug"]
opened_files = []
stdout_dest = open(
os.path.join(data_path, "simplex_stdout.log"),
"w",
)
opened_files.append(stdout_dest)
stderr_dest = stdout_dest
return Daemon(
subprocess.Popen(
args,
shell=False,
stdin=subprocess.PIPE,
stdout=stdout_dest,
stderr=stderr_dest,
cwd=data_path,
),
opened_files,
)

20
basicswap/network/util.py Normal file
View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
from basicswap.util.address import b58decode
def getMsgPubkey(self, msg) -> bytes:
if "pk_from" in msg:
return bytes.fromhex(msg["pk_from"])
rv = self.callrpc(
"smsggetpubkey",
[
msg["from"],
],
)
return b58decode(rv["publickey"])

229
basicswap/util/smsg.py Normal file
View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import hashlib
import hmac
import secrets
import time
from typing import Union, Dict
from coincurve.keys import (
PublicKey,
PrivateKey,
)
from Crypto.Cipher import AES
from basicswap.util.crypto import hash160, sha256, ripemd160
from basicswap.util.ecc import getSecretInt
from basicswap.contrib.test_framework.messages import (
uint256_from_compact,
uint256_from_str,
)
AES_BLOCK_SIZE = 16
def aes_pad(s: bytes):
c = AES_BLOCK_SIZE - len(s) % AES_BLOCK_SIZE
return s + (bytes((c,)) * c)
def aes_unpad(s: bytes):
return s[: -(s[len(s) - 1])]
def aes_encrypt(raw: bytes, pass_data: bytes, iv: bytes):
assert len(pass_data) == 32
assert len(iv) == 16
raw = aes_pad(raw)
cipher = AES.new(pass_data, AES.MODE_CBC, iv)
return cipher.encrypt(raw)
def aes_decrypt(enc, pass_data: bytes, iv: bytes):
assert len(pass_data) == 32
assert len(iv) == 16
cipher = AES.new(pass_data, AES.MODE_CBC, iv)
return aes_unpad(cipher.decrypt(enc))
SMSG_MIN_TTL = 60 * 60
SMSG_BUCKET_LEN = 60 * 60
SMSG_HDR_LEN = (
108 # Length of unencrypted header, 4 + 4 + 2 + 1 + 8 + 4 + 16 + 33 + 32 + 4
)
SMSG_PL_HDR_LEN = 1 + 20 + 65 + 4 # Length of encrypted header in payload
def smsgGetTimestamp(smsg_message: bytes) -> int:
assert len(smsg_message) > SMSG_HDR_LEN
return int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little")
def smsgGetPOWHash(smsg_message: bytes) -> bytes:
assert len(smsg_message) > SMSG_HDR_LEN
ofs: int = 4
nonce: bytes = smsg_message[ofs : ofs + 4]
iv: bytes = nonce * 8
m = hmac.new(iv, digestmod="SHA256")
m.update(smsg_message[4:])
return m.digest()
def smsgGetID(smsg_message: bytes) -> bytes:
assert len(smsg_message) > SMSG_HDR_LEN
smsg_timestamp = int.from_bytes(smsg_message[11 : 11 + 8], byteorder="little")
return smsg_timestamp.to_bytes(8, byteorder="big") + ripemd160(smsg_message[8:])
def smsgEncrypt(privkey_from: bytes, pubkey_to: bytes, payload: bytes) -> bytes:
# assert len(payload) < 128 # Requires lz4 if payload > 128 bytes
# TODO: Add lz4 to match core smsg
smsg_timestamp = int(time.time())
r = getSecretInt().to_bytes(32, byteorder="big")
R = PublicKey.from_secret(r).format()
p = PrivateKey(r).ecdh(pubkey_to)
H = hashlib.sha512(p).digest()
key_e: bytes = H[:32]
key_m: bytes = H[32:]
smsg_iv: bytes = secrets.token_bytes(16)
payload_hash: bytes = sha256(sha256(payload))
signature: bytes = PrivateKey(privkey_from).sign_recoverable(
payload_hash, hasher=None
)
# Convert format to BTC, add 4 to mark as compressed key
recid = signature[64]
signature = bytes((27 + recid + 4,)) + signature[:64]
pubkey_from: bytes = PublicKey.from_secret(privkey_from).format()
pkh_from: bytes = hash160(pubkey_from)
len_payload = len(payload)
address_version = 0
plaintext_data: bytes = (
bytes((address_version,))
+ pkh_from
+ signature
+ len_payload.to_bytes(4, byteorder="little")
+ payload
)
ciphertext: bytes = aes_encrypt(plaintext_data, key_e, smsg_iv)
m = hmac.new(key_m, digestmod="SHA256")
m.update(smsg_timestamp.to_bytes(8, byteorder="little"))
m.update(smsg_iv)
m.update(ciphertext)
mac: bytes = m.digest()
smsg_hash = bytes((0,)) * 4
smsg_nonce = bytes((0,)) * 4
smsg_version = bytes((2, 1))
smsg_flags = bytes((0,))
smsg_ttl = SMSG_MIN_TTL
assert len(R) == 33
assert len(mac) == 32
smsg_message: bytes = (
smsg_hash
+ smsg_nonce
+ smsg_version
+ smsg_flags
+ smsg_timestamp.to_bytes(8, byteorder="little")
+ smsg_ttl.to_bytes(4, byteorder="little")
+ smsg_iv
+ R
+ mac
+ len(ciphertext).to_bytes(4, byteorder="little")
+ ciphertext
)
target: int = uint256_from_compact(0x1EFFFFFF)
for i in range(1000000):
pow_hash = smsgGetPOWHash(smsg_message)
if uint256_from_str(pow_hash) > target:
smsg_nonce = (int.from_bytes(smsg_nonce, byteorder="little") + 1).to_bytes(
4, byteorder="little"
)
smsg_message = pow_hash[:4] + smsg_nonce + smsg_message[8:]
continue
smsg_message = pow_hash[:4] + smsg_message[4:]
return smsg_message
raise ValueError("Failed to set POW hash.")
def smsgDecrypt(
privkey_to: bytes, encrypted_message: bytes, output_dict: bool = False
) -> Union[bytes, Dict]:
# Without lz4
assert len(encrypted_message) > SMSG_HDR_LEN
smsg_timestamp = int.from_bytes(encrypted_message[11 : 11 + 8], byteorder="little")
ofs: int = 23
smsg_iv = encrypted_message[ofs : ofs + 16]
ofs += 16
R = encrypted_message[ofs : ofs + 33]
ofs += 33
mac = encrypted_message[ofs : ofs + 32]
ofs += 32
ciphertextlen = int.from_bytes(encrypted_message[ofs : ofs + 4], byteorder="little")
ofs += 4
ciphertext = encrypted_message[ofs:]
assert len(ciphertext) == ciphertextlen
p = PrivateKey(privkey_to).ecdh(R)
H = hashlib.sha512(p).digest()
key_e: bytes = H[:32]
key_m: bytes = H[32:]
m = hmac.new(key_m, digestmod="SHA256")
m.update(smsg_timestamp.to_bytes(8, byteorder="little"))
m.update(smsg_iv)
m.update(ciphertext)
mac_calculated: bytes = m.digest()
assert mac == mac_calculated
plaintext = aes_decrypt(ciphertext, key_e, smsg_iv)
ofs = 1
pkh_from = plaintext[ofs : ofs + 20]
ofs += 20
signature = plaintext[ofs : ofs + 65]
ofs += 65
ofs += 4
payload = plaintext[ofs:]
payload_hash: bytes = sha256(sha256(payload))
# Convert format from BTC
recid = (signature[0] - 27) & 3
signature = signature[1:] + bytes((recid,))
pubkey_signer = PublicKey.from_signature_and_message(
signature, payload_hash, hasher=None
).format()
pkh_from_recovered: bytes = hash160(pubkey_signer)
assert pkh_from == pkh_from_recovered
if output_dict:
return {
"msgid": smsgGetID(encrypted_message).hex(),
"sent": smsg_timestamp,
"hex": payload.hex(),
"pk_from": pubkey_signer.hex(),
}
return payload

View File

@@ -3,4 +3,5 @@ python-gnupg==0.5.4
Jinja2==3.1.6 Jinja2==3.1.6
pycryptodome==3.21.0 pycryptodome==3.21.0
PySocks==1.7.1 PySocks==1.7.1
websocket-client==1.8.0
coincurve@https://github.com/basicswap/coincurve/archive/refs/tags/basicswap_v0.2.zip coincurve@https://github.com/basicswap/coincurve/archive/refs/tags/basicswap_v0.2.zip

View File

@@ -1,5 +1,5 @@
# #
# This file is autogenerated by pip-compile with Python 3.12 # This file is autogenerated by pip-compile with Python 3.13
# by the following command: # by the following command:
# #
# pip-compile --generate-hashes --output-file=requirements.txt requirements.in # pip-compile --generate-hashes --output-file=requirements.txt requirements.in
@@ -305,3 +305,7 @@ pyzmq==26.2.1 \
--hash=sha256:f9ba5def063243793dec6603ad1392f735255cbc7202a3a484c14f99ec290705 \ --hash=sha256:f9ba5def063243793dec6603ad1392f735255cbc7202a3a484c14f99ec290705 \
--hash=sha256:fc409c18884eaf9ddde516d53af4f2db64a8bc7d81b1a0c274b8aa4e929958e8 --hash=sha256:fc409c18884eaf9ddde516d53af4f2db64a8bc7d81b1a0c274b8aa4e929958e8
# via -r requirements.in # via -r requirements.in
websocket-client==1.8.0 \
--hash=sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526 \
--hash=sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da
# via -r requirements.in

View File

@@ -30,7 +30,6 @@ from basicswap.contrib.test_framework.messages import (
CTransaction, CTransaction,
CTxIn, CTxIn,
COutPoint, COutPoint,
ToHex,
) )
from basicswap.contrib.test_framework.script import ( from basicswap.contrib.test_framework.script import (
CScript, CScript,
@@ -318,7 +317,7 @@ class Test(TestFunctions):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -357,10 +356,10 @@ class Test(TestFunctions):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.099), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.099), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2 tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend) tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try: try:

View File

@@ -0,0 +1,342 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
"""
docker run \
-e "ADDR=127.0.0.1" \
-e "PASS=password" \
-p 5223:5223 \
-v /tmp/simplex/smp/config:/etc/opt/simplex:z \
-v /tmp/simplex/smp/logs:/var/opt/simplex:z \
-v /tmp/simplex/certs:/certificates \
simplexchat/smp-server:latest
Fingerprint: Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=
export SIMPLEX_SERVER_ADDRESS=smp://Q8SNxc2SRcKyXlhJM8KFUgPNW4KXPGRm4eSLtT_oh-I=:password@127.0.0.1:5223,443
https://github.com/simplex-chat/simplex-chat/issues/4127
json: {"corrId":"3","cmd":"/_send #1 text test123"}
direct message: {"corrId":"1","cmd":"/_send @2 text the message"}
"""
import json
import logging
import os
import random
import shutil
import sys
import unittest
import basicswap.config as cfg
from basicswap.basicswap import (
BidStates,
SwapTypes,
)
from basicswap.chainparams import Coins
from basicswap.network.simplex import (
WebSocketThread,
waitForConnected,
waitForResponse,
)
from basicswap.network.simplex_chat import startSimplexClient
from tests.basicswap.common import (
stopDaemons,
wait_for_bid,
wait_for_offer,
)
from tests.basicswap.test_xmr import BaseTest, test_delay_event, RESET_TEST
SIMPLEX_SERVER_ADDRESS = os.getenv("SIMPLEX_SERVER_ADDRESS")
SIMPLEX_CLIENT_PATH = os.path.expanduser(os.getenv("SIMPLEX_CLIENT_PATH"))
TEST_DIR = cfg.TEST_DATADIRS
logger = logging.getLogger()
logger.level = logging.DEBUG
if not len(logger.handlers):
logger.addHandler(logging.StreamHandler(sys.stdout))
class TestSimplex(unittest.TestCase):
daemons = []
remove_testdir: bool = False
@classmethod
def tearDownClass(cls):
stopDaemons(cls.daemons)
def test_basic(self):
if os.path.isdir(TEST_DIR):
if RESET_TEST:
logging.info("Removing " + TEST_DIR)
shutil.rmtree(TEST_DIR)
else:
logging.info("Restoring instance from " + TEST_DIR)
if not os.path.exists(TEST_DIR):
os.makedirs(TEST_DIR)
client1_dir = os.path.join(TEST_DIR, "client1")
if os.path.exists(client1_dir):
shutil.rmtree(client1_dir)
client1_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client1_dir,
SIMPLEX_SERVER_ADDRESS,
5225,
logger,
test_delay_event,
)
self.daemons.append(client1_daemon)
client2_dir = os.path.join(TEST_DIR, "client2")
if os.path.exists(client2_dir):
shutil.rmtree(client2_dir)
client2_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client2_dir,
SIMPLEX_SERVER_ADDRESS,
5226,
logger,
test_delay_event,
)
self.daemons.append(client2_daemon)
threads = []
try:
ws_thread = WebSocketThread("ws://127.0.0.1:5225", tag="C1")
ws_thread.start()
threads.append(ws_thread)
ws_thread2 = WebSocketThread("ws://127.0.0.1:5226", tag="C2")
ws_thread2.start()
threads.append(ws_thread2)
waitForConnected(ws_thread, test_delay_event)
sent_id = ws_thread.send_command("/group bsx")
response = waitForResponse(ws_thread, sent_id, test_delay_event)
assert response["resp"]["type"] == "groupCreated"
ws_thread.send_command("/set voice #bsx off")
ws_thread.send_command("/set files #bsx off")
ws_thread.send_command("/set direct #bsx off")
ws_thread.send_command("/set reactions #bsx off")
ws_thread.send_command("/set reports #bsx off")
ws_thread.send_command("/set disappear #bsx on week")
sent_id = ws_thread.send_command("/create link #bsx")
connReqContact = None
connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event)
connReqContact = connReqMsgData["resp"]["connReqContact"]
group_link = "https://simplex.chat" + connReqContact[8:]
logger.info(f"group_link: {group_link}")
sent_id = ws_thread2.send_command("/c " + group_link)
response = waitForResponse(ws_thread2, sent_id, test_delay_event)
assert "groupLinkId" in response["resp"]["connection"]
sent_id = ws_thread2.send_command("/groups")
response = waitForResponse(ws_thread2, sent_id, test_delay_event)
assert len(response["resp"]["groups"]) == 1
ws_thread.send_command("#bsx test msg 1")
found_1 = False
found_2 = False
for i in range(100):
message = ws_thread.queue_get()
if message is not None:
data = json.loads(message)
# print(f"message 1: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in (
"chatItemsStatusesUpdated",
"newChatItems",
):
for chat_item in data["resp"]["chatItems"]:
# print(f"chat_item 1: {json.dumps(chat_item, indent=4)}")
if chat_item["chatItem"]["meta"]["itemStatus"][
"type"
] in ("sndRcvd", "rcvNew"):
if (
chat_item["chatItem"]["content"]["msgContent"][
"text"
]
== "test msg 1"
):
found_1 = True
except Exception as e:
print(f"error 1: {e}")
message = ws_thread2.queue_get()
if message is not None:
data = json.loads(message)
# print(f"message 2: {json.dumps(data, indent=4)}")
try:
if data["resp"]["type"] in (
"chatItemsStatusesUpdated",
"newChatItems",
):
for chat_item in data["resp"]["chatItems"]:
# print(f"chat_item 1: {json.dumps(chat_item, indent=4)}")
if chat_item["chatItem"]["meta"]["itemStatus"][
"type"
] in ("sndRcvd", "rcvNew"):
if (
chat_item["chatItem"]["content"]["msgContent"][
"text"
]
== "test msg 1"
):
found_2 = True
except Exception as e:
print(f"error 2: {e}")
if found_1 and found_2:
break
test_delay_event.wait(0.5)
assert found_1 is True
assert found_2 is True
finally:
for t in threads:
t.stop()
t.join()
class Test(BaseTest):
__test__ = True
start_ltc_nodes = False
start_xmr_nodes = True
group_link = None
daemons = []
coin_to = Coins.XMR
# coin_to = Coins.PART
@classmethod
def prepareTestDir(cls):
base_ws_port: int = 5225
for i in range(cls.num_nodes):
client_dir = os.path.join(TEST_DIR, f"simplex_client{i}")
if os.path.exists(client_dir):
shutil.rmtree(client_dir)
client_daemon = startSimplexClient(
SIMPLEX_CLIENT_PATH,
client_dir,
SIMPLEX_SERVER_ADDRESS,
base_ws_port + i,
logger,
test_delay_event,
)
cls.daemons.append(client_daemon)
# Create the group for bsx
logger.info("Creating BSX group")
ws_thread = None
try:
ws_thread = WebSocketThread(f"ws://127.0.0.1:{base_ws_port}", tag="C0")
ws_thread.start()
waitForConnected(ws_thread, test_delay_event)
sent_id = ws_thread.send_command("/group bsx")
response = waitForResponse(ws_thread, sent_id, test_delay_event)
assert response["resp"]["type"] == "groupCreated"
ws_thread.send_command("/set voice #bsx off")
ws_thread.send_command("/set files #bsx off")
ws_thread.send_command("/set direct #bsx off")
ws_thread.send_command("/set reactions #bsx off")
ws_thread.send_command("/set reports #bsx off")
ws_thread.send_command("/set disappear #bsx on week")
sent_id = ws_thread.send_command("/create link #bsx")
connReqContact = None
connReqMsgData = waitForResponse(ws_thread, sent_id, test_delay_event)
connReqContact = connReqMsgData["resp"]["connReqContact"]
cls.group_link = "https://simplex.chat" + connReqContact[8:]
logger.info(f"BSX group_link: {cls.group_link}")
finally:
if ws_thread:
ws_thread.stop()
ws_thread.join()
@classmethod
def tearDownClass(cls):
logging.info("Finalising Test")
super(Test, cls).tearDownClass()
stopDaemons(cls.daemons)
@classmethod
def addCoinSettings(cls, settings, datadir, node_id):
settings["networks"] = [
{
"type": "simplex",
"server_address": SIMPLEX_SERVER_ADDRESS,
"client_path": SIMPLEX_CLIENT_PATH,
"ws_port": 5225 + node_id,
"group_link": cls.group_link,
},
]
def test_01_swap(self):
logging.info("---------- Test xmr swap")
swap_clients = self.swap_clients
for sc in swap_clients:
sc.dleag_split_size_init = 9000
sc.dleag_split_size = 11000
assert len(swap_clients[0].active_networks) == 1
assert swap_clients[0].active_networks[0]["type"] == "simplex"
coin_from = Coins.BTC
coin_to = self.coin_to
ci_from = swap_clients[0].ci(coin_from)
ci_to = swap_clients[1].ci(coin_to)
swap_value = ci_from.make_int(random.uniform(0.2, 20.0), r=1)
rate_swap = ci_to.make_int(random.uniform(0.2, 20.0), r=1)
offer_id = swap_clients[0].postOffer(
coin_from, coin_to, swap_value, rate_swap, swap_value, SwapTypes.XMR_SWAP
)
wait_for_offer(test_delay_event, swap_clients[1], offer_id)
offer = swap_clients[1].getOffer(offer_id)
bid_id = swap_clients[1].postBid(offer_id, offer.amount_from)
wait_for_bid(test_delay_event, swap_clients[0], bid_id, BidStates.BID_RECEIVED)
swap_clients[0].acceptBid(bid_id)
wait_for_bid(
test_delay_event,
swap_clients[0],
bid_id,
BidStates.SWAP_COMPLETED,
wait_for=320,
)
wait_for_bid(
test_delay_event,
swap_clients[1],
bid_id,
BidStates.SWAP_COMPLETED,
sent=True,
wait_for=320,
)

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2025 The Basicswap developers
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
import logging
from basicswap.chainparams import Coins
from basicswap.util.smsg import (
smsgEncrypt,
smsgDecrypt,
smsgGetID,
smsgGetTimestamp,
SMSG_BUCKET_LEN,
)
from basicswap.contrib.test_framework.messages import (
NODE_SMSG,
msg_smsgPong,
msg_smsgMsg,
)
from basicswap.contrib.test_framework.p2p import (
P2PInterface,
P2P_SERVICES,
NetworkThread,
)
from basicswap.contrib.test_framework.util import (
PortSeed,
)
from tests.basicswap.common import BASE_PORT
from tests.basicswap.test_xmr import BaseTest, test_delay_event
class P2PInterfaceSMSG(P2PInterface):
def __init__(self):
super().__init__()
self.is_part = True
def on_smsgPing(self, msg):
logging.info("on_smsgPing")
self.send_message(msg_smsgPong(1))
def on_smsgPong(self, msg):
logging.info("on_smsgPong", msg)
def on_smsgInv(self, msg):
logging.info("on_smsgInv")
def wait_for_smsg(ci, msg_id: str, wait_for=20) -> None:
for i in range(wait_for):
if test_delay_event.is_set():
raise ValueError("Test stopped.")
try:
ci.rpc_wallet("smsg", [msg_id])
return
except Exception as e:
logging.info(e)
test_delay_event.wait(1)
raise ValueError("wait_for_smsg timed out.")
class Test(BaseTest):
__test__ = True
start_ltc_nodes = False
start_xmr_nodes = False
@classmethod
def setUpClass(cls):
super(Test, cls).setUpClass()
PortSeed.n = 1
logging.info("Setting up network thread")
cls.network_thread = NetworkThread()
cls.network_thread.network_event_loop.set_debug(True)
cls.network_thread.start()
cls.network_thread.network_event_loop.set_debug(True)
@classmethod
def run_loop_ended(cls):
logging.info("run_loop_ended")
logging.info("Closing down network thread")
cls.network_thread.close()
@classmethod
def tearDownClass(cls):
logging.info("Finalising Test")
# logging.info('Closing down network thread')
# cls.network_thread.close()
super(Test, cls).tearDownClass()
@classmethod
def coins_loop(cls):
super(Test, cls).coins_loop()
def test_01_p2p(self):
swap_clients = self.swap_clients
kwargs = {}
kwargs["dstport"] = BASE_PORT
kwargs["dstaddr"] = "127.0.0.1"
services = P2P_SERVICES | NODE_SMSG
p2p_conn = P2PInterfaceSMSG()
p2p_conn.p2p_connected_to_node = True
p2p_conn.peer_connect(
**kwargs,
services=services,
send_version=True,
net="regtest",
timeout_factor=99999,
supports_v2_p2p=False,
)()
p2p_conn.wait_for_connect()
p2p_conn.wait_for_verack()
p2p_conn.sync_with_ping()
ci0_part = swap_clients[0].ci(Coins.PART)
test_key_recv: bytes = ci0_part.getNewRandomKey()
test_key_recv_wif: str = ci0_part.encodeKey(test_key_recv)
test_key_recv_pk: bytes = ci0_part.getPubkey(test_key_recv)
ci0_part.rpc("smsgimportprivkey", [test_key_recv_wif, "test key"])
message_test: str = "Test message"
test_key_send: bytes = ci0_part.getNewRandomKey()
encrypted_message: bytes = smsgEncrypt(
test_key_send, test_key_recv_pk, message_test.encode("utf-8")
)
decrypted_message: bytes = smsgDecrypt(test_key_recv, encrypted_message)
assert decrypted_message.decode("utf-8") == message_test
msg_id: bytes = smsgGetID(encrypted_message)
smsg_timestamp: int = smsgGetTimestamp(encrypted_message)
smsg_bucket: int = smsg_timestamp - (smsg_timestamp % SMSG_BUCKET_LEN)
smsgMsg = msg_smsgMsg(1, smsg_bucket, encrypted_message)
p2p_conn.send_message(smsgMsg)
wait_for_smsg(ci0_part, msg_id.hex())
rv = ci0_part.rpc_wallet("smsg", [msg_id.hex()])
assert rv["text"] == message_test

View File

@@ -26,7 +26,6 @@ from tests.basicswap.common import (
waitForRPC, waitForRPC,
) )
from basicswap.contrib.test_framework.messages import ( from basicswap.contrib.test_framework.messages import (
ToHex,
CTxIn, CTxIn,
COutPoint, COutPoint,
CTransaction, CTransaction,
@@ -251,7 +250,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -285,10 +284,10 @@ class TestBCH(BasicSwapTest):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2 tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend) tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try: try:
@@ -362,7 +361,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -405,7 +404,7 @@ class TestBCH(BasicSwapTest):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
try: try:
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",
@@ -640,7 +639,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -682,7 +681,7 @@ class TestBCH(BasicSwapTest):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",
@@ -730,7 +729,7 @@ class TestBCH(BasicSwapTest):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -772,7 +771,7 @@ class TestBCH(BasicSwapTest):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",

View File

@@ -46,8 +46,7 @@ from tests.basicswap.common import (
) )
from basicswap.contrib.test_framework.descriptors import descsum_create from basicswap.contrib.test_framework.descriptors import descsum_create
from basicswap.contrib.test_framework.messages import ( from basicswap.contrib.test_framework.messages import (
ToHex, from_hex,
FromHex,
CTxIn, CTxIn,
COutPoint, COutPoint,
CTransaction, CTransaction,
@@ -860,7 +859,7 @@ class BasicSwapTest(TestFunctions):
addr_p2sh_segwit, addr_p2sh_segwit,
], ],
) )
decoded_tx = FromHex(CTransaction(), tx_funded) decoded_tx = from_hex(CTransaction(), tx_funded)
decoded_tx.vin[0].scriptSig = bytes.fromhex("16" + addr_p2sh_segwit_info["hex"]) decoded_tx.vin[0].scriptSig = bytes.fromhex("16" + addr_p2sh_segwit_info["hex"])
txid_with_scriptsig = decoded_tx.rehash() txid_with_scriptsig = decoded_tx.rehash()
assert txid_with_scriptsig == tx_signed_decoded["txid"] assert txid_with_scriptsig == tx_signed_decoded["txid"]
@@ -950,7 +949,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -979,10 +978,10 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script, script,
] ]
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2 tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend) tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try: try:
@@ -1055,7 +1054,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -1094,7 +1093,7 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script, script,
] ]
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
try: try:
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",
@@ -1435,7 +1434,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -1477,7 +1476,7 @@ class BasicSwapTest(TestFunctions):
) )
) )
tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out)) tx_spend.vout.append(ci.txoType()(ci.make_int(1.0999), script_out))
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",
@@ -1525,7 +1524,7 @@ class BasicSwapTest(TestFunctions):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex]) tx_funded = ci.rpc_wallet("fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = ci.rpc_wallet( tx_signed = ci.rpc_wallet(
@@ -1567,7 +1566,7 @@ class BasicSwapTest(TestFunctions):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script, script,
] ]
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
txid = ci.rpc( txid = ci.rpc(
"sendrawtransaction", "sendrawtransaction",

View File

@@ -56,7 +56,6 @@ from basicswap.contrib.test_framework.messages import (
CTransaction, CTransaction,
CTxIn, CTxIn,
CTxInWitness, CTxInWitness,
ToHex,
) )
from basicswap.contrib.test_framework.script import ( from basicswap.contrib.test_framework.script import (
CScript, CScript,
@@ -211,7 +210,7 @@ class Test(BaseTest):
tx = CTransaction() tx = CTransaction()
tx.nVersion = ci.txVersion() tx.nVersion = ci.txVersion()
tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest)) tx.vout.append(ci.txoType()(ci.make_int(1.1), script_dest))
tx_hex = ToHex(tx) tx_hex = tx.serialize().hex()
tx_funded = callnoderpc(0, "fundrawtransaction", [tx_hex]) tx_funded = callnoderpc(0, "fundrawtransaction", [tx_hex])
utxo_pos = 0 if tx_funded["changepos"] == 1 else 1 utxo_pos = 0 if tx_funded["changepos"] == 1 else 1
tx_signed = callnoderpc( tx_signed = callnoderpc(
@@ -248,10 +247,10 @@ class Test(BaseTest):
tx_spend.wit.vtxinwit[0].scriptWitness.stack = [ tx_spend.wit.vtxinwit[0].scriptWitness.stack = [
script, script,
] ]
tx_spend_hex = ToHex(tx_spend) tx_spend_hex = tx_spend.serialize().hex()
tx_spend.nLockTime = chain_height + 2 tx_spend.nLockTime = chain_height + 2
tx_spend_invalid_hex = ToHex(tx_spend) tx_spend_invalid_hex = tx_spend.serialize().hex()
for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]: for tx_hex in [tx_spend_invalid_hex, tx_spend_hex]:
try: try:

View File

@@ -247,7 +247,7 @@ def ltcCli(cmd, node_id=0):
def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("signal {} detected.".format(sig)) logging.info(f"signal {sig} detected.")
signal_event.set() signal_event.set()
test_delay_event.set() test_delay_event.set()
@@ -309,6 +309,7 @@ def run_loop(cls):
for c in cls.swap_clients: for c in cls.swap_clients:
c.update() c.update()
test_delay_event.wait(1.0) test_delay_event.wait(1.0)
cls.run_loop_ended()
class BaseTest(unittest.TestCase): class BaseTest(unittest.TestCase):
@@ -322,12 +323,13 @@ class BaseTest(unittest.TestCase):
ltc_daemons = [] ltc_daemons = []
xmr_daemons = [] xmr_daemons = []
xmr_wallet_auth = [] xmr_wallet_auth = []
restore_instance = False restore_instance: bool = False
extra_wait_time = 0 extra_wait_time: int = 0
num_nodes: int = NUM_NODES
start_ltc_nodes = False start_ltc_nodes: bool = False
start_xmr_nodes = True start_xmr_nodes: bool = True
has_segwit = True has_segwit: bool = True
xmr_addr = None xmr_addr = None
btc_addr = None btc_addr = None
@@ -392,6 +394,8 @@ class BaseTest(unittest.TestCase):
cls.stream_fp.setFormatter(formatter) cls.stream_fp.setFormatter(formatter)
logger.addHandler(cls.stream_fp) logger.addHandler(cls.stream_fp)
cls.prepareTestDir()
try: try:
logging.info("Preparing coin nodes.") logging.info("Preparing coin nodes.")
for i in range(NUM_NODES): for i in range(NUM_NODES):
@@ -645,6 +649,7 @@ class BaseTest(unittest.TestCase):
start_nodes, start_nodes,
cls, cls,
) )
basicswap_dir = os.path.join( basicswap_dir = os.path.join(
os.path.join(TEST_DIR, "basicswap_" + str(i)) os.path.join(TEST_DIR, "basicswap_" + str(i))
) )
@@ -966,6 +971,10 @@ class BaseTest(unittest.TestCase):
super(BaseTest, cls).tearDownClass() super(BaseTest, cls).tearDownClass()
@classmethod
def prepareTestDir(cls):
pass
@classmethod @classmethod
def addCoinSettings(cls, settings, datadir, node_id): def addCoinSettings(cls, settings, datadir, node_id):
pass pass
@@ -995,6 +1004,10 @@ class BaseTest(unittest.TestCase):
{"wallet_address": cls.xmr_addr, "amount_of_blocks": 1}, {"wallet_address": cls.xmr_addr, "amount_of_blocks": 1},
) )
@classmethod
def run_loop_ended(cls):
pass
@classmethod @classmethod
def waitForParticlHeight(cls, num_blocks, node_id=0): def waitForParticlHeight(cls, num_blocks, node_id=0):
logging.info( logging.info(