Merge pull request #418 from tecnovert/refactor

Refactor
This commit is contained in:
tecnovert
2026-01-12 18:05:50 +00:00
committed by GitHub
3 changed files with 95 additions and 24 deletions

View File

@@ -3271,17 +3271,20 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
self.log.debug(f"logBidEvent {self.log.id(bid_id)} {event_type}") self.log.debug(f"logBidEvent {self.log.id(bid_id)} {event_type}")
self.logEvent(Concepts.BID, bid_id, event_type, event_msg, cursor) self.logEvent(Concepts.BID, bid_id, event_type, event_msg, cursor)
def countBidEvents(self, bid, event_type, cursor): def countEvents(self, linked_type: int, linked_id: bytes, event_type: int, cursor):
q = cursor.execute( q = cursor.execute(
"SELECT COUNT(*) FROM eventlog WHERE linked_type = :linked_type AND linked_id = :linked_id AND event_type = :event_type", "SELECT COUNT(*) FROM eventlog WHERE linked_type = :linked_type AND linked_id = :linked_id AND event_type = :event_type",
{ {
"linked_type": int(Concepts.BID), "linked_type": int(Concepts.BID),
"linked_id": bid.bid_id, "linked_id": linked_id,
"event_type": int(event_type), "event_type": int(event_type),
}, },
).fetchone() ).fetchone()
return q[0] return q[0]
def countBidEvents(self, bid, event_type: int, cursor):
return self.countEvents(int(Concepts.BID), bid.bid_id, int(event_type))
def getEvents(self, linked_type: int, linked_id: bytes): def getEvents(self, linked_type: int, linked_id: bytes):
events = [] events = []
cursor = self.openDB() cursor = self.openDB()
@@ -5011,18 +5014,17 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
def setBidError( def setBidError(
self, self,
bid_id: bytes,
bid, bid,
error_str: str, error_str: str,
save_bid: bool = True, save_bid: bool = True,
xmr_swap=None, xmr_swap=None,
cursor=None, cursor=None,
) -> None: ) -> None:
self.log.error(f"Bid {self.log.id(bid_id)} - Error: {error_str}") self.log.error(f"Bid {self.log.id(bid.bid_id)} - Error: {error_str}")
self.logEvent(Concepts.BID, bid_id, EventLogTypes.ERROR, error_str, cursor) self.logEvent(Concepts.BID, bid.bid_id, EventLogTypes.ERROR, error_str, cursor)
bid.setState(BidStates.BID_ERROR) bid.setState(BidStates.BID_ERROR)
if save_bid: if save_bid:
self.saveBid(bid_id, bid, xmr_swap=xmr_swap, cursor=cursor) self.saveBid(bid.bid_id, bid, xmr_swap=xmr_swap, cursor=cursor)
def createInitiateTxn( def createInitiateTxn(
self, coin_type, bid_id: bytes, bid, initiate_script, prefunded_tx=None self, coin_type, bid_id: bytes, bid, initiate_script, prefunded_tx=None
@@ -7030,7 +7032,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
) )
else: else:
self.setBidError( self.setBidError(
bid.bid_id,
bid, bid,
"Unexpected txn spent coin a lock tx: {}".format(spend_txid_hex), "Unexpected txn spent coin a lock tx: {}".format(spend_txid_hex),
save_bid=False, save_bid=False,
@@ -9154,7 +9155,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
except Exception as ex: except Exception as ex:
if self.debug: if self.debug:
self.log.error(traceback.format_exc()) self.log.error(traceback.format_exc())
self.setBidError(bid.bid_id, bid, str(ex), xmr_swap=xmr_swap) self.setBidError(bid, str(ex), xmr_swap=xmr_swap)
def watchXmrSwap(self, bid, offer, xmr_swap, cursor=None) -> None: def watchXmrSwap(self, bid, offer, xmr_swap, cursor=None) -> None:
self.log.debug(f"Adaptor-sig swap in progress, bid {self.log.id(bid.bid_id)}.") self.log.debug(f"Adaptor-sig swap in progress, bid {self.log.id(bid.bid_id)}.")
@@ -9498,7 +9499,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
self.saveBidInSession(bid_id, bid, cursor, xmr_swap, save_in_progress=offer) self.saveBidInSession(bid_id, bid, cursor, xmr_swap, save_in_progress=offer)
return return
unlock_time = 0 unlock_time: int = 0
if bid.debug_ind in ( if bid.debug_ind in (
DebugTypes.CREATE_INVALID_COIN_B_LOCK, DebugTypes.CREATE_INVALID_COIN_B_LOCK,
DebugTypes.B_LOCK_TX_MISSED_SEND, DebugTypes.B_LOCK_TX_MISSED_SEND,
@@ -9569,7 +9570,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
) )
else: else:
self.setBidError( self.setBidError(
bid_id,
bid, bid,
"publishBLockTx failed: " + str(ex), "publishBLockTx failed: " + str(ex),
save_bid=False, save_bid=False,
@@ -9597,7 +9597,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
self.logBidEvent(bid.bid_id, EventLogTypes.LOCK_TX_B_PUBLISHED, "", cursor) self.logBidEvent(bid.bid_id, EventLogTypes.LOCK_TX_B_PUBLISHED, "", cursor)
if bid.debug_ind == DebugTypes.BID_STOP_AFTER_COIN_B_LOCK: if bid.debug_ind == DebugTypes.BID_STOP_AFTER_COIN_B_LOCK:
self.log.debug( self.log.debug(
"Adaptor-sig bid {self.log.id(bid_id)}: Stalling bid for testing: {bid.debug_ind}." f"Adaptor-sig bid {self.log.id(bid_id)}: Stalling bid for testing: {bid.debug_ind}."
) )
bid.setState(BidStates.BID_STALLED_FOR_TEST) bid.setState(BidStates.BID_STALLED_FOR_TEST)
self.logBidEvent( self.logBidEvent(
@@ -9896,7 +9896,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
) )
else: else:
self.setBidError( self.setBidError(
bid_id,
bid, bid,
"spendBLockTx failed: " + str(ex), "spendBLockTx failed: " + str(ex),
save_bid=False, save_bid=False,
@@ -10015,7 +10014,6 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
) )
else: else:
self.setBidError( self.setBidError(
bid_id,
bid, bid,
"spendBLockTx for refund failed: " + str(ex), "spendBLockTx for refund failed: " + str(ex),
save_bid=False, save_bid=False,
@@ -10220,7 +10218,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
except Exception as ex: except Exception as ex:
if self.debug: if self.debug:
self.log.error(traceback.format_exc()) self.log.error(traceback.format_exc())
self.setBidError(bid_id, bid, str(ex)) self.setBidError(bid, str(ex))
def processXmrBidLockSpendTx(self, msg) -> None: def processXmrBidLockSpendTx(self, msg) -> None:
# Follower receiving MSG4F # Follower receiving MSG4F
@@ -10285,7 +10283,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
except Exception as ex: except Exception as ex:
if self.debug: if self.debug:
self.log.error(traceback.format_exc()) self.log.error(traceback.format_exc())
self.setBidError(bid_id, bid, str(ex)) self.setBidError(bid, str(ex))
# Update copy of bid in swaps_in_progress # Update copy of bid in swaps_in_progress
self.swaps_in_progress[bid_id] = (bid, offer) self.swaps_in_progress[bid_id] = (bid, offer)
@@ -10387,7 +10385,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
except Exception as ex: except Exception as ex:
if self.debug: if self.debug:
self.log.error(traceback.format_exc()) self.log.error(traceback.format_exc())
self.setBidError(bid_id, bid, str(ex)) self.setBidError(bid, str(ex))
self.swaps_in_progress[bid_id] = (bid, offer) self.swaps_in_progress[bid_id] = (bid, offer)
return return
@@ -11000,9 +10998,10 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
to_remove = [] to_remove = []
if now - self._last_checked_progress >= self.check_progress_seconds: if now - self._last_checked_progress >= self.check_progress_seconds:
for bid_id, v in self.swaps_in_progress.items(): for bid_id, v in self.swaps_in_progress.items():
bid, offer = v
try: try:
if self.checkBidState(bid_id, v[0], v[1]) is True: if self.checkBidState(bid_id, bid, offer) is True:
to_remove.append((bid_id, v[0], v[1])) to_remove.append((bid_id, bid, offer))
except Exception as ex: except Exception as ex:
if self.debug: if self.debug:
self.log.error("checkBidState %s", traceback.format_exc()) self.log.error("checkBidState %s", traceback.format_exc())
@@ -11018,7 +11017,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp):
) )
else: else:
self.log.error(f"checkBidState {self.log.id(bid_id)} {ex}.") self.log.error(f"checkBidState {self.log.id(bid_id)} {ex}.")
self.setBidError(bid_id, v[0], str(ex)) self.setBidError(bid, str(ex))
for bid_id, bid, offer in to_remove: for bid_id, bid, offer in to_remove:
self.deactivateBid(None, offer, bid) self.deactivateBid(None, offer, bid)

View File

@@ -76,10 +76,16 @@ class Table:
__sqlite3_table__ = True __sqlite3_table__ = True
def __init__(self, **kwargs): def __init__(self, **kwargs):
init_all_columns: bool = True
for name, value in kwargs.items(): for name, value in kwargs.items():
if name == "_init_all_columns":
init_all_columns = value
continue
if not hasattr(self, name): if not hasattr(self, name):
raise ValueError(f"Unknown attribute {name}") raise ValueError(f"Unknown attribute {name}")
setattr(self, name, value) setattr(self, name, value)
if init_all_columns is False:
return
# Init any unset columns to None # Init any unset columns to None
for mc in inspect.getmembers(self): for mc in inspect.getmembers(self):
mc_name, mc_obj = mc mc_name, mc_obj = mc
@@ -1033,7 +1039,7 @@ class DBMethods:
if cursor is None: if cursor is None:
self.closeDB(use_cursor, commit=False) self.closeDB(use_cursor, commit=False)
def add(self, obj, cursor, upsert: bool = False): def add(self, obj, cursor, upsert: bool = False, columns_list=None):
if cursor is None: if cursor is None:
raise ValueError("Cursor is null") raise ValueError("Cursor is null")
if not hasattr(obj, "__tablename__"): if not hasattr(obj, "__tablename__"):
@@ -1046,7 +1052,8 @@ class DBMethods:
# See if the instance overwrote any class methods # See if the instance overwrote any class methods
for mc in inspect.getmembers(obj.__class__): for mc in inspect.getmembers(obj.__class__):
mc_name, mc_obj = mc mc_name, mc_obj = mc
if columns_list is not None and mc_name not in columns_list:
continue
if not hasattr(mc_obj, "__sqlite3_column__"): if not hasattr(mc_obj, "__sqlite3_column__"):
continue continue
@@ -1087,6 +1094,7 @@ class DBMethods:
order_by={}, order_by={},
query_suffix=None, query_suffix=None,
extra_query_data={}, extra_query_data={},
columns_list=None,
): ):
if cursor is None: if cursor is None:
raise ValueError("Cursor is null") raise ValueError("Cursor is null")
@@ -1099,6 +1107,8 @@ class DBMethods:
for mc in inspect.getmembers(table_class): for mc in inspect.getmembers(table_class):
mc_name, mc_obj = mc mc_name, mc_obj = mc
if columns_list is not None and mc_name not in columns_list:
continue
if not hasattr(mc_obj, "__sqlite3_column__"): if not hasattr(mc_obj, "__sqlite3_column__"):
continue continue
if len(columns) > 0: if len(columns) > 0:
@@ -1167,6 +1177,7 @@ class DBMethods:
order_by={}, order_by={},
query_suffix=None, query_suffix=None,
extra_query_data={}, extra_query_data={},
columns_list=None,
): ):
return firstOrNone( return firstOrNone(
self.query( self.query(
@@ -1176,10 +1187,11 @@ class DBMethods:
order_by, order_by,
query_suffix, query_suffix,
extra_query_data, extra_query_data,
columns_list,
) )
) )
def updateDB(self, obj, cursor, constraints=[]): def updateDB(self, obj, cursor, constraints=[], columns_list=None):
if cursor is None: if cursor is None:
raise ValueError("Cursor is null") raise ValueError("Cursor is null")
if not hasattr(obj, "__tablename__"): if not hasattr(obj, "__tablename__"):
@@ -1191,7 +1203,6 @@ class DBMethods:
values = {} values = {}
for mc in inspect.getmembers(obj.__class__): for mc in inspect.getmembers(obj.__class__):
mc_name, mc_obj = mc mc_name, mc_obj = mc
if not hasattr(mc_obj, "__sqlite3_column__"): if not hasattr(mc_obj, "__sqlite3_column__"):
continue continue
@@ -1203,7 +1214,8 @@ class DBMethods:
if mc_name in constraints: if mc_name in constraints:
values[mc_name] = m_obj values[mc_name] = m_obj
continue continue
if columns_list is not None and mc_name not in columns_list:
continue
if len(values) > 0: if len(values) > 0:
query += ", " query += ", "
query += f"{mc_name} = :{mc_name}" query += f"{mc_name} = :{mc_name}"

