diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index c1fbfb2..e0079d6 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -500,7 +500,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): self.db_version = self.getIntKV("db_version", cursor, CURRENT_DB_VERSION) self.db_data_version = self.getIntKV("db_data_version", cursor, 0) self._contract_count = self.getIntKV("contract_count", cursor, 0) - self.commitDB() finally: self.closeDB(cursor) @@ -583,6 +582,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): def finalise(self): self.log.info("Finalising") + if self.ws_server: + try: + self.log.info("Stopping websocket server.") + self.ws_server.shutdown_gracefully() + except Exception as e: # noqa: F841 + traceback.print_exc() + self._price_fetch_running = False if self._price_fetch_thread and self._price_fetch_thread.is_alive(): self._price_fetch_thread.join(timeout=5) @@ -715,6 +721,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): "rpcauth": rpcauth, "blocks_confirmed": chain_client_settings.get("blocks_confirmed", 6), "conf_target": chain_client_settings.get("conf_target", 2), + "watched_transactions": [], "watched_outputs": [], "watched_scripts": [], "last_height_checked": last_height_checked, @@ -4423,17 +4430,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf) xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33] elif ci_to.curve_type() == Curves.secp256k1: - for i in range(10): - xmr_swap.kbsf_dleag = ci_to.signRecoverable( - kbsf, "proof kbsf owned for swap" - ) - pk_recovered = ci_to.verifySigAndRecover( - xmr_swap.kbsf_dleag, "proof kbsf owned for swap" - ) - if pk_recovered == xmr_swap.pkbsf: - break - self.log.debug("kbsl recovered pubkey mismatch, retrying.") - assert pk_recovered == xmr_swap.pkbsf + xmr_swap.kbsf_dleag = ci_to.signRecoverable( + kbsf, "proof kbsf owned for swap" + ) + pk_recovered = ci_to.verifySigAndRecover( + xmr_swap.kbsf_dleag, "proof kbsf owned for swap" + ) + ensure(pk_recovered == xmr_swap.pkbsf, "kbsf recovered pubkey mismatch") xmr_swap.pkasf = xmr_swap.pkbsf else: raise ValueError("Unknown curve") @@ -4736,17 +4739,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): xmr_swap.kbsl_dleag = ci_to.proveDLEAG(kbsl) msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag[:dleag_split_size_init] elif ci_to.curve_type() == Curves.secp256k1: - for i in range(10): - xmr_swap.kbsl_dleag = ci_to.signRecoverable( - kbsl, "proof kbsl owned for swap" - ) - pk_recovered = ci_to.verifySigAndRecover( - xmr_swap.kbsl_dleag, "proof kbsl owned for swap" - ) - if pk_recovered == xmr_swap.pkbsl: - break - self.log.debug("kbsl recovered pubkey mismatch, retrying.") - assert pk_recovered == xmr_swap.pkbsl + xmr_swap.kbsl_dleag = ci_to.signRecoverable( + kbsl, "proof kbsl owned for swap" + ) + pk_recovered = ci_to.verifySigAndRecover( + xmr_swap.kbsl_dleag, "proof kbsl owned for swap" + ) + ensure(pk_recovered == xmr_swap.pkbsl, "kbsl recovered pubkey mismatch") msg_buf.kbsl_dleag = xmr_swap.kbsl_dleag else: raise ValueError("Unknown curve") @@ -5556,7 +5555,11 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): block_header = ci.getBlockHeaderFromHeight(tx_height) block_time = block_header["time"] cc = self.coin_clients[coin_type] - if len(cc["watched_outputs"]) == 0 and len(cc["watched_scripts"]) == 0: + if ( + len(cc["watched_outputs"]) == 0 + and len(cc["watched_transactions"]) == 0 + and len(cc["watched_scripts"]) == 0 + ): cc["last_height_checked"] = tx_height cc["block_check_min_time"] = block_time self.setIntKV("block_check_min_time_" + coin_name, block_time, cursor) @@ -6702,6 +6705,39 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): except Exception: return None + def addWatchedTransaction( + self, coin_type, bid_id, txid_hex, tx_type, swap_type=None + ): + self.log.debug( + f"Adding watched transaction {Coins(coin_type).name} bid {self.log.id(bid_id)} tx {self.log.id(txid_hex)} type {tx_type}" + ) + + watched = self.coin_clients[coin_type]["watched_transactions"] + + for wo in watched: + if wo.bid_id == bid_id and wo.txid_hex == txid_hex: + self.log.debug("Transaction already being watched.") + return + + watched.append( + WatchedTransaction(bid_id, coin_type, txid_hex, tx_type, swap_type) + ) + + def removeWatchedTransaction(self, coin_type, bid_id: bytes, txid_hex: str) -> None: + # Remove all for bid if txid is None + self.log.debug( + f"Removing watched transaction {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}" + ) + watched = self.coin_clients[coin_type]["watched_transactions"] + old_len = len(watched) + for i in range(old_len - 1, -1, -1): + wo = watched[i] + if wo.bid_id == bid_id and (txid_hex is None or wo.txid_hex == txid_hex): + del watched[i] + self.log.debug( + f"Removed watched transaction {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(wo.txid_hex)}" + ) + def addWatchedOutput( self, coin_type, bid_id, txid_hex, vout, tx_type, swap_type=None ): @@ -6710,7 +6746,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): ) watched = self.coin_clients[coin_type]["watched_outputs"] - for wo in watched: if wo.bid_id == bid_id and wo.txid_hex == txid_hex and wo.vout == vout: self.log.debug("Output already being watched.") @@ -6721,13 +6756,14 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): def removeWatchedOutput(self, coin_type, bid_id: bytes, txid_hex: str) -> None: # Remove all for bid if txid is None self.log.debug( - f"removeWatchedOutput {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}" + f"Removing watched output {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(txid_hex)}" ) - old_len = len(self.coin_clients[coin_type]["watched_outputs"]) + watched = self.coin_clients[coin_type]["watched_outputs"] + old_len = len(watched) for i in range(old_len - 1, -1, -1): - wo = self.coin_clients[coin_type]["watched_outputs"][i] + wo = watched[i] if wo.bid_id == bid_id and (txid_hex is None or wo.txid_hex == txid_hex): - del self.coin_clients[coin_type]["watched_outputs"][i] + del watched[i] self.log.debug( f"Removed watched output {Coins(coin_type).name} {self.log.id(bid_id)} {self.log.id(wo.txid_hex)}" ) @@ -6740,7 +6776,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): ) watched = self.coin_clients[coin_type]["watched_scripts"] - for ws in watched: if ws.bid_id == bid_id and ws.tx_type == tx_type and ws.script == script: self.log.debug("Script already being watched.") @@ -6753,21 +6788,22 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): ) -> None: # Remove all for bid if script and type_ind is None self.log.debug( - "removeWatchedScript {} {}{}".format( + "Removing watched script {} {}{}".format( Coins(coin_type).name, {self.log.id(bid_id)}, (" type " + str(tx_type)) if tx_type is not None else "", ) ) - old_len = len(self.coin_clients[coin_type]["watched_scripts"]) + watched = self.coin_clients[coin_type]["watched_scripts"] + old_len = len(watched) for i in range(old_len - 1, -1, -1): - ws = self.coin_clients[coin_type]["watched_scripts"][i] + ws = watched[i] if ( ws.bid_id == bid_id and (script is None or ws.script == script) and (tx_type is None or ws.tx_type == tx_type) ): - del self.coin_clients[coin_type]["watched_scripts"][i] + del watched[i] self.log.debug( f"Removed watched script {Coins(coin_type).name} {self.log.id(bid_id)}" ) @@ -7145,6 +7181,17 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): finally: self.closeDB(cursor) + def processFoundTransaction( + self, + watched_txn: WatchedTransaction, + block_hash_hex: str, + block_height: int, + chain_blocks: int, + ): + self.log.warning( + f"Unknown swap_type for found transaction: {self.logIDT(bytes.fromhex(watched_txn.txid_hex))}." + ) + def processSpentOutput( self, coin_type, watched_output, spend_txid_hex, spend_n, spend_txn ) -> None: @@ -7407,6 +7454,13 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): continue for tx in block["tx"]: + for t in c["watched_transactions"]: + if t.block_hash is not None: + continue + if tx["txid"] == t.txid_hex: + self.processFoundTransaction( + t, block_hash, block["height"], chain_blocks + ) for s in c["watched_scripts"]: for i, txo in enumerate(tx["vout"]): if "scriptPubKey" in txo and "hex" in txo["scriptPubKey"]: @@ -8716,20 +8770,18 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): # Extract pubkeys from MSG1F DLEAG xmr_swap.pkasl = xmr_swap.kbsl_dleag[0:33] - if not ci_from.verifyPubkey(xmr_swap.pkasl): - raise ValueError("Invalid coin a pubkey.") xmr_swap.pkbsl = xmr_swap.kbsl_dleag[33 : 33 + 32] - if not ci_to.verifyPubkey(xmr_swap.pkbsl): - raise ValueError("Invalid coin b pubkey.") elif ci_to.curve_type() == Curves.secp256k1: xmr_swap.pkasl = ci_to.verifySigAndRecover( xmr_swap.kbsl_dleag, "proof kbsl owned for swap" ) - if not ci_from.verifyPubkey(xmr_swap.pkasl): - raise ValueError("Invalid coin a pubkey.") xmr_swap.pkbsl = xmr_swap.pkasl else: raise ValueError("Unknown curve") + if not ci_from.verifyPubkey(xmr_swap.pkasl): + raise ValueError("Invalid coin a pubkey.") + if not ci_to.verifyPubkey(xmr_swap.pkbsl): + raise ValueError("Invalid coin b pubkey.") # vkbv and vkbvl are verified in processXmrBidAccept xmr_swap.pkbv = ci_to.sumPubkeys(xmr_swap.pkbvl, xmr_swap.pkbvf) diff --git a/basicswap/bin/run.py b/basicswap/bin/run.py index 43fcf7d..d67859f 100755 --- a/basicswap/bin/run.py +++ b/basicswap/bin/run.py @@ -608,13 +608,6 @@ def runClient( except Exception as e: # noqa: F841 traceback.print_exc() - if swap_client.ws_server: - try: - swap_client.log.info("Stopping websocket server.") - swap_client.ws_server.shutdown_gracefully() - except Exception as e: # noqa: F841 - traceback.print_exc() - swap_client.finalise() closed_pids = [] diff --git a/basicswap/db.py b/basicswap/db.py index a0a15d5..5bea941 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -792,8 +792,8 @@ class NetworkPortal(Table): created_at = Column("integer") -def extract_schema() -> dict: - g = globals().copy() +def extract_schema(input_globals=None) -> dict: + g = globals() if input_globals is None else input_globals tables = {} for name, obj in g.items(): if not inspect.isclass(obj): diff --git a/basicswap/db_upgrades.py b/basicswap/db_upgrades.py index 5157223..89ccaff 100644 --- a/basicswap/db_upgrades.py +++ b/basicswap/db_upgrades.py @@ -152,7 +152,94 @@ def upgradeDatabaseData(self, data_version): self.closeDB(cursor, commit=False) -def upgradeDatabase(self, db_version): +def upgradeDatabaseFromSchema(self, cursor, expect_schema): + have_tables = {} + + query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" + tables = cursor.execute(query).fetchall() + for table in tables: + table_name = table[0] + if table_name in ("sqlite_sequence",): + continue + + have_table = {} + have_columns = {} + query = "SELECT * FROM PRAGMA_TABLE_INFO(:table_name) ORDER BY cid DESC;" + columns = cursor.execute(query, {"table_name": table_name}).fetchall() + for column in columns: + cid, name, data_type, notnull, default_value, primary_key = column + have_columns[name] = {"type": data_type, "primary_key": primary_key} + + have_table["columns"] = have_columns + + cursor.execute(f"PRAGMA INDEX_LIST('{table_name}');") + indices = cursor.fetchall() + for index in indices: + seq, index_name, unique, origin, partial = index + + if origin == "pk": # Created by a PRIMARY KEY constraint + continue + + cursor.execute(f"PRAGMA INDEX_INFO('{index_name}');") + index_info = cursor.fetchall() + + add_index = {"index_name": index_name} + for index_columns in index_info: + seqno, cid, name = index_columns + if origin == "u": # Created by a UNIQUE constraint + have_columns[name]["unique"] = 1 + else: + if "column_1" not in add_index: + add_index["column_1"] = name + elif "column_2" not in add_index: + add_index["column_2"] = name + elif "column_3" not in add_index: + add_index["column_3"] = name + else: + raise RuntimeError("Add more index columns.") + if origin == "c": + if "indices" not in table: + have_table["indices"] = [] + have_table["indices"].append(add_index) + + have_tables[table_name] = have_table + + for table_name, table in expect_schema.items(): + if table_name not in have_tables: + self.log.info(f"Creating table {table_name}.") + create_table(cursor, table_name, table) + continue + + have_table = have_tables[table_name] + have_columns = have_table["columns"] + for colname, column in table["columns"].items(): + if colname not in have_columns: + col_type = column["type"] + self.log.info(f"Adding column {colname} to table {table_name}.") + cursor.execute( + f"ALTER TABLE {table_name} ADD COLUMN {colname} {col_type}" + ) + indices = table.get("indices", []) + have_indices = have_table.get("indices", []) + for index in indices: + index_name = index["index_name"] + if not any( + have_idx.get("index_name") == index_name for have_idx in have_indices + ): + self.log.info(f"Adding index {index_name} to table {table_name}.") + column_1 = index["column_1"] + column_2 = index.get("column_2", None) + column_3 = index.get("column_3", None) + query: str = f"CREATE INDEX {index_name} ON {table_name} ({column_1}" + if column_2: + query += f", {column_2}" + if column_3: + query += f", {column_3}" + query += ")" + cursor.execute(query) + + +def upgradeDatabase(self, db_version: int): if self._force_db_upgrade is False and db_version >= CURRENT_DB_VERSION: return @@ -174,103 +261,15 @@ def upgradeDatabase(self, db_version): ] expect_schema = extract_schema() - have_tables = {} try: cursor = self.openDB() - for rename_column in rename_columns: dbv, table_name, colname_from, colname_to = rename_column if db_version < dbv: cursor.execute( f"ALTER TABLE {table_name} RENAME COLUMN {colname_from} TO {colname_to}" ) - - query = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;" - tables = cursor.execute(query).fetchall() - for table in tables: - table_name = table[0] - if table_name in ("sqlite_sequence",): - continue - - have_table = {} - have_columns = {} - query = "SELECT * FROM PRAGMA_TABLE_INFO(:table_name) ORDER BY cid DESC;" - columns = cursor.execute(query, {"table_name": table_name}).fetchall() - for column in columns: - cid, name, data_type, notnull, default_value, primary_key = column - have_columns[name] = {"type": data_type, "primary_key": primary_key} - - have_table["columns"] = have_columns - - cursor.execute(f"PRAGMA INDEX_LIST('{table_name}');") - indices = cursor.fetchall() - for index in indices: - seq, index_name, unique, origin, partial = index - - if origin == "pk": # Created by a PRIMARY KEY constraint - continue - - cursor.execute(f"PRAGMA INDEX_INFO('{index_name}');") - index_info = cursor.fetchall() - - add_index = {"index_name": index_name} - for index_columns in index_info: - seqno, cid, name = index_columns - if origin == "u": # Created by a UNIQUE constraint - have_columns[name]["unique"] = 1 - else: - if "column_1" not in add_index: - add_index["column_1"] = name - elif "column_2" not in add_index: - add_index["column_2"] = name - elif "column_3" not in add_index: - add_index["column_3"] = name - else: - raise RuntimeError("Add more index columns.") - if origin == "c": - if "indices" not in table: - have_table["indices"] = [] - have_table["indices"].append(add_index) - - have_tables[table_name] = have_table - - for table_name, table in expect_schema.items(): - if table_name not in have_tables: - self.log.info(f"Creating table {table_name}.") - create_table(cursor, table_name, table) - continue - - have_table = have_tables[table_name] - have_columns = have_table["columns"] - for colname, column in table["columns"].items(): - if colname not in have_columns: - col_type = column["type"] - self.log.info(f"Adding column {colname} to table {table_name}.") - cursor.execute( - f"ALTER TABLE {table_name} ADD COLUMN {colname} {col_type}" - ) - indices = table.get("indices", []) - have_indices = have_table.get("indices", []) - for index in indices: - index_name = index["index_name"] - if not any( - have_idx.get("index_name") == index_name - for have_idx in have_indices - ): - self.log.info(f"Adding index {index_name} to table {table_name}.") - column_1 = index["column_1"] - column_2 = index.get("column_2", None) - column_3 = index.get("column_3", None) - query: str = ( - f"CREATE INDEX {index_name} ON {table_name} ({column_1}" - ) - if column_2: - query += f", {column_2}" - if column_3: - query += f", {column_3}" - query += ")" - cursor.execute(query) - + upgradeDatabaseFromSchema(self, cursor, expect_schema) if CURRENT_DB_VERSION != db_version: self.db_version = CURRENT_DB_VERSION self.setIntKV("db_version", CURRENT_DB_VERSION, cursor) diff --git a/basicswap/interface/btc.py b/basicswap/interface/btc.py index c7b0ab5..029613c 100644 --- a/basicswap/interface/btc.py +++ b/basicswap/interface/btc.py @@ -235,6 +235,14 @@ class BTCInterface(Secp256k1Interface): def txoType(): return CTxOut + @staticmethod + def outpointType(): + return COutPoint + + @staticmethod + def txiType(): + return CTxIn + @staticmethod def getExpectedSequence(lockType: int, lockVal: int) -> int: ensure(lockVal >= 1, "Bad lockVal") @@ -1203,7 +1211,7 @@ class BTCInterface(Secp256k1Interface): ensure(C == Kaf, "Bad script pubkey") fee_paid = swap_value - locked_coin - assert fee_paid > 0 + ensure(fee_paid > 0, "negative fee_paid") dummy_witness_stack = self.getScriptLockTxDummyWitness(prevout_script) witness_bytes = self.getWitnessStackSerialisedLength(dummy_witness_stack) @@ -1267,7 +1275,7 @@ class BTCInterface(Secp256k1Interface): tx_value = tx.vout[0].nValue fee_paid = prevout_value - tx_value - assert fee_paid > 0 + ensure(fee_paid > 0, "negative fee_paid") dummy_witness_stack = self.getScriptLockRefundSpendTxDummyWitness( prevout_script @@ -2575,12 +2583,7 @@ class BTCInterface(Secp256k1Interface): self._log.id(bytes.fromhex(tx["txid"])) ) ) - self.rpc( - "sendrawtransaction", - [ - tx_signed, - ], - ) + self.publishTx(tx_signed) return tx["txid"] diff --git a/basicswap/protocols/xmr_swap_1.py b/basicswap/protocols/xmr_swap_1.py index 733aea9..1a3e012 100644 --- a/basicswap/protocols/xmr_swap_1.py +++ b/basicswap/protocols/xmr_swap_1.py @@ -191,17 +191,11 @@ def setDLEAG(xmr_swap, ci_to, kbsf: bytes) -> None: xmr_swap.kbsf_dleag = ci_to.proveDLEAG(kbsf) xmr_swap.pkasf = xmr_swap.kbsf_dleag[0:33] elif ci_to.curve_type() == Curves.secp256k1: - for i in range(10): - xmr_swap.kbsf_dleag = ci_to.signRecoverable( - kbsf, "proof kbsf owned for swap" - ) - pk_recovered: bytes = ci_to.verifySigAndRecover( - xmr_swap.kbsf_dleag, "proof kbsf owned for swap" - ) - if pk_recovered == xmr_swap.pkbsf: - break - # self.log.debug('kbsl recovered pubkey mismatch, retrying.') - assert pk_recovered == xmr_swap.pkbsf + xmr_swap.kbsf_dleag = ci_to.signRecoverable(kbsf, "proof kbsf owned for swap") + pk_recovered: bytes = ci_to.verifySigAndRecover( + xmr_swap.kbsf_dleag, "proof kbsf owned for swap" + ) + ensure(pk_recovered == xmr_swap.pkbsf, "kbsf recovered pubkey mismatch") xmr_swap.pkasf = xmr_swap.pkbsf else: raise ValueError("Unknown curve") diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 1dfc7b5..7d61ac3 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -94,6 +94,7 @@ class Test(unittest.TestCase): time_val = 48 * 60 * 60 encoded = ci.getExpectedSequence(TxLockTypes.SEQUENCE_LOCK_TIME, time_val) decoded = ci.decodeSequence(encoded) + assert encoded == 4194642 assert decoded >= time_val assert decoded <= time_val + 512 diff --git a/tests/basicswap/test_run.py b/tests/basicswap/test_run.py index 1428cd0..54af4e9 100644 --- a/tests/basicswap/test_run.py +++ b/tests/basicswap/test_run.py @@ -319,6 +319,7 @@ class Test(BaseTest): test_coin_from = Coins.PART # p2wpkh logging.info("---------- Test {} segwit".format(test_coin_from.name)) + ci = self.swap_clients[0].ci(test_coin_from) addr_native = ci.rpc_wallet("getnewaddress", ["p2pkh segwit test"]) @@ -329,9 +330,11 @@ class Test(BaseTest): ], ) assert addr_info["iswitness"] is False # address is p2pkh, not p2wpkh + addr_segwit = ci.rpc_wallet( "getnewaddress", ["p2wpkh segwit test", True, False, False, "bech32"] ) + addr_info = ci.rpc_wallet( "getaddressinfo", [ @@ -351,6 +354,7 @@ class Test(BaseTest): ], ], ) + assert len(txid) == 64 tx_wallet = ci.rpc_wallet( "gettransaction",