Merge pull request #423 from tecnovert/witness_stack_est

Estimate witness stack size for multiple inputs.
This commit is contained in:
tecnovert
2026-01-31 12:07:47 +00:00
committed by GitHub
4 changed files with 50 additions and 15 deletions

View File

@@ -9,6 +9,7 @@
import threading import threading
from enum import IntEnum from enum import IntEnum
from typing import List
from basicswap.chainparams import ( from basicswap.chainparams import (
chainparams, chainparams,
@@ -180,13 +181,16 @@ class CoinInterface:
class AdaptorSigInterface: class AdaptorSigInterface:
def getScriptLockTxDummyWitness(self, script: bytes): def getP2WPKHDummyWitness(self) -> List[bytes]:
return [bytes(72), bytes(33)]
def getScriptLockTxDummyWitness(self, script: bytes) -> List[bytes]:
return [b"", bytes(72), bytes(72), bytes(len(script))] return [b"", bytes(72), bytes(72), bytes(len(script))]
def getScriptLockRefundSpendTxDummyWitness(self, script: bytes): def getScriptLockRefundSpendTxDummyWitness(self, script: bytes) -> List[bytes]:
return [b"", bytes(72), bytes(72), bytes((1,)), bytes(len(script))] return [b"", bytes(72), bytes(72), bytes((1,)), bytes(len(script))]
def getScriptLockRefundSwipeTxDummyWitness(self, script: bytes): def getScriptLockRefundSwipeTxDummyWitness(self, script: bytes) -> List[bytes]:
return [bytes(72), b"", bytes(len(script))] return [bytes(72), b"", bytes(len(script))]

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) 2020-2024 tecnovert # Copyright (c) 2020-2024 tecnovert
# Copyright (c) 2024-2025 The Basicswap developers # Copyright (c) 2024-2026 The Basicswap developers
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php. # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
@@ -828,7 +828,7 @@ class BTCInterface(Secp256k1Interface):
def getScriptDummyWitness(self, script: bytes) -> List[bytes]: def getScriptDummyWitness(self, script: bytes) -> List[bytes]:
if self.isScriptP2WPKH(script): if self.isScriptP2WPKH(script):
return [bytes(72), bytes(33)] return self.getP2WPKHDummyWitness()
raise ValueError("Unknown script type") raise ValueError("Unknown script type")
def createSCLockRefundTx( def createSCLockRefundTx(
@@ -1943,9 +1943,16 @@ class BTCInterface(Secp256k1Interface):
raise ValueError("Unimplemented") raise ValueError("Unimplemented")
def getWitnessStackSerialisedLength(self, witness_stack): def getWitnessStackSerialisedLength(self, witness_stack):
length = getCompactSizeLen(len(witness_stack)) length: int = 0
for e in witness_stack: if len(witness_stack) > 0 and isinstance(witness_stack[0], list):
length += getWitnessElementLen(len(e)) for input_stack in witness_stack:
length += getCompactSizeLen(len(input_stack))
for e in input_stack:
length += getWitnessElementLen(len(e))
else:
length += getCompactSizeLen(len(witness_stack))
for e in witness_stack:
length += getWitnessElementLen(len(e))
# See core SerializeTransaction # See core SerializeTransaction
length += 1 # vinDummy length += 1 # vinDummy

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright (c) 2020-2024 tecnovert # Copyright (c) 2020-2024 tecnovert
# Copyright (c) 2024-2025 The Basicswap developers # Copyright (c) 2024-2026 The Basicswap developers
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php. # file LICENSE or http://www.opensource.org/licenses/mit-license.php.
@@ -137,7 +137,7 @@ class PARTInterface(BTCInterface):
def getScriptDummyWitness(self, script: bytes) -> List[bytes]: def getScriptDummyWitness(self, script: bytes) -> List[bytes]:
if self.isScriptP2WPKH(script) or self.isScriptP2PKH(script): if self.isScriptP2WPKH(script) or self.isScriptP2PKH(script):
return [bytes(72), bytes(33)] return self.getP2WPKHDummyWitness()
raise ValueError("Unknown script type") raise ValueError("Unknown script type")
def formatStealthAddress(self, scan_pubkey, spend_pubkey) -> str: def formatStealthAddress(self, scan_pubkey, spend_pubkey) -> str:
@@ -146,9 +146,16 @@ class PARTInterface(BTCInterface):
return encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey) return encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey)
def getWitnessStackSerialisedLength(self, witness_stack) -> int: def getWitnessStackSerialisedLength(self, witness_stack) -> int:
length: int = getCompactSizeLen(len(witness_stack)) length: int = 0
for e in witness_stack: if len(witness_stack) > 0 and isinstance(witness_stack[0], list):
length += getWitnessElementLen(len(e)) for input_stack in witness_stack:
length += getCompactSizeLen(len(input_stack))
for e in input_stack:
length += getWitnessElementLen(len(e))
else:
length += getCompactSizeLen(len(witness_stack))
for e in witness_stack:
length += getWitnessElementLen(len(e))
return length return length
def getWalletRestoreHeight(self) -> int: def getWalletRestoreHeight(self) -> int:

View File

@@ -1509,6 +1509,23 @@ class BasicSwapTest(TestFunctions):
vsize = tx_decoded["vsize"] vsize = tx_decoded["vsize"]
expect_fee_int = round(self.test_fee_rate * vsize / 1000) expect_fee_int = round(self.test_fee_rate * vsize / 1000)
tx_obj = ci.loadTx(lock_tx)
vsize_from_ci = ci.getTxVSize(tx_obj)
assert vsize == vsize_from_ci
tx_no_witness = tx_obj.serialize_without_witness()
dummy_witness_stack = []
for txi in tx_obj.vin:
dummy_witness_stack.append(ci.getP2WPKHDummyWitness())
witness_bytes_len_est: int = ci.getWitnessStackSerialisedLength(
dummy_witness_stack
)
tx_obj_no_witness = ci.loadTx(tx_no_witness)
vsize_estimated = ci.getTxVSize(
tx_obj_no_witness, add_witness_bytes=witness_bytes_len_est
)
assert vsize <= vsize_estimated and vsize_estimated - vsize < 4
out_value: int = 0 out_value: int = 0
for txo in tx_decoded["vout"]: for txo in tx_decoded["vout"]:
if "value" in txo: if "value" in txo:
@@ -1556,7 +1573,7 @@ class BasicSwapTest(TestFunctions):
expect_vsize: int = ci.xmr_swap_a_lock_spend_tx_vsize() expect_vsize: int = ci.xmr_swap_a_lock_spend_tx_vsize()
assert expect_vsize >= vsize_actual assert expect_vsize >= vsize_actual
assert expect_vsize - vsize_actual < 10 assert expect_vsize - vsize_actual <= 10
# Test chain b (no-script) lock tx size # Test chain b (no-script) lock tx size
v = ci.getNewRandomKey() v = ci.getNewRandomKey()
@@ -1577,7 +1594,7 @@ class BasicSwapTest(TestFunctions):
expect_vsize: int = ci.xmr_swap_b_lock_spend_tx_vsize() expect_vsize: int = ci.xmr_swap_b_lock_spend_tx_vsize()
assert expect_vsize >= lock_tx_b_spend_decoded["vsize"] assert expect_vsize >= lock_tx_b_spend_decoded["vsize"]
assert expect_vsize - lock_tx_b_spend_decoded["vsize"] < 10 assert expect_vsize - lock_tx_b_spend_decoded["vsize"] <= 10
def test_011_p2sh(self): def test_011_p2sh(self):
# Not used in bsx for native-segwit coins # Not used in bsx for native-segwit coins