mirror of
https://github.com/basicswap/basicswap.git
synced 2026-01-27 19:15:09 +01:00
db: enable partial retrievals and updates
This commit is contained in:
@@ -76,10 +76,16 @@ class Table:
|
||||
__sqlite3_table__ = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
init_all_columns: bool = True
|
||||
for name, value in kwargs.items():
|
||||
if name == "_init_all_columns":
|
||||
init_all_columns = value
|
||||
continue
|
||||
if not hasattr(self, name):
|
||||
raise ValueError(f"Unknown attribute {name}")
|
||||
setattr(self, name, value)
|
||||
if init_all_columns is False:
|
||||
return
|
||||
# Init any unset columns to None
|
||||
for mc in inspect.getmembers(self):
|
||||
mc_name, mc_obj = mc
|
||||
@@ -1033,7 +1039,7 @@ class DBMethods:
|
||||
if cursor is None:
|
||||
self.closeDB(use_cursor, commit=False)
|
||||
|
||||
def add(self, obj, cursor, upsert: bool = False):
|
||||
def add(self, obj, cursor, upsert: bool = False, columns_list=None):
|
||||
if cursor is None:
|
||||
raise ValueError("Cursor is null")
|
||||
if not hasattr(obj, "__tablename__"):
|
||||
@@ -1046,7 +1052,8 @@ class DBMethods:
|
||||
# See if the instance overwrote any class methods
|
||||
for mc in inspect.getmembers(obj.__class__):
|
||||
mc_name, mc_obj = mc
|
||||
|
||||
if columns_list is not None and mc_name not in columns_list:
|
||||
continue
|
||||
if not hasattr(mc_obj, "__sqlite3_column__"):
|
||||
continue
|
||||
|
||||
@@ -1087,6 +1094,7 @@ class DBMethods:
|
||||
order_by={},
|
||||
query_suffix=None,
|
||||
extra_query_data={},
|
||||
columns_list=None,
|
||||
):
|
||||
if cursor is None:
|
||||
raise ValueError("Cursor is null")
|
||||
@@ -1099,6 +1107,8 @@ class DBMethods:
|
||||
|
||||
for mc in inspect.getmembers(table_class):
|
||||
mc_name, mc_obj = mc
|
||||
if columns_list is not None and mc_name not in columns_list:
|
||||
continue
|
||||
if not hasattr(mc_obj, "__sqlite3_column__"):
|
||||
continue
|
||||
if len(columns) > 0:
|
||||
@@ -1167,6 +1177,7 @@ class DBMethods:
|
||||
order_by={},
|
||||
query_suffix=None,
|
||||
extra_query_data={},
|
||||
columns_list=None,
|
||||
):
|
||||
return firstOrNone(
|
||||
self.query(
|
||||
@@ -1176,10 +1187,11 @@ class DBMethods:
|
||||
order_by,
|
||||
query_suffix,
|
||||
extra_query_data,
|
||||
columns_list,
|
||||
)
|
||||
)
|
||||
|
||||
def updateDB(self, obj, cursor, constraints=[]):
|
||||
def updateDB(self, obj, cursor, constraints=[], columns_list=None):
|
||||
if cursor is None:
|
||||
raise ValueError("Cursor is null")
|
||||
if not hasattr(obj, "__tablename__"):
|
||||
@@ -1191,7 +1203,6 @@ class DBMethods:
|
||||
values = {}
|
||||
for mc in inspect.getmembers(obj.__class__):
|
||||
mc_name, mc_obj = mc
|
||||
|
||||
if not hasattr(mc_obj, "__sqlite3_column__"):
|
||||
continue
|
||||
|
||||
@@ -1203,7 +1214,8 @@ class DBMethods:
|
||||
if mc_name in constraints:
|
||||
values[mc_name] = m_obj
|
||||
continue
|
||||
|
||||
if columns_list is not None and mc_name not in columns_list:
|
||||
continue
|
||||
if len(values) > 0:
|
||||
query += ", "
|
||||
query += f"{mc_name} = :{mc_name}"
|
||||
|
||||
@@ -663,6 +663,7 @@ class Test(unittest.TestCase):
|
||||
ki.record_id = 1
|
||||
ki.address = "test1"
|
||||
ki.label = "test1"
|
||||
ki.note = "note1"
|
||||
try:
|
||||
db_test.add(ki, cursor, upsert=False)
|
||||
except Exception as e:
|
||||
@@ -670,6 +671,65 @@ class Test(unittest.TestCase):
|
||||
else:
|
||||
raise ValueError("Should have errored.")
|
||||
db_test.add(ki, cursor, upsert=True)
|
||||
|
||||
# Test columns list
|
||||
ki_test = db_test.queryOne(
|
||||
KnownIdentity,
|
||||
cursor,
|
||||
{"address": "test1"},
|
||||
columns_list=[
|
||||
"label",
|
||||
],
|
||||
)
|
||||
assert ki_test.label == "test1"
|
||||
assert ki_test.address is None
|
||||
|
||||
# Test updating partial row
|
||||
ki_test.label = "test2"
|
||||
ki_test.record_id = 1
|
||||
db_test.add(
|
||||
ki_test,
|
||||
cursor,
|
||||
upsert=True,
|
||||
columns_list=[
|
||||
"record_id",
|
||||
"label",
|
||||
],
|
||||
)
|
||||
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
|
||||
assert ki_test.record_id == 1
|
||||
assert ki_test.address == "test1"
|
||||
assert ki_test.label == "test2"
|
||||
assert ki_test.note == "note1"
|
||||
|
||||
ki_test.note = "test2"
|
||||
ki_test.label = "test3"
|
||||
|
||||
db_test.updateDB(
|
||||
ki_test,
|
||||
cursor,
|
||||
["record_id"],
|
||||
columns_list=[
|
||||
"label",
|
||||
],
|
||||
)
|
||||
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
|
||||
assert ki_test.record_id == 1
|
||||
assert ki_test.address == "test1"
|
||||
assert ki_test.label == "test3"
|
||||
assert ki_test.note == "note1"
|
||||
|
||||
# Test partially initialised object
|
||||
ki_test_p = KnownIdentity(
|
||||
_init_all_columns=False, record_id=1, label="test4"
|
||||
)
|
||||
db_test.add(ki_test_p, cursor, upsert=True)
|
||||
ki_test = db_test.queryOne(KnownIdentity, cursor, {"address": "test1"})
|
||||
assert ki_test.record_id == 1
|
||||
assert ki_test.address == "test1"
|
||||
assert ki_test.label == "test4"
|
||||
assert ki_test.note == "note1"
|
||||
|
||||
finally:
|
||||
db_test.closeDB(cursor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user