feat: reject secret-hash offers where coin pair can use adaptor sig

This commit is contained in:
tecnovert
2026-06-09 10:49:20 +02:00
parent d23665d585
commit 3aacc57f09
5 changed files with 91 additions and 33 deletions
+78 -1
View File
@@ -8,6 +8,7 @@
import hashlib
import logging
import os
import random
import secrets
import threading
@@ -22,10 +23,15 @@ from coincurve.ecdsaotves import (
)
from coincurve.keys import PrivateKey
from basicswap.basicswap import (
Coins,
BasicSwap,
SwapTypes,
)
from basicswap.contrib.mnemonic import Mnemonic
from basicswap.db import create_db_, DBMethods, KnownIdentity
from basicswap.util import h2b
from basicswap.util.address import decodeAddress
from basicswap.util.address import decodeAddress, toWIF
from basicswap.util.crypto import ripemd160, hash160, blake256
from basicswap.util.extkey import ExtKeyPair
from basicswap.util.integer import encode_varint, decode_varint
@@ -60,6 +66,9 @@ from basicswap.contrib.test_framework.messages import (
CTxOut,
uint256_from_str,
)
from tests.basicswap.common import (
PREFIX_SECRET_KEY_REGTEST,
)
logger = logging.getLogger()
@@ -790,6 +799,74 @@ class Test(unittest.TestCase):
== "252cd6e85b99e0fd554c44d5fe638923f7ef563048362406a665cf3400feb1bd"
)
def test_validateSwapType(self):
logging.info("---------- Test validateSwapType")
basicswap_dir = "/tmp/bsx_test_other"
if not os.path.exists(basicswap_dir):
os.makedirs(basicswap_dir)
k = PrivateKey()
settings = {
"network_key": toWIF(PREFIX_SECRET_KEY_REGTEST, k.secret),
"network_pubkey": k.public_key.format().hex(),
}
sc = BasicSwap(
basicswap_dir,
settings,
"regtest",
log_name="bsx_test_other",
)
should_pass = [
(Coins.BTC, Coins.XMR, SwapTypes.XMR_SWAP),
(Coins.XMR, Coins.BTC, SwapTypes.XMR_SWAP),
(Coins.BTC, Coins.FIRO, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.BTC, SwapTypes.XMR_SWAP),
(Coins.PIVX, Coins.BTC, SwapTypes.SELLER_FIRST),
(Coins.BTC, Coins.PIVX, SwapTypes.SELLER_FIRST),
(Coins.DASH, Coins.PIVX, SwapTypes.SELLER_FIRST),
(Coins.PIVX, Coins.DASH, SwapTypes.SELLER_FIRST),
]
should_fail = [
(Coins.BTC, Coins.XMR, SwapTypes.SELLER_FIRST),
(Coins.XMR, Coins.PART_ANON, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.PART_ANON, SwapTypes.XMR_SWAP),
(Coins.PART_ANON, Coins.FIRO, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.BTC, SwapTypes.SELLER_FIRST),
(Coins.BTC, Coins.FIRO, SwapTypes.SELLER_FIRST),
]
for case in should_pass:
sc.validateSwapType(case[0], case[1], case[2])
for case in should_fail:
self.assertRaises(
ValueError, sc.validateSwapType, case[0], case[1], case[2]
)
sc.chain = "mainnet"
for case in should_pass:
try:
sc.validateSwapType(case[0], case[1], case[2])
except Exception as e:
assert "Coin pair should use adaptor sig swap type" in str(e)
else:
if case[2] != SwapTypes.XMR_SWAP:
if (
case[0] not in sc.coins_without_segwit
or case[1] not in sc.coins_without_segwit
):
raise ValueError(f"Invalid swap pair in strict mode {case}")
for case in should_fail:
self.assertRaises(
ValueError, sc.validateSwapType, case[0], case[1], case[2]
)
sc.settings["strict_swap_type"] = False
for case in should_pass:
sc.validateSwapType(case[0], case[1], case[2])
del sc
if __name__ == "__main__":
unittest.main()
-29
View File
@@ -159,35 +159,6 @@ class Test(BaseTest):
rv = read_json_api(1800, "rateslist?from=PART&to=BTC")
assert len(rv) == 1
def test_004_validateSwapType(self):
logging.info("---------- Test validateSwapType")
sc = self.swap_clients[0]
should_pass = [
(Coins.BTC, Coins.XMR, SwapTypes.XMR_SWAP),
(Coins.XMR, Coins.BTC, SwapTypes.XMR_SWAP),
(Coins.BTC, Coins.FIRO, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.BTC, SwapTypes.XMR_SWAP),
(Coins.PIVX, Coins.BTC, SwapTypes.SELLER_FIRST),
(Coins.BTC, Coins.PIVX, SwapTypes.SELLER_FIRST),
]
should_fail = [
(Coins.BTC, Coins.XMR, SwapTypes.SELLER_FIRST),
(Coins.XMR, Coins.PART_ANON, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.PART_ANON, SwapTypes.XMR_SWAP),
(Coins.PART_ANON, Coins.FIRO, SwapTypes.XMR_SWAP),
(Coins.FIRO, Coins.BTC, SwapTypes.SELLER_FIRST),
(Coins.BTC, Coins.FIRO, SwapTypes.SELLER_FIRST),
]
for case in should_pass:
sc.validateSwapType(case[0], case[1], case[2])
for case in should_fail:
self.assertRaises(
ValueError, sc.validateSwapType, case[0], case[1], case[2]
)
def test_003_cltv(self):
test_coin_from = Coins.PART
logging.info("---------- Test {} cltv".format(test_coin_from.name))
+1 -3
View File
@@ -38,9 +38,7 @@ from basicswap.basicswap_util import (
EventLogTypes,
)
from basicswap.util import COIN, format_amount, make_int, TemporaryError
from basicswap.util.address import (
toWIF,
)
from basicswap.util.address import toWIF
from basicswap.rpc import (
callrpc,
)