diff --git a/basicswap/db.py b/basicswap/db.py index 2b2537f..b1e53f6 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -76,10 +76,16 @@ class Table: __sqlite3_table__ = True def __init__(self, **kwargs): + init_all_columns: bool = True for name, value in kwargs.items(): + if name == "_init_all_columns": + init_all_columns = value + continue if not hasattr(self, name): raise ValueError(f"Unknown attribute {name}") setattr(self, name, value) + if init_all_columns is False: + return # Init any unset columns to None for mc in inspect.getmembers(self): mc_name, mc_obj = mc @@ -1033,7 +1039,7 @@ class DBMethods: if cursor is None: self.closeDB(use_cursor, commit=False) - def add(self, obj, cursor, upsert: bool = False): + def add(self, obj, cursor, upsert: bool = False, columns_list=None): if cursor is None: raise ValueError("Cursor is null") if not hasattr(obj, "__tablename__"): @@ -1046,7 +1052,8 @@ class DBMethods: # See if the instance overwrote any class methods for mc in inspect.getmembers(obj.__class__): mc_name, mc_obj = mc - + if columns_list is not None and mc_name not in columns_list: + continue if not hasattr(mc_obj, "__sqlite3_column__"): continue @@ -1087,6 +1094,7 @@ class DBMethods: order_by={}, query_suffix=None, extra_query_data={}, + columns_list=None, ): if cursor is None: raise ValueError("Cursor is null") @@ -1099,6 +1107,8 @@ class DBMethods: for mc in inspect.getmembers(table_class): mc_name, mc_obj = mc + if columns_list is not None and mc_name not in columns_list: + continue if not hasattr(mc_obj, "__sqlite3_column__"): continue if len(columns) > 0: @@ -1167,6 +1177,7 @@ class DBMethods: order_by={}, query_suffix=None, extra_query_data={}, + columns_list=None, ): return firstOrNone( self.query( @@ -1176,10 +1187,11 @@ class DBMethods: order_by, query_suffix, extra_query_data, + columns_list, ) ) - def updateDB(self, obj, cursor, constraints=[]): + def updateDB(self, obj, cursor, constraints=[], columns_list=None): if cursor is None: raise ValueError("Cursor is null") if not hasattr(obj, "__tablename__"): @@ -1191,7 +1203,6 @@ class DBMethods: values = {} for mc in inspect.getmembers(obj.__class__): mc_name, mc_obj = mc - if not hasattr(mc_obj, "__sqlite3_column__"): continue @@ -1203,7 +1214,8 @@ class DBMethods: if mc_name in constraints: values[mc_name] = m_obj continue - + if columns_list is not None and mc_name not in columns_list: + continue if len(values) > 0: query += ", " query += f"{mc_name} = :{mc_name}" diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 7d61ac3..8ecae8f 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -663,6 +663,7 @@ class Test(unittest.TestCase): ki.record_id = 1 ki.address = "test1" ki.label = "test1" + ki.note = "note1" try: db_test.add(ki, cursor, upsert=False) except Exception as e: @@ -670,6 +671,65 @@ class Test(unittest.TestCase): else: raise ValueError("Should have errored.") db_test.add(ki, cursor, upsert=True) + + # Test columns list + ki_test = db_test.queryOne( + KnownIdentity, + cursor, + {"address": "test1"}, + columns_list=[ + "label", + ], + ) + assert ki_test.label == "test1" + assert ki_test.address is None + + # Test updating partial row + ki_test.label = "test2" + ki_test.record_id = 1 + db_test.add( + ki_test, + cursor, + upsert=True, + columns_list=[ + "record_id", + "label", + ], + ) + ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"}) + assert ki_test.record_id == 1 + assert ki_test.address == "test1" + assert ki_test.label == "test2" + assert ki_test.note == "note1" + + ki_test.note = "test2" + ki_test.label = "test3" + + db_test.updateDB( + ki_test, + cursor, + ["record_id"], + columns_list=[ + "label", + ], + ) + ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"}) + assert ki_test.record_id == 1 + assert ki_test.address == "test1" + assert ki_test.label == "test3" + assert ki_test.note == "note1" + + # Test partially initialised object + ki_test_p = KnownIdentity( + _init_all_columns=False, record_id=1, label="test4" + ) + db_test.add(ki_test_p, cursor, upsert=True) + ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"}) + assert ki_test.record_id == 1 + assert ki_test.address == "test1" + assert ki_test.label == "test4" + assert ki_test.note == "note1" + finally: db_test.closeDB(cursor)