Fix getLinkedMessageId and validateSwapType

This commit is contained in:
tecnovert
2023-08-02 13:57:29 +02:00
parent a13a5d4bf6
commit 8f4b962285
5 changed files with 56 additions and 16 deletions

View File

@@ -266,7 +266,7 @@ class BasicSwap(BaseApp):
# TODO: Set dynamically
self.scriptless_coins = (Coins.XMR, Coins.PART_ANON)
self.adaptor_swap_only_coins = self.scriptless_coins + (Coins.PART_BLIND, )
self.secret_hash_swap_only_coins = (Coins.PIVX, Coins.DASH, Coins.FIRO, Coins.NMC)
self.secret_hash_swap_only_coins = (Coins.PIVX, Coins.DASH, Coins.FIRO, Coins.NMC) # Coins without segwit
# TODO: Adjust ranges
self.min_delay_event = self.settings.get('min_delay_event', 10)
@@ -1145,14 +1145,16 @@ class BasicSwap(BaseApp):
return bytes.fromhex(ro['msgid'])
def validateSwapType(self, coin_from, coin_to, swap_type):
if (coin_from in self.adaptor_swap_only_coins or coin_to in self.adaptor_swap_only_coins) and swap_type != SwapTypes.XMR_SWAP:
raise ValueError('Invalid swap type for: {} -> {}'.format(coin_from.name, coin_to.name))
if swap_type == SwapTypes.XMR_SWAP:
if (coin_from in self.secret_hash_swap_only_coins or coin_to in self.secret_hash_swap_only_coins):
reverse_bid: bool = coin_from in self.scriptless_coins
itx_coin = coin_to if reverse_bid else coin_from
if (itx_coin in self.secret_hash_swap_only_coins):
raise ValueError('Invalid swap type for: {} -> {}'.format(coin_from.name, coin_to.name))
if (coin_from in self.scriptless_coins and coin_to in self.scriptless_coins):
raise ValueError('Invalid swap type for: {} -> {}'.format(coin_from.name, coin_to.name))
else:
if coin_from in self.adaptor_swap_only_coins or coin_to in self.adaptor_swap_only_coins:
raise ValueError('Invalid swap type for: {} -> {}'.format(coin_from.name, coin_to.name))
def notify(self, event_type, event_data, session=None) -> None:
show_event = event_type not in self._disabled_notification_types
@@ -2066,8 +2068,8 @@ class BasicSwap(BaseApp):
def getLinkedMessageId(self, linked_type: int, linked_id: int, msg_type: int, msg_sequence: int = 0, session=None) -> bytes:
try:
use_session = self.openSession(session)
q = session.execute('SELECT msg_id FROM message_links WHERE linked_type = :linked_type AND linked_id = :linked_id AND msg_type = :msg_type AND msg_sequence = :msg_sequence',
{'linked_type': linked_type, 'linked_id': linked_id, 'msg_type': msg_type, 'msg_sequence': msg_sequence}).first()
q = use_session.execute('SELECT msg_id FROM message_links WHERE linked_type = :linked_type AND linked_id = :linked_id AND msg_type = :msg_type AND msg_sequence = :msg_sequence',
{'linked_type': linked_type, 'linked_id': linked_id, 'msg_type': msg_type, 'msg_sequence': msg_sequence}).first()
return q[0]
finally:
if session is None: