diff --git a/basicswap/interface/base.py b/basicswap/interface/base.py index dc708ef..9a8e0f2 100644 --- a/basicswap/interface/base.py +++ b/basicswap/interface/base.py @@ -9,6 +9,7 @@ import threading from enum import IntEnum +from typing import List from basicswap.chainparams import ( chainparams, @@ -180,13 +181,16 @@ class CoinInterface: class AdaptorSigInterface: - def getScriptLockTxDummyWitness(self, script: bytes): + def getP2WPKHDummyWitness(self) -> List[bytes]: + return [bytes(72), bytes(33)] + + def getScriptLockTxDummyWitness(self, script: bytes) -> List[bytes]: return [b"", bytes(72), bytes(72), bytes(len(script))] - def getScriptLockRefundSpendTxDummyWitness(self, script: bytes): + def getScriptLockRefundSpendTxDummyWitness(self, script: bytes) -> List[bytes]: return [b"", bytes(72), bytes(72), bytes((1,)), bytes(len(script))] - def getScriptLockRefundSwipeTxDummyWitness(self, script: bytes): + def getScriptLockRefundSwipeTxDummyWitness(self, script: bytes) -> List[bytes]: return [bytes(72), b"", bytes(len(script))] diff --git a/basicswap/interface/btc.py b/basicswap/interface/btc.py index 029613c..8fba86c 100644 --- a/basicswap/interface/btc.py +++ b/basicswap/interface/btc.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2020-2024 tecnovert -# Copyright (c) 2024-2025 The Basicswap developers +# Copyright (c) 2024-2026 The Basicswap developers # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. @@ -828,7 +828,7 @@ class BTCInterface(Secp256k1Interface): def getScriptDummyWitness(self, script: bytes) -> List[bytes]: if self.isScriptP2WPKH(script): - return [bytes(72), bytes(33)] + return self.getP2WPKHDummyWitness() raise ValueError("Unknown script type") def createSCLockRefundTx( @@ -1943,9 +1943,16 @@ class BTCInterface(Secp256k1Interface): raise ValueError("Unimplemented") def getWitnessStackSerialisedLength(self, witness_stack): - length = getCompactSizeLen(len(witness_stack)) - for e in witness_stack: - length += getWitnessElementLen(len(e)) + length: int = 0 + if len(witness_stack) > 0 and isinstance(witness_stack[0], list): + for input_stack in witness_stack: + length += getCompactSizeLen(len(input_stack)) + for e in input_stack: + length += getWitnessElementLen(len(e)) + else: + length += getCompactSizeLen(len(witness_stack)) + for e in witness_stack: + length += getWitnessElementLen(len(e)) # See core SerializeTransaction length += 1 # vinDummy diff --git a/basicswap/interface/part.py b/basicswap/interface/part.py index 646f895..a40a99b 100644 --- a/basicswap/interface/part.py +++ b/basicswap/interface/part.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) 2020-2024 tecnovert -# Copyright (c) 2024-2025 The Basicswap developers +# Copyright (c) 2024-2026 The Basicswap developers # Distributed under the MIT software license, see the accompanying # file LICENSE or http://www.opensource.org/licenses/mit-license.php. @@ -137,7 +137,7 @@ class PARTInterface(BTCInterface): def getScriptDummyWitness(self, script: bytes) -> List[bytes]: if self.isScriptP2WPKH(script) or self.isScriptP2PKH(script): - return [bytes(72), bytes(33)] + return self.getP2WPKHDummyWitness() raise ValueError("Unknown script type") def formatStealthAddress(self, scan_pubkey, spend_pubkey) -> str: @@ -146,9 +146,16 @@ class PARTInterface(BTCInterface): return encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey) def getWitnessStackSerialisedLength(self, witness_stack) -> int: - length: int = getCompactSizeLen(len(witness_stack)) - for e in witness_stack: - length += getWitnessElementLen(len(e)) + length: int = 0 + if len(witness_stack) > 0 and isinstance(witness_stack[0], list): + for input_stack in witness_stack: + length += getCompactSizeLen(len(input_stack)) + for e in input_stack: + length += getWitnessElementLen(len(e)) + else: + length += getCompactSizeLen(len(witness_stack)) + for e in witness_stack: + length += getWitnessElementLen(len(e)) return length def getWalletRestoreHeight(self) -> int: diff --git a/tests/basicswap/test_btc_xmr.py b/tests/basicswap/test_btc_xmr.py index 53a1e8e..18329da 100644 --- a/tests/basicswap/test_btc_xmr.py +++ b/tests/basicswap/test_btc_xmr.py @@ -1509,6 +1509,23 @@ class BasicSwapTest(TestFunctions): vsize = tx_decoded["vsize"] expect_fee_int = round(self.test_fee_rate * vsize / 1000) + tx_obj = ci.loadTx(lock_tx) + vsize_from_ci = ci.getTxVSize(tx_obj) + assert vsize == vsize_from_ci + tx_no_witness = tx_obj.serialize_without_witness() + + dummy_witness_stack = [] + for txi in tx_obj.vin: + dummy_witness_stack.append(ci.getP2WPKHDummyWitness()) + witness_bytes_len_est: int = ci.getWitnessStackSerialisedLength( + dummy_witness_stack + ) + tx_obj_no_witness = ci.loadTx(tx_no_witness) + vsize_estimated = ci.getTxVSize( + tx_obj_no_witness, add_witness_bytes=witness_bytes_len_est + ) + assert vsize <= vsize_estimated and vsize_estimated - vsize < 4 + out_value: int = 0 for txo in tx_decoded["vout"]: if "value" in txo: @@ -1556,7 +1573,7 @@ class BasicSwapTest(TestFunctions): expect_vsize: int = ci.xmr_swap_a_lock_spend_tx_vsize() assert expect_vsize >= vsize_actual - assert expect_vsize - vsize_actual < 10 + assert expect_vsize - vsize_actual <= 10 # Test chain b (no-script) lock tx size v = ci.getNewRandomKey() @@ -1577,7 +1594,7 @@ class BasicSwapTest(TestFunctions): expect_vsize: int = ci.xmr_swap_b_lock_spend_tx_vsize() assert expect_vsize >= lock_tx_b_spend_decoded["vsize"] - assert expect_vsize - lock_tx_b_spend_decoded["vsize"] < 10 + assert expect_vsize - lock_tx_b_spend_decoded["vsize"] <= 10 def test_011_p2sh(self): # Not used in bsx for native-segwit coins