tests: Add script test

This commit is contained in:
tecnovert
2023-02-14 23:34:01 +02:00
parent 9117e2b723
commit dc0bd147b8
11 changed files with 1379 additions and 106 deletions

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019-2022 tecnovert
# Copyright (c) 2019-2023 tecnovert
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
@@ -867,7 +867,7 @@ class BasicSwap(BaseApp):
def updateIdentityBidState(self, session, address: str, bid) -> None:
identity_stats = session.query(KnownIdentity).filter_by(address=address).first()
if not identity_stats:
identity_stats = KnownIdentity(address=address, created_at=int(time.time()))
identity_stats = KnownIdentity(active_ind=1, address=address, created_at=int(time.time()))
if bid.state == BidStates.SWAP_COMPLETED:
if bid.was_sent:
@@ -1171,7 +1171,7 @@ class BasicSwap(BaseApp):
def buildNotificationsCache(self, session):
self._notifications_cache.clear()
q = session.execute(f'SELECT created_at, event_type, event_data FROM notifications WHERE active_ind=1 ORDER BY created_at ASC LIMIT {self._show_notifications}')
q = session.execute(f'SELECT created_at, event_type, event_data FROM notifications WHERE active_ind = 1 ORDER BY created_at ASC LIMIT {self._show_notifications}')
for entry in q:
self._notifications_cache[entry[0]] = (entry[1], json.loads(entry[2].decode('UTF-8')))
@@ -1181,6 +1181,71 @@ class BasicSwap(BaseApp):
rv.append((time.strftime('%d-%m-%y %H:%M:%S', time.localtime(k)), int(v[0]), v[1]))
return rv
def setIdentityData(self, filters, data):
address = filters['address']
ci = self.ci(Coins.PART)
ensure(ci.isValidAddress(address), 'Invalid identity address')
try:
now = int(time.time())
session = self.openSession()
q = session.execute(f'SELECT COUNT(*) FROM knownidentities WHERE address = "{address}"').first()
if q[0] < 1:
q = session.execute(f'INSERT INTO knownidentities (active_ind, address, created_at) VALUES (1, "{address}", {now})')
values = []
pattern = ''
if 'label' in data:
pattern += (', ' if pattern != '' else '')
pattern += 'label = "{}"'.format(data['label'])
values.append(address)
q = session.execute(f'UPDATE knownidentities SET {pattern} WHERE address = "{address}"')
finally:
self.closeSession(session)
def listIdentities(self, filters):
try:
session = self.openSession()
query_str = 'SELECT address, label, num_sent_bids_successful, num_recv_bids_successful, ' + \
' num_sent_bids_rejected, num_recv_bids_rejected, num_sent_bids_failed, num_recv_bids_failed ' + \
' FROM knownidentities ' + \
' WHERE active_ind = 1 '
address = filters.get('address', None)
if address is not None:
query_str += f' AND address = "{address}" '
sort_dir = filters.get('sort_dir', 'DESC').upper()
sort_by = filters.get('sort_by', 'created_at')
query_str += f' ORDER BY {sort_by} {sort_dir}'
limit = filters.get('limit', None)
if limit is not None:
query_str += f' LIMIT {limit}'
offset = filters.get('offset', None)
if offset is not None:
query_str += f' OFFSET {offset}'
q = session.execute(query_str)
rv = []
for row in q:
identity = {
'address': row[0],
'label': row[1],
'num_sent_bids_successful': zeroIfNone(row[2]),
'num_recv_bids_successful': zeroIfNone(row[3]),
'num_sent_bids_rejected': zeroIfNone(row[4]),
'num_recv_bids_rejected': zeroIfNone(row[5]),
'num_sent_bids_failed': zeroIfNone(row[6]),
'num_recv_bids_failed': zeroIfNone(row[7]),
}
rv.append(identity)
return rv
finally:
self.closeSession(session)
def vacuumDB(self):
try:
session = self.openSession()
@@ -2127,8 +2192,9 @@ class BasicSwap(BaseApp):
session = scoped_session(self.session_factory)
identity = session.query(KnownIdentity).filter_by(address=address).first()
if identity is None:
identity = KnownIdentity(address=address)
identity = KnownIdentity(active_ind=1, address=address)
identity.label = label
identity.updated_at = int(time.time())
session.add(identity)
session.commit()
finally:
@@ -5434,7 +5500,6 @@ class BasicSwap(BaseApp):
settings_changed = True
if settings_changed:
settings_path = os.path.join(self.data_dir, cfg.CONFIG_FILENAME)
settings_path_new = settings_path + '.new'
shutil.copyfile(settings_path, settings_path + '.last')
@@ -5838,8 +5903,9 @@ class BasicSwap(BaseApp):
filter_coin_to = filters.get('coin_to', None)
if filter_coin_to and filter_coin_to > -1:
q = q.filter(Offer.coin_to == int(filter_coin_to))
filter_include_sent = filters.get('include_sent', None)
if filter_include_sent and filter_include_sent is not True:
if filter_include_sent is not None and filter_include_sent is not True:
q = q.filter(Offer.was_sent == False) # noqa: E712
order_dir = filters.get('sort_dir', 'desc')
@@ -5874,15 +5940,14 @@ class BasicSwap(BaseApp):
session.remove()
self.mxDB.release()
def listBids(self, sent=False, offer_id=None, for_html=False, filters={}, with_identity_info=False):
def listBids(self, sent=False, offer_id=None, for_html=False, filters={}):
self.mxDB.acquire()
try:
rv = []
now = int(time.time())
session = scoped_session(self.session_factory)
identity_fields = ''
query_str = 'SELECT bids.created_at, bids.expire_at, bids.bid_id, bids.offer_id, bids.amount, bids.state, bids.was_received, tx1.state, tx2.state, offers.coin_from, bids.rate, bids.bid_addr {} FROM bids '.format(identity_fields) + \
query_str = 'SELECT bids.created_at, bids.expire_at, bids.bid_id, bids.offer_id, bids.amount, bids.state, bids.was_received, tx1.state, tx2.state, offers.coin_from, bids.rate, bids.bid_addr FROM bids ' + \
'LEFT JOIN offers ON offers.offer_id = bids.offer_id ' + \
'LEFT JOIN transactions AS tx1 ON tx1.bid_id = bids.bid_id AND tx1.tx_type = {} '.format(TxTypes.ITX) + \
'LEFT JOIN transactions AS tx2 ON tx2.bid_id = bids.bid_id AND tx2.tx_type = {} '.format(TxTypes.PTX)
@@ -5901,9 +5966,14 @@ class BasicSwap(BaseApp):
bid_state_ind = filters.get('bid_state_ind', -1)
if bid_state_ind != -1:
query_str += 'AND bids.state = {} '.format(bid_state_ind)
with_available_or_active = filters.get('with_available_or_active', False)
with_expired = filters.get('with_expired', True)
if with_expired is not True:
query_str += 'AND bids.expire_at > {} '.format(now)
if with_available_or_active:
query_str += 'AND (bids.state NOT IN ({}, {}, {}, {}, {}) AND (bids.state > {} OR bids.expire_at > {})) '.format(BidStates.SWAP_COMPLETED, BidStates.BID_ERROR, BidStates.BID_REJECTED, BidStates.SWAP_TIMEDOUT, BidStates.BID_ABANDONED, BidStates.BID_RECEIVED, now)
else:
if with_expired is not True:
query_str += 'AND bids.expire_at > {} '.format(now)
sort_dir = filters.get('sort_dir', 'DESC').upper()
sort_by = filters.get('sort_by', 'created_at')

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2019-2022 tecnovert
# Copyright (c) 2019-2023 tecnovert
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
@@ -697,11 +697,8 @@ class HttpThread(threading.Thread, HTTPServer):
data = response.read()
conn.close()
def stopped(self):
return self.stop_event.is_set()
def serve_forever(self):
while not self.stopped():
while not self.stop_event.is_set():
self.handle_request()
self.socket.close()

View File

@@ -1237,7 +1237,7 @@ class BTCInterface(CoinInterface):
def describeTx(self, tx_hex: str):
return self.rpc_callback('decoderawtransaction', [tx_hex])
def getSpendableBalance(self):
def getSpendableBalance(self) -> int:
return self.make_int(self.rpc_callback('getbalances')['mine']['trusted'])
def createUTXO(self, value_sats: int):

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2020-2022 tecnovert
# Copyright (c) 2020-2023 tecnovert
# Distributed under the MIT software license, see the accompanying
# file LICENSE or http://www.opensource.org/licenses/mit-license.php.
@@ -77,10 +77,10 @@ class PARTInterface(BTCInterface):
# TODO: Double check
return True
def getNewAddress(self, use_segwit, label='swap_receive'):
def getNewAddress(self, use_segwit, label='swap_receive') -> str:
return self.rpc_callback('getnewaddress', [label])
def getNewStealthAddress(self, label='swap_stealth'):
def getNewStealthAddress(self, label='swap_stealth') -> str:
return self.rpc_callback('getnewstealthaddress', [label])
def haveSpentIndex(self):
@@ -105,7 +105,7 @@ class PARTInterface(BTCInterface):
def getScriptForPubkeyHash(self, pkh):
return CScript([OP_DUP, OP_HASH160, pkh, OP_EQUALVERIFY, OP_CHECKSIG])
def formatStealthAddress(self, scan_pubkey, spend_pubkey):
def formatStealthAddress(self, scan_pubkey, spend_pubkey) -> str:
prefix_byte = chainparams[self.coin_type()][self._network]['stealth_key_prefix']
return encodeStealthAddress(prefix_byte, scan_pubkey, spend_pubkey)
@@ -116,7 +116,7 @@ class PARTInterface(BTCInterface):
length += getWitnessElementLen(len(e) // 2) # hex -> bytes
return length
def getWalletRestoreHeight(self):
def getWalletRestoreHeight(self) -> int:
start_time = self.rpc_callback('getwalletinfo')['keypoololdest']
blockchaininfo = self.rpc_callback('getblockchaininfo')
@@ -131,6 +131,15 @@ class PARTInterface(BTCInterface):
block_header = self.rpc_callback('getblockheader', [block_hash])
return block_header['height']
def isValidAddress(self, address: str) -> bool:
try:
rv = self.rpc_callback('validateaddress', [address])
if rv['isvalid'] is True:
return True
except Exception as ex:
self._log.debug('validateaddress failed: {}'.format(address))
return False
class PARTInterfaceBlind(PARTInterface):
@staticmethod
@@ -622,7 +631,7 @@ class PARTInterfaceBlind(PARTInterface):
return bytes.fromhex(lock_refund_swipe_tx_hex)
def getSpendableBalance(self):
def getSpendableBalance(self) -> int:
return self.make_int(self.rpc_callback('getbalances')['mine']['blind_trusted'])
def publishBLockTx(self, vkbv, Kbs, output_amount, feerate, delay_for: int = 10, unlock_time: int = 0) -> bytes:
@@ -840,7 +849,7 @@ class PARTInterfaceAnon(PARTInterface):
rv = self.rpc_callback('sendtypeto', params)
return bytes.fromhex(rv['txid'])
def findTxnByHash(self, txid_hex):
def findTxnByHash(self, txid_hex: str):
# txindex is enabled for Particl
try:
@@ -854,5 +863,5 @@ class PARTInterfaceAnon(PARTInterface):
return None
def getSpendableBalance(self):
def getSpendableBalance(self) -> int:
return self.make_int(self.rpc_callback('getbalances')['mine']['anon_trusted'])

View File

@@ -9,6 +9,7 @@ import random
import urllib.parse
from .util import (
ensure,
toBool,
)
from .basicswap_util import (
@@ -38,7 +39,7 @@ from .ui.page_offers import postNewOffer
from .protocols.xmr_swap_1 import recoverNoScriptTxnWithKey, getChainBSplitKey
def getFormData(post_string, is_json):
def getFormData(post_string: str, is_json: bool):
if post_string == '':
raise ValueError('No post data')
if is_json:
@@ -138,6 +139,7 @@ def js_offers(self, url_split, post_string, is_json, sent=False) -> bytes:
return bytes(json.dumps(rv), 'UTF-8')
offer_id = bytes.fromhex(url_split[3])
with_extra_info = False
filters = {
'coin_from': -1,
'coin_to': -1,
@@ -174,12 +176,15 @@ def js_offers(self, url_split, post_string, is_json, sent=False) -> bytes:
if have_data_entry(post_data, 'include_sent'):
filters['include_sent'] = toBool(get_data_entry(post_data, 'include_sent'))
if have_data_entry(post_data, 'with_extra_info'):
with_extra_info = toBool(get_data_entry(post_data, 'with_extra_info'))
offers = swap_client.listOffers(sent, filters)
rv = []
for o in offers:
ci_from = swap_client.ci(o.coin_from)
ci_to = swap_client.ci(o.coin_to)
rv.append({
offer_data = {
'swap_type': o.swap_type,
'addr_from': o.addr_from,
'addr_to': o.addr_to,
@@ -191,8 +196,11 @@ def js_offers(self, url_split, post_string, is_json, sent=False) -> bytes:
'amount_from': ci_from.format_amount(o.amount_from),
'amount_to': ci_to.format_amount((o.amount_from * o.rate) // ci_from.COIN()),
'rate': ci_to.format_amount(o.rate),
})
}
if with_extra_info:
offer_data['amount_negotiable'] = o.amount_negotiable
offer_data['rate_negotiable'] = o.rate_negotiable
rv.append(offer_data)
return bytes(json.dumps(rv), 'UTF-8')
@@ -200,7 +208,60 @@ def js_sentoffers(self, url_split, post_string, is_json) -> bytes:
return js_offers(self, url_split, post_string, is_json, True)
def js_bids(self, url_split, post_string, is_json) -> bytes:
def parseBidFilters(post_data):
offer_id = None
filters = {}
if have_data_entry(post_data, 'offer_id'):
offer_id = bytes.fromhex(get_data_entry(post_data, 'offer_id'))
assert (len(offer_id) == 28)
if have_data_entry(post_data, 'sort_by'):
sort_by = get_data_entry(post_data, 'sort_by')
assert (sort_by in ['created_at', ]), 'Invalid sort by'
filters['sort_by'] = sort_by
if have_data_entry(post_data, 'sort_dir'):
sort_dir = get_data_entry(post_data, 'sort_dir')
assert (sort_dir in ['asc', 'desc']), 'Invalid sort dir'
filters['sort_dir'] = sort_dir
if have_data_entry(post_data, 'offset'):
filters['offset'] = int(get_data_entry(post_data, 'offset'))
if have_data_entry(post_data, 'limit'):
filters['limit'] = int(get_data_entry(post_data, 'limit'))
assert (filters['limit'] > 0 and filters['limit'] <= PAGE_LIMIT), 'Invalid limit'
if have_data_entry(post_data, 'with_available_or_active'):
filters['with_available_or_active'] = toBool(get_data_entry(post_data, 'with_available_or_active'))
elif have_data_entry(post_data, 'with_expired'):
filters['with_expired'] = toBool(get_data_entry(post_data, 'with_expired'))
if have_data_entry(post_data, 'with_extra_info'):
filters['with_extra_info'] = toBool(get_data_entry(post_data, 'with_extra_info'))
return offer_id, filters
def formatBids(swap_client, bids, filters) -> bytes:
with_extra_info = filters.get('with_extra_info', False)
rv = []
for b in bids:
bid_data = {
'bid_id': b[2].hex(),
'offer_id': b[3].hex(),
'created_at': b[0],
'expire_at': b[1],
'coin_from': b[9],
'amount_from': swap_client.ci(b[9]).format_amount(b[4]),
'bid_state': strBidState(b[5])
}
if with_extra_info:
bid_data['addr_from'] = b[11]
rv.append(bid_data)
return bytes(json.dumps(rv), 'UTF-8')
def js_bids(self, url_split, post_string: str, is_json: bool) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
if len(url_split) > 3:
@@ -281,22 +342,21 @@ def js_bids(self, url_split, post_string, is_json) -> bytes:
data = describeBid(swap_client, bid, xmr_swap, offer, xmr_offer, events, edit_bid, show_txns, for_api=True)
return bytes(json.dumps(data), 'UTF-8')
bids = swap_client.listBids()
return bytes(json.dumps([{
'bid_id': b[2].hex(),
'offer_id': b[3].hex(),
'created_at': b[0],
'expire_at': b[1],
'coin_from': b[9],
'amount_from': swap_client.ci(b[9]).format_amount(b[4]),
'bid_state': strBidState(b[5])
} for b in bids]), 'UTF-8')
post_data = getFormData(post_string, is_json)
offer_id, filters = parseBidFilters(post_data)
bids = swap_client.listBids(offer_id=offer_id, filters=filters)
return formatBids(swap_client, bids, filters)
def js_sentbids(self, url_split, post_string, is_json) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
return bytes(json.dumps(swap_client.listBids(sent=True)), 'UTF-8')
post_data = getFormData(post_string, is_json)
offer_id, filters = parseBidFilters(post_data)
bids = swap_client.listBids(sent=True, offer_id=offer_id, filters=filters)
return formatBids(swap_client, bids, filters)
def js_network(self, url_split, post_string, is_json) -> bytes:
@@ -318,7 +378,7 @@ def js_smsgaddresses(self, url_split, post_string, is_json) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
if len(url_split) > 3:
post_data = getFormData(post_string, is_json)
post_data = {} if post_string == '' else getFormData(post_string, is_json)
if url_split[3] == 'new':
addressnote = get_data_entry_or(post_data, 'addressnote', '')
new_addr, pubkey = swap_client.newSMSGAddress(addressnote=addressnote)
@@ -417,11 +477,54 @@ def js_generatenotification(self, url_split, post_string, is_json) -> bytes:
def js_notifications(self, url_split, post_string, is_json) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
swap_client.getNotifications()
return bytes(json.dumps(swap_client.getNotifications()), 'UTF-8')
def js_identities(self, url_split, post_string: str, is_json: bool) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
filters = {
'page_no': 1,
'limit': PAGE_LIMIT,
'sort_by': 'created_at',
'sort_dir': 'desc',
}
if len(url_split) > 3:
address = url_split[3]
filters['address'] = address
if post_string != '':
post_data = getFormData(post_string, is_json)
if have_data_entry(post_data, 'sort_by'):
sort_by = get_data_entry(post_data, 'sort_by')
assert (sort_by in ['created_at', 'rate']), 'Invalid sort by'
filters['sort_by'] = sort_by
if have_data_entry(post_data, 'sort_dir'):
sort_dir = get_data_entry(post_data, 'sort_dir')
assert (sort_dir in ['asc', 'desc']), 'Invalid sort dir'
filters['sort_dir'] = sort_dir
if have_data_entry(post_data, 'offset'):
filters['offset'] = int(get_data_entry(post_data, 'offset'))
if have_data_entry(post_data, 'limit'):
filters['limit'] = int(get_data_entry(post_data, 'limit'))
assert (filters['limit'] > 0 and filters['limit'] <= PAGE_LIMIT), 'Invalid limit'
set_data = {}
if have_data_entry(post_data, 'set_label'):
set_data['label'] = get_data_entry(post_data, 'set_label')
if set_data:
ensure('address' in filters, 'Must provide an address to modify data')
swap_client.setIdentityData(filters, set_data)
return bytes(json.dumps(swap_client.listIdentities(filters)), 'UTF-8')
def js_vacuumdb(self, url_split, post_string, is_json) -> bytes:
swap_client = self.server.swap_client
swap_client.checkSystemStatus()
@@ -528,6 +631,7 @@ pages = {
'rateslist': js_rates_list,
'generatenotification': js_generatenotification,
'notifications': js_notifications,
'identities': js_identities,
'vacuumdb': js_vacuumdb,
'getcoinseed': js_getcoinseed,
'setpassword': js_setpassword,