From db2ba192202f8188c22d8ee45aa740a66f716f1f Mon Sep 17 00:00:00 2001 From: tecnovert Date: Mon, 7 Apr 2025 22:01:32 +0200 Subject: [PATCH] Improve checkSplitMessages. --- basicswap/basicswap.py | 98 +++++++++++++++++------------------------- basicswap/db.py | 25 ++++++++--- 2 files changed, 59 insertions(+), 64 deletions(-) diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 2fe97b7..cbbd980 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -6879,80 +6879,61 @@ class BasicSwap(BaseApp): now: int = self.getTime() ttl_xmr_split_messages = 60 * 60 bid_cursor = None - dleag_proof_len: int = 48893 # coincurve.dleag.dleag_proof_len() - expect_segments: int = -( - (dleag_proof_len - self.dleag_split_size_init) // -self.dleag_split_size - ) # ceiling division try: cursor = self.openDB() bid_cursor = self.getNewDBCursor() q_bids = self.query( - Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING)} + Bid, + bid_cursor, + { + "state": ( + int(BidStates.BID_RECEIVING), + int(BidStates.BID_RECEIVING_ACC), + ) + }, ) for bid in q_bids: q = cursor.execute( - "SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type", - {"bid_id": bid.bid_id, "msg_type": int(XmrSplitMsgTypes.BID)}, - ).fetchone() - num_segments = q[0] - if num_segments >= expect_segments: - try: - self.receiveXmrBid(bid, cursor) - except Exception as ex: - self.log.info( - f"Verify adaptor-sig bid {self.log.id(bid.bid_id)} failed: {ex}" - ) - if self.debug: - self.log.error(traceback.format_exc()) - bid.setState( - BidStates.BID_ERROR, "Failed validation: " + str(ex) - ) - self.updateDB( - bid, - cursor, - [ - "bid_id", - ], - ) - self.updateBidInProgress(bid) - continue - if bid.created_at + ttl_xmr_split_messages < now: - self.log.debug( - f"Expiring partially received bid: {self.log.id(bid.bid_id)}." - ) - bid.setState(BidStates.BID_ERROR, "Timed out") - self.updateDB( - bid, - cursor, - [ - "bid_id", - ], - ) - - q_bids = self.query( - Bid, bid_cursor, {"state": int(BidStates.BID_RECEIVING_ACC)} - ) - for bid in q_bids: - q = cursor.execute( - "SELECT COUNT(*) FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type", + "SELECT LENGTH(kbsl_dleag), LENGTH(kbsf_dleag) FROM xmr_swaps WHERE bid_id = :bid_id", { "bid_id": bid.bid_id, - "msg_type": int(XmrSplitMsgTypes.BID_ACCEPT), }, ).fetchone() - num_segments = q[0] - if num_segments >= expect_segments: + kbsl_dleag_len: int = q[0] + kbsf_dleag_len: int = q[1] + + if bid.state == int(BidStates.BID_RECEIVING_ACC): + bid_type: str = "bid accept" + msg_type: int = int(XmrSplitMsgTypes.BID_ACCEPT) + total_dleag_size: int = kbsl_dleag_len + else: + bid_type: str = "bid" + msg_type: int = int(XmrSplitMsgTypes.BID) + total_dleag_size: int = kbsf_dleag_len + + q = cursor.execute( + "SELECT COUNT(*), SUM(LENGTH(dleag)) AS total_dleag_size FROM xmr_split_data WHERE bid_id = :bid_id AND msg_type = :msg_type", + {"bid_id": bid.bid_id, "msg_type": msg_type}, + ).fetchone() + total_dleag_size += 0 if q[1] is None else q[1] + + if total_dleag_size >= dleag_proof_len: try: - self.receiveXmrBidAccept(bid, cursor) + if bid.state == int(BidStates.BID_RECEIVING): + self.receiveXmrBid(bid, cursor) + elif bid.state == int(BidStates.BID_RECEIVING_ACC): + self.receiveXmrBidAccept(bid, cursor) + else: + raise ValueError("Unexpected bid state") except Exception as ex: + self.log.info( + f"Verify adaptor-sig {bid_type} {self.log.id(bid.bid_id)} failed: {ex}" + ) if self.debug: self.log.error(traceback.format_exc()) - self.log.info( - f"Verify adaptor-sig bid accept {self.log.id(bid.bid_id)} failed: {ex}." - ) bid.setState( - BidStates.BID_ERROR, "Failed accept validation: " + str(ex) + BidStates.BID_ERROR, f"Failed {bid_type} validation: {ex}" ) self.updateDB( bid, @@ -6965,7 +6946,7 @@ class BasicSwap(BaseApp): continue if bid.created_at + ttl_xmr_split_messages < now: self.log.debug( - f"Expiring partially received bid accept: {self.log.id(bid.bid_id)}." + f"Expiring partially received {bid_type}: {self.log.id(bid.bid_id)}." ) bid.setState(BidStates.BID_ERROR, "Timed out") self.updateDB( @@ -6975,7 +6956,6 @@ class BasicSwap(BaseApp): "bid_id", ], ) - # Expire old records cursor.execute( "DELETE FROM xmr_split_data WHERE created_at + :ttl < :now", diff --git a/basicswap/db.py b/basicswap/db.py index 7b83b78..b418c84 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -929,15 +929,12 @@ class DBMethods: table_name: str = table_class.__tablename__ query: str = "SELECT " - columns = [] for mc in inspect.getmembers(table_class): mc_name, mc_obj = mc - if not hasattr(mc_obj, "__sqlite3_column__"): continue - if len(columns) > 0: query += ", " query += mc_name @@ -945,10 +942,29 @@ class DBMethods: query += f" FROM {table_name} WHERE 1=1 " + query_data = {} for ck in constraints: if not validColumnName(ck): raise ValueError(f"Invalid constraint column: {ck}") - query += f" AND {ck} = :{ck} " + + constraint_value = constraints[ck] + if isinstance(constraint_value, tuple) or isinstance( + constraint_value, list + ): + if len(constraint_value) < 2: + raise ValueError(f"Too few constraint values for list: {ck}") + query += f" AND {ck} IN (" + + for i, cv in enumerate(constraint_value): + cv_name: str = f"{ck}_{i}" + if i > 0: + query += "," + query += ":" + cv_name + query_data[cv_name] = cv + query += ") " + else: + query += f" AND {ck} = :{ck} " + query_data[ck] = constraint_value for order_col, order_dir in order_by.items(): if validColumnName(order_col) is False: @@ -961,7 +977,6 @@ class DBMethods: if query_suffix: query += query_suffix - query_data = constraints.copy() query_data.update(extra_query_data) rows = cursor.execute(query, query_data) for row in rows: