diff --git a/basicswap/interface/xmr.py b/basicswap/interface/xmr.py index f6bd3c4..cf6ce73 100644 --- a/basicswap/interface/xmr.py +++ b/basicswap/interface/xmr.py @@ -7,6 +7,7 @@ # file LICENSE or http://www.opensource.org/licenses/mit-license.php. import logging +import os import basicswap.contrib.ed25519_fast as edf import basicswap.ed25519_fast_util as edu @@ -204,16 +205,51 @@ class XMRInterface(CoinInterface): except Exception as e: if "no connection to daemon" in str(e): self._log.debug(f"{self.coin_name()} {e}") - return # bypass refresh error to allow startup with a busy daemon - if "invalid signature" in str(e): - self._log.debug(f"{self.coin_name()} wallet is corrupt") - raise - - try: - self.rpc_wallet("close_wallet") - self._log.debug(f"Closing {self.coin_name()} wallet") - except Exception as e: # noqa: F841 - pass + return # Bypass refresh error to allow startup with a busy daemon + if any( + x in str(e) + for x in ( + "invalid signature", + "std::bad_alloc", + "basic_string::_M_replace_aux", + ) + ): + self._log.error(f"{self.coin_name()} wallet is corrupt.") + chain_client_settings = self._sc.getChainClientSettings( + self.coin_type() + ) # basicswap.json + if chain_client_settings.get("manage_wallet_daemon", False): + self._log.info(f"Renaming {self.coin_name()} wallet cache file.") + walletpath = os.path.join( + chain_client_settings.get("datadir", "none"), + "wallets", + filename, + ) + if not os.path.isfile(walletpath): + self._log.warning( + f"Could not find {self.coin_name()} wallet cache file." + ) + raise + bkp_path = walletpath + ".corrupt" + for i in range(100): + if not os.path.exists(bkp_path): + break + bkp_path = walletpath + f".corrupt{i}" + if os.path.exists(bkp_path): + self._log.error( + f"Could not find backup path for {self.coin_name()} wallet." + ) + raise + os.rename(walletpath, bkp_path) + # Drop through to open_wallet + else: + raise + else: + try: + self.rpc_wallet("close_wallet") + self._log.debug(f"Closing {self.coin_name()} wallet") + except Exception as e: # noqa: F841 + pass self.rpc_wallet("open_wallet", params) self._log.debug(f"Attempting to open {self.coin_name()} wallet") diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py index a85c7db..26aa476 100644 --- a/tests/basicswap/test_xmr.py +++ b/tests/basicswap/test_xmr.py @@ -1089,6 +1089,62 @@ class Test(BaseTest): def notest_00_delay(self): test_delay_event.wait(100000) + def test_007_corrupt_wallet(self): + logging.info(f"---------- Test {Coins.XMR.name} corrupt wallet") + swap_clients = self.swap_clients + ci = swap_clients[0].ci(Coins.XMR) + + chain_client_settings = swap_clients[0].getChainClientSettings(Coins.XMR) + wallet_name = chain_client_settings["wallet_name"] + try: + ci.rpc_wallet("close_wallet") + logging.info(f"Closing {ci.coin_name()} wallet") + except Exception as e: + logging.info(f"Closing {ci.coin_name()} wallet failed with: {e}") + + walletpath = os.path.join( + chain_client_settings.get("datadir", "none"), "wallets", wallet_name + ) + wallet_cache_bytes = os.path.getsize(walletpath) + logging.info(f"[rm] wallet_cache_bytes {wallet_cache_bytes}") + logging.info(f"[rm] walletpath {walletpath}") + shutil.copy(walletpath, walletpath + ".orig") + + # Failed to open wallet : basic_string::_M_replace_aux + # with open(walletpath, "wb") as fp: + # fp.write(os.urandom(wallet_cache_bytes)) + + # Failed to open wallet : std::bad_alloc + with open(walletpath, "ab") as fp: + fp.write(os.urandom(1000)) + + # TODO: Get "invalid signature" + + try: + ci.openWallet(wallet_name) + except Exception as e: + logging.info(f"Opening {ci.coin_name()} wallet failed with: {e}") + assert any( + x in str(e) + for x in ( + "invalid signature", + "std::bad_alloc", + "basic_string::_M_replace_aux", + ) + ) + else: + raise ValueError("Should fail!") + + try: + chain_client_settings["manage_wallet_daemon"] = True + try: + ci.openWallet(wallet_name) + except Exception as e: + logging.info(f"Opening {ci.coin_name()} wallet failed with: {e}") + raise + finally: + chain_client_settings["manage_wallet_daemon"] = False + def test_010_txn_size(self): logging.info("---------- Test {} txn_size".format(Coins.PART))