mirror of
https://github.com/basicswap/basicswap.git
synced 2025-12-29 00:41:39 +01:00
@@ -72,6 +72,7 @@ from .db_util import remove_expired_data
|
|||||||
from .http_server import HttpThread
|
from .http_server import HttpThread
|
||||||
from .rpc import escape_rpcauth
|
from .rpc import escape_rpcauth
|
||||||
from .rpc_xmr import make_xmr_rpc2_func
|
from .rpc_xmr import make_xmr_rpc2_func
|
||||||
|
from .types import WatchedTransaction, WatchedScript, WatchedOutput
|
||||||
from .ui.app import UIApp
|
from .ui.app import UIApp
|
||||||
from .ui.util import getCoinName
|
from .ui.util import getCoinName
|
||||||
from .util import (
|
from .util import (
|
||||||
@@ -287,37 +288,6 @@ def threadPollChainState(swap_client, coin_type):
|
|||||||
swap_client.chainstate_delay_event.wait(random.randrange(*poll_delay_range))
|
swap_client.chainstate_delay_event.wait(random.randrange(*poll_delay_range))
|
||||||
|
|
||||||
|
|
||||||
class WatchedOutput: # Watch for spends
|
|
||||||
__slots__ = ("bid_id", "txid_hex", "vout", "tx_type", "swap_type")
|
|
||||||
|
|
||||||
def __init__(self, bid_id: bytes, txid_hex: str, vout, tx_type, swap_type):
|
|
||||||
self.bid_id = bid_id
|
|
||||||
self.txid_hex = txid_hex
|
|
||||||
self.vout = vout
|
|
||||||
self.tx_type = tx_type
|
|
||||||
self.swap_type = swap_type
|
|
||||||
|
|
||||||
|
|
||||||
class WatchedScript: # Watch for txns containing outputs
|
|
||||||
__slots__ = ("bid_id", "script", "tx_type", "swap_type")
|
|
||||||
|
|
||||||
def __init__(self, bid_id: bytes, script: bytes, tx_type, swap_type):
|
|
||||||
self.bid_id = bid_id
|
|
||||||
self.script = script
|
|
||||||
self.tx_type = tx_type
|
|
||||||
self.swap_type = swap_type
|
|
||||||
|
|
||||||
|
|
||||||
class WatchedTransaction:
|
|
||||||
# TODO
|
|
||||||
# Watch for presence in mempool (getrawtransaction)
|
|
||||||
def __init__(self, bid_id: bytes, txid_hex: str, tx_type, swap_type):
|
|
||||||
self.bid_id = bid_id
|
|
||||||
self.txid_hex = txid_hex
|
|
||||||
self.tx_type = tx_type
|
|
||||||
self.swap_type = swap_type
|
|
||||||
|
|
||||||
|
|
||||||
class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
||||||
ws_server = None
|
ws_server = None
|
||||||
protocolInterfaces = {
|
protocolInterfaces = {
|
||||||
@@ -530,7 +500,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
self.db_version = self.getIntKV("db_version", cursor, CURRENT_DB_VERSION)
|
self.db_version = self.getIntKV("db_version", cursor, CURRENT_DB_VERSION)
|
||||||
self.db_data_version = self.getIntKV("db_data_version", cursor, 0)
|
self.db_data_version = self.getIntKV("db_data_version", cursor, 0)
|
||||||
self._contract_count = self.getIntKV("contract_count", cursor, 0)
|
self._contract_count = self.getIntKV("contract_count", cursor, 0)
|
||||||
self.commitDB()
|
|
||||||
finally:
|
finally:
|
||||||
self.closeDB(cursor)
|
self.closeDB(cursor)
|
||||||
|
|
||||||
@@ -613,6 +582,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
def finalise(self):
|
def finalise(self):
|
||||||
self.log.info("Finalising")
|
self.log.info("Finalising")
|
||||||
|
|
||||||
|
if self.ws_server:
|
||||||
|
try:
|
||||||
|
self.log.info("Stopping websocket server.")
|
||||||
|
self.ws_server.shutdown_gracefully()
|
||||||
|
except Exception as e: # noqa: F841
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
self._price_fetch_running = False
|
self._price_fetch_running = False
|
||||||
if self._price_fetch_thread and self._price_fetch_thread.is_alive():
|
if self._price_fetch_thread and self._price_fetch_thread.is_alive():
|
||||||
self._price_fetch_thread.join(timeout=5)
|
self._price_fetch_thread.join(timeout=5)
|
||||||
@@ -745,6 +721,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
"rpcauth": rpcauth,
|
"rpcauth": rpcauth,
|
||||||
"blocks_confirmed": chain_client_settings.get("blocks_confirmed", 6),
|
"blocks_confirmed": chain_client_settings.get("blocks_confirmed", 6),
|
||||||
"conf_target": chain_client_settings.get("conf_target", 2),
|
"conf_target": chain_client_settings.get("conf_target", 2),
|
||||||
|
"watched_transactions": [],
|
||||||
"watched_outputs": [],
|
"watched_outputs": [],
|
||||||
"watched_scripts": [],
|
"watched_scripts": [],
|
||||||
"last_height_checked": last_height_checked,
|
"last_height_checked": last_height_checked,
|
||||||
@@ -4453,17 +4430,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
|
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
|
||||||
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
|
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
|
||||||
elif ci_to.curve_type() == Curves.secp256k1:
|
elif ci_to.curve_type() == Curves.secp256k1:
|
||||||
for i in range(10):
|
xmr_swap.kbsf_dleag = ci_to.signRecoverable(
|
||||||
xmr_swap.kbsf_dleag = ci_to.signRecoverable(
|
kbsf, "proof kbsf owned for swap"
|
||||||
kbsf, "proof kbsf owned for swap"
|
)
|
||||||
)
|
pk_recovered = ci_to.verifySigAndRecover(
|
||||||
pk_recovered = ci_to.verifySigAndRecover(
|
xmr_swap.kbsf_dleag, "proof kbsf owned for swap"
|
||||||
xmr_swap.kbsf_dleag, "proof kbsf owned for swap"
|
)
|
||||||
)
|
ensure(pk_recovered == xmr_swap.pkbsf, "kbsf recovered pubkey mismatch")
|
||||||
if pk_recovered == xmr_swap.pkbsf:
|
|
||||||
break
|
|
||||||
self.log.debug("kbsl recovered pubkey mismatch, retrying.")
|
|
||||||
assert pk_recovered == xmr_swap.pkbsf
|
|
||||||
xmr_swap.pkasf = xmr_swap.pkbsf
|
xmr_swap.pkasf = xmr_swap.pkbsf
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown curve")
|
raise ValueError("Unknown curve")
|
||||||
@@ -4766,17 +4739,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
|
xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl)
|
||||||
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:dleag_split_size_init]
|
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:dleag_split_size_init]
|
||||||
elif ci_to.curve_type() == Curves.secp256k1:
|
elif ci_to.curve_type() == Curves.secp256k1:
|
||||||
for i in range(10):
|
xmr_swap.kbsl_dleag = ci_to.signRecoverable(
|
||||||
xmr_swap.kbsl_dleag = ci_to.signRecoverable(
|
kbsl, "proof kbsl owned for swap"
|
||||||
kbsl, "proof kbsl owned for swap"
|
)
|
||||||
)
|
pk_recovered = ci_to.verifySigAndRecover(
|
||||||
pk_recovered = ci_to.verifySigAndRecover(
|
xmr_swap.kbsl_dleag, "proof kbsl owned for swap"
|
||||||
xmr_swap.kbsl_dleag, "proof kbsl owned for swap"
|
)
|
||||||
)
|
ensure(pk_recovered == xmr_swap.pkbsl, "kbsl recovered pubkey mismatch")
|
||||||
if pk_recovered == xmr_swap.pkbsl:
|
|
||||||
break
|
|
||||||
self.log.debug("kbsl recovered pubkey mismatch, retrying.")
|
|
||||||
assert pk_recovered == xmr_swap.pkbsl
|
|
||||||
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag
|
msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown curve")
|
raise ValueError("Unknown curve")
|
||||||
@@ -5586,7 +5555,11 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
block_header = ci.getBlockHeaderFromHeight(tx_height)
|
block_header = ci.getBlockHeaderFromHeight(tx_height)
|
||||||
block_time = block_header["time"]
|
block_time = block_header["time"]
|
||||||
cc = self.coin_clients[coin_type]
|
cc = self.coin_clients[coin_type]
|
||||||
if len(cc["watched_outputs"]) == 0 and len(cc["watched_scripts"]) == 0:
|
if (
|
||||||
|
len(cc["watched_outputs"]) == 0
|
||||||
|
and len(cc["watched_transactions"]) == 0
|
||||||
|
and len(cc["watched_scripts"]) == 0
|
||||||
|
):
|
||||||
cc["last_height_checked"] = tx_height
|
cc["last_height_checked"] = tx_height
|
||||||
cc["block_check_min_time"] = block_time
|
cc["block_check_min_time"] = block_time
|
||||||
self.setIntKV("block_check_min_time_" + coin_name, block_time, cursor)
|
self.setIntKV("block_check_min_time_" + coin_name, block_time, cursor)
|
||||||
@@ -6732,6 +6705,39 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def addWatchedTransaction(
|
||||||
|
self, coin_type, bid_id, txid_hex, tx_type, swap_type=None
|
||||||
|
):
|
||||||
|
self.log.debug(
|
||||||
|
f"Adding watched transaction {Coins(coin_type).name} bid {self.log.id(bid_id)} tx {self.log.id(txid_hex)} type {tx_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
watched = self.coin_clients[coin_type]["watched_transactions"]
|
||||||
|
|
||||||
|
for wo in watched:
|
||||||
|
if wo.bid_id == bid_id and wo.txid_hex == txid_hex:
|
||||||
|
self.log.debug("Transaction already being watched.")
|
||||||
|
return
|
||||||
|
|
||||||
|
watched.append(
|
||||||
|
WatchedTransaction(bid_id, coin_type, txid_hex, tx_type, swap_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
def removeWatchedTransaction(self, coin_type, bid_id: bytes, txid_hex: str) -> None:
|
||||||
|
# Remove all for bid if txid is None
|
||||||
|
self.log.debug(
|
||||||
|
f"Removing watched transaction {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}"
|
||||||
|
)
|
||||||
|
watched = self.coin_clients[coin_type]["watched_transactions"]
|
||||||
|
old_len = len(watched)
|
||||||
|
for i in range(old_len - 1, -1, -1):
|
||||||
|
wo = watched[i]
|
||||||
|
if wo.bid_id == bid_id and (txid_hex is None or wo.txid_hex == txid_hex):
|
||||||
|
del watched[i]
|
||||||
|
self.log.debug(
|
||||||
|
f"Removed watched transaction {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(wo.txid_hex)}"
|
||||||
|
)
|
||||||
|
|
||||||
def addWatchedOutput(
|
def addWatchedOutput(
|
||||||
self, coin_type, bid_id, txid_hex, vout, tx_type, swap_type=None
|
self, coin_type, bid_id, txid_hex, vout, tx_type, swap_type=None
|
||||||
):
|
):
|
||||||
@@ -6740,7 +6746,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
watched = self.coin_clients[coin_type]["watched_outputs"]
|
watched = self.coin_clients[coin_type]["watched_outputs"]
|
||||||
|
|
||||||
for wo in watched:
|
for wo in watched:
|
||||||
if wo.bid_id == bid_id and wo.txid_hex == txid_hex and wo.vout == vout:
|
if wo.bid_id == bid_id and wo.txid_hex == txid_hex and wo.vout == vout:
|
||||||
self.log.debug("Output already being watched.")
|
self.log.debug("Output already being watched.")
|
||||||
@@ -6751,13 +6756,14 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
def removeWatchedOutput(self, coin_type, bid_id: bytes, txid_hex: str) -> None:
|
def removeWatchedOutput(self, coin_type, bid_id: bytes, txid_hex: str) -> None:
|
||||||
# Remove all for bid if txid is None
|
# Remove all for bid if txid is None
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
f"removeWatchedOutput {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}"
|
f"Removing watched output {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}"
|
||||||
)
|
)
|
||||||
old_len = len(self.coin_clients[coin_type]["watched_outputs"])
|
watched = self.coin_clients[coin_type]["watched_outputs"]
|
||||||
|
old_len = len(watched)
|
||||||
for i in range(old_len - 1, -1, -1):
|
for i in range(old_len - 1, -1, -1):
|
||||||
wo = self.coin_clients[coin_type]["watched_outputs"][i]
|
wo = watched[i]
|
||||||
if wo.bid_id == bid_id and (txid_hex is None or wo.txid_hex == txid_hex):
|
if wo.bid_id == bid_id and (txid_hex is None or wo.txid_hex == txid_hex):
|
||||||
del self.coin_clients[coin_type]["watched_outputs"][i]
|
del watched[i]
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
f"Removed watched output {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(wo.txid_hex)}"
|
f"Removed watched output {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(wo.txid_hex)}"
|
||||||
)
|
)
|
||||||
@@ -6770,7 +6776,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
watched = self.coin_clients[coin_type]["watched_scripts"]
|
watched = self.coin_clients[coin_type]["watched_scripts"]
|
||||||
|
|
||||||
for ws in watched:
|
for ws in watched:
|
||||||
if ws.bid_id == bid_id and ws.tx_type == tx_type and ws.script == script:
|
if ws.bid_id == bid_id and ws.tx_type == tx_type and ws.script == script:
|
||||||
self.log.debug("Script already being watched.")
|
self.log.debug("Script already being watched.")
|
||||||
@@ -6783,21 +6788,22 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
) -> None:
|
) -> None:
|
||||||
# Remove all for bid if script and type_ind is None
|
# Remove all for bid if script and type_ind is None
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
"removeWatchedScript {} {}{}".format(
|
"Removing watched script {} {}{}".format(
|
||||||
Coins(coin_type).name,
|
Coins(coin_type).name,
|
||||||
{self.log.id(bid_id)},
|
{self.log.id(bid_id)},
|
||||||
(" type " + str(tx_type)) if tx_type is not None else "",
|
(" type " + str(tx_type)) if tx_type is not None else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
old_len = len(self.coin_clients[coin_type]["watched_scripts"])
|
watched = self.coin_clients[coin_type]["watched_scripts"]
|
||||||
|
old_len = len(watched)
|
||||||
for i in range(old_len - 1, -1, -1):
|
for i in range(old_len - 1, -1, -1):
|
||||||
ws = self.coin_clients[coin_type]["watched_scripts"][i]
|
ws = watched[i]
|
||||||
if (
|
if (
|
||||||
ws.bid_id == bid_id
|
ws.bid_id == bid_id
|
||||||
and (script is None or ws.script == script)
|
and (script is None or ws.script == script)
|
||||||
and (tx_type is None or ws.tx_type == tx_type)
|
and (tx_type is None or ws.tx_type == tx_type)
|
||||||
):
|
):
|
||||||
del self.coin_clients[coin_type]["watched_scripts"][i]
|
del watched[i]
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
f"Removed watched script {Coins(coin_type).name} {self.log.id(bid_id)}"
|
f"Removed watched script {Coins(coin_type).name} {self.log.id(bid_id)}"
|
||||||
)
|
)
|
||||||
@@ -7175,6 +7181,17 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
finally:
|
finally:
|
||||||
self.closeDB(cursor)
|
self.closeDB(cursor)
|
||||||
|
|
||||||
|
def processFoundTransaction(
|
||||||
|
self,
|
||||||
|
watched_txn: WatchedTransaction,
|
||||||
|
block_hash_hex: str,
|
||||||
|
block_height: int,
|
||||||
|
chain_blocks: int,
|
||||||
|
):
|
||||||
|
self.log.warning(
|
||||||
|
f"Unknown swap_type for found transaction: {self.logIDT(bytes.fromhex(watched_txn.txid_hex))}."
|
||||||
|
)
|
||||||
|
|
||||||
def processSpentOutput(
|
def processSpentOutput(
|
||||||
self, coin_type, watched_output, spend_txid_hex, spend_n, spend_txn
|
self, coin_type, watched_output, spend_txid_hex, spend_n, spend_txn
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -7437,6 +7454,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for tx in block["tx"]:
|
for tx in block["tx"]:
|
||||||
|
for t in c["watched_transactions"]:
|
||||||
|
if t.block_hash is not None:
|
||||||
|
continue
|
||||||
|
if tx["txid"] == t.txid_hex:
|
||||||
|
self.processFoundTransaction(
|
||||||
|
t, block_hash, block["height"], chain_blocks
|
||||||
|
)
|
||||||
for s in c["watched_scripts"]:
|
for s in c["watched_scripts"]:
|
||||||
for i, txo in enumerate(tx["vout"]):
|
for i, txo in enumerate(tx["vout"]):
|
||||||
if "scriptPubKey" in txo and "hex" in txo["scriptPubKey"]:
|
if "scriptPubKey" in txo and "hex" in txo["scriptPubKey"]:
|
||||||
@@ -8746,20 +8770,18 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
|
|
||||||
# Extract pubkeys from MSG1F DLEAG
|
# Extract pubkeys from MSG1F DLEAG
|
||||||
xmr_swap.pkasl = xmr_swap.kbsl_dleag[0:33]
|
xmr_swap.pkasl = xmr_swap.kbsl_dleag[0:33]
|
||||||
if not ci_from.verifyPubkey(xmr_swap.pkasl):
|
|
||||||
raise ValueError("Invalid coin a pubkey.")
|
|
||||||
xmr_swap.pkbsl = xmr_swap.kbsl_dleag[33 : 33 + 32]
|
xmr_swap.pkbsl = xmr_swap.kbsl_dleag[33 : 33 + 32]
|
||||||
if not ci_to.verifyPubkey(xmr_swap.pkbsl):
|
|
||||||
raise ValueError("Invalid coin b pubkey.")
|
|
||||||
elif ci_to.curve_type() == Curves.secp256k1:
|
elif ci_to.curve_type() == Curves.secp256k1:
|
||||||
xmr_swap.pkasl = ci_to.verifySigAndRecover(
|
xmr_swap.pkasl = ci_to.verifySigAndRecover(
|
||||||
xmr_swap.kbsl_dleag, "proof kbsl owned for swap"
|
xmr_swap.kbsl_dleag, "proof kbsl owned for swap"
|
||||||
)
|
)
|
||||||
if not ci_from.verifyPubkey(xmr_swap.pkasl):
|
|
||||||
raise ValueError("Invalid coin a pubkey.")
|
|
||||||
xmr_swap.pkbsl = xmr_swap.pkasl
|
xmr_swap.pkbsl = xmr_swap.pkasl
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown curve")
|
raise ValueError("Unknown curve")
|
||||||
|
if not ci_from.verifyPubkey(xmr_swap.pkasl):
|
||||||
|
raise ValueError("Invalid coin a pubkey.")
|
||||||
|
if not ci_to.verifyPubkey(xmr_swap.pkbsl):
|
||||||
|
raise ValueError("Invalid coin b pubkey.")
|
||||||
|
|
||||||
# vkbv and vkbvl are verified in processXmrBidAccept
|
# vkbv and vkbvl are verified in processXmrBidAccept
|
||||||
xmr_swap.pkbv = ci_to.sumPubkeys(xmr_swap.pkbvl, xmr_swap.pkbvf)
|
xmr_swap.pkbv = ci_to.sumPubkeys(xmr_swap.pkbvl, xmr_swap.pkbvf)
|
||||||
@@ -10827,7 +10849,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
|
|||||||
>= self.check_expiring_bids_offers_seconds
|
>= self.check_expiring_bids_offers_seconds
|
||||||
):
|
):
|
||||||
check_records = True
|
check_records = True
|
||||||
self._last_checked_expiring_bids = now
|
self._last_checked_expiring_bids_offers = now
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(bids_to_expire) == 0
|
len(bids_to_expire) == 0
|
||||||
|
|||||||
@@ -608,13 +608,6 @@ def runClient(
|
|||||||
except Exception as e: # noqa: F841
|
except Exception as e: # noqa: F841
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
if swap_client.ws_server:
|
|
||||||
try:
|
|
||||||
swap_client.log.info("Stopping websocket server.")
|
|
||||||
swap_client.ws_server.shutdown_gracefully()
|
|
||||||
except Exception as e: # noqa: F841
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
swap_client.finalise()
|
swap_client.finalise()
|
||||||
|
|
||||||
closed_pids = []
|
closed_pids = []
|
||||||
|
|||||||
@@ -792,8 +792,8 @@ class NetworkPortal(Table):
|
|||||||
created_at = Column("integer")
|
created_at = Column("integer")
|
||||||
|
|
||||||
|
|
||||||
def extract_schema() -> dict:
|
def extract_schema(input_globals=None) -> dict:
|
||||||
g = globals().copy()
|
g = globals() if input_globals is None else input_globals
|
||||||
tables = {}
|
tables = {}
|
||||||
for name, obj in g.items():
|
for name, obj in g.items():
|
||||||
if not inspect.isclass(obj):
|
if not inspect.isclass(obj):
|
||||||
|
|||||||
@@ -152,7 +152,94 @@ def upgradeDatabaseData(self, data_version):
|
|||||||
self.closeDB(cursor, commit=False)
|
self.closeDB(cursor, commit=False)
|
||||||
|
|
||||||
|
|
||||||
def upgradeDatabase(self, db_version):
|
def upgradeDatabaseFromSchema(self, cursor, expect_schema):
|
||||||
|
have_tables = {}
|
||||||
|
|
||||||
|
query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
|
||||||
|
tables = cursor.execute(query).fetchall()
|
||||||
|
for table in tables:
|
||||||
|
table_name = table[0]
|
||||||
|
if table_name in ("sqlite_sequence",):
|
||||||
|
continue
|
||||||
|
|
||||||
|
have_table = {}
|
||||||
|
have_columns = {}
|
||||||
|
query = "SELECT * FROM PRAGMA_TABLE_INFO(:table_name) ORDER BY cid DESC;"
|
||||||
|
columns = cursor.execute(query, {"table_name": table_name}).fetchall()
|
||||||
|
for column in columns:
|
||||||
|
cid, name, data_type, notnull, default_value, primary_key = column
|
||||||
|
have_columns[name] = {"type": data_type, "primary_key": primary_key}
|
||||||
|
|
||||||
|
have_table["columns"] = have_columns
|
||||||
|
|
||||||
|
cursor.execute(f"PRAGMA INDEX_LIST('{table_name}');")
|
||||||
|
indices = cursor.fetchall()
|
||||||
|
for index in indices:
|
||||||
|
seq, index_name, unique, origin, partial = index
|
||||||
|
|
||||||
|
if origin == "pk": # Created by a PRIMARY KEY constraint
|
||||||
|
continue
|
||||||
|
|
||||||
|
cursor.execute(f"PRAGMA INDEX_INFO('{index_name}');")
|
||||||
|
index_info = cursor.fetchall()
|
||||||
|
|
||||||
|
add_index = {"index_name": index_name}
|
||||||
|
for index_columns in index_info:
|
||||||
|
seqno, cid, name = index_columns
|
||||||
|
if origin == "u": # Created by a UNIQUE constraint
|
||||||
|
have_columns[name]["unique"] = 1
|
||||||
|
else:
|
||||||
|
if "column_1" not in add_index:
|
||||||
|
add_index["column_1"] = name
|
||||||
|
elif "column_2" not in add_index:
|
||||||
|
add_index["column_2"] = name
|
||||||
|
elif "column_3" not in add_index:
|
||||||
|
add_index["column_3"] = name
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Add more index columns.")
|
||||||
|
if origin == "c":
|
||||||
|
if "indices" not in table:
|
||||||
|
have_table["indices"] = []
|
||||||
|
have_table["indices"].append(add_index)
|
||||||
|
|
||||||
|
have_tables[table_name] = have_table
|
||||||
|
|
||||||
|
for table_name, table in expect_schema.items():
|
||||||
|
if table_name not in have_tables:
|
||||||
|
self.log.info(f"Creating table {table_name}.")
|
||||||
|
create_table(cursor, table_name, table)
|
||||||
|
continue
|
||||||
|
|
||||||
|
have_table = have_tables[table_name]
|
||||||
|
have_columns = have_table["columns"]
|
||||||
|
for colname, column in table["columns"].items():
|
||||||
|
if colname not in have_columns:
|
||||||
|
col_type = column["type"]
|
||||||
|
self.log.info(f"Adding column {colname} to table {table_name}.")
|
||||||
|
cursor.execute(
|
||||||
|
f"ALTER TABLE {table_name} ADD COLUMN {colname} {col_type}"
|
||||||
|
)
|
||||||
|
indices = table.get("indices", [])
|
||||||
|
have_indices = have_table.get("indices", [])
|
||||||
|
for index in indices:
|
||||||
|
index_name = index["index_name"]
|
||||||
|
if not any(
|
||||||
|
have_idx.get("index_name") == index_name for have_idx in have_indices
|
||||||
|
):
|
||||||
|
self.log.info(f"Adding index {index_name} to table {table_name}.")
|
||||||
|
column_1 = index["column_1"]
|
||||||
|
column_2 = index.get("column_2", None)
|
||||||
|
column_3 = index.get("column_3", None)
|
||||||
|
query: str = f"CREATE INDEX {index_name} ON {table_name} ({column_1}"
|
||||||
|
if column_2:
|
||||||
|
query += f", {column_2}"
|
||||||
|
if column_3:
|
||||||
|
query += f", {column_3}"
|
||||||
|
query += ")"
|
||||||
|
cursor.execute(query)
|
||||||
|
|
||||||
|
|
||||||
|
def upgradeDatabase(self, db_version: int):
|
||||||
if self._force_db_upgrade is False and db_version >= CURRENT_DB_VERSION:
|
if self._force_db_upgrade is False and db_version >= CURRENT_DB_VERSION:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -174,103 +261,15 @@ def upgradeDatabase(self, db_version):
|
|||||||
]
|
]
|
||||||
|
|
||||||
expect_schema = extract_schema()
|
expect_schema = extract_schema()
|
||||||
have_tables = {}
|
|
||||||
try:
|
try:
|
||||||
cursor = self.openDB()
|
cursor = self.openDB()
|
||||||
|
|
||||||
for rename_column in rename_columns:
|
for rename_column in rename_columns:
|
||||||
dbv, table_name, colname_from, colname_to = rename_column
|
dbv, table_name, colname_from, colname_to = rename_column
|
||||||
if db_version < dbv:
|
if db_version < dbv:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
f"ALTER TABLE {table_name} RENAME COLUMN {colname_from} TO {colname_to}"
|
f"ALTER TABLE {table_name} RENAME COLUMN {colname_from} TO {colname_to}"
|
||||||
)
|
)
|
||||||
|
upgradeDatabaseFromSchema(self, cursor, expect_schema)
|
||||||
query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
|
|
||||||
tables = cursor.execute(query).fetchall()
|
|
||||||
for table in tables:
|
|
||||||
table_name = table[0]
|
|
||||||
if table_name in ("sqlite_sequence",):
|
|
||||||
continue
|
|
||||||
|
|
||||||
have_table = {}
|
|
||||||
have_columns = {}
|
|
||||||
query = "SELECT * FROM PRAGMA_TABLE_INFO(:table_name) ORDER BY cid DESC;"
|
|
||||||
columns = cursor.execute(query, {"table_name": table_name}).fetchall()
|
|
||||||
for column in columns:
|
|
||||||
cid, name, data_type, notnull, default_value, primary_key = column
|
|
||||||
have_columns[name] = {"type": data_type, "primary_key": primary_key}
|
|
||||||
|
|
||||||
have_table["columns"] = have_columns
|
|
||||||
|
|
||||||
cursor.execute(f"PRAGMA INDEX_LIST('{table_name}');")
|
|
||||||
indices = cursor.fetchall()
|
|
||||||
for index in indices:
|
|
||||||
seq, index_name, unique, origin, partial = index
|
|
||||||
|
|
||||||
if origin == "pk": # Created by a PRIMARY KEY constraint
|
|
||||||
continue
|
|
||||||
|
|
||||||
cursor.execute(f"PRAGMA INDEX_INFO('{index_name}');")
|
|
||||||
index_info = cursor.fetchall()
|
|
||||||
|
|
||||||
add_index = {"index_name": index_name}
|
|
||||||
for index_columns in index_info:
|
|
||||||
seqno, cid, name = index_columns
|
|
||||||
if origin == "u": # Created by a UNIQUE constraint
|
|
||||||
have_columns[name]["unique"] = 1
|
|
||||||
else:
|
|
||||||
if "column_1" not in add_index:
|
|
||||||
add_index["column_1"] = name
|
|
||||||
elif "column_2" not in add_index:
|
|
||||||
add_index["column_2"] = name
|
|
||||||
elif "column_3" not in add_index:
|
|
||||||
add_index["column_3"] = name
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Add more index columns.")
|
|
||||||
if origin == "c":
|
|
||||||
if "indices" not in table:
|
|
||||||
have_table["indices"] = []
|
|
||||||
have_table["indices"].append(add_index)
|
|
||||||
|
|
||||||
have_tables[table_name] = have_table
|
|
||||||
|
|
||||||
for table_name, table in expect_schema.items():
|
|
||||||
if table_name not in have_tables:
|
|
||||||
self.log.info(f"Creating table {table_name}.")
|
|
||||||
create_table(cursor, table_name, table)
|
|
||||||
continue
|
|
||||||
|
|
||||||
have_table = have_tables[table_name]
|
|
||||||
have_columns = have_table["columns"]
|
|
||||||
for colname, column in table["columns"].items():
|
|
||||||
if colname not in have_columns:
|
|
||||||
col_type = column["type"]
|
|
||||||
self.log.info(f"Adding column {colname} to table {table_name}.")
|
|
||||||
cursor.execute(
|
|
||||||
f"ALTER TABLE {table_name} ADD COLUMN {colname} {col_type}"
|
|
||||||
)
|
|
||||||
indices = table.get("indices", [])
|
|
||||||
have_indices = have_table.get("indices", [])
|
|
||||||
for index in indices:
|
|
||||||
index_name = index["index_name"]
|
|
||||||
if not any(
|
|
||||||
have_idx.get("index_name") == index_name
|
|
||||||
for have_idx in have_indices
|
|
||||||
):
|
|
||||||
self.log.info(f"Adding index {index_name} to table {table_name}.")
|
|
||||||
column_1 = index["column_1"]
|
|
||||||
column_2 = index.get("column_2", None)
|
|
||||||
column_3 = index.get("column_3", None)
|
|
||||||
query: str = (
|
|
||||||
f"CREATE INDEX {index_name} ON {table_name} ({column_1}"
|
|
||||||
)
|
|
||||||
if column_2:
|
|
||||||
query += f", {column_2}"
|
|
||||||
if column_3:
|
|
||||||
query += f", {column_3}"
|
|
||||||
query += ")"
|
|
||||||
cursor.execute(query)
|
|
||||||
|
|
||||||
if CURRENT_DB_VERSION != db_version:
|
if CURRENT_DB_VERSION != db_version:
|
||||||
self.db_version = CURRENT_DB_VERSION
|
self.db_version = CURRENT_DB_VERSION
|
||||||
self.setIntKV("db_version", CURRENT_DB_VERSION, cursor)
|
self.setIntKV("db_version", CURRENT_DB_VERSION, cursor)
|
||||||
|
|||||||
@@ -235,6 +235,14 @@ class BTCInterface(Secp256k1Interface):
|
|||||||
def txoType():
|
def txoType():
|
||||||
return CTxOut
|
return CTxOut
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def outpointType():
|
||||||
|
return COutPoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def txiType():
|
||||||
|
return CTxIn
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getExpectedSequence(lockType: int, lockVal: int) -> int:
|
def getExpectedSequence(lockType: int, lockVal: int) -> int:
|
||||||
ensure(lockVal >= 1, "Bad lockVal")
|
ensure(lockVal >= 1, "Bad lockVal")
|
||||||
@@ -1203,7 +1211,7 @@ class BTCInterface(Secp256k1Interface):
|
|||||||
ensure(C == Kaf, "Bad script pubkey")
|
ensure(C == Kaf, "Bad script pubkey")
|
||||||
|
|
||||||
fee_paid = swap_value - locked_coin
|
fee_paid = swap_value - locked_coin
|
||||||
assert fee_paid > 0
|
ensure(fee_paid > 0, "negative fee_paid")
|
||||||
|
|
||||||
dummy_witness_stack = self.getScriptLockTxDummyWitness(prevout_script)
|
dummy_witness_stack = self.getScriptLockTxDummyWitness(prevout_script)
|
||||||
witness_bytes = self.getWitnessStackSerialisedLength(dummy_witness_stack)
|
witness_bytes = self.getWitnessStackSerialisedLength(dummy_witness_stack)
|
||||||
@@ -1267,7 +1275,7 @@ class BTCInterface(Secp256k1Interface):
|
|||||||
tx_value = tx.vout[0].nValue
|
tx_value = tx.vout[0].nValue
|
||||||
|
|
||||||
fee_paid = prevout_value - tx_value
|
fee_paid = prevout_value - tx_value
|
||||||
assert fee_paid > 0
|
ensure(fee_paid > 0, "negative fee_paid")
|
||||||
|
|
||||||
dummy_witness_stack = self.getScriptLockRefundSpendTxDummyWitness(
|
dummy_witness_stack = self.getScriptLockRefundSpendTxDummyWitness(
|
||||||
prevout_script
|
prevout_script
|
||||||
@@ -2575,12 +2583,7 @@ class BTCInterface(Secp256k1Interface):
|
|||||||
self._log.id(bytes.fromhex(tx["txid"]))
|
self._log.id(bytes.fromhex(tx["txid"]))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.rpc(
|
self.publishTx(tx_signed)
|
||||||
"sendrawtransaction",
|
|
||||||
[
|
|
||||||
tx_signed,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
return tx["txid"]
|
return tx["txid"]
|
||||||
|
|
||||||
|
|||||||
@@ -191,17 +191,11 @@ def setDLEAG(xmr_swap, ci_to, kbsf: bytes) -> None:
|
|||||||
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
|
xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf)
|
||||||
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
|
xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33]
|
||||||
elif ci_to.curve_type() == Curves.secp256k1:
|
elif ci_to.curve_type() == Curves.secp256k1:
|
||||||
for i in range(10):
|
xmr_swap.kbsf_dleag = ci_to.signRecoverable(kbsf, "proof kbsf owned for swap")
|
||||||
xmr_swap.kbsf_dleag = ci_to.signRecoverable(
|
pk_recovered: bytes = ci_to.verifySigAndRecover(
|
||||||
kbsf, "proof kbsf owned for swap"
|
xmr_swap.kbsf_dleag, "proof kbsf owned for swap"
|
||||||
)
|
)
|
||||||
pk_recovered: bytes = ci_to.verifySigAndRecover(
|
ensure(pk_recovered == xmr_swap.pkbsf, "kbsf recovered pubkey mismatch")
|
||||||
xmr_swap.kbsf_dleag, "proof kbsf owned for swap"
|
|
||||||
)
|
|
||||||
if pk_recovered == xmr_swap.pkbsf:
|
|
||||||
break
|
|
||||||
# self.log.debug('kbsl recovered pubkey mismatch, retrying.')
|
|
||||||
assert pk_recovered == xmr_swap.pkbsf
|
|
||||||
xmr_swap.pkasf = xmr_swap.pkbsf
|
xmr_swap.pkasf = xmr_swap.pkbsf
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown curve")
|
raise ValueError("Unknown curve")
|
||||||
|
|||||||
51
basicswap/types.py
Normal file
51
basicswap/types.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright (c) 2025 The Basicswap developers
|
||||||
|
# Distributed under the MIT software license, see the accompanying
|
||||||
|
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
|
||||||
|
|
||||||
|
|
||||||
|
class WatchedOutput: # Watch for spends
|
||||||
|
__slots__ = ("bid_id", "txid_hex", "vout", "tx_type", "swap_type")
|
||||||
|
|
||||||
|
def __init__(self, bid_id: bytes, txid_hex: str, vout, tx_type, swap_type):
|
||||||
|
self.bid_id = bid_id
|
||||||
|
self.txid_hex = txid_hex
|
||||||
|
self.vout = vout
|
||||||
|
self.tx_type = tx_type
|
||||||
|
self.swap_type = swap_type
|
||||||
|
|
||||||
|
|
||||||
|
class WatchedScript: # Watch for txns containing outputs
|
||||||
|
__slots__ = ("bid_id", "script", "tx_type", "swap_type")
|
||||||
|
|
||||||
|
def __init__(self, bid_id: bytes, script: bytes, tx_type, swap_type):
|
||||||
|
self.bid_id = bid_id
|
||||||
|
self.script = script
|
||||||
|
self.tx_type = tx_type
|
||||||
|
self.swap_type = swap_type
|
||||||
|
|
||||||
|
|
||||||
|
class WatchedTransaction:
|
||||||
|
__slots__ = (
|
||||||
|
"bid_id",
|
||||||
|
"coin_type",
|
||||||
|
"txid_hex",
|
||||||
|
"tx_type",
|
||||||
|
"swap_type",
|
||||||
|
"block_hash",
|
||||||
|
"depth",
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Watch for presence in mempool (getrawtransaction)
|
||||||
|
def __init__(
|
||||||
|
self, bid_id: bytes, coin_type: int, txid_hex: str, tx_type, swap_type
|
||||||
|
):
|
||||||
|
self.bid_id = bid_id
|
||||||
|
self.coin_type = coin_type
|
||||||
|
self.txid_hex = txid_hex
|
||||||
|
self.tx_type = tx_type
|
||||||
|
self.swap_type = swap_type
|
||||||
|
self.block_hash = None
|
||||||
|
self.depth = -1
|
||||||
@@ -94,6 +94,7 @@ class Test(unittest.TestCase):
|
|||||||
time_val = 48 * 60 * 60
|
time_val = 48 * 60 * 60
|
||||||
encoded = ci.getExpectedSequence(TxLockTypes.SEQUENCE_LOCK_TIME, time_val)
|
encoded = ci.getExpectedSequence(TxLockTypes.SEQUENCE_LOCK_TIME, time_val)
|
||||||
decoded = ci.decodeSequence(encoded)
|
decoded = ci.decodeSequence(encoded)
|
||||||
|
assert encoded == 4194642
|
||||||
assert decoded >= time_val
|
assert decoded >= time_val
|
||||||
assert decoded <= time_val + 512
|
assert decoded <= time_val + 512
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,7 @@ class Test(BaseTest):
|
|||||||
test_coin_from = Coins.PART
|
test_coin_from = Coins.PART
|
||||||
# p2wpkh
|
# p2wpkh
|
||||||
logging.info("---------- Test {} segwit".format(test_coin_from.name))
|
logging.info("---------- Test {} segwit".format(test_coin_from.name))
|
||||||
|
|
||||||
ci = self.swap_clients[0].ci(test_coin_from)
|
ci = self.swap_clients[0].ci(test_coin_from)
|
||||||
|
|
||||||
addr_native = ci.rpc_wallet("getnewaddress", ["p2pkh segwit test"])
|
addr_native = ci.rpc_wallet("getnewaddress", ["p2pkh segwit test"])
|
||||||
@@ -329,9 +330,11 @@ class Test(BaseTest):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
assert addr_info["iswitness"] is False # address is p2pkh, not p2wpkh
|
assert addr_info["iswitness"] is False # address is p2pkh, not p2wpkh
|
||||||
|
|
||||||
addr_segwit = ci.rpc_wallet(
|
addr_segwit = ci.rpc_wallet(
|
||||||
"getnewaddress", ["p2wpkh segwit test", True, False, False, "bech32"]
|
"getnewaddress", ["p2wpkh segwit test", True, False, False, "bech32"]
|
||||||
)
|
)
|
||||||
|
|
||||||
addr_info = ci.rpc_wallet(
|
addr_info = ci.rpc_wallet(
|
||||||
"getaddressinfo",
|
"getaddressinfo",
|
||||||
[
|
[
|
||||||
@@ -351,6 +354,7 @@ class Test(BaseTest):
|
|||||||
],
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(txid) == 64
|
assert len(txid) == 64
|
||||||
tx_wallet = ci.rpc_wallet(
|
tx_wallet = ci.rpc_wallet(
|
||||||
"gettransaction",
|
"gettransaction",
|
||||||
|
|||||||
Reference in New Issue
Block a user