diff --git a/basicswap/db.py b/basicswap/db.py index bf374c5..aade15b 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -721,6 +721,52 @@ class DirectMessageRouteLink(Table): created_at = Column("integer") +class NetworkPortal(Table): + __tablename__ = "network_portals" + + def __init__( + self, time_start, time_valid, network_from, network_to, address_from, address_to + ): + super().__init__() + self.active_ind = 1 + self.time_start = time_start + self.time_valid = time_valid + self.network_from = network_from + self.network_to = network_to + self.address_from = address_from + self.address_to = address_to + + self.smsg_difficulty = 0x1EFFFFFF + + self.num_refreshes = 0 + self.messages_sent = 0 + self.responses_seen = 0 + self.time_last_used = 0 + self.num_issues = 0 + + record_id = Column("integer", primary_key=True, autoincrement=True) + active_ind = Column("integer") + own_portal = Column("integer") + + address_from = Column("string", unique=True) + address_to = Column("string") + + network_from = Column("integer") + network_to = Column("integer") + + time_start = Column("integer") + time_valid = Column("integer") + smsg_difficulty = Column("integer") + num_refreshes = Column("integer") + + messages_sent = Column("integer") + responses_seen = Column("integer") + time_last_used = Column("integer") + num_issues = Column("integer") + + created_at = Column("integer") + + def extract_schema() -> dict: g = globals().copy() tables = {} diff --git a/basicswap/network/bsx_network.py b/basicswap/network/bsx_network.py index e280d56..8cbed3d 100644 --- a/basicswap/network/bsx_network.py +++ b/basicswap/network/bsx_network.py @@ -17,7 +17,11 @@ from basicswap.basicswap_util import ( from basicswap.chainparams import ( Coins, ) -from basicswap.db import DirectMessageRoute +from basicswap.db import ( + DirectMessageRoute, + NetworkPortal, + SmsgAddress, +) from basicswap.messages_npb import ( MessagePortalOffer, MessagePortalSend, @@ -39,41 +43,6 @@ from basicswap.util.logging import LogCategories as LC from basicswap.util.smsg import smsgGetID -class NetworkPortal: - __slots__ = ( - "time_start", - "time_valid", - "network_from", - "network_to", - "address_from", - "address_to", - "smsg_difficulty", - "num_refreshes", - "messages_sent", - "responses_seen", - "time_last_used", - "num_issues", - ) - - def __init__( - self, time_start, time_valid, network_from, network_to, address_from, address_to - ): - self.time_start = time_start - self.time_valid = time_valid - self.network_from = network_from - self.network_to = network_to - self.address_from = address_from - self.address_to = address_to - - self.smsg_difficulty = 0x1EFFFFFF - - self.num_refreshes = 0 - self.messages_sent = 0 - self.responses_seen = 0 - self.time_last_used = 0 - self.num_issues = 0 - - def networkTypeToID(type: str) -> int: # TODO: remove if type == "smsg": @@ -114,6 +83,7 @@ class BSXNetwork: "check_bridges_seconds", 10, 1, 10 * 60 ) self._last_checked_bridges = 0 + self._forget_portals_after = 86400 * 7 self._zmq_queue_enabled = self.settings.get("zmq_queue_enabled", True) self._poll_smsg = self.settings.get("poll_smsg", False) @@ -219,8 +189,29 @@ class BSXNetwork: # TODO: Ensure smsg is enabled for the active wallet. if self._smsg_plaintext_version >= 2: - ro = self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False]) - self.log.debug("smsgoptions {ro}") + self.callrpc("smsgoptions", ["set", "addReceivedPubkeys", False]) + + now: int = self.getTime() + + # Load portal data + try: + cursor = self.openDB() + portals = self.query( + NetworkPortal, + cursor, + ) + for portal_data in portals: + if portal_data.time_start + portal_data.time_valid < now: + # Database records are kept longer + continue + + if portal_data.own_portal == 1: + self.own_portals.add(portal_data) + else: + self.known_portals.add(portal_data) + + finally: + self.closeDB(cursor) def add_connection(self, host, port, peer_pubkey): self.log.info(f"add_connection {host} {port} {peer_pubkey.hex()}.") @@ -739,13 +730,15 @@ class BSXNetwork: addr_portal: str = self.prepareSMSGAddress( None, AddressTypes.PORTAL_LOCAL, cursor ) + portal = NetworkPortal( + now, 30 * 60, network_from_id, network_to_id, addr_portal, addr_to + ) + portal.created_at = now + portal.own_portal = True + portal.record_id = self.add(portal, cursor) finally: self.closeDB(cursor) - portal = NetworkPortal( - now, 30 * 60, network_from_id, network_to_id, addr_portal, addr_to - ) - smsg_difficulty: int = 0x1EFFFFFF if self._have_smsg_rpc: smsg_difficulty = self.callrpc("smsggetdifficulty", [-1, True]) @@ -918,15 +911,51 @@ class BSXNetwork: self.log.warning("Offered portal is expired.") return - received_portal = NetworkPortal( - time_start, - portal_data.time_valid, - portal_data.network_type_from, - portal_data.network_type_to, - addr_portal, - portal_data.portal_address_to, - ) - received_portal.smsg_difficulty = portal_data.smsg_difficulty + cursor = self.openDB() + try: + received_portal = self.queryOne( + NetworkPortal, + cursor, + { + "address_from": addr_portal, + }, + ) + if received_portal is None: + received_portal = NetworkPortal( + time_start, + portal_data.time_valid, + portal_data.network_type_from, + portal_data.network_type_to, + addr_portal, + portal_data.portal_address_to, + ) + received_portal.created_at = now + else: + received_portal.num_refreshes += 1 + received_portal.smsg_difficulty = portal_data.smsg_difficulty + received_portal.time_start = time_start + received_portal.time_valid = portal_data.time_valid + + self.add(received_portal, cursor, upsert=True) + + address_record = self.queryOne( + SmsgAddress, + cursor, + { + "addr": addr_portal, + }, + ) + if address_record is None or len(address_record.pubkey) < 33: + if address_record is None: + address_record = SmsgAddress() + address_record.active_ind = 1 + address_record.created_at = now + address_record.addr = received_portal.address_from + address_record.use_type = AddressTypes.PORTAL + address_record.pubkey = msg["pubkey_from"] + self.add(address_record, cursor, upsert=True) + finally: + self.closeDB(cursor) if received_portal.network_from not in self.known_portals: self.known_portals[received_portal.network_from] = {} @@ -939,7 +968,7 @@ class BSXNetwork: for portal in portals_from_to: if portal.address_from == received_portal.address_from: - portal.num_refreshes += 1 + portal.num_refreshes = received_portal.num_refreshes portal.time_start = received_portal.time_start portal.time_valid = received_portal.time_valid portal.smsg_difficulty = received_portal.smsg_difficulty @@ -947,31 +976,6 @@ class BSXNetwork: portals_from_to.append(received_portal) - try: - cursor = self.openDB() - query: str = "SELECT addr_id FROM smsgaddresses WHERE addr = :addr" - addresses = cursor.execute( - query, {"addr": received_portal.address_from} - ).fetchall() - if len(addresses) < 1: - pk_address_from: str = msg["pubkey_from"] - query: str = ( - "INSERT INTO smsgaddresses (active_ind, created_at, addr, pubkey, use_type) VALUES (:active_ind, :created_at, :addr, :pubkey, :use_type)" - ) - cursor.execute( - query, - { - "active_ind": 1, - "created_at": now, - "addr": received_portal.address_from, - "pubkey": pk_address_from, - "use_type": AddressTypes.PORTAL, - }, - ) - - finally: - self.closeDB(cursor) - def processPortalMessage(self, msg): msg_id = msg["msgid"] self.log.debug(f"Processing network portal message {msg_id}.") @@ -1001,6 +1005,21 @@ class BSXNetwork: else: raise ValueError(f"Unknown network ID {network_to_id}") + portal.messages_sent += 1 + cursor = self.openDB() + try: + portal_record = self.queryOne( + NetworkPortal, + cursor, + { + "address_from": portal.address_from, + }, + ) + portal_record.messages_sent = portal.messages_sent + self.add(portal_record, cursor, upsert=True) + finally: + self.closeDB(cursor) + def updateNetworkBridges(self, now: int) -> None: for network in self.active_networks: network_from_id: int = networkTypeToID(network["type"]) @@ -1016,6 +1035,13 @@ class BSXNetwork: else: if portal.time_start + portal.time_valid <= now - (5 * 60): self.refreshPortal(portal) + + cursor = self.openDB() + try: + query: str = "DELETE FROM network_portals WHERE time_start < :time_delete" + cursor.execute(query, {"time_delete": now - self._forget_portals_after}) + finally: + self.closeDB(cursor) self._last_checked_bridges = now def updateNetwork(self) -> None: