diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 737dca0..ac35f8d 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -3611,6 +3611,16 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): raise ValueError( f"Invalid swap type for: {coin_from.name} -> {coin_to.name}" ) + strict_swap_type: bool = self.settings.get( + "strict_swap_type", False if self.chain == "regtest" else True + ) + if strict_swap_type and ( + coin_from not in self.coins_without_segwit + or coin_to not in self.coins_without_segwit + ): + raise ValueError( + f"Coin pair should use adaptor sig swap type: {coin_from.name} -> {coin_to.name}" + ) def _process_notification_safe(self, event_type, event_data) -> None: try: diff --git a/doc/release-notes.md b/doc/release-notes.md index 8e82176..e342f02 100644 --- a/doc/release-notes.md +++ b/doc/release-notes.md @@ -3,6 +3,8 @@ ============== - Updated docker base images to Debian Trixie. +- By default reject secret hash type offers where the coin pair could use adaptor sig swap. + - override with "strict_swap_type" setting. 0.16.3 diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 2782b1c..c992642 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -8,6 +8,7 @@ import hashlib import logging +import os import random import secrets import threading @@ -22,10 +23,15 @@ from coincurve.ecdsaotves import ( ) from coincurve.keys import PrivateKey +from basicswap.basicswap import ( + Coins, + BasicSwap, + SwapTypes, +) from basicswap.contrib.mnemonic import Mnemonic from basicswap.db import create_db_, DBMethods, KnownIdentity from basicswap.util import h2b -from basicswap.util.address import decodeAddress +from basicswap.util.address import decodeAddress, toWIF from basicswap.util.crypto import ripemd160, hash160, blake256 from basicswap.util.extkey import ExtKeyPair from basicswap.util.integer import encode_varint, decode_varint @@ -60,6 +66,9 @@ from basicswap.contrib.test_framework.messages import ( CTxOut, uint256_from_str, ) +from tests.basicswap.common import ( + PREFIX_SECRET_KEY_REGTEST, +) logger = logging.getLogger() @@ -790,6 +799,74 @@ class Test(unittest.TestCase): == "252cd6e85b99e0fd554c44d5fe638923f7ef563048362406a665cf3400feb1bd" ) + def test_validateSwapType(self): + logging.info("---------- Test validateSwapType") + basicswap_dir = "/tmp/bsx_test_other" + if not os.path.exists(basicswap_dir): + os.makedirs(basicswap_dir) + + k = PrivateKey() + settings = { + "network_key": toWIF(PREFIX_SECRET_KEY_REGTEST, k.secret), + "network_pubkey": k.public_key.format().hex(), + } + + sc = BasicSwap( + basicswap_dir, + settings, + "regtest", + log_name="bsx_test_other", + ) + + should_pass = [ + (Coins.BTC, Coins.XMR, SwapTypes.XMR_SWAP), + (Coins.XMR, Coins.BTC, SwapTypes.XMR_SWAP), + (Coins.BTC, Coins.FIRO, SwapTypes.XMR_SWAP), + (Coins.FIRO, Coins.BTC, SwapTypes.XMR_SWAP), + (Coins.PIVX, Coins.BTC, SwapTypes.SELLER_FIRST), + (Coins.BTC, Coins.PIVX, SwapTypes.SELLER_FIRST), + (Coins.DASH, Coins.PIVX, SwapTypes.SELLER_FIRST), + (Coins.PIVX, Coins.DASH, SwapTypes.SELLER_FIRST), + ] + should_fail = [ + (Coins.BTC, Coins.XMR, SwapTypes.SELLER_FIRST), + (Coins.XMR, Coins.PART_ANON, SwapTypes.XMR_SWAP), + (Coins.FIRO, Coins.PART_ANON, SwapTypes.XMR_SWAP), + (Coins.PART_ANON, Coins.FIRO, SwapTypes.XMR_SWAP), + (Coins.FIRO, Coins.BTC, SwapTypes.SELLER_FIRST), + (Coins.BTC, Coins.FIRO, SwapTypes.SELLER_FIRST), + ] + + for case in should_pass: + sc.validateSwapType(case[0], case[1], case[2]) + for case in should_fail: + self.assertRaises( + ValueError, sc.validateSwapType, case[0], case[1], case[2] + ) + sc.chain = "mainnet" + for case in should_pass: + try: + sc.validateSwapType(case[0], case[1], case[2]) + except Exception as e: + assert "Coin pair should use adaptor sig swap type" in str(e) + else: + if case[2] != SwapTypes.XMR_SWAP: + if ( + case[0] not in sc.coins_without_segwit + or case[1] not in sc.coins_without_segwit + ): + raise ValueError(f"Invalid swap pair in strict mode {case}") + for case in should_fail: + self.assertRaises( + ValueError, sc.validateSwapType, case[0], case[1], case[2] + ) + + sc.settings["strict_swap_type"] = False + for case in should_pass: + sc.validateSwapType(case[0], case[1], case[2]) + + del sc + if __name__ == "__main__": unittest.main() diff --git a/tests/basicswap/test_run.py b/tests/basicswap/test_run.py index a6266f4..f67ae06 100644 --- a/tests/basicswap/test_run.py +++ b/tests/basicswap/test_run.py @@ -159,35 +159,6 @@ class Test(BaseTest): rv = read_json_api(1800, "rateslist?from=PART&to=BTC") assert len(rv) == 1 - def test_004_validateSwapType(self): - logging.info("---------- Test validateSwapType") - - sc = self.swap_clients[0] - - should_pass = [ - (Coins.BTC, Coins.XMR, SwapTypes.XMR_SWAP), - (Coins.XMR, Coins.BTC, SwapTypes.XMR_SWAP), - (Coins.BTC, Coins.FIRO, SwapTypes.XMR_SWAP), - (Coins.FIRO, Coins.BTC, SwapTypes.XMR_SWAP), - (Coins.PIVX, Coins.BTC, SwapTypes.SELLER_FIRST), - (Coins.BTC, Coins.PIVX, SwapTypes.SELLER_FIRST), - ] - should_fail = [ - (Coins.BTC, Coins.XMR, SwapTypes.SELLER_FIRST), - (Coins.XMR, Coins.PART_ANON, SwapTypes.XMR_SWAP), - (Coins.FIRO, Coins.PART_ANON, SwapTypes.XMR_SWAP), - (Coins.PART_ANON, Coins.FIRO, SwapTypes.XMR_SWAP), - (Coins.FIRO, Coins.BTC, SwapTypes.SELLER_FIRST), - (Coins.BTC, Coins.FIRO, SwapTypes.SELLER_FIRST), - ] - - for case in should_pass: - sc.validateSwapType(case[0], case[1], case[2]) - for case in should_fail: - self.assertRaises( - ValueError, sc.validateSwapType, case[0], case[1], case[2] - ) - def test_003_cltv(self): test_coin_from = Coins.PART logging.info("---------- Test {} cltv".format(test_coin_from.name)) diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py index 271ceb8..15f8c41 100644 --- a/tests/basicswap/test_xmr.py +++ b/tests/basicswap/test_xmr.py @@ -38,9 +38,7 @@ from basicswap.basicswap_util import ( EventLogTypes, ) from basicswap.util import COIN, format_amount, make_int, TemporaryError -from basicswap.util.address import ( - toWIF, -) +from basicswap.util.address import toWIF from basicswap.rpc import ( callrpc, )