View File

@@ -663,6 +663,7 @@ class Test(unittest.TestCase):
ki.record_id = 1 ki.record_id = 1
ki.address = "test1" ki.address = "test1"
ki.label = "test1" ki.label = "test1"
ki.note = "note1"
try: try:
db_test.add(ki, cursor, upsert=False) db_test.add(ki, cursor, upsert=False)
except Exception as e: except Exception as e:
@@ -670,6 +671,65 @@ class Test(unittest.TestCase):
else: else:
raise ValueError("Should have errored.") raise ValueError("Should have errored.")
db_test.add(ki, cursor, upsert=True) db_test.add(ki, cursor, upsert=True)
# Test columns list
ki_test = db_test.queryOne(
KnownIdentity,
cursor,
{"address": "test1"},
columns_list=[
"label",
],
)
assert ki_test.label == "test1"
assert ki_test.address is None
# Test updating partial row
ki_test.label = "test2"
ki_test.record_id = 1
db_test.add(
ki_test,
cursor,
upsert=True,
columns_list=[
"record_id",
"label",
],
)
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
assert ki_test.record_id == 1
assert ki_test.address == "test1"
assert ki_test.label == "test2"
assert ki_test.note == "note1"
ki_test.note = "test2"
ki_test.label = "test3"
db_test.updateDB(
ki_test,
cursor,
["record_id"],
columns_list=[
"label",
],
)
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
assert ki_test.record_id == 1
assert ki_test.address == "test1"
assert ki_test.label == "test3"
assert ki_test.note == "note1"
# Test partially initialised object
ki_test_p = KnownIdentity(
_init_all_columns=False, record_id=1, label="test4"
)
db_test.add(ki_test_p, cursor, upsert=True)
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
assert ki_test.record_id == 1
assert ki_test.address == "test1"
assert ki_test.label == "test4"
assert ki_test.note == "note1"
finally: finally:
db_test.closeDB(cursor) db_test.closeDB(cursor)