diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 53bb162..edc79db 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -148,14 +148,7 @@ from .db import ( XmrSwap, ) from .wallet_manager import WalletManager -from .db_wallet import ( - WalletAddress, - WalletLockedUTXO, - WalletPendingTx, - WalletState, - WalletTxCache, - WalletWatchOnly, -) + from .explorers import ( ExplorerInsight, ExplorerBitAps, @@ -614,15 +607,7 @@ class BasicSwap(BaseApp, BSXNetwork, UIApp): if not db_exists: self.log.info("First run") - wallet_tables = [ - WalletAddress, - WalletLockedUTXO, - WalletPendingTx, - WalletState, - WalletTxCache, - WalletWatchOnly, - ] - create_db(self.sqlite_file, self.log, extra_tables=wallet_tables) + create_db(self.sqlite_file, self.log) cursor = self.openDB() try: diff --git a/basicswap/db.py b/basicswap/db.py index 9b49fae..330667f 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -772,8 +772,8 @@ class NetworkPortal(Table): created_at = Column("integer") -def extract_schema(extra_tables: list = None) -> dict: - g = globals().copy() +def extract_schema(extra_tables: list = None, input_globals: dict = None) -> dict: + g = (input_globals if input_globals else globals()).copy() if extra_tables: for table_class in extra_tables: @@ -898,18 +898,21 @@ def create_table(c, table_name, table) -> None: c.execute(query) -def create_db_(con, log, extra_tables: list = None) -> None: - db_schema = extract_schema(extra_tables=extra_tables) +def create_db_(con, log) -> None: + from .db_wallet import extract_wallet_schema + + db_schema = extract_schema() + db_schema.update(extract_wallet_schema()) c = con.cursor() for table_name, table in db_schema.items(): create_table(c, table_name, table) -def create_db(db_path: str, log, extra_tables: list = None) -> None: +def create_db(db_path: str, log) -> None: con = None try: con = sqlite3.connect(db_path) - create_db_(con, log, extra_tables=extra_tables) + create_db_(con, log) con.commit() finally: if con: diff --git a/basicswap/db_upgrades.py b/basicswap/db_upgrades.py index cae8782..c7c29ce 100644 --- a/basicswap/db_upgrades.py +++ b/basicswap/db_upgrades.py @@ -18,14 +18,7 @@ from .db import ( extract_schema, ) -from .db_wallet import ( - WalletAddress, - WalletLockedUTXO, - WalletPendingTx, - WalletState, - WalletTxCache, - WalletWatchOnly, -) +from .db_wallet import extract_wallet_schema from .basicswap_util import ( BidStates, @@ -277,15 +270,8 @@ def upgradeDatabase(self, db_version: int): ), ] - wallet_tables = [ - WalletAddress, - WalletLockedUTXO, - WalletPendingTx, - WalletState, - WalletTxCache, - WalletWatchOnly, - ] - expect_schema = extract_schema(extra_tables=wallet_tables) + expect_schema = extract_schema() + expect_schema.update(extract_wallet_schema()) have_tables = {} try: cursor = self.openDB() diff --git a/basicswap/db_wallet.py b/basicswap/db_wallet.py index 1642b41..dd10577 100644 --- a/basicswap/db_wallet.py +++ b/basicswap/db_wallet.py @@ -5,7 +5,7 @@ # file LICENSE or http://www.opensource.org/licenses/mit-license.php. -from .db import Column, Index, Table, UniqueConstraint +from .db import Column, Index, Table, UniqueConstraint, extract_schema class WalletAddress(Table): @@ -120,3 +120,7 @@ class WalletPendingTx(Table): __unique_1__ = UniqueConstraint("coin_type", "txid") __index_pending_coin__ = Index("idx_pending_coin", "coin_type", "confirmed_at") + + +def extract_wallet_schema() -> dict: + return extract_schema(input_globals=globals())