From 5d84d54e6f6967242efa0409780a10ba894ea5d8 Mon Sep 17 00:00:00 2001 From: tecnovert Date: Sat, 31 Oct 2020 22:08:30 +0200 Subject: [PATCH] Replace makeInt with make_int --- basicswap/basicswap.py | 38 +- basicswap/chainparams.py | 24 +- basicswap/contrib/MoneroPy/__init__.py | 0 basicswap/contrib/MoneroPy/base58.py | 168 ++ basicswap/contrib/ellipticcurve.py | 486 +++++ basicswap/contrib/test_framework/__init__.py | 0 basicswap/contrib/test_framework/address.py | 158 ++ basicswap/contrib/test_framework/authproxy.py | 204 ++ basicswap/contrib/test_framework/coverage.py | 109 + basicswap/contrib/test_framework/key.py | 393 ++++ basicswap/contrib/test_framework/messages.py | 1756 +++++++++++++++++ basicswap/contrib/test_framework/script.py | 740 +++++++ .../contrib/test_framework/segwit_addr.py | 107 + basicswap/contrib/test_framework/siphash.py | 63 + basicswap/contrib/test_framework/util.py | 619 ++++++ .../contrib/test_framework/wallet_util.py | 131 ++ basicswap/db.py | 2 +- basicswap/ecc_util.py | 222 +++ basicswap/http_server.py | 4 +- basicswap/interface_btc.py | 805 ++++++++ basicswap/interface_ltc.py | 12 + basicswap/interface_part.py | 28 + basicswap/interface_xmr.py | 230 +++ basicswap/rpc.py | 15 +- basicswap/rpc_xmr.py | 85 + basicswap/util.py | 109 +- basicswap/util_xmr.py | 17 + setup.py | 1 + tests/basicswap/__init__.py | 2 + tests/basicswap/common.py | 14 + tests/basicswap/test_other.py | 67 +- tests/basicswap/test_run.py | 12 +- tests/basicswap/test_xmr.py | 246 +++ 33 files changed, 6838 insertions(+), 29 deletions(-) create mode 100644 basicswap/contrib/MoneroPy/__init__.py create mode 100644 basicswap/contrib/MoneroPy/base58.py create mode 100644 basicswap/contrib/ellipticcurve.py create mode 100644 basicswap/contrib/test_framework/__init__.py create mode 100644 basicswap/contrib/test_framework/address.py create mode 100644 basicswap/contrib/test_framework/authproxy.py create mode 100644 basicswap/contrib/test_framework/coverage.py create mode 100644 basicswap/contrib/test_framework/key.py create mode 100755 basicswap/contrib/test_framework/messages.py create mode 100644 basicswap/contrib/test_framework/script.py create mode 100644 basicswap/contrib/test_framework/segwit_addr.py create mode 100644 basicswap/contrib/test_framework/siphash.py create mode 100644 basicswap/contrib/test_framework/util.py create mode 100755 basicswap/contrib/test_framework/wallet_util.py create mode 100644 basicswap/ecc_util.py create mode 100644 basicswap/interface_btc.py create mode 100644 basicswap/interface_ltc.py create mode 100644 basicswap/interface_part.py create mode 100644 basicswap/interface_xmr.py create mode 100644 basicswap/rpc_xmr.py create mode 100644 basicswap/util_xmr.py create mode 100644 tests/basicswap/common.py create mode 100644 tests/basicswap/test_xmr.py diff --git a/basicswap/basicswap.py b/basicswap/basicswap.py index 3fa699f..dc1284b 100644 --- a/basicswap/basicswap.py +++ b/basicswap/basicswap.py @@ -19,6 +19,11 @@ import secrets from sqlalchemy.orm import sessionmaker, scoped_session from enum import IntEnum, auto +from .interface_part import PARTInterface +from .interface_btc import BTCInterface +from .interface_ltc import LTCInterface +from .interface_xmr import XMRInterface + from . import __version__ from .util import ( COIN, @@ -31,7 +36,7 @@ from .util import ( decodeWif, toWIF, getKeyID, - makeInt, + make_int, ) from .chainparams import ( chainparams, @@ -417,6 +422,27 @@ class BasicSwap(BaseApp): 'chain_lookups': chain_client_settings.get('chain_lookups', 'local'), } + if self.coin_clients[coin]['connection_type'] == 'rpc': + if coin == Coins.XMR: + self.coin_clients[coin]['walletrpcport'] = chain_client_settings.get('walletrpcport', chainparams[coin][self.chain]['walletrpcport']) + if 'walletrpcpassword' in chain_client_settings: + self.coin_clients[coin]['walletrpcauth'] = chain_client_settings['walletrpcuser'] + ':' + chain_client_settings['walletrpcpassword'] + else: + raise ValueError('Missing XMR wallet rpc credentials.') + self.coin_clients[coin]['interface'] = self.createInterface(coin) + + def createInterface(self, coin): + if coin == Coins.PART: + return PARTInterface(self.coin_clients[coin]) + elif coin == Coins.BTC: + return BTCInterface(self.coin_clients[coin]) + elif coin == Coins.LTC: + return LTCInterface(self.coin_clients[coin]) + elif coin == Coins.XMR: + return XMRInterface(self.coin_clients[coin]) + else: + raise ValueError('Unknown coin type') + def setCoinRunParams(self, coin): cc = self.coin_clients[coin] if cc['connection_type'] == 'rpc' and cc['rpcauth'] is None: @@ -1699,7 +1725,7 @@ class BasicSwap(BaseApp): continue # Verify amount if assert_amount: - assert(makeInt(o['amount']) == int(assert_amount)), 'Incorrect output amount in txn {}: {} != {}.'.format(assert_txid, makeInt(o['amount']), int(assert_amount)) + assert(make_int(o['amount']) == int(assert_amount)), 'Incorrect output amount in txn {}: {} != {}.'.format(assert_txid, make_int(o['amount']), int(assert_amount)) if not sum_output: if o['height'] > 0: @@ -1711,7 +1737,7 @@ class BasicSwap(BaseApp): 'index': o['vout'], 'height': o['height'], 'n_conf': n_conf, - 'value': makeInt(o['amount']), + 'value': make_int(o['amount']), } else: sum_unspent += o['amount'] * COIN @@ -1744,7 +1770,7 @@ class BasicSwap(BaseApp): # Verify amount vout = getVoutByAddress(initiate_txn, p2sh) - out_value = makeInt(initiate_txn['vout'][vout]['value']) + out_value = make_int(initiate_txn['vout'][vout]['value']) assert(out_value == int(bid.amount)), 'Incorrect output amount in initiate txn {}: {} != {}.'.format(initiate_txnid_hex, out_value, int(bid.amount)) bid.initiate_tx.conf = initiate_txn['confirmations'] @@ -2442,8 +2468,8 @@ class BasicSwap(BaseApp): 'deposit_address': self.getCachedAddressForCoin(coin), 'name': chainparams[coin]['name'].capitalize(), 'blocks': blockchaininfo['blocks'], - 'balance': format8(makeInt(walletinfo['balance'])), - 'unconfirmed': format8(makeInt(walletinfo.get('unconfirmed_balance'))), + 'balance': format8(make_int(walletinfo['balance'])), + 'unconfirmed': format8(make_int(walletinfo.get('unconfirmed_balance'))), 'synced': '{0:.2f}'.format(round(blockchaininfo['verificationprogress'], 2)), } return rv diff --git a/basicswap/chainparams.py b/basicswap/chainparams.py index 2adf2ee..d778f10 100644 --- a/basicswap/chainparams.py +++ b/basicswap/chainparams.py @@ -14,8 +14,9 @@ class Coins(IntEnum): PART = 1 BTC = 2 LTC = 3 - # DCR = 4 + #DCR = 4 NMC = 5 + XMR = 6 chainparams = { @@ -156,5 +157,26 @@ chainparams = { 'min_amount': 1000, 'max_amount': 100000 * COIN, } + }, + Coins.XMR: { + 'name': 'monero', + 'ticker': 'XMR', + 'client': 'xmr', + 'mainnet': { + 'rpcport': 18081, + 'walletrpcport': 18082, + }, + 'testnet': { + 'rpcport': 28081, + 'walletrpcport': 28082, + }, + 'regtest': { + 'rpcport': 18081, + 'walletrpcport': 18082, + } } } + +class CoinInterface: + pass + diff --git a/basicswap/contrib/MoneroPy/__init__.py b/basicswap/contrib/MoneroPy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/basicswap/contrib/MoneroPy/base58.py b/basicswap/contrib/MoneroPy/base58.py new file mode 100644 index 0000000..83424de --- /dev/null +++ b/basicswap/contrib/MoneroPy/base58.py @@ -0,0 +1,168 @@ +# MoneroPy - A python toolbox for Monero +# Copyright (C) 2016 The MoneroPy Developers. +# +# MoneroPy is released under the BSD 3-Clause license. Use and redistribution of +# this software is subject to the license terms in the LICENSE file found in the +# top-level directory of this distribution. + +__alphabet = [ord(s) for s in '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'] +__b58base = 58 +__UINT64MAX = 2**64 +__encodedBlockSizes = [0, 2, 3, 5, 6, 7, 9, 10, 11] +__fullBlockSize = 8 +__fullEncodedBlockSize = 11 + +def _hexToBin(hex): + if len(hex) % 2 != 0: + return "Hex string has invalid length!" + return [int(hex[i*2:i*2+2], 16) for i in range(len(hex)//2)] + +def _binToHex(bin): + return "".join([("0" + hex(int(bin[i])).split('x')[1])[-2:] for i in range(len(bin))]) + +def _strToBin(a): + return [ord(s) for s in a] + +def _binToStr(bin): + return ''.join([chr(bin[i]) for i in range(len(bin))]) + +def _uint8be_to_64(data): + l_data = len(data) + + if l_data < 1 or l_data > 8: + return "Invalid input length" + + res = 0 + switch = 9 - l_data + for i in range(l_data): + if switch == 1: + res = res << 8 | data[i] + elif switch == 2: + res = res << 8 | data[i] + elif switch == 3: + res = res << 8 | data[i] + elif switch == 4: + res = res << 8 | data[i] + elif switch == 5: + res = res << 8 | data[i] + elif switch == 6: + res = res << 8 | data[i] + elif switch == 7: + res = res << 8 | data[i] + elif switch == 8: + res = res << 8 | data[i] + else: + return "Impossible condition" + return res + +def _uint64_to_8be(num, size): + res = [0] * size; + if size < 1 or size > 8: + return "Invalid input length" + + twopow8 = 2**8 + for i in range(size-1,-1,-1): + res[i] = num % twopow8 + num = num // twopow8 + + return res + +def encode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + return "Invalid block length: " + str(l_data) + + num = _uint8be_to_64(data) + i = __encodedBlockSizes[l_data] - 1 + + while num > 0: + remainder = num % __b58base + num = num // __b58base + buf[index+i] = __alphabet[remainder]; + i -= 1 + + return buf + +def encode(hex): + '''Encode hexadecimal string as base58 (ex: encoding a Monero address).''' + data = _hexToBin(hex) + l_data = len(data) + + if l_data == 0: + return "" + + full_block_count = l_data // __fullBlockSize + last_block_size = l_data % __fullBlockSize + res_size = full_block_count * __fullEncodedBlockSize + __encodedBlockSizes[last_block_size] + + res = [0] * res_size + for i in range(res_size): + res[i] = __alphabet[0] + + for i in range(full_block_count): + res = encode_block(data[(i*__fullBlockSize):(i*__fullBlockSize+__fullBlockSize)], res, i * __fullEncodedBlockSize) + + if last_block_size > 0: + res = encode_block(data[(full_block_count*__fullBlockSize):(full_block_count*__fullBlockSize+last_block_size)], res, full_block_count * __fullEncodedBlockSize) + + return _binToStr(res) + +def decode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + return "Invalid block length: " + l_data + + res_size = __encodedBlockSizes.index(l_data) + if res_size <= 0: + return "Invalid block size" + + res_num = 0 + order = 1 + for i in range(l_data-1, -1, -1): + digit = __alphabet.index(data[i]) + if digit < 0: + return "Invalid symbol" + + product = order * digit + res_num + if product > __UINT64MAX: + return "Overflow" + + res_num = product + order = order * __b58base + + if res_size < __fullBlockSize and 2**(8 * res_size) <= res_num: + return "Overflow 2" + + tmp_buf = _uint64_to_8be(res_num, res_size) + for i in range(len(tmp_buf)): + buf[i+index] = tmp_buf[i] + + return buf + +def decode(enc): + '''Decode a base58 string (ex: a Monero address) into hexidecimal form.''' + enc = _strToBin(enc) + l_enc = len(enc) + + if l_enc == 0: + return "" + + full_block_count = l_enc // __fullEncodedBlockSize + last_block_size = l_enc % __fullEncodedBlockSize + last_block_decoded_size = __encodedBlockSizes.index(last_block_size) + + if last_block_decoded_size < 0: + return "Invalid encoded length" + + data_size = full_block_count * __fullBlockSize + last_block_decoded_size + + data = [0] * data_size + for i in range(full_block_count): + data = decode_block(enc[(i*__fullEncodedBlockSize):(i*__fullEncodedBlockSize+__fullEncodedBlockSize)], data, i * __fullBlockSize) + + if last_block_size > 0: + data = decode_block(enc[(full_block_count*__fullEncodedBlockSize):(full_block_count*__fullEncodedBlockSize+last_block_size)], data, full_block_count * __fullBlockSize) + + return _binToHex(data) diff --git a/basicswap/contrib/ellipticcurve.py b/basicswap/contrib/ellipticcurve.py new file mode 100644 index 0000000..8a58166 --- /dev/null +++ b/basicswap/contrib/ellipticcurve.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# +# Implementation of elliptic curves, for cryptographic applications. +# +# This module doesn't provide any way to choose a random elliptic +# curve, nor to verify that an elliptic curve was chosen randomly, +# because one can simply use NIST's standard curves. +# +# Notes from X9.62-1998 (draft): +# Nomenclature: +# - Q is a public key. +# The "Elliptic Curve Domain Parameters" include: +# - q is the "field size", which in our case equals p. +# - p is a big prime. +# - G is a point of prime order (5.1.1.1). +# - n is the order of G (5.1.1.1). +# Public-key validation (5.2.2): +# - Verify that Q is not the point at infinity. +# - Verify that X_Q and Y_Q are in [0,p-1]. +# - Verify that Q is on the curve. +# - Verify that nQ is the point at infinity. +# Signature generation (5.3): +# - Pick random k from [1,n-1]. +# Signature checking (5.4.2): +# - Verify that r and s are in [1,n-1]. +# +# Version of 2008.11.25. +# +# Revision history: +# 2005.12.31 - Initial version. +# 2008.11.25 - Change CurveFp.is_on to contains_point. +# +# Written in 2005 by Peter Pearson and placed in the public domain. + +def inverse_mod(a, m): + """Inverse of a mod m.""" + + if a < 0 or m <= a: + a = a % m + + # From Ferguson and Schneier, roughly: + + c, d = a, m + uc, vc, ud, vd = 1, 0, 0, 1 + while c != 0: + q, c, d = divmod(d, c) + (c,) + uc, vc, ud, vd = ud - q * uc, vd - q * vc, uc, vc + + # At this point, d is the GCD, and ud*a+vd*m = d. + # If d == 1, this means that ud is a inverse. + + assert d == 1 + if ud > 0: + return ud + else: + return ud + m + + +def modular_sqrt(a, p): + # from http://eli.thegreenplace.net/2009/03/07/computing-modular-square-roots-in-python/ + """ Find a quadratic residue (mod p) of 'a'. p + must be an odd prime. + + Solve the congruence of the form: + x^2 = a (mod p) + And returns x. Note that p - x is also a root. + + 0 is returned is no square root exists for + these a and p. + + The Tonelli-Shanks algorithm is used (except + for some simple cases in which the solution + is known from an identity). This algorithm + runs in polynomial time (unless the + generalized Riemann hypothesis is false). + """ + # Simple cases + # + if legendre_symbol(a, p) != 1: + return 0 + elif a == 0: + return 0 + elif p == 2: + return p + elif p % 4 == 3: + return pow(a, (p + 1) // 4, p) + + # Partition p-1 to s * 2^e for an odd s (i.e. + # reduce all the powers of 2 from p-1) + # + s = p - 1 + e = 0 + while s % 2 == 0: + s /= 2 + e += 1 + + # Find some 'n' with a legendre symbol n|p = -1. + # Shouldn't take long. + # + n = 2 + while legendre_symbol(n, p) != -1: + n += 1 + + # Here be dragons! + # Read the paper "Square roots from 1; 24, 51, + # 10 to Dan Shanks" by Ezra Brown for more + # information + # + + # x is a guess of the square root that gets better + # with each iteration. + # b is the "fudge factor" - by how much we're off + # with the guess. The invariant x^2 = ab (mod p) + # is maintained throughout the loop. + # g is used for successive powers of n to update + # both a and b + # r is the exponent - decreases with each update + # + x = pow(a, (s + 1) // 2, p) + b = pow(a, s, p) + g = pow(n, s, p) + r = e + + while True: + t = b + m = 0 + for m in range(r): + if t == 1: + break + t = pow(t, 2, p) + + if m == 0: + return x + + gs = pow(g, 2 ** (r - m - 1), p) + g = (gs * gs) % p + x = (x * gs) % p + b = (b * g) % p + r = m + + +def legendre_symbol(a, p): + """ Compute the Legendre symbol a|p using + Euler's criterion. p is a prime, a is + relatively prime to p (if p divides + a, then a|p = 0) + + Returns 1 if a has a square root modulo + p, -1 otherwise. + """ + ls = pow(a, (p - 1) // 2, p) + return -1 if ls == p - 1 else ls + + +def jacobi_symbol(n, k): + """Compute the Jacobi symbol of n modulo k + + See http://en.wikipedia.org/wiki/Jacobi_symbol + + For our application k is always prime, so this is the same as the Legendre symbol.""" + assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + n %= k + t = 0 + while n != 0: + while n & 1 == 0: + n >>= 1 + r = k & 7 + t ^= (r == 3 or r == 5) + n, k = k, n + t ^= (n & k & 3 == 3) + n = n % k + if k == 1: + return -1 if t else 1 + return 0 + + +class CurveFp(object): + """Elliptic Curve over the field of integers modulo a prime.""" + def __init__(self, p, a, b): + """The curve of points satisfying y^2 = x^3 + a*x + b (mod p).""" + self.__p = p + self.__a = a + self.__b = b + + def p(self): + return self.__p + + def a(self): + return self.__a + + def b(self): + return self.__b + + def contains_point(self, x, y): + """Is the point (x,y) on this curve?""" + return (y * y - (x * x * x + self.__a * x + self.__b)) % self.__p == 0 + + +class Point(object): + """ A point on an elliptic curve. Altering x and y is forbidding, + but they can be read by the x() and y() methods.""" + def __init__(self, curve, x, y, order=None): + """curve, x, y, order; order (optional) is the order of this point.""" + self.__curve = curve + self.__x = x + self.__y = y + self.__order = order + # self.curve is allowed to be None only for INFINITY: + if self.__curve: + assert self.__curve.contains_point(x, y) + if order: + assert self * order == INFINITY + + def __eq__(self, other): + """Return 1 if the points are identical, 0 otherwise.""" + if self.__curve == other.__curve \ + and self.__x == other.__x \ + and self.__y == other.__y: + return 1 + else: + return 0 + + def __add__(self, other): + """Add one point to another point.""" + + # X9.62 B.3: + if other == INFINITY: + return self + if self == INFINITY: + return other + assert self.__curve == other.__curve + if self.__x == other.__x: + if (self.__y + other.__y) % self.__curve.p() == 0: + return INFINITY + else: + return self.double() + + p = self.__curve.p() + + l = ((other.__y - self.__y) * inverse_mod(other.__x - self.__x, p)) % p + + x3 = (l * l - self.__x - other.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def __sub__(self, other): + #The inverse of a point P=(xP,yP) is its reflexion across the x-axis : P′=(xP,−yP). + #If you want to compute Q−P, just replace yP by −yP in the usual formula for point addition. + + # X9.62 B.3: + if other == INFINITY: + return self + if self == INFINITY: + return other + assert self.__curve == other.__curve + + p = self.__curve.p() + #opi = inverse_mod(other.__y, p) + opi = -other.__y % p + #print(opi) + #print(-other.__y % p) + + if self.__x == other.__x: + if (self.__y + opi) % self.__curve.p() == 0: + return INFINITY + else: + return self.double + + l = ((opi - self.__y) * inverse_mod(other.__x - self.__x, p)) % p + + x3 = (l * l - self.__x - other.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def __mul__(self, e): + if self.__order: + e %= self.__order + if e == 0 or self == INFINITY: + return INFINITY + result, q = INFINITY, self + while e: + if e & 1: + result += q + e, q = e >> 1, q.double() + return result + + """ + def __mul__(self, other): + #Multiply a point by an integer. + + def leftmost_bit( x ): + assert x > 0 + result = 1 + while result <= x: result = 2 * result + return result // 2 + + e = other + if self.__order: e = e % self.__order + if e == 0: return INFINITY + if self == INFINITY: return INFINITY + assert e > 0 + + # From X9.62 D.3.2: + + e3 = 3 * e + negative_self = Point( self.__curve, self.__x, -self.__y, self.__order ) + i = leftmost_bit( e3 ) // 2 + result = self + # print "Multiplying %s by %d (e3 = %d):" % ( self, other, e3 ) + while i > 1: + result = result.double() + if ( e3 & i ) != 0 and ( e & i ) == 0: result = result + self + if ( e3 & i ) == 0 and ( e & i ) != 0: result = result + negative_self + # print ". . . i = %d, result = %s" % ( i, result ) + i = i // 2 + + return result + """ + + def __rmul__(self, other): + """Multiply a point by an integer.""" + + return self * other + + def __str__(self): + if self == INFINITY: + return "infinity" + return "(%d, %d)" % (self.__x, self.__y) + + def inverse(self): + return Point(self.__curve, self.__x, -self.__y % self.__curve.p()) + + def double(self): + """Return a new point that is twice the old.""" + + if self == INFINITY: + return INFINITY + + # X9.62 B.3: + + p = self.__curve.p() + a = self.__curve.a() + + l = ((3 * self.__x * self.__x + a) * inverse_mod(2 * self.__y, p)) % p + + x3 = (l * l - 2 * self.__x) % p + y3 = (l * (self.__x - x3) - self.__y) % p + + return Point(self.__curve, x3, y3) + + def x(self): + return self.__x + + def y(self): + return self.__y + + def pair(self): + return (self.__x, self.__y) + + def curve(self): + return self.__curve + + def order(self): + return self.__order + + +# This one point is the Point At Infinity for all purposes: +INFINITY = Point(None, None, None) + + +def __main__(): + + class FailedTest(Exception): + pass + + def test_add(c, x1, y1, x2, y2, x3, y3): + """We expect that on curve c, (x1,y1) + (x2, y2 ) = (x3, y3).""" + p1 = Point(c, x1, y1) + p2 = Point(c, x2, y2) + p3 = p1 + p2 + print("%s + %s = %s" % (p1, p2, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + def test_double(c, x1, y1, x3, y3): + """We expect that on curve c, 2*(x1,y1) = (x3, y3).""" + p1 = Point(c, x1, y1) + p3 = p1.double() + print("%s doubled = %s" % (p1, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + def test_double_infinity(c): + """We expect that on curve c, 2*INFINITY = INFINITY.""" + p1 = INFINITY + p3 = p1.double() + print("%s doubled = %s" % (p1, p3)) + if p3.x() != INFINITY.x() or p3.y() != INFINITY.y(): + raise FailedTest("Failure: should give (%d,%d)." % (INFINITY.x(), INFINITY.y())) + else: + print(" Good.") + + def test_multiply(c, x1, y1, m, x3, y3): + """We expect that on curve c, m*(x1,y1) = (x3,y3).""" + p1 = Point(c, x1, y1) + p3 = p1 * m + print("%s * %d = %s" % (p1, m, p3)) + if p3.x() != x3 or p3.y() != y3: + raise FailedTest("Failure: should give (%d,%d)." % (x3, y3)) + else: + print(" Good.") + + # A few tests from X9.62 B.3: + + c = CurveFp(23, 1, 1) + test_add(c, 3, 10, 9, 7, 17, 20) + test_double(c, 3, 10, 7, 12) + test_add(c, 3, 10, 3, 10, 7, 12) # (Should just invoke double.) + test_multiply(c, 3, 10, 2, 7, 12) + + test_double_infinity(c) + + # From X9.62 I.1 (p. 96): + + g = Point(c, 13, 7, 7) + + check = INFINITY + for i in range(7 + 1): + p = (i % 7) * g + print("%s * %d = %s, expected %s . . ." % (g, i, p, check)) + if p == check: + print(" Good.") + else: + raise FailedTest("Bad.") + check = check + g + + # NIST Curve P-192: + p = 6277101735386680763835789423207666416083908700390324961279 + r = 6277101735386680763835789423176059013767194773182842284081 + #s = 0x3045ae6fc8422f64ed579528d38120eae12196d5L + c = 0x3099d2bbbfcb2538542dcd5fb078b6ef5f3d6fe2c745de65 + b = 0x64210519e59c80e70fa7e9ab72243049feb8deecc146b9b1 + Gx = 0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012 + Gy = 0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811 + + c192 = CurveFp(p, -3, b) + p192 = Point(c192, Gx, Gy, r) + + # Checking against some sample computations presented + # in X9.62: + + d = 651056770906015076056810763456358567190100156695615665659 + Q = d * p192 + if Q.x() != 0x62B12D60690CDCF330BABAB6E69763B471F994DD702D16A5: + raise FailedTest("p192 * d came out wrong.") + else: + print("p192 * d came out right.") + + k = 6140507067065001063065065565667405560006161556565665656654 + R = k * p192 + if R.x() != 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \ + or R.y() != 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835: + raise FailedTest("k * p192 came out wrong.") + else: + print("k * p192 came out right.") + + u1 = 2563697409189434185194736134579731015366492496392189760599 + u2 = 6266643813348617967186477710235785849136406323338782220568 + temp = u1 * p192 + u2 * Q + if temp.x() != 0x885052380FF147B734C330C43D39B2C4A89F29B0F749FEAD \ + or temp.y() != 0x9CF9FA1CBEFEFB917747A3BB29C072B9289C2547884FD835: + raise FailedTest("u1 * p192 + u2 * Q came out wrong.") + else: + print("u1 * p192 + u2 * Q came out right.") + + +if __name__ == "__main__": + __main__() diff --git a/basicswap/contrib/test_framework/__init__.py b/basicswap/contrib/test_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/basicswap/contrib/test_framework/address.py b/basicswap/contrib/test_framework/address.py new file mode 100644 index 0000000..7d15167 --- /dev/null +++ b/basicswap/contrib/test_framework/address.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) 2016-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Encode and decode BASE58, P2PKH and P2SH addresses.""" + +import enum +import unittest + +from .script import hash256, hash160, sha256, CScript, OP_0 +from .util import hex_str_to_bytes + +from . import segwit_addr + +from .util import assert_equal + +ADDRESS_BCRT1_UNSPENDABLE = 'bcrt1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq3xueyj' +ADDRESS_BCRT1_UNSPENDABLE_DESCRIPTOR = 'addr(bcrt1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq3xueyj)#juyq9d97' +# Coins sent to this address can be spent with a witness stack of just OP_TRUE +ADDRESS_BCRT1_P2WSH_OP_TRUE = 'bcrt1qft5p2uhsdcdc3l2ua4ap5qqfg4pjaqlp250x7us7a8qqhrxrxfsqseac85' + + +class AddressType(enum.Enum): + bech32 = 'bech32' + p2sh_segwit = 'p2sh-segwit' + legacy = 'legacy' # P2PKH + + +chars = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' + + +def byte_to_base58(b, version): + result = '' + str = b.hex() + str = chr(version).encode('latin-1').hex() + str + checksum = hash256(hex_str_to_bytes(str)).hex() + str += checksum[:8] + value = int('0x'+str,0) + while value > 0: + result = chars[value % 58] + result + value //= 58 + while (str[:2] == '00'): + result = chars[0] + result + str = str[2:] + return result + + +def base58_to_byte(s, verify_checksum=True): + if not s: + return b'' + n = 0 + for c in s: + n *= 58 + assert c in chars + digit = chars.index(c) + n += digit + h = '%x' % n + if len(h) % 2: + h = '0' + h + res = n.to_bytes((n.bit_length() + 7) // 8, 'big') + pad = 0 + for c in s: + if c == chars[0]: + pad += 1 + else: + break + res = b'\x00' * pad + res + if verify_checksum: + assert_equal(hash256(res[:-4])[:4], res[-4:]) + + return res[1:-4], int(res[0]) + + +def keyhash_to_p2pkh(hash, main = False, btc = True): + assert (len(hash) == 20 or len(hash) == 32) + if len(hash) == 20: + if btc: + version = 0 if main else 111 + else: + version = 56 if main else 118 + return byte_to_base58(hash, version) + version = 57 if main else 119 + return byte_to_base58(hash, version) + +def scripthash_to_p2sh(hash, main = False, btc = True): + assert (len(hash) == 20) + if btc: + version = 5 if main else 196 + else: + version = 60 if main else 122 + return byte_to_base58(hash, version) + +def key_to_p2pkh(key, main = False): + key = check_key(key) + return keyhash_to_p2pkh(hash160(key), main) + +def script_to_p2sh(script, main = False, btc = True): + script = check_script(script) + return scripthash_to_p2sh(hash160(script), main, btc) + +def key_to_p2sh_p2wpkh(key, main = False): + key = check_key(key) + p2shscript = CScript([OP_0, hash160(key)]) + return script_to_p2sh(p2shscript, main) + +def program_to_witness(version, program, main = False): + if (type(program) is str): + program = hex_str_to_bytes(program) + assert 0 <= version <= 16 + assert 2 <= len(program) <= 40 + assert version > 0 or len(program) in [20, 32] + return segwit_addr.encode("bc" if main else "bcrt", version, program) + +def script_to_p2wsh(script, main = False): + script = check_script(script) + return program_to_witness(0, sha256(script), main) + +def key_to_p2wpkh(key, main = False): + key = check_key(key) + return program_to_witness(0, hash160(key), main) + +def script_to_p2sh_p2wsh(script, main = False): + script = check_script(script) + p2shscript = CScript([OP_0, sha256(script)]) + return script_to_p2sh(p2shscript, main) + +def check_key(key): + if (type(key) is str): + key = hex_str_to_bytes(key) # Assuming this is hex string + if (type(key) is bytes and (len(key) == 33 or len(key) == 65)): + return key + assert False + +def check_script(script): + if (type(script) is str): + script = hex_str_to_bytes(script) # Assuming this is hex string + if (type(script) is bytes or type(script) is CScript): + return script + assert False + + +class TestFrameworkScript(unittest.TestCase): + def test_base58encodedecode(self): + def check_base58(data, version): + self.assertEqual(base58_to_byte(byte_to_base58(data, version)), (data, version)) + + check_base58(b'\x1f\x8e\xa1p*{\xd4\x94\x1b\xca\tA\xb8R\xc4\xbb\xfe\xdb.\x05', 111) + check_base58(b':\x0b\x05\xf4\xd7\xf6l;\xa7\x00\x9fE50)l\x84\\\xc9\xcf', 111) + check_base58(b'A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\0\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 111) + check_base58(b'\x1f\x8e\xa1p*{\xd4\x94\x1b\xca\tA\xb8R\xc4\xbb\xfe\xdb.\x05', 0) + check_base58(b':\x0b\x05\xf4\xd7\xf6l;\xa7\x00\x9fE50)l\x84\\\xc9\xcf', 0) + check_base58(b'A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) + check_base58(b'\0\0\0A\xc1\xea\xf1\x11\x80%Y\xba\xd6\x1b`\xd6+\x1f\x89|c\x92\x8a', 0) diff --git a/basicswap/contrib/test_framework/authproxy.py b/basicswap/contrib/test_framework/authproxy.py new file mode 100644 index 0000000..0530893 --- /dev/null +++ b/basicswap/contrib/test_framework/authproxy.py @@ -0,0 +1,204 @@ +# Copyright (c) 2011 Jeff Garzik +# +# Previous copyright, from python-jsonrpc/jsonrpc/proxy.py: +# +# Copyright (c) 2007 Jan-Klaas Kollhof +# +# This file is part of jsonrpc. +# +# jsonrpc is free software; you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this software; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +"""HTTP proxy for opening RPC connection to bitcoind. + +AuthServiceProxy has the following improvements over python-jsonrpc's +ServiceProxy class: + +- HTTP connections persist for the life of the AuthServiceProxy object + (if server supports HTTP/1.1) +- sends protocol 'version', per JSON-RPC 1.1 +- sends proper, incrementing 'id' +- sends Basic HTTP authentication headers +- parses all JSON numbers that look like floats as Decimal +- uses standard Python json lib +""" + +import base64 +import decimal +from http import HTTPStatus +import http.client +import json +import logging +import os +import socket +import time +import urllib.parse + +HTTP_TIMEOUT = 30 +USER_AGENT = "AuthServiceProxy/0.1" + +log = logging.getLogger("BitcoinRPC") + +class JSONRPCException(Exception): + def __init__(self, rpc_error, http_status=None): + try: + errmsg = '%(message)s (%(code)i)' % rpc_error + except (KeyError, TypeError): + errmsg = '' + super().__init__(errmsg) + self.error = rpc_error + self.http_status = http_status + + +def EncodeDecimal(o): + if isinstance(o, decimal.Decimal): + return str(o) + raise TypeError(repr(o) + " is not JSON serializable") + +class AuthServiceProxy(): + __id_count = 0 + + # ensure_ascii: escape unicode as \uXXXX, passed to json.dumps + def __init__(self, service_url, service_name=None, timeout=HTTP_TIMEOUT, connection=None, ensure_ascii=True): + self.__service_url = service_url + self._service_name = service_name + self.ensure_ascii = ensure_ascii # can be toggled on the fly by tests + self.__url = urllib.parse.urlparse(service_url) + user = None if self.__url.username is None else self.__url.username.encode('utf8') + passwd = None if self.__url.password is None else self.__url.password.encode('utf8') + authpair = user + b':' + passwd + self.__auth_header = b'Basic ' + base64.b64encode(authpair) + self.timeout = timeout + self._set_conn(connection) + + def __getattr__(self, name): + if name.startswith('__') and name.endswith('__'): + # Python internal stuff + raise AttributeError + if self._service_name is not None: + name = "%s.%s" % (self._service_name, name) + return AuthServiceProxy(self.__service_url, name, connection=self.__conn) + + def _request(self, method, path, postdata): + ''' + Do a HTTP request, with retry if we get disconnected (e.g. due to a timeout). + This is a workaround for https://bugs.python.org/issue3566 which is fixed in Python 3.5. + ''' + headers = {'Host': self.__url.hostname, + 'User-Agent': USER_AGENT, + 'Authorization': self.__auth_header, + 'Content-type': 'application/json'} + if os.name == 'nt': + # Windows somehow does not like to re-use connections + # TODO: Find out why the connection would disconnect occasionally and make it reusable on Windows + # Avoid "ConnectionAbortedError: [WinError 10053] An established connection was aborted by the software in your host machine" + self._set_conn() + try: + self.__conn.request(method, path, postdata, headers) + return self._get_response() + except (BrokenPipeError, ConnectionResetError): + # Python 3.5+ raises BrokenPipeError when the connection was reset + # ConnectionResetError happens on FreeBSD + self.__conn.close() + self.__conn.request(method, path, postdata, headers) + return self._get_response() + except OSError as e: + retry = ( + '[WinError 10053] An established connection was aborted by the software in your host machine' in str(e)) + if retry: + self.__conn.close() + self.__conn.request(method, path, postdata, headers) + return self._get_response() + else: + raise + + def get_request(self, *args, **argsn): + AuthServiceProxy.__id_count += 1 + + log.debug("-{}-> {} {}".format( + AuthServiceProxy.__id_count, + self._service_name, + json.dumps(args or argsn, default=EncodeDecimal, ensure_ascii=self.ensure_ascii), + )) + if args and argsn: + raise ValueError('Cannot handle both named and positional arguments') + return {'version': '1.1', + 'method': self._service_name, + 'params': args or argsn, + 'id': AuthServiceProxy.__id_count} + + def __call__(self, *args, **argsn): + postdata = json.dumps(self.get_request(*args, **argsn), default=EncodeDecimal, ensure_ascii=self.ensure_ascii) + response, status = self._request('POST', self.__url.path, postdata.encode('utf-8')) + if response['error'] is not None: + raise JSONRPCException(response['error'], status) + elif 'result' not in response: + raise JSONRPCException({ + 'code': -343, 'message': 'missing JSON-RPC result'}, status) + elif status != HTTPStatus.OK: + raise JSONRPCException({ + 'code': -342, 'message': 'non-200 HTTP status code but no JSON-RPC error'}, status) + else: + return response['result'] + + def batch(self, rpc_call_list): + postdata = json.dumps(list(rpc_call_list), default=EncodeDecimal, ensure_ascii=self.ensure_ascii) + log.debug("--> " + postdata) + response, status = self._request('POST', self.__url.path, postdata.encode('utf-8')) + if status != HTTPStatus.OK: + raise JSONRPCException({ + 'code': -342, 'message': 'non-200 HTTP status code but no JSON-RPC error'}, status) + return response + + def _get_response(self): + req_start_time = time.time() + try: + http_response = self.__conn.getresponse() + except socket.timeout: + raise JSONRPCException({ + 'code': -344, + 'message': '%r RPC took longer than %f seconds. Consider ' + 'using larger timeout for calls that take ' + 'longer to return.' % (self._service_name, + self.__conn.timeout)}) + if http_response is None: + raise JSONRPCException({ + 'code': -342, 'message': 'missing HTTP response from server'}) + + content_type = http_response.getheader('Content-Type') + if content_type != 'application/json': + raise JSONRPCException( + {'code': -342, 'message': 'non-JSON HTTP response with \'%i %s\' from server' % (http_response.status, http_response.reason)}, + http_response.status) + + responsedata = http_response.read().decode('utf8') + response = json.loads(responsedata, parse_float=decimal.Decimal) + elapsed = time.time() - req_start_time + if "error" in response and response["error"] is None: + log.debug("<-%s- [%.6f] %s" % (response["id"], elapsed, json.dumps(response["result"], default=EncodeDecimal, ensure_ascii=self.ensure_ascii))) + else: + log.debug("<-- [%.6f] %s" % (elapsed, responsedata)) + return response, http_response.status + + def __truediv__(self, relative_uri): + return AuthServiceProxy("{}/{}".format(self.__service_url, relative_uri), self._service_name, connection=self.__conn) + + def _set_conn(self, connection=None): + port = 80 if self.__url.port is None else self.__url.port + if connection: + self.__conn = connection + self.timeout = connection.timeout + elif self.__url.scheme == 'https': + self.__conn = http.client.HTTPSConnection(self.__url.hostname, port, timeout=self.timeout) + else: + self.__conn = http.client.HTTPConnection(self.__url.hostname, port, timeout=self.timeout) diff --git a/basicswap/contrib/test_framework/coverage.py b/basicswap/contrib/test_framework/coverage.py new file mode 100644 index 0000000..7705dd3 --- /dev/null +++ b/basicswap/contrib/test_framework/coverage.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# Copyright (c) 2015-2018 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Utilities for doing coverage analysis on the RPC interface. + +Provides a way to track which RPC commands are exercised during +testing. +""" + +import os + + +REFERENCE_FILENAME = 'rpc_interface.txt' + + +class AuthServiceProxyWrapper(): + """ + An object that wraps AuthServiceProxy to record specific RPC calls. + + """ + def __init__(self, auth_service_proxy_instance, coverage_logfile=None): + """ + Kwargs: + auth_service_proxy_instance (AuthServiceProxy): the instance + being wrapped. + coverage_logfile (str): if specified, write each service_name + out to a file when called. + + """ + self.auth_service_proxy_instance = auth_service_proxy_instance + self.coverage_logfile = coverage_logfile + + def __getattr__(self, name): + return_val = getattr(self.auth_service_proxy_instance, name) + if not isinstance(return_val, type(self.auth_service_proxy_instance)): + # If proxy getattr returned an unwrapped value, do the same here. + return return_val + return AuthServiceProxyWrapper(return_val, self.coverage_logfile) + + def __call__(self, *args, **kwargs): + """ + Delegates to AuthServiceProxy, then writes the particular RPC method + called to a file. + + """ + return_val = self.auth_service_proxy_instance.__call__(*args, **kwargs) + self._log_call() + return return_val + + def _log_call(self): + rpc_method = self.auth_service_proxy_instance._service_name + + if self.coverage_logfile: + with open(self.coverage_logfile, 'a+', encoding='utf8') as f: + f.write("%s\n" % rpc_method) + + def __truediv__(self, relative_uri): + return AuthServiceProxyWrapper(self.auth_service_proxy_instance / relative_uri, + self.coverage_logfile) + + def get_request(self, *args, **kwargs): + self._log_call() + return self.auth_service_proxy_instance.get_request(*args, **kwargs) + +def get_filename(dirname, n_node): + """ + Get a filename unique to the test process ID and node. + + This file will contain a list of RPC commands covered. + """ + pid = str(os.getpid()) + return os.path.join( + dirname, "coverage.pid%s.node%s.txt" % (pid, str(n_node))) + + +def write_all_rpc_commands(dirname, node): + """ + Write out a list of all RPC functions available in `bitcoin-cli` for + coverage comparison. This will only happen once per coverage + directory. + + Args: + dirname (str): temporary test dir + node (AuthServiceProxy): client + + Returns: + bool. if the RPC interface file was written. + + """ + filename = os.path.join(dirname, REFERENCE_FILENAME) + + if os.path.isfile(filename): + return False + + help_output = node.help().split('\n') + commands = set() + + for line in help_output: + line = line.strip() + + # Ignore blanks and headers + if line and not line.startswith('='): + commands.add("%s\n" % line.split()[0]) + + with open(filename, 'w', encoding='utf8') as f: + f.writelines(list(commands)) + + return True diff --git a/basicswap/contrib/test_framework/key.py b/basicswap/contrib/test_framework/key.py new file mode 100644 index 0000000..55e2de1 --- /dev/null +++ b/basicswap/contrib/test_framework/key.py @@ -0,0 +1,393 @@ +# Copyright (c) 2019 Pieter Wuille +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Test-only secp256k1 elliptic curve implementation + +WARNING: This code is slow, uses bad randomness, does not properly protect +keys, and is trivially vulnerable to side channel attacks. Do not use for +anything but tests.""" +import random + +def modinv(a, n): + """Compute the modular inverse of a modulo n + + See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers. + """ + t1, t2 = 0, 1 + r1, r2 = n, a + while r2 != 0: + q = r1 // r2 + t1, t2 = t2, t1 - q * t2 + r1, r2 = r2, r1 - q * r2 + if r1 > 1: + return None + if t1 < 0: + t1 += n + return t1 + +def jacobi_symbol(n, k): + """Compute the Jacobi symbol of n modulo k + + See http://en.wikipedia.org/wiki/Jacobi_symbol + + For our application k is always prime, so this is the same as the Legendre symbol.""" + assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + n %= k + t = 0 + while n != 0: + while n & 1 == 0: + n >>= 1 + r = k & 7 + t ^= (r == 3 or r == 5) + n, k = k, n + t ^= (n & k & 3 == 3) + n = n % k + if k == 1: + return -1 if t else 1 + return 0 + +def modsqrt(a, p): + """Compute the square root of a modulo p when p % 4 = 3. + + The Tonelli-Shanks algorithm can be used. See https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm + + Limiting this function to only work for p % 4 = 3 means we don't need to + iterate through the loop. The highest n such that p - 1 = 2^n Q with Q odd + is n = 1. Therefore Q = (p-1)/2 and sqrt = a^((Q+1)/2) = a^((p+1)/4) + + secp256k1's is defined over field of size 2**256 - 2**32 - 977, which is 3 mod 4. + """ + if p % 4 != 3: + raise NotImplementedError("modsqrt only implemented for p % 4 = 3") + sqrt = pow(a, (p + 1)//4, p) + if pow(sqrt, 2, p) == a % p: + return sqrt + return None + +class EllipticCurve: + def __init__(self, p, a, b): + """Initialize elliptic curve y^2 = x^3 + a*x + b over GF(p).""" + self.p = p + self.a = a % p + self.b = b % p + + def affine(self, p1): + """Convert a Jacobian point tuple p1 to affine form, or None if at infinity. + + An affine point is represented as the Jacobian (x, y, 1)""" + x1, y1, z1 = p1 + if z1 == 0: + return None + inv = modinv(z1, self.p) + inv_2 = (inv**2) % self.p + inv_3 = (inv_2 * inv) % self.p + return ((inv_2 * x1) % self.p, (inv_3 * y1) % self.p, 1) + + def negate(self, p1): + """Negate a Jacobian point tuple p1.""" + x1, y1, z1 = p1 + return (x1, (self.p - y1) % self.p, z1) + + def on_curve(self, p1): + """Determine whether a Jacobian tuple p is on the curve (and not infinity)""" + x1, y1, z1 = p1 + z2 = pow(z1, 2, self.p) + z4 = pow(z2, 2, self.p) + return z1 != 0 and (pow(x1, 3, self.p) + self.a * x1 * z4 + self.b * z2 * z4 - pow(y1, 2, self.p)) % self.p == 0 + + def is_x_coord(self, x): + """Test whether x is a valid X coordinate on the curve.""" + x_3 = pow(x, 3, self.p) + return jacobi_symbol(x_3 + self.a * x + self.b, self.p) != -1 + + def lift_x(self, x): + """Given an X coordinate on the curve, return a corresponding affine point.""" + x_3 = pow(x, 3, self.p) + v = x_3 + self.a * x + self.b + y = modsqrt(v, self.p) + if y is None: + return None + return (x, y, 1) + + def double(self, p1): + """Double a Jacobian tuple p1 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Doubling""" + x1, y1, z1 = p1 + if z1 == 0: + return (0, 1, 0) + y1_2 = (y1**2) % self.p + y1_4 = (y1_2**2) % self.p + x1_2 = (x1**2) % self.p + s = (4*x1*y1_2) % self.p + m = 3*x1_2 + if self.a: + m += self.a * pow(z1, 4, self.p) + m = m % self.p + x2 = (m**2 - 2*s) % self.p + y2 = (m*(s - x2) - 8*y1_4) % self.p + z2 = (2*y1*z1) % self.p + return (x2, y2, z2) + + def add_mixed(self, p1, p2): + """Add a Jacobian tuple p1 and an affine tuple p2 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition (with affine point)""" + x1, y1, z1 = p1 + x2, y2, z2 = p2 + assert(z2 == 1) + # Adding to the point at infinity is a no-op + if z1 == 0: + return p2 + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + u2 = (x2 * z1_2) % self.p + s2 = (y2 * z1_3) % self.p + if x1 == u2: + if (y1 != s2): + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. + return self.double(p1) + h = u2 - x1 + r = s2 - y1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (x1 * h_2) % self.p + x3 = (r**2 - h_3 - 2*u1_h_2) % self.p + y3 = (r*(u1_h_2 - x3) - y1*h_3) % self.p + z3 = (h*z1) % self.p + return (x3, y3, z3) + + def add(self, p1, p2): + """Add two Jacobian tuples p1 and p2 + + See https://en.wikibooks.org/wiki/Cryptography/Prime_Curve/Jacobian_Coordinates - Point Addition""" + x1, y1, z1 = p1 + x2, y2, z2 = p2 + # Adding the point at infinity is a no-op + if z1 == 0: + return p2 + if z2 == 0: + return p1 + # Adding an Affine to a Jacobian is more efficient since we save field multiplications and squarings when z = 1 + if z1 == 1: + return self.add_mixed(p2, p1) + if z2 == 1: + return self.add_mixed(p1, p2) + z1_2 = (z1**2) % self.p + z1_3 = (z1_2 * z1) % self.p + z2_2 = (z2**2) % self.p + z2_3 = (z2_2 * z2) % self.p + u1 = (x1 * z2_2) % self.p + u2 = (x2 * z1_2) % self.p + s1 = (y1 * z2_3) % self.p + s2 = (y2 * z1_3) % self.p + if u1 == u2: + if (s1 != s2): + # p1 and p2 are inverses. Return the point at infinity. + return (0, 1, 0) + # p1 == p2. The formulas below fail when the two points are equal. + return self.double(p1) + h = u2 - u1 + r = s2 - s1 + h_2 = (h**2) % self.p + h_3 = (h_2 * h) % self.p + u1_h_2 = (u1 * h_2) % self.p + x3 = (r**2 - h_3 - 2*u1_h_2) % self.p + y3 = (r*(u1_h_2 - x3) - s1*h_3) % self.p + z3 = (h*z1*z2) % self.p + return (x3, y3, z3) + + def mul(self, ps): + """Compute a (multi) point multiplication + + ps is a list of (Jacobian tuple, scalar) pairs. + """ + r = (0, 1, 0) + for i in range(255, -1, -1): + r = self.double(r) + for (p, n) in ps: + if ((n >> i) & 1): + r = self.add(r, p) + return r + +SECP256K1 = EllipticCurve(2**256 - 2**32 - 977, 0, 7) +SECP256K1_G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8, 1) +SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +SECP256K1_ORDER_HALF = SECP256K1_ORDER // 2 + +class ECPubKey(): + """A secp256k1 public key""" + + def __init__(self): + """Construct an uninitialized public key""" + self.valid = False + + def set_int(self, x, y): + p = (x, y, 1) + self.valid = SECP256K1.on_curve(p) + if self.valid: + self.p = p + self.compressed = False + + def set(self, data): + """Construct a public key from a serialization in compressed or uncompressed format""" + if (len(data) == 65 and data[0] == 0x04): + p = (int.from_bytes(data[1:33], 'big'), int.from_bytes(data[33:65], 'big'), 1) + self.valid = SECP256K1.on_curve(p) + if self.valid: + self.p = p + self.compressed = False + elif (len(data) == 33 and (data[0] == 0x02 or data[0] == 0x03)): + x = int.from_bytes(data[1:33], 'big') + if SECP256K1.is_x_coord(x): + p = SECP256K1.lift_x(x) + # if the oddness of the y co-ord isn't correct, find the other + # valid y + if (p[1] & 1) != (data[0] & 1): + p = SECP256K1.negate(p) + self.p = p + self.valid = True + self.compressed = True + else: + self.valid = False + else: + self.valid = False + + @property + def is_compressed(self): + return self.compressed + + @property + def is_valid(self): + return self.valid + + def get_bytes(self): + assert(self.valid) + p = SECP256K1.affine(self.p) + if p is None: + return None + if self.compressed: + return bytes([0x02 + (p[1] & 1)]) + p[0].to_bytes(32, 'big') + else: + return bytes([0x04]) + p[0].to_bytes(32, 'big') + p[1].to_bytes(32, 'big') + + def verify_ecdsa(self, sig, msg, low_s=True): + """Verify a strictly DER-encoded ECDSA signature against this pubkey. + + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA verifier algorithm""" + assert(self.valid) + + # Extract r and s from the DER formatted signature. Return false for + # any DER encoding errors. + if (sig[1] + 2 != len(sig)): + return False + if (len(sig) < 4): + return False + if (sig[0] != 0x30): + return False + if (sig[2] != 0x02): + return False + rlen = sig[3] + if (len(sig) < 6 + rlen): + return False + if rlen < 1 or rlen > 33: + return False + if sig[4] >= 0x80: + return False + if (rlen > 1 and (sig[4] == 0) and not (sig[5] & 0x80)): + return False + r = int.from_bytes(sig[4:4+rlen], 'big') + if (sig[4+rlen] != 0x02): + return False + slen = sig[5+rlen] + if slen < 1 or slen > 33: + return False + if (len(sig) != 6 + rlen + slen): + return False + if sig[6+rlen] >= 0x80: + return False + if (slen > 1 and (sig[6+rlen] == 0) and not (sig[7+rlen] & 0x80)): + return False + s = int.from_bytes(sig[6+rlen:6+rlen+slen], 'big') + + # Verify that r and s are within the group order + if r < 1 or s < 1 or r >= SECP256K1_ORDER or s >= SECP256K1_ORDER: + return False + if low_s and s >= SECP256K1_ORDER_HALF: + return False + z = int.from_bytes(msg, 'big') + + # Run verifier algorithm on r, s + w = modinv(s, SECP256K1_ORDER) + u1 = z*w % SECP256K1_ORDER + u2 = r*w % SECP256K1_ORDER + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, u1), (self.p, u2)])) + if R is None or R[0] != r: + return False + return True + +class ECKey(): + """A secp256k1 private key""" + + def __init__(self): + self.valid = False + + def set(self, secret, compressed): + """Construct a private key object with given 32-byte secret and compressed flag.""" + assert(len(secret) == 32) + secret = int.from_bytes(secret, 'big') + self.valid = (secret > 0 and secret < SECP256K1_ORDER) + if self.valid: + self.secret = secret + self.compressed = compressed + + def generate(self, compressed=True): + """Generate a random private key (compressed or uncompressed).""" + self.set(random.randrange(1, SECP256K1_ORDER).to_bytes(32, 'big'), compressed) + + def get_bytes(self): + """Retrieve the 32-byte representation of this key.""" + assert(self.valid) + return self.secret.to_bytes(32, 'big') + + @property + def is_valid(self): + return self.valid + + @property + def is_compressed(self): + return self.compressed + + def get_pubkey(self): + """Compute an ECPubKey object for this secret key.""" + assert(self.valid) + ret = ECPubKey() + p = SECP256K1.mul([(SECP256K1_G, self.secret)]) + ret.p = p + ret.valid = True + ret.compressed = self.compressed + return ret + + def sign_ecdsa(self, msg, low_s=True): + """Construct a DER-encoded ECDSA signature with this key. + + See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the + ECDSA signer algorithm.""" + assert(self.valid) + z = int.from_bytes(msg, 'big') + # Note: no RFC6979, but a simple random nonce (some tests rely on distinct transactions for the same operation) + k = random.randrange(1, SECP256K1_ORDER) + R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, k)])) + r = R[0] % SECP256K1_ORDER + s = (modinv(k, SECP256K1_ORDER) * (z + self.secret * r)) % SECP256K1_ORDER + if low_s and s > SECP256K1_ORDER_HALF: + s = SECP256K1_ORDER - s + # Represent in DER format. The byte representations of r and s have + # length rounded up (255 bits becomes 32 bytes and 256 bits becomes 33 + # bytes). + rb = r.to_bytes((r.bit_length() + 8) // 8, 'big') + sb = s.to_bytes((s.bit_length() + 8) // 8, 'big') + return b'\x30' + bytes([4 + len(rb) + len(sb), 2, len(rb)]) + rb + bytes([2, len(sb)]) + sb diff --git a/basicswap/contrib/test_framework/messages.py b/basicswap/contrib/test_framework/messages.py new file mode 100755 index 0000000..70e353e --- /dev/null +++ b/basicswap/contrib/test_framework/messages.py @@ -0,0 +1,1756 @@ +#!/usr/bin/env python3 +# Copyright (c) 2010 ArtForz -- public domain half-a-node +# Copyright (c) 2012 Jeff Garzik +# Copyright (c) 2010-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Bitcoin test framework primitive and message structures + +CBlock, CTransaction, CBlockHeader, CTxIn, CTxOut, etc....: + data structures that should map to corresponding structures in + bitcoin/primitives + +msg_block, msg_tx, msg_headers, etc.: + data structures that represent network messages + +ser_*, deser_*: functions that handle serialization/deserialization. + +Classes use __slots__ to ensure extraneous attributes aren't accidentally added +by tests, compromising their intended effect. +""" +from codecs import encode +import copy +import hashlib +from io import BytesIO +import random +import socket +import struct +import time + +from .siphash import siphash256 +from .util import hex_str_to_bytes, assert_equal + +MIN_VERSION_SUPPORTED = 60001 +#MY_VERSION = 70014 # past bip-31 for ping/pong +MY_VERSION = 90009 +MY_SUBVERSION = b"/python-mininode-tester:0.0.3/" +MY_RELAY = 1 # from version 70001 onwards, fRelay should be appended to version messages (BIP37) + +MAX_LOCATOR_SZ = 101 +MAX_BLOCK_BASE_SIZE = 1000000 +MAX_BLOOM_FILTER_SIZE = 36000 +MAX_BLOOM_HASH_FUNCS = 50 + +COIN = 100000000 # 1 btc in satoshis +MAX_MONEY = 21000000 * COIN + +BIP125_SEQUENCE_NUMBER = 0xfffffffd # Sequence number that is BIP 125 opt-in and BIP 68-opt-out + +NODE_NETWORK = (1 << 0) +NODE_GETUTXO = (1 << 1) +NODE_BLOOM = (1 << 2) +NODE_WITNESS = (1 << 3) +NODE_NETWORK_LIMITED = (1 << 10) + +MSG_TX = 1 +MSG_BLOCK = 2 +MSG_FILTERED_BLOCK = 3 +MSG_CMPCT_BLOCK = 4 +MSG_WITNESS_FLAG = 1 << 30 +MSG_TYPE_MASK = 0xffffffff >> 2 + +FILTER_TYPE_BASIC = 0 + +PARTICL_TX_VERSION = 0xa0 +PARTICL_TX_ANON_MARKER = 0xffffffa0 +OUTPUT_TYPE_STANDARD = 1 +OUTPUT_TYPE_CT = 2 +OUTPUT_TYPE_RINGCT = 3 +OUTPUT_TYPE_DATA = 4 + + +# Serialization/deserialization tools +def sha256(s): + return hashlib.new('sha256', s).digest() + +def hash256(s): + return sha256(sha256(s)) + +def ser_compact_size(l): + r = b"" + if l < 253: + r = struct.pack("B", l) + elif l < 0x10000: + r = struct.pack(">= 32 + return rs + + +def uint256_from_str(s): + r = 0 + t = struct.unpack("> 24) & 0xFF + v = (c & 0xFFFFFF) << (8 * (nbytes - 3)) + return v + + +def deser_vector(f, c): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = c() + t.deserialize(f) + r.append(t) + return r + + +# ser_function_name: Allow for an alternate serialization function on the +# entries in the vector (we use this for serializing the vector of transactions +# for a witness block). +def ser_vector(l, ser_function_name=None): + r = ser_compact_size(len(l)) + for i in l: + if ser_function_name: + r += getattr(i, ser_function_name)() + else: + r += i.serialize() + return r + + +def deser_uint256_vector(f): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = deser_uint256(f) + r.append(t) + return r + + +def ser_uint256_vector(l): + r = ser_compact_size(len(l)) + for i in l: + r += ser_uint256(i) + return r + + +def deser_string_vector(f): + nit = deser_compact_size(f) + r = [] + for i in range(nit): + t = deser_string(f) + r.append(t) + return r + + +def ser_string_vector(l): + r = ser_compact_size(len(l)) + for sv in l: + r += ser_string(sv) + return r + + +# Deserialize from a hex string representation (eg from RPC) +def FromHex(obj, hex_string): + obj.deserialize(BytesIO(hex_str_to_bytes(hex_string))) + return obj + +# Convert a binary-serializable object to hex (eg for submission via RPC) +def ToHex(obj): + return obj.serialize().hex() + +# Objects that map to bitcoind objects, which can be serialized/deserialized + + +class CAddress: + __slots__ = ("ip", "nServices", "pchReserved", "port", "time") + + def __init__(self): + self.time = 0 + self.nServices = 1 + self.pchReserved = b"\x00" * 10 + b"\xff" * 2 + self.ip = "0.0.0.0" + self.port = 0 + + def deserialize(self, f, with_time=True): + if with_time: + self.time = struct.unpack("H", f.read(2))[0] + + def serialize(self, with_time=True): + r = b"" + if with_time: + r += struct.pack("H", self.port) + return r + + def __repr__(self): + return "CAddress(nServices=%i ip=%s port=%i)" % (self.nServices, + self.ip, self.port) + + +class CInv: + __slots__ = ("hash", "type") + + typemap = { + 0: "Error", + MSG_TX: "TX", + MSG_BLOCK: "Block", + MSG_TX | MSG_WITNESS_FLAG: "WitnessTx", + MSG_BLOCK | MSG_WITNESS_FLAG: "WitnessBlock", + MSG_FILTERED_BLOCK: "filtered Block", + 4: "CompactBlock" + } + + def __init__(self, t=0, h=0): + self.type = t + self.hash = h + + def deserialize(self, f): + self.type = struct.unpack(" 21000000 * COIN: + return False + return True + + def __repr__(self): + return "CTransaction(nVersion=%i vin=%s vout=%s wit=%s nLockTime=%i)" \ + % (self.nVersion, repr(self.vin), repr(self.vout), repr(self.wit), self.nLockTime) + + +class CBlockHeader: + __slots__ = ("hash", "hashMerkleRoot", "hashPrevBlock", "nBits", "nNonce", + "nTime", "nVersion", "sha256", + "is_part", "hashWitnessMerkleRoot") + + def __init__(self, header=None, is_part=False): + self.is_part = is_part + if header is None: + self.set_null() + else: + self.is_part = header.is_part + self.nVersion = header.nVersion + self.hashPrevBlock = header.hashPrevBlock + self.hashMerkleRoot = header.hashMerkleRoot + if self.is_part: + self.hashWitnessMerkleRoot = header.hashWitnessMerkleRoot + self.nTime = header.nTime + self.nBits = header.nBits + self.nNonce = header.nNonce + self.sha256 = header.sha256 + self.hash = header.hash + self.calc_sha256() + + def set_null(self): + self.nVersion = 1 + self.hashPrevBlock = 0 + self.hashMerkleRoot = 0 + if self.is_part: + self.hashWitnessMerkleRoot = 0 + self.nTime = 0 + self.nBits = 0 + self.nNonce = 0 + self.sha256 = None + self.hash = None + + def deserialize(self, f): + self.nVersion = struct.unpack(" 1: + newhashes = [] + for i in range(0, len(hashes), 2): + i2 = min(i+1, len(hashes)-1) + newhashes.append(hash256(hashes[i] + hashes[i2])) + hashes = newhashes + return uint256_from_str(hashes[0]) + + def calc_merkle_root(self): + hashes = [] + for tx in self.vtx: + tx.calc_sha256() + hashes.append(ser_uint256(tx.sha256)) + return self.get_merkle_root(hashes) + + def calc_witness_merkle_root(self): + # For witness root purposes, the hash of the + # coinbase, with witness, is defined to be 0...0 + hashes = [ser_uint256(0)] + + for tx in self.vtx[1:]: + # Calculate the hashes with witness data + hashes.append(ser_uint256(tx.calc_sha256(True))) + + return self.get_merkle_root(hashes) + + def is_valid(self): + self.calc_sha256() + target = uint256_from_compact(self.nBits) + if self.sha256 > target: + return False + for tx in self.vtx: + if not tx.is_valid(): + return False + if self.calc_merkle_root() != self.hashMerkleRoot: + return False + return True + + def solve(self): + self.rehash() + target = uint256_from_compact(self.nBits) + while self.sha256 > target: + self.nNonce += 1 + self.rehash() + + def __repr__(self): + return "CBlock(nVersion=%i hashPrevBlock=%064x hashMerkleRoot=%064x nTime=%s nBits=%08x nNonce=%08x vtx=%s)" \ + % (self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, + time.ctime(self.nTime), self.nBits, self.nNonce, repr(self.vtx)) + + +class PrefilledTransaction: + __slots__ = ("index", "tx") + + def __init__(self, index=0, tx = None): + self.index = index + self.tx = tx + + def deserialize(self, f): + self.index = deser_compact_size(f) + self.tx = CTransaction() + self.tx.deserialize(f) + + def serialize(self, with_witness=True): + r = b"" + r += ser_compact_size(self.index) + if with_witness: + r += self.tx.serialize_with_witness() + else: + r += self.tx.serialize_without_witness() + return r + + def serialize_without_witness(self): + return self.serialize(with_witness=False) + + def serialize_with_witness(self): + return self.serialize(with_witness=True) + + def __repr__(self): + return "PrefilledTransaction(index=%d, tx=%s)" % (self.index, repr(self.tx)) + + +# This is what we send on the wire, in a cmpctblock message. +class P2PHeaderAndShortIDs: + __slots__ = ("header", "nonce", "prefilled_txn", "prefilled_txn_length", + "shortids", "shortids_length") + + def __init__(self): + self.header = CBlockHeader() + self.nonce = 0 + self.shortids_length = 0 + self.shortids = [] + self.prefilled_txn_length = 0 + self.prefilled_txn = [] + + def deserialize(self, f): + self.header.deserialize(f) + self.nonce = struct.unpack("= 70001: + # Relay field is optional for version 70001 onwards + try: + self.nRelay = struct.unpack(" +class msg_headers: + __slots__ = ("headers",) + msgtype = b"headers" + + def __init__(self, headers=None): + self.headers = headers if headers is not None else [] + + def deserialize(self, f): + # comment in bitcoind indicates these should be deserialized as blocks + blocks = deser_vector(f, CBlock) + for x in blocks: + self.headers.append(CBlockHeader(x)) + + def serialize(self): + blocks = [CBlock(x) for x in self.headers] + return ser_vector(blocks) + + def __repr__(self): + return "msg_headers(headers=%s)" % repr(self.headers) + + +class msg_merkleblock: + __slots__ = ("merkleblock",) + msgtype = b"merkleblock" + + def __init__(self, merkleblock=None): + if merkleblock is None: + self.merkleblock = CMerkleBlock() + else: + self.merkleblock = merkleblock + + def deserialize(self, f): + self.merkleblock.deserialize(f) + + def serialize(self): + return self.merkleblock.serialize() + + def __repr__(self): + return "msg_merkleblock(merkleblock=%s)" % (repr(self.merkleblock)) + + +class msg_filterload: + __slots__ = ("data", "nHashFuncs", "nTweak", "nFlags") + msgtype = b"filterload" + + def __init__(self, data=b'00', nHashFuncs=0, nTweak=0, nFlags=0): + self.data = data + self.nHashFuncs = nHashFuncs + self.nTweak = nTweak + self.nFlags = nFlags + + def deserialize(self, f): + self.data = deser_string(f) + self.nHashFuncs = struct.unpack(">= 8 + if r[-1] & 0x80: + r.append(0x80 if neg else 0) + elif neg: + r[-1] |= 0x80 + return bytes([len(r)]) + r + + @staticmethod + def decode(vch): + result = 0 + # We assume valid push_size and minimal encoding + value = vch[1:] + if len(value) == 0: + return result + for i, byte in enumerate(value): + result |= int(byte) << 8 * i + if value[-1] >= 0x80: + # Mask for all but the highest result bit + num_mask = (2**(len(value) * 8) - 1) >> 1 + result &= num_mask + result *= -1 + return result + + +class CScript(bytes): + """Serialized script + + A bytes subclass, so you can use this directly whenever bytes are accepted. + Note that this means that indexing does *not* work - you'll get an index by + byte rather than opcode. This format was chosen for efficiency so that the + general case would not require creating a lot of little CScriptOP objects. + + iter(script) however does iterate by opcode. + """ + __slots__ = () + + @classmethod + def __coerce_instance(cls, other): + # Coerce other into bytes + if isinstance(other, CScriptOp): + other = bytes([other]) + elif isinstance(other, CScriptNum): + if (other.value == 0): + other = bytes([CScriptOp(OP_0)]) + else: + other = CScriptNum.encode(other) + elif isinstance(other, int): + if 0 <= other <= 16: + other = bytes([CScriptOp.encode_op_n(other)]) + elif other == -1: + other = bytes([OP_1NEGATE]) + else: + other = CScriptOp.encode_op_pushdata(bn2vch(other)) + elif isinstance(other, (bytes, bytearray)): + other = CScriptOp.encode_op_pushdata(other) + return other + + def __add__(self, other): + # add makes no sense for a CScript() + raise NotImplementedError + + def join(self, iterable): + # join makes no sense for a CScript() + raise NotImplementedError + + def __new__(cls, value=b''): + if isinstance(value, bytes) or isinstance(value, bytearray): + return super().__new__(cls, value) + else: + def coerce_iterable(iterable): + for instance in iterable: + yield cls.__coerce_instance(instance) + # Annoyingly on both python2 and python3 bytes.join() always + # returns a bytes instance even when subclassed. + return super().__new__(cls, b''.join(coerce_iterable(value))) + + def raw_iter(self): + """Raw iteration + + Yields tuples of (opcode, data, sop_idx) so that the different possible + PUSHDATA encodings can be accurately distinguished, as well as + determining the exact opcode byte indexes. (sop_idx) + """ + i = 0 + while i < len(self): + sop_idx = i + opcode = self[i] + i += 1 + + if opcode > OP_PUSHDATA4: + yield (opcode, None, sop_idx) + else: + datasize = None + pushdata_type = None + if opcode < OP_PUSHDATA1: + pushdata_type = 'PUSHDATA(%d)' % opcode + datasize = opcode + + elif opcode == OP_PUSHDATA1: + pushdata_type = 'PUSHDATA1' + if i >= len(self): + raise CScriptInvalidError('PUSHDATA1: missing data length') + datasize = self[i] + i += 1 + + elif opcode == OP_PUSHDATA2: + pushdata_type = 'PUSHDATA2' + if i + 1 >= len(self): + raise CScriptInvalidError('PUSHDATA2: missing data length') + datasize = self[i] + (self[i + 1] << 8) + i += 2 + + elif opcode == OP_PUSHDATA4: + pushdata_type = 'PUSHDATA4' + if i + 3 >= len(self): + raise CScriptInvalidError('PUSHDATA4: missing data length') + datasize = self[i] + (self[i + 1] << 8) + (self[i + 2] << 16) + (self[i + 3] << 24) + i += 4 + + else: + assert False # shouldn't happen + + data = bytes(self[i:i + datasize]) + + # Check for truncation + if len(data) < datasize: + raise CScriptTruncatedPushDataError('%s: truncated data' % pushdata_type, data) + + i += datasize + + yield (opcode, data, sop_idx) + + def __iter__(self): + """'Cooked' iteration + + Returns either a CScriptOP instance, an integer, or bytes, as + appropriate. + + See raw_iter() if you need to distinguish the different possible + PUSHDATA encodings. + """ + for (opcode, data, sop_idx) in self.raw_iter(): + if data is not None: + yield data + else: + opcode = CScriptOp(opcode) + + if opcode.is_small_int(): + yield opcode.decode_op_n() + else: + yield CScriptOp(opcode) + + def __repr__(self): + def _repr(o): + if isinstance(o, bytes): + return "x('%s')" % o.hex() + else: + return repr(o) + + ops = [] + i = iter(self) + while True: + op = None + try: + op = _repr(next(i)) + except CScriptTruncatedPushDataError as err: + op = '%s...' % (_repr(err.data), err) + break + except CScriptInvalidError as err: + op = '' % err + break + except StopIteration: + break + finally: + if op is not None: + ops.append(op) + + return "CScript([%s])" % ', '.join(ops) + + def GetSigOpCount(self, fAccurate): + """Get the SigOp count. + + fAccurate - Accurately count CHECKMULTISIG, see BIP16 for details. + + Note that this is consensus-critical. + """ + n = 0 + lastOpcode = OP_INVALIDOPCODE + for (opcode, data, sop_idx) in self.raw_iter(): + if opcode in (OP_CHECKSIG, OP_CHECKSIGVERIFY): + n += 1 + elif opcode in (OP_CHECKMULTISIG, OP_CHECKMULTISIGVERIFY): + if fAccurate and (OP_1 <= lastOpcode <= OP_16): + n += opcode.decode_op_n() + else: + n += 20 + lastOpcode = opcode + return n + + +SIGHASH_ALL = 1 +SIGHASH_NONE = 2 +SIGHASH_SINGLE = 3 +SIGHASH_ANYONECANPAY = 0x80 + +def FindAndDelete(script, sig): + """Consensus critical, see FindAndDelete() in Satoshi codebase""" + r = b'' + last_sop_idx = sop_idx = 0 + skip = True + for (opcode, data, sop_idx) in script.raw_iter(): + if not skip: + r += script[last_sop_idx:sop_idx] + last_sop_idx = sop_idx + if script[sop_idx:sop_idx + len(sig)] == sig: + skip = True + else: + skip = False + if not skip: + r += script[last_sop_idx:] + return CScript(r) + + +def LegacySignatureHash(script, txTo, inIdx, hashtype): + """Consensus-correct SignatureHash + + Returns (hash, err) to precisely match the consensus-critical behavior of + the SIGHASH_SINGLE bug. (inIdx is *not* checked for validity) + """ + HASH_ONE = b'\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + + if inIdx >= len(txTo.vin): + return (HASH_ONE, "inIdx %d out of range (%d)" % (inIdx, len(txTo.vin))) + txtmp = CTransaction(txTo) + + for txin in txtmp.vin: + txin.scriptSig = b'' + txtmp.vin[inIdx].scriptSig = FindAndDelete(script, CScript([OP_CODESEPARATOR])) + + if (hashtype & 0x1f) == SIGHASH_NONE: + txtmp.vout = [] + + for i in range(len(txtmp.vin)): + if i != inIdx: + txtmp.vin[i].nSequence = 0 + + elif (hashtype & 0x1f) == SIGHASH_SINGLE: + outIdx = inIdx + if outIdx >= len(txtmp.vout): + return (HASH_ONE, "outIdx %d out of range (%d)" % (outIdx, len(txtmp.vout))) + + tmp = txtmp.vout[outIdx] + txtmp.vout = [] + for i in range(outIdx): + txtmp.vout.append(CTxOut(-1)) + txtmp.vout.append(tmp) + + for i in range(len(txtmp.vin)): + if i != inIdx: + txtmp.vin[i].nSequence = 0 + + if hashtype & SIGHASH_ANYONECANPAY: + tmp = txtmp.vin[inIdx] + txtmp.vin = [] + txtmp.vin.append(tmp) + + s = txtmp.serialize_without_witness() + s += struct.pack(b"> 25 + chk = (chk & 0x1ffffff) << 5 ^ value + for i in range(5): + chk ^= generator[i] if ((top >> i) & 1) else 0 + return chk + + +def bech32_hrp_expand(hrp): + """Expand the HRP into values for checksum computation.""" + return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] + + +def bech32_verify_checksum(hrp, data): + """Verify a checksum given HRP and converted data characters.""" + return bech32_polymod(bech32_hrp_expand(hrp) + data) == 1 + + +def bech32_create_checksum(hrp, data): + """Compute the checksum values given HRP and data.""" + values = bech32_hrp_expand(hrp) + data + polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ 1 + return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)] + + +def bech32_encode(hrp, data): + """Compute a Bech32 string given HRP and data values.""" + combined = data + bech32_create_checksum(hrp, data) + return hrp + '1' + ''.join([CHARSET[d] for d in combined]) + + +def bech32_decode(bech): + """Validate a Bech32 string, and determine HRP and data.""" + if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or + (bech.lower() != bech and bech.upper() != bech)): + return (None, None) + bech = bech.lower() + pos = bech.rfind('1') + if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: + return (None, None) + if not all(x in CHARSET for x in bech[pos+1:]): + return (None, None) + hrp = bech[:pos] + data = [CHARSET.find(x) for x in bech[pos+1:]] + if not bech32_verify_checksum(hrp, data): + return (None, None) + return (hrp, data[:-6]) + + +def convertbits(data, frombits, tobits, pad=True): + """General power-of-2 base conversion.""" + acc = 0 + bits = 0 + ret = [] + maxv = (1 << tobits) - 1 + max_acc = (1 << (frombits + tobits - 1)) - 1 + for value in data: + if value < 0 or (value >> frombits): + return None + acc = ((acc << frombits) | value) & max_acc + bits += frombits + while bits >= tobits: + bits -= tobits + ret.append((acc >> bits) & maxv) + if pad: + if bits: + ret.append((acc << (tobits - bits)) & maxv) + elif bits >= frombits or ((acc << (tobits - bits)) & maxv): + return None + return ret + + +def decode(hrp, addr): + """Decode a segwit address.""" + hrpgot, data = bech32_decode(addr) + if hrpgot != hrp: + return (None, None) + decoded = convertbits(data[1:], 5, 8, False) + if decoded is None or len(decoded) < 2 or len(decoded) > 40: + return (None, None) + if data[0] > 16: + return (None, None) + if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: + return (None, None) + return (data[0], decoded) + + +def encode(hrp, witver, witprog): + """Encode a segwit address.""" + ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5)) + if decode(hrp, ret) == (None, None): + return None + return ret diff --git a/basicswap/contrib/test_framework/siphash.py b/basicswap/contrib/test_framework/siphash.py new file mode 100644 index 0000000..8583684 --- /dev/null +++ b/basicswap/contrib/test_framework/siphash.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) 2016-2018 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Specialized SipHash-2-4 implementations. + +This implements SipHash-2-4 for 256-bit integers. +""" + +def rotl64(n, b): + return n >> (64 - b) | (n & ((1 << (64 - b)) - 1)) << b + +def siphash_round(v0, v1, v2, v3): + v0 = (v0 + v1) & ((1 << 64) - 1) + v1 = rotl64(v1, 13) + v1 ^= v0 + v0 = rotl64(v0, 32) + v2 = (v2 + v3) & ((1 << 64) - 1) + v3 = rotl64(v3, 16) + v3 ^= v2 + v0 = (v0 + v3) & ((1 << 64) - 1) + v3 = rotl64(v3, 21) + v3 ^= v0 + v2 = (v2 + v1) & ((1 << 64) - 1) + v1 = rotl64(v1, 17) + v1 ^= v2 + v2 = rotl64(v2, 32) + return (v0, v1, v2, v3) + +def siphash256(k0, k1, h): + n0 = h & ((1 << 64) - 1) + n1 = (h >> 64) & ((1 << 64) - 1) + n2 = (h >> 128) & ((1 << 64) - 1) + n3 = (h >> 192) & ((1 << 64) - 1) + v0 = 0x736f6d6570736575 ^ k0 + v1 = 0x646f72616e646f6d ^ k1 + v2 = 0x6c7967656e657261 ^ k0 + v3 = 0x7465646279746573 ^ k1 ^ n0 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n0 + v3 ^= n1 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n1 + v3 ^= n2 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n2 + v3 ^= n3 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= n3 + v3 ^= 0x2000000000000000 + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0 ^= 0x2000000000000000 + v2 ^= 0xFF + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + v0, v1, v2, v3 = siphash_round(v0, v1, v2, v3) + return v0 ^ v1 ^ v2 ^ v3 diff --git a/basicswap/contrib/test_framework/util.py b/basicswap/contrib/test_framework/util.py new file mode 100644 index 0000000..c9f55e8 --- /dev/null +++ b/basicswap/contrib/test_framework/util.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# Copyright (c) 2014-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Helpful routines for regression testing.""" + +from base64 import b64encode +from binascii import unhexlify +from decimal import Decimal, ROUND_DOWN +from subprocess import CalledProcessError +import inspect +import json +import logging +import os +import random +import re +import time + +from . import coverage +from .authproxy import AuthServiceProxy, JSONRPCException +from io import BytesIO + +logger = logging.getLogger("TestFramework.utils") + +# Assert functions +################## + + +def assert_approx(v, vexp, vspan=0.00001): + """Assert that `v` is within `vspan` of `vexp`""" + if v < vexp - vspan: + raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) + if v > vexp + vspan: + raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan))) + + +def assert_fee_amount(fee, tx_size, fee_per_kB): + """Assert the fee was in range""" + target_fee = round(tx_size * fee_per_kB / 1000, 8) + if fee < target_fee: + raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee))) + # allow the wallet's estimation to be at most 2 bytes off + if fee > (tx_size + 2) * fee_per_kB / 1000: + raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee))) + + +def assert_equal(thing1, thing2, *args): + if thing1 != thing2 or any(thing1 != arg for arg in args): + raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args)) + + +def assert_greater_than(thing1, thing2): + if thing1 <= thing2: + raise AssertionError("%s <= %s" % (str(thing1), str(thing2))) + + +def assert_greater_than_or_equal(thing1, thing2): + if thing1 < thing2: + raise AssertionError("%s < %s" % (str(thing1), str(thing2))) + + +def assert_raises(exc, fun, *args, **kwds): + assert_raises_message(exc, None, fun, *args, **kwds) + + +def assert_raises_message(exc, message, fun, *args, **kwds): + try: + fun(*args, **kwds) + except JSONRPCException: + raise AssertionError("Use assert_raises_rpc_error() to test RPC failures") + except exc as e: + if message is not None and message not in e.error['message']: + raise AssertionError( + "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format( + message, e.error['message'])) + except Exception as e: + raise AssertionError("Unexpected exception raised: " + type(e).__name__) + else: + raise AssertionError("No exception raised") + + +def assert_raises_process_error(returncode, output, fun, *args, **kwds): + """Execute a process and asserts the process return code and output. + + Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError + and verifies that the return code and output are as expected. Throws AssertionError if + no CalledProcessError was raised or if the return code and output are not as expected. + + Args: + returncode (int): the process return code. + output (string): [a substring of] the process output. + fun (function): the function to call. This should execute a process. + args*: positional arguments for the function. + kwds**: named arguments for the function. + """ + try: + fun(*args, **kwds) + except CalledProcessError as e: + if returncode != e.returncode: + raise AssertionError("Unexpected returncode %i" % e.returncode) + if output not in e.output: + raise AssertionError("Expected substring not found:" + e.output) + else: + raise AssertionError("No exception raised") + + +def assert_raises_rpc_error(code, message, fun, *args, **kwds): + """Run an RPC and verify that a specific JSONRPC exception code and message is raised. + + Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException + and verifies that the error code and message are as expected. Throws AssertionError if + no JSONRPCException was raised or if the error code/message are not as expected. + + Args: + code (int), optional: the error code returned by the RPC call (defined + in src/rpc/protocol.h). Set to None if checking the error code is not required. + message (string), optional: [a substring of] the error string returned by the + RPC call. Set to None if checking the error string is not required. + fun (function): the function to call. This should be the name of an RPC. + args*: positional arguments for the function. + kwds**: named arguments for the function. + """ + assert try_rpc(code, message, fun, *args, **kwds), "No exception raised" + + +def try_rpc(code, message, fun, *args, **kwds): + """Tries to run an rpc command. + + Test against error code and message if the rpc fails. + Returns whether a JSONRPCException was raised.""" + try: + fun(*args, **kwds) + except JSONRPCException as e: + # JSONRPCException was thrown as expected. Check the code and message values are correct. + if (code is not None) and (code != e.error["code"]): + raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"]) + if (message is not None) and (message not in e.error['message']): + raise AssertionError( + "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format( + message, e.error['message'])) + return True + except Exception as e: + raise AssertionError("Unexpected exception raised: " + type(e).__name__) + else: + return False + + +def assert_is_hex_string(string): + try: + int(string, 16) + except Exception as e: + raise AssertionError("Couldn't interpret %r as hexadecimal; raised: %s" % (string, e)) + + +def assert_is_hash_string(string, length=64): + if not isinstance(string, str): + raise AssertionError("Expected a string, got type %r" % type(string)) + elif length and len(string) != length: + raise AssertionError("String of length %d expected; got %d" % (length, len(string))) + elif not re.match('[abcdef0-9]+$', string): + raise AssertionError("String %r contains invalid characters for a hash." % string) + + +def assert_array_result(object_array, to_match, expected, should_not_find=False): + """ + Pass in array of JSON objects, a dictionary with key/value pairs + to match against, and another dictionary with expected key/value + pairs. + If the should_not_find flag is true, to_match should not be found + in object_array + """ + if should_not_find: + assert_equal(expected, {}) + num_matched = 0 + for item in object_array: + all_match = True + for key, value in to_match.items(): + if item[key] != value: + all_match = False + if not all_match: + continue + elif should_not_find: + num_matched = num_matched + 1 + for key, value in expected.items(): + if item[key] != value: + raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value))) + num_matched = num_matched + 1 + if num_matched == 0 and not should_not_find: + raise AssertionError("No objects matched %s" % (str(to_match))) + if num_matched > 0 and should_not_find: + raise AssertionError("Objects were found %s" % (str(to_match))) + + +# Utility functions +################### + + +def check_json_precision(): + """Make sure json library being used does not lose precision converting BTC values""" + n = Decimal("20000000.00000003") + satoshis = int(json.loads(json.dumps(float(n))) * 1.0e8) + if satoshis != 2000000000000003: + raise RuntimeError("JSON encode/decode loses precision") + + +def EncodeDecimal(o): + if isinstance(o, Decimal): + return str(o) + raise TypeError(repr(o) + " is not JSON serializable") + + +def count_bytes(hex_string): + return len(bytearray.fromhex(hex_string)) + + +def hex_str_to_bytes(hex_str): + return unhexlify(hex_str.encode('ascii')) + + +def str_to_b64str(string): + return b64encode(string.encode('utf-8')).decode('ascii') + + +def satoshi_round(amount): + return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) + + +def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0): + if attempts == float('inf') and timeout == float('inf'): + timeout = 60 + timeout = timeout * timeout_factor + attempt = 0 + time_end = time.time() + timeout + + while attempt < attempts and time.time() < time_end: + if lock: + with lock: + if predicate(): + return + else: + if predicate(): + return + attempt += 1 + time.sleep(0.05) + + # Print the cause of the timeout + predicate_source = "''''\n" + inspect.getsource(predicate) + "'''" + logger.error("wait_until() failed. Predicate: {}".format(predicate_source)) + if attempt >= attempts: + raise AssertionError("Predicate {} not true after {} attempts".format(predicate_source, attempts)) + elif time.time() >= time_end: + raise AssertionError("Predicate {} not true after {} seconds".format(predicate_source, timeout)) + raise RuntimeError('Unreachable') + + +# RPC/P2P connection constants and functions +############################################ + +# The maximum number of nodes a single test can spawn +MAX_NODES = 12 +# Don't assign rpc or p2p ports lower than this +PORT_MIN = int(os.getenv('TEST_RUNNER_PORT_MIN', default=11000)) +# The number of ports to "reserve" for p2p and rpc, each +PORT_RANGE = 5000 + + +class PortSeed: + # Must be initialized with a unique integer for each process + n = None + + +def get_rpc_proxy(url, node_number, *, timeout=None, coveragedir=None): + """ + Args: + url (str): URL of the RPC server to call + node_number (int): the node number (or id) that this calls to + + Kwargs: + timeout (int): HTTP timeout in seconds + coveragedir (str): Directory + + Returns: + AuthServiceProxy. convenience object for making RPC calls. + + """ + proxy_kwargs = {} + if timeout is not None: + proxy_kwargs['timeout'] = int(timeout) + + proxy = AuthServiceProxy(url, **proxy_kwargs) + proxy.url = url # store URL on proxy for info + + coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None + + return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile) + + +def p2p_port(n): + assert n <= MAX_NODES + return PORT_MIN + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES) + + +def rpc_port(n): + return PORT_MIN + PORT_RANGE + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES) + + +def rpc_url(datadir, i, chain, rpchost): + rpc_u, rpc_p = get_auth_cookie(datadir, chain) + host = '127.0.0.1' + port = rpc_port(i) + if rpchost: + parts = rpchost.split(':') + if len(parts) == 2: + host, port = parts + else: + host = rpchost + return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port)) + + +# Node functions +################ + + +def initialize_datadir(dirname, n, chain): + datadir = get_datadir_path(dirname, n) + if not os.path.isdir(datadir): + os.makedirs(datadir) + # Translate chain name to config name + if chain == 'testnet3': + chain_name_conf_arg = 'testnet' + chain_name_conf_section = 'test' + else: + chain_name_conf_arg = chain + chain_name_conf_section = chain + with open(os.path.join(datadir, "particl.conf"), 'w', encoding='utf8') as f: + f.write("{}=1\n".format(chain_name_conf_arg)) + f.write("[{}]\n".format(chain_name_conf_section)) + f.write("port=" + str(p2p_port(n)) + "\n") + f.write("rpcport=" + str(rpc_port(n)) + "\n") + f.write("fallbackfee=0.0002\n") + f.write("server=1\n") + f.write("keypool=1\n") + f.write("discover=0\n") + f.write("dnsseed=0\n") + f.write("listenonion=0\n") + f.write("printtoconsole=0\n") + f.write("upnp=0\n") + f.write("shrinkdebugfile=0\n") + os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True) + os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True) + return datadir + + +def get_datadir_path(dirname, n): + return os.path.join(dirname, "node" + str(n)) + + +def append_config(datadir, options): + with open(os.path.join(datadir, "particl.conf"), 'a', encoding='utf8') as f: + for option in options: + f.write(option + "\n") + + +def get_auth_cookie(datadir, chain): + user = None + password = None + if os.path.isfile(os.path.join(datadir, "particl.conf")): + with open(os.path.join(datadir, "particl.conf"), 'r', encoding='utf8') as f: + for line in f: + if line.startswith("rpcuser="): + assert user is None # Ensure that there is only one rpcuser line + user = line.split("=")[1].strip("\n") + if line.startswith("rpcpassword="): + assert password is None # Ensure that there is only one rpcpassword line + password = line.split("=")[1].strip("\n") + try: + with open(os.path.join(datadir, chain, ".cookie"), 'r', encoding="ascii") as f: + userpass = f.read() + split_userpass = userpass.split(':') + user = split_userpass[0] + password = split_userpass[1] + except OSError: + pass + if user is None or password is None: + raise ValueError("No RPC credentials") + return user, password + + +# If a cookie file exists in the given datadir, delete it. +def delete_cookie_file(datadir, chain): + if os.path.isfile(os.path.join(datadir, chain, ".cookie")): + logger.debug("Deleting leftover cookie file") + os.remove(os.path.join(datadir, chain, ".cookie")) + + +def softfork_active(node, key): + """Return whether a softfork is active.""" + return node.getblockchaininfo()['softforks'][key]['active'] + + +def set_node_times(nodes, t): + for node in nodes: + node.setmocktime(t) + + +def disconnect_nodes(from_connection, node_num): + def get_peer_ids(): + result = [] + for peer in from_connection.getpeerinfo(): + if "testnode{}".format(node_num) in peer['subver']: + result.append(peer['id']) + return result + + peer_ids = get_peer_ids() + if not peer_ids: + logger.warning("disconnect_nodes: {} and {} were not connected".format( + from_connection.index, + node_num, + )) + return + for peer_id in peer_ids: + try: + from_connection.disconnectnode(nodeid=peer_id) + except JSONRPCException as e: + # If this node is disconnected between calculating the peer id + # and issuing the disconnect, don't worry about it. + # This avoids a race condition if we're mass-disconnecting peers. + if e.error['code'] != -29: # RPC_CLIENT_NODE_NOT_CONNECTED + raise + + # wait to disconnect + wait_until(lambda: not get_peer_ids(), timeout=5) + + +def connect_nodes(from_connection, node_num): + ip_port = "127.0.0.1:" + str(p2p_port(node_num)) + from_connection.addnode(ip_port, "onetry") + # poll until version handshake complete to avoid race conditions + # with transaction relaying + # See comments in net_processing: + # * Must have a version message before anything else + # * Must have a verack message before anything else + wait_until(lambda: all(peer['version'] != 0 for peer in from_connection.getpeerinfo())) + wait_until(lambda: all(peer['bytesrecv_per_msg'].pop('verack', 0) == 24 for peer in from_connection.getpeerinfo())) + + +# Transaction/Block functions +############################# + + +def find_output(node, txid, amount, *, blockhash=None): + """ + Return index to output of txid with value amount + Raises exception if there is none. + """ + txdata = node.getrawtransaction(txid, 1, blockhash) + for i in range(len(txdata["vout"])): + if txdata["vout"][i]["value"] == amount: + return i + raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount))) + + +def gather_inputs(from_node, amount_needed, confirmations_required=1): + """ + Return a random set of unspent txouts that are enough to pay amount_needed + """ + assert confirmations_required >= 0 + utxo = from_node.listunspent(confirmations_required) + random.shuffle(utxo) + inputs = [] + total_in = Decimal("0.00000000") + while total_in < amount_needed and len(utxo) > 0: + t = utxo.pop() + total_in += t["amount"] + inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]}) + if total_in < amount_needed: + raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in)) + return (total_in, inputs) + + +def make_change(from_node, amount_in, amount_out, fee): + """ + Create change output(s), return them + """ + outputs = {} + amount = amount_out + fee + change = amount_in - amount + if change > amount * 2: + # Create an extra change output to break up big inputs + change_address = from_node.getnewaddress() + # Split change in two, being careful of rounding: + outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN) + change = amount_in - amount - outputs[change_address] + if change > 0: + outputs[from_node.getnewaddress()] = change + return outputs + + +def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants): + """ + Create a random transaction. + Returns (txid, hex-encoded-transaction-data, fee) + """ + from_node = random.choice(nodes) + to_node = random.choice(nodes) + fee = min_fee + fee_increment * random.randint(0, fee_variants) + + (total_in, inputs) = gather_inputs(from_node, amount + fee) + outputs = make_change(from_node, total_in, amount, fee) + outputs[to_node.getnewaddress()] = float(amount) + + rawtx = from_node.createrawtransaction(inputs, outputs) + signresult = from_node.signrawtransactionwithwallet(rawtx) + txid = from_node.sendrawtransaction(signresult["hex"], 0) + + return (txid, signresult["hex"], fee) + + +# Helper to create at least "count" utxos +# Pass in a fee that is sufficient for relay and mining new transactions. +def create_confirmed_utxos(fee, node, count): + to_generate = int(0.5 * count) + 101 + while to_generate > 0: + node.generate(min(25, to_generate)) + to_generate -= 25 + utxos = node.listunspent() + iterations = count - len(utxos) + addr1 = node.getnewaddress() + addr2 = node.getnewaddress() + if iterations <= 0: + return utxos + for i in range(iterations): + t = utxos.pop() + inputs = [] + inputs.append({"txid": t["txid"], "vout": t["vout"]}) + outputs = {} + send_value = t['amount'] - fee + outputs[addr1] = satoshi_round(send_value / 2) + outputs[addr2] = satoshi_round(send_value / 2) + raw_tx = node.createrawtransaction(inputs, outputs) + signed_tx = node.signrawtransactionwithwallet(raw_tx)["hex"] + node.sendrawtransaction(signed_tx) + + while (node.getmempoolinfo()['size'] > 0): + node.generate(1) + + utxos = node.listunspent() + assert len(utxos) >= count + return utxos + + +# Create large OP_RETURN txouts that can be appended to a transaction +# to make it large (helper for constructing large transactions). +def gen_return_txouts(): + # Some pre-processing to create a bunch of OP_RETURN txouts to insert into transactions we create + # So we have big transactions (and therefore can't fit very many into each block) + # create one script_pubkey + script_pubkey = "6a4d0200" # OP_RETURN OP_PUSH2 512 bytes + for i in range(512): + script_pubkey = script_pubkey + "01" + # concatenate 128 txouts of above script_pubkey which we'll insert before the txout for change + txouts = [] + from .messages import CTxOut + txout = CTxOut() + txout.nValue = 0 + txout.scriptPubKey = hex_str_to_bytes(script_pubkey) + for k in range(128): + txouts.append(txout) + return txouts + + +# Create a spend of each passed-in utxo, splicing in "txouts" to each raw +# transaction to make it large. See gen_return_txouts() above. +def create_lots_of_big_transactions(node, txouts, utxos, num, fee): + addr = node.getnewaddress() + txids = [] + from .messages import CTransaction + for _ in range(num): + t = utxos.pop() + inputs = [{"txid": t["txid"], "vout": t["vout"]}] + outputs = {} + change = t['amount'] - fee + outputs[addr] = satoshi_round(change) + rawtx = node.createrawtransaction(inputs, outputs) + tx = CTransaction() + tx.deserialize(BytesIO(hex_str_to_bytes(rawtx))) + for txout in txouts: + tx.vout.append(txout) + newtx = tx.serialize().hex() + signresult = node.signrawtransactionwithwallet(newtx, None, "NONE") + txid = node.sendrawtransaction(signresult["hex"], 0) + txids.append(txid) + return txids + + +def mine_large_block(node, utxos=None): + # generate a 66k transaction, + # and 14 of them is close to the 1MB block limit + num = 14 + txouts = gen_return_txouts() + utxos = utxos if utxos is not None else [] + if len(utxos) < num: + utxos.clear() + utxos.extend(node.listunspent()) + fee = 100 * node.getnetworkinfo()["relayfee"] + create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee) + node.generate(1) + + +def find_vout_for_address(node, txid, addr): + """ + Locate the vout index of the given transaction sending to the + given address. Raises runtime error exception if not found. + """ + tx = node.getrawtransaction(txid, True) + for i in range(len(tx["vout"])): + if any([addr == a for a in tx["vout"][i]["scriptPubKey"]["addresses"]]): + return i + raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr)) diff --git a/basicswap/contrib/test_framework/wallet_util.py b/basicswap/contrib/test_framework/wallet_util.py new file mode 100755 index 0000000..688d565 --- /dev/null +++ b/basicswap/contrib/test_framework/wallet_util.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright (c) 2018-2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +"""Useful util functions for testing the wallet""" +from collections import namedtuple + +from .address import ( + byte_to_base58, + key_to_p2pkh, + key_to_p2sh_p2wpkh, + key_to_p2wpkh, + script_to_p2sh, + script_to_p2sh_p2wsh, + script_to_p2wsh, +) +from .key import ECKey +from .script import ( + CScript, + OP_0, + OP_2, + OP_3, + OP_CHECKMULTISIG, + OP_CHECKSIG, + OP_DUP, + OP_EQUAL, + OP_EQUALVERIFY, + OP_HASH160, + hash160, + sha256, +) +from .util import hex_str_to_bytes + +Key = namedtuple('Key', ['privkey', + 'pubkey', + 'p2pkh_script', + 'p2pkh_addr', + 'p2wpkh_script', + 'p2wpkh_addr', + 'p2sh_p2wpkh_script', + 'p2sh_p2wpkh_redeem_script', + 'p2sh_p2wpkh_addr']) + +Multisig = namedtuple('Multisig', ['privkeys', + 'pubkeys', + 'p2sh_script', + 'p2sh_addr', + 'redeem_script', + 'p2wsh_script', + 'p2wsh_addr', + 'p2sh_p2wsh_script', + 'p2sh_p2wsh_addr']) + +def get_key(node): + """Generate a fresh key on node + + Returns a named tuple of privkey, pubkey and all address and scripts.""" + addr = node.getnewaddress() + pubkey = node.getaddressinfo(addr)['pubkey'] + pkh = hash160(hex_str_to_bytes(pubkey)) + return Key(privkey=node.dumpprivkey(addr), + pubkey=pubkey, + p2pkh_script=CScript([OP_DUP, OP_HASH160, pkh, OP_EQUALVERIFY, OP_CHECKSIG]).hex(), + p2pkh_addr=key_to_p2pkh(pubkey), + p2wpkh_script=CScript([OP_0, pkh]).hex(), + p2wpkh_addr=key_to_p2wpkh(pubkey), + p2sh_p2wpkh_script=CScript([OP_HASH160, hash160(CScript([OP_0, pkh])), OP_EQUAL]).hex(), + p2sh_p2wpkh_redeem_script=CScript([OP_0, pkh]).hex(), + p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey)) + +def get_generate_key(): + """Generate a fresh key + + Returns a named tuple of privkey, pubkey and all address and scripts.""" + eckey = ECKey() + eckey.generate() + privkey = bytes_to_wif(eckey.get_bytes()) + pubkey = eckey.get_pubkey().get_bytes().hex() + pkh = hash160(hex_str_to_bytes(pubkey)) + return Key(privkey=privkey, + pubkey=pubkey, + p2pkh_script=CScript([OP_DUP, OP_HASH160, pkh, OP_EQUALVERIFY, OP_CHECKSIG]).hex(), + p2pkh_addr=key_to_p2pkh(pubkey), + p2wpkh_script=CScript([OP_0, pkh]).hex(), + p2wpkh_addr=key_to_p2wpkh(pubkey), + p2sh_p2wpkh_script=CScript([OP_HASH160, hash160(CScript([OP_0, pkh])), OP_EQUAL]).hex(), + p2sh_p2wpkh_redeem_script=CScript([OP_0, pkh]).hex(), + p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey)) + +def get_multisig(node): + """Generate a fresh 2-of-3 multisig on node + + Returns a named tuple of privkeys, pubkeys and all address and scripts.""" + addrs = [] + pubkeys = [] + for _ in range(3): + addr = node.getaddressinfo(node.getnewaddress()) + addrs.append(addr['address']) + pubkeys.append(addr['pubkey']) + script_code = CScript([OP_2] + [hex_str_to_bytes(pubkey) for pubkey in pubkeys] + [OP_3, OP_CHECKMULTISIG]) + witness_script = CScript([OP_0, sha256(script_code)]) + return Multisig(privkeys=[node.dumpprivkey(addr) for addr in addrs], + pubkeys=pubkeys, + p2sh_script=CScript([OP_HASH160, hash160(script_code), OP_EQUAL]).hex(), + p2sh_addr=script_to_p2sh(script_code), + redeem_script=script_code.hex(), + p2wsh_script=witness_script.hex(), + p2wsh_addr=script_to_p2wsh(script_code), + p2sh_p2wsh_script=CScript([OP_HASH160, witness_script, OP_EQUAL]).hex(), + p2sh_p2wsh_addr=script_to_p2sh_p2wsh(script_code)) + +def test_address(node, address, **kwargs): + """Get address info for `address` and test whether the returned values are as expected.""" + addr_info = node.getaddressinfo(address) + for key, value in kwargs.items(): + if value is None: + if key in addr_info.keys(): + raise AssertionError("key {} unexpectedly returned in getaddressinfo.".format(key)) + elif addr_info[key] != value: + raise AssertionError("key {} value {} did not match expected value {}".format(key, addr_info[key], value)) + +def bytes_to_wif(b, compressed=True, prefix=239): + if compressed: + b += b'\x01' + return byte_to_base58(b, prefix) + +def generate_wif_key(): + # Makes a WIF privkey for imports + k = ECKey() + k.generate() + return bytes_to_wif(k.get_bytes(), k.is_compressed) diff --git a/basicswap/db.py b/basicswap/db.py index ffdc322..02f32a8 100644 --- a/basicswap/db.py +++ b/basicswap/db.py @@ -97,7 +97,7 @@ class Bid(Base): participate_txn_refund = sa.Column(sa.LargeBinary) state = sa.Column(sa.Integer) - state_time = sa.Column(sa.BigInteger) # timestamp of last state change + state_time = sa.Column(sa.BigInteger) # Timestamp of last state change states = sa.Column(sa.LargeBinary) # Packed states and times state_note = sa.Column(sa.String) diff --git a/basicswap/ecc_util.py b/basicswap/ecc_util.py new file mode 100644 index 0000000..e88a010 --- /dev/null +++ b/basicswap/ecc_util.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import codecs +import hashlib +import secrets + +from .contrib.ellipticcurve import CurveFp, Point, INFINITY, jacobi_symbol + + +class ECCParameters(): + def __init__(self, p, a, b, Gx, Gy, o): + self.p = p + self.a = a + self.b = b + self.Gx = Gx + self.Gy = Gy + self.o = o + + +ep = ECCParameters( \ + p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f, \ + a = 0x0, \ + b = 0x7, \ + Gx = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, \ + Gy = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, \ + o = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141) # noqa: E221,E251,E502 + +curve_secp256k1 = CurveFp(ep.p, ep.a, ep.b) +G = Point(curve_secp256k1, ep.Gx, ep.Gy, ep.o) +SECP256K1_ORDER_HALF = ep.o // 2 + + +def ToDER(P): + return bytes((4, )) + int(P.x()).to_bytes(32, byteorder='big') + int(P.y()).to_bytes(32, byteorder='big') + + +def bytes32ToInt(b): + return int.from_bytes(b, byteorder='big') + + +def intToBytes32(i): + return i.to_bytes(32, byteorder='big') + + +def intToBytes32_le(i): + return i.to_bytes(32, byteorder='little') + + +def bytesToHexStr(b): + return codecs.encode(b, 'hex').decode('utf-8') + + +def hexStrToBytes(h): + if h.startswith('0x'): + h = h[2:] + return bytes.fromhex(h) + + +def getSecretBytes(): + i = 1 + secrets.randbelow(ep.o - 1) + return intToBytes32(i) + + +def getSecretInt(): + return 1 + secrets.randbelow(ep.o - 1) + + +def getInsecureBytes(): + while True: + s = os.urandom(32) + + s_test = int.from_bytes(s, byteorder='big') + if s_test > 1 and s_test < ep.o: + return s + + +def getInsecureInt(): + while True: + s = os.urandom(32) + + s_test = int.from_bytes(s, byteorder='big') + if s_test > 1 and s_test < ep.o: + return s_test + + +def powMod(x, y, z): + # Calculate (x ** y) % z efficiently. + number = 1 + while y: + if y & 1: + number = number * x % z + y >>= 1 # y //= 2 + + x = x * x % z + return number + + +def ExpandPoint(xb, sign): + x = int.from_bytes(xb, byteorder='big') + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + if sign: + y = ep.p - y + return Point(curve_secp256k1, x, y, ep.o) + + +def CPKToPoint(cpk): + y_parity = cpk[0] - 2 + + x = int.from_bytes(cpk[1:], byteorder='big') + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + if y % 2 != y_parity: + y = ep.p - y + + return Point(curve_secp256k1, x, y, ep.o) + + +def pointToCPK2(point, ind=0x09): + # The function is_square(x), where x is an integer, returns whether or not x is a quadratic residue modulo p. Since p is prime, it is equivalent to the Legendre symbol (x / p) = x(p-1)/2 mod p being equal to 1[8]. + ind = bytes((ind ^ (1 if jacobi_symbol(point.y(), ep.p) == 1 else 0),)) + return ind + point.x().to_bytes(32, byteorder='big') + + +def pointToCPK(point): + + y = point.y().to_bytes(32, byteorder='big') + ind = bytes((0x03,)) if y[31] % 2 else bytes((0x02,)) + + cpk = ind + point.x().to_bytes(32, byteorder='big') + return cpk + + +def secretToCPK(secret): + secretInt = secret if isinstance(secret, int) \ + else int.from_bytes(secret, byteorder='big') + + R = G * secretInt + + Y = R.y().to_bytes(32, byteorder='big') + ind = bytes((0x03,)) if Y[31] % 2 else bytes((0x02,)) + + pubkey = ind + R.x().to_bytes(32, byteorder='big') + + return pubkey + + +def getKeypair(): + secretBytes = getSecretBytes() + return secretBytes, secretToCPK(secretBytes) + + +def hashToCurve(pubkey): + + xBytes = hashlib.sha256(pubkey).digest() + x = int.from_bytes(xBytes, byteorder='big') + + for k in range(0, 100): + # get matching y element for point + y_parity = 0 # always pick 0, + a = (powMod(x, 3, ep.p) + 7) % ep.p + y = powMod(a, (ep.p + 1) // 4, ep.p) + + # print("before parity %x" % (y)) + if y % 2 != y_parity: + y = ep.p - y + + # If x is always mod P, can R ever not be on the curve? + try: + R = Point(curve_secp256k1, x, y, ep.o) + except Exception: + x = (x + 1) % ep.p # % P? + continue + + if R == INFINITY or R * ep.o != INFINITY: # is R * O != INFINITY check necessary? Validation of Elliptic Curve Public Keys says no if cofactor = 1 + x = (x + 1) % ep.p # % P? + continue + return R + + raise ValueError('hashToCurve failed for 100 tries') + + +def hash256(inb): + return hashlib.sha256(inb).digest() + + +i2b = intToBytes32 +b2i = bytes32ToInt +b2h = bytesToHexStr +h2b = hexStrToBytes + + +def i2h(x): + return b2h(i2b(x)) + + +def testEccUtils(): + print('testEccUtils()') + + G_enc = ToDER(G) + assert(G_enc.hex() == '0479be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8') + + G_enc = pointToCPK(G) + assert(G_enc.hex() == '0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798') + G_dec = CPKToPoint(G_enc) + assert(G_dec == G) + + G_enc = pointToCPK2(G) + assert(G_enc.hex() == '0879be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798') + + H = hashToCurve(ToDER(G)) + assert(pointToCPK(H).hex() == '0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0') + + print('Passed.') + + +if __name__ == "__main__": + testEccUtils() diff --git a/basicswap/http_server.py b/basicswap/http_server.py index 1eeac72..1731a78 100644 --- a/basicswap/http_server.py +++ b/basicswap/http_server.py @@ -19,7 +19,7 @@ from . import __version__ from .util import ( COIN, format8, - makeInt, + make_int, dumpj, ) from .chainparams import ( @@ -129,7 +129,7 @@ def validateAmountString(amount): def inputAmount(amount_str): validateAmountString(amount_str) - return makeInt(amount_str) + return make_int(amount_str) def setCoinFilter(form_data, field_name): diff --git a/basicswap/interface_btc.py b/basicswap/interface_btc.py new file mode 100644 index 0000000..7c989e8 --- /dev/null +++ b/basicswap/interface_btc.py @@ -0,0 +1,805 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import time +import hashlib +import logging +from io import BytesIO + +from .util import ( + decodeScriptNum, + getCompactSizeLen, + dumpj, + format_amount, + make_int +) + +from .ecc_util import ( + G, ep, + pointToCPK, CPKToPoint, + getSecretInt, + b2h, i2b, b2i, i2h) + +from .contrib.test_framework.messages import ( + COIN, + COutPoint, + CTransaction, + CTxIn, + CTxInWitness, + CTxOut, + FromHex, + ToHex) + +from .contrib.test_framework.script import ( + CScript, + CScriptOp, + CScriptNum, + OP_IF, OP_ELSE, OP_ENDIF, + OP_0, + OP_2, + OP_16, + OP_EQUALVERIFY, + OP_CHECKSIG, + OP_SIZE, + OP_SHA256, + OP_CHECKMULTISIG, + OP_CHECKSEQUENCEVERIFY, + OP_DROP, + SIGHASH_ALL, + SegwitV0SignatureHash, + hash160) + +from .contrib.test_framework.key import ECKey, ECPubKey + +from .chainparams import CoinInterface +from .rpc import make_rpc_func +from .util import assert_cond + + +def findOutput(tx, script_pk): + for i in range(len(tx.vout)): + if tx.vout[i].scriptPubKey == script_pk: + return i + return None + + +class BTCInterface(CoinInterface): + @staticmethod + def exp(): + return 8 + + @staticmethod + def nbk(): + return 32 + + @staticmethod + def nbK(): # No. of bytes requires to encode a public key + return 33 + + @staticmethod + def witnessScaleFactor(): + return 4 + + @staticmethod + def txVersion(): + return 2 + + @staticmethod + def getTxOutputValue(tx): + rv = 0 + for output in tx.vout: + rv += output.nValue + return rv + + def compareFeeRates(self, a, b): + return abs(a - b) < 20 + + def __init__(self, coin_settings): + self.rpc_callback = make_rpc_func(coin_settings['rpcport'], coin_settings['rpcauth']) + self.txoType = CTxOut + + def getNewSecretKey(self): + return getSecretInt() + + def pubkey(self, key): + return G * key + + def encodePubkey(self, pk): + return pointToCPK(pk) + + def decodePubkey(self, pke): + return CPKToPoint(pke) + + def decodeKey(self, k): + i = b2i(k) + assert(i < ep.o) + return i + + def sumKeys(self, ka, kb): + return (ka + kb) % ep.o + + def sumPubkeys(self, Ka, Kb): + return Ka + Kb + + def extractScriptLockScriptValues(self, script_bytes): + script_len = len(script_bytes) + assert_cond(script_len > 112, 'Bad script length') + assert_cond(script_bytes[0] == OP_IF) + assert_cond(script_bytes[1] == OP_SIZE) + assert_cond(script_bytes[2:4] == bytes((1, 32))) # 0120, CScriptNum length, then data + assert_cond(script_bytes[4] == OP_EQUALVERIFY) + assert_cond(script_bytes[5] == OP_SHA256) + assert_cond(script_bytes[6] == 32) + secret_hash = script_bytes[7: 7 + 32] + assert_cond(script_bytes[39] == OP_EQUALVERIFY) + assert_cond(script_bytes[40] == OP_2) + assert_cond(script_bytes[41] == 33) + pk1 = script_bytes[42: 42 + 33] + assert_cond(script_bytes[75] == 33) + pk2 = script_bytes[76: 76 + 33] + assert_cond(script_bytes[109] == OP_2) + assert_cond(script_bytes[110] == OP_CHECKMULTISIG) + assert_cond(script_bytes[111] == OP_ELSE) + o = 112 + + # Decode script num + csv_val, nb = decodeScriptNum(script_bytes, o) + o += nb + + assert_cond(script_len == o + 8 + 66, 'Bad script length') # Fails if script too long + assert_cond(script_bytes[o] == OP_CHECKSEQUENCEVERIFY) + o += 1 + assert_cond(script_bytes[o] == OP_DROP) + o += 1 + assert_cond(script_bytes[o] == OP_2) + o += 1 + assert_cond(script_bytes[o] == 33) + o += 1 + pk3 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == 33) + o += 1 + pk4 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == OP_2) + o += 1 + assert_cond(script_bytes[o] == OP_CHECKMULTISIG) + o += 1 + assert_cond(script_bytes[o] == OP_ENDIF) + + return secret_hash, pk1, pk2, csv_val, pk3, pk4 + + def genScriptLockTxScript(self, sh, Kal, Kaf, lock_blocks, Karl, Karf): + return CScript([ + CScriptOp(OP_IF), + CScriptOp(OP_SIZE), 32, CScriptOp(OP_EQUALVERIFY), + CScriptOp(OP_SHA256), sh, CScriptOp(OP_EQUALVERIFY), + 2, self.encodePubkey(Kal), self.encodePubkey(Kaf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ELSE), + lock_blocks, CScriptOp(OP_CHECKSEQUENCEVERIFY), CScriptOp(OP_DROP), + 2, self.encodePubkey(Karl), self.encodePubkey(Karf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ENDIF)]) + + def createScriptLockTx(self, value, sh, Kal, Kaf, lock_blocks, Karl, Karf): + + script = self.genScriptLockTxScript(sh, Kal, Kaf, lock_blocks, Karl, Karf) + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vout.append(self.txoType(value, CScript([OP_0, hashlib.sha256(script).digest()]))) + + return tx, script + + def extractScriptLockRefundScriptValues(self, script_bytes): + script_len = len(script_bytes) + assert_cond(script_len > 73, 'Bad script length') + assert_cond(script_bytes[0] == OP_IF) + assert_cond(script_bytes[1] == OP_2) + assert_cond(script_bytes[2] == 33) + pk1 = script_bytes[3: 3 + 33] + assert_cond(script_bytes[36] == 33) + pk2 = script_bytes[37: 37 + 33] + assert_cond(script_bytes[70] == OP_2) + assert_cond(script_bytes[71] == OP_CHECKMULTISIG) + assert_cond(script_bytes[72] == OP_ELSE) + o = 73 + csv_val, nb = decodeScriptNum(script_bytes, o) + o += nb + + assert_cond(script_len == o + 5 + 33, 'Bad script length') # Fails if script too long + assert_cond(script_bytes[o] == OP_CHECKSEQUENCEVERIFY) + o += 1 + assert_cond(script_bytes[o] == OP_DROP) + o += 1 + assert_cond(script_bytes[o] == 33) + o += 1 + pk3 = script_bytes[o: o + 33] + o += 33 + assert_cond(script_bytes[o] == OP_CHECKSIG) + o += 1 + assert_cond(script_bytes[o] == OP_ENDIF) + + return pk1, pk2, csv_val, pk3 + + def genScriptLockRefundTxScript(self, Karl, Karf, csv_val, Kaf): + return CScript([ + CScriptOp(OP_IF), + 2, self.encodePubkey(Karl), self.encodePubkey(Karf), 2, CScriptOp(OP_CHECKMULTISIG), + CScriptOp(OP_ELSE), + csv_val, CScriptOp(OP_CHECKSEQUENCEVERIFY), CScriptOp(OP_DROP), + self.encodePubkey(Kaf), CScriptOp(OP_CHECKSIG), + CScriptOp(OP_ENDIF)]) + + def createScriptLockRefundTx(self, tx_lock, script_lock, Karl, Karf, csv_val, Kaf, tx_fee_rate): + + output_script = CScript([OP_0, hashlib.sha256(script_lock).digest()]) + locked_n = findOutput(tx_lock, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock.vout[locked_n].nValue + + tx_lock.rehash() + tx_lock_hash_int = tx_lock.sha256 + + sh, A, B, lock1_value, C, D = self.extractScriptLockScriptValues(script_lock) + + refund_script = self.genScriptLockRefundTxScript(Karl, Karf, csv_val, Kaf) + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_hash_int, locked_n), nSequence=lock1_value)) + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, hashlib.sha256(refund_script).digest()]))) + + witness_bytes = len(script_lock) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 2 # 2 empty witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx, refund_script, tx.vout[0].nValue + + def createScriptLockRefundSpendTx(self, tx_lock_refund, script_lock_refund, Kal, tx_fee_rate): + # Returns the coinA locked coin to the leader + # The follower will sign the multisig path with a signature encumbered by the leader's coinB spend pubkey + # When the leader publishes the decrypted signature the leader's coinB spend privatekey will be revealed to the follower + + output_script = CScript([OP_0, hashlib.sha256(script_lock_refund).digest()]) + locked_n = findOutput(tx_lock_refund, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock_refund.vout[locked_n].nValue + + tx_lock_refund.rehash() + tx_lock_refund_hash_int = tx_lock_refund.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_refund_hash_int, locked_n), nSequence=0)) + + pubkeyhash = hash160(self.encodePubkey(Kal)) + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, pubkeyhash]))) + + witness_bytes = len(script_lock_refund) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byte size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundSpendTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def createScriptLockRefundSpendToFTx(self, tx_lock_refund, script_lock_refund, pkh_dest, tx_fee_rate): + # Sends the coinA locked coin to the follower + output_script = CScript([OP_0, hashlib.sha256(script_lock_refund).digest()]) + locked_n = findOutput(tx_lock_refund, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock_refund.vout[locked_n].nValue + + A, B, lock2_value, C = self.extractScriptLockRefundScriptValues(script_lock_refund) + + tx_lock_refund.rehash() + tx_lock_refund_hash_int = tx_lock_refund.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_refund_hash_int, locked_n), nSequence=lock2_value)) + + tx.vout.append(self.txoType(locked_coin, CScript([OP_0, pkh_dest]))) + + witness_bytes = len(script_lock_refund) + witness_bytes += 73 # signature (72 + 1 byte size) + witness_bytes += 1 # 1 empty stack value + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockRefundSpendToFTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def createScriptLockSpendTx(self, tx_lock, script_lock, pkh_dest, tx_fee_rate): + + output_script = CScript([OP_0, hashlib.sha256(script_lock).digest()]) + locked_n = findOutput(tx_lock, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx_lock.vout[locked_n].nValue + + tx_lock.rehash() + tx_lock_hash_int = tx_lock.sha256 + + tx = CTransaction() + tx.nVersion = self.txVersion() + tx.vin.append(CTxIn(COutPoint(tx_lock_hash_int, locked_n))) + + p2wpkh = CScript([OP_0, pkh_dest]) + tx.vout.append(self.txoType(locked_coin, p2wpkh)) + + witness_bytes = len(script_lock) + witness_bytes += 33 # sv, size + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + pay_fee = int(tx_fee_rate * vsize / 1000) + tx.vout[0].nValue = locked_coin - pay_fee + + tx.rehash() + logging.info('createScriptLockSpendTx %s:\n fee_rate, vsize, fee: %ld, %ld, %ld.', + i2h(tx.sha256), tx_fee_rate, vsize, pay_fee) + + return tx + + def verifyLockTx(self, tx, script_out, + swap_value, + sh, + Kal, Kaf, + lock_value, feerate, + Karl, Karf, + check_lock_tx_inputs): + # Verify: + # + + # Not necessary to check the lock txn is mineable, as protocol will wait for it to confirm + # However by checking early we can avoid wasting time processing unmineable txns + # Check fee is reasonable + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'Bad nLockTime') + + script_pk = CScript([OP_0, hashlib.sha256(script_out).digest()]) + locked_n = findOutput(tx, script_pk) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx.vout[locked_n].nValue + + assert_cond(locked_coin == swap_value, 'Bad locked value') + + # Check script and values + shv, A, B, csv_val, C, D = self.extractScriptLockScriptValues(script_out) + assert_cond(shv == sh, 'Bad hash lock') + assert_cond(A == self.encodePubkey(Kal), 'Bad script pubkey') + assert_cond(B == self.encodePubkey(Kaf), 'Bad script pubkey') + assert_cond(csv_val == lock_value, 'Bad script csv value') + assert_cond(C == self.encodePubkey(Karl), 'Bad script pubkey') + assert_cond(D == self.encodePubkey(Karf), 'Bad script pubkey') + + if check_lock_tx_inputs: + # Check that inputs are unspent and verify fee rate + inputs_value = 0 + add_bytes = 0 + add_witness_bytes = getCompactSizeLen(len(tx.vin)) + for pi in tx.vin: + ptx = self.rpc_callback('getrawtransaction', [i2h(pi.prevout.hash), True]) + print('ptx', dumpj(ptx)) + prevout = ptx['vout'][pi.prevout.n] + inputs_value += make_int(prevout['value']) + + prevout_type = prevout['scriptPubKey']['type'] + if prevout_type == 'witness_v0_keyhash': + add_witness_bytes += 107 # sig 72, pk 33 and 2 size bytes + add_witness_bytes += getCompactSizeLen(107) + else: + # Assume P2PKH, TODO more types + add_bytes += 107 # OP_PUSH72 OP_PUSH33 + + outputs_value = 0 + for txo in tx.vout: + outputs_value += txo.nValue + fee_paid = inputs_value - outputs_value + assert(fee_paid > 0) + + vsize = self.getTxVSize(tx, add_bytes, add_witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', locked_coin, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + logging.warning('feerate paid doesn\'t match expected: %ld, %ld', fee_rate_paid, feerate) + # TODO: Display warning to user + + return tx_hash, locked_n + + def verifyLockRefundTx(self, tx, script_out, + prevout_id, prevout_n, prevout_seq, prevout_script, + Karl, Karf, csv_val_expect, Kaf, swap_value, feerate): + # Verify: + # Must have only one input with correct prevout and sequence + # Must have only one output to the p2wsh of the lock refund script + # Output value must be locked_coin - lock tx fee + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock refund tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + assert_cond(tx.vin[0].nSequence == prevout_seq, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(prevout_id) and tx.vin[0].prevout.n == prevout_n, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + + script_pk = CScript([OP_0, hashlib.sha256(script_out).digest()]) + locked_n = findOutput(tx, script_pk) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = tx.vout[locked_n].nValue + + # Check script and values + A, B, csv_val, C = self.extractScriptLockRefundScriptValues(script_out) + assert_cond(A == self.encodePubkey(Karl), 'Bad script pubkey') + assert_cond(B == self.encodePubkey(Karf), 'Bad script pubkey') + assert_cond(csv_val == csv_val_expect, 'Bad script csv value') + assert_cond(C == self.encodePubkey(Kaf), 'Bad script pubkey') + + fee_paid = swap_value - locked_coin + assert(fee_paid > 0) + + witness_bytes = len(prevout_script) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 2 # 2 empty witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', locked_coin, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return tx_hash, locked_coin + + def verifyLockRefundSpendTx(self, tx, + lock_refund_tx_id, prevout_script, + Kal, + prevout_value, feerate): + # Verify: + # Must have only one input with correct prevout (n is always 0) and sequence + # Must have only one output sending lock refund tx value - fee to leader's address, TODO: follower shouldn't need to verify destination addr + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock refund spend tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + assert_cond(tx.vin[0].nSequence == 0, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(lock_refund_tx_id) and tx.vin[0].prevout.n == 0, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + + p2wpkh = CScript([OP_0, hash160(self.encodePubkey(Kal))]) + locked_n = findOutput(tx, p2wpkh) + assert_cond(locked_n is not None, 'Output not found in lock refund spend tx') + tx_value = tx.vout[locked_n].nValue + + fee_paid = prevout_value - tx_value + assert(fee_paid > 0) + + witness_bytes = len(prevout_script) + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', tx_value, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return True + + def verifyLockSpendTx(self, tx, + lock_tx, lock_tx_script, + a_pkhash_f, feerate): + # Verify: + # Must have only one input with correct prevout (n is always 0) and sequence + # Must have only one output with destination and amount + + tx_hash = self.getTxHash(tx) + logging.info('Verifying lock spend tx: {}.'.format(b2h(tx_hash))) + + assert_cond(tx.nVersion == self.txVersion(), 'Bad version') + assert_cond(tx.nLockTime == 0, 'nLockTime not 0') + assert_cond(len(tx.vin) == 1, 'tx doesn\'t have one input') + + lock_tx_id = self.getTxHash(lock_tx) + + output_script = CScript([OP_0, hashlib.sha256(lock_tx_script).digest()]) + locked_n = findOutput(lock_tx, output_script) + assert_cond(locked_n is not None, 'Output not found in tx') + locked_coin = lock_tx.vout[locked_n].nValue + + assert_cond(tx.vin[0].nSequence == 0, 'Bad input nSequence') + assert_cond(len(tx.vin[0].scriptSig) == 0, 'Input scriptsig not empty') + assert_cond(tx.vin[0].prevout.hash == b2i(lock_tx_id) and tx.vin[0].prevout.n == locked_n, 'Input prevout mismatch') + + assert_cond(len(tx.vout) == 1, 'tx doesn\'t have one output') + p2wpkh = CScript([OP_0, a_pkhash_f]) + assert_cond(tx.vout[0].scriptPubKey == p2wpkh, 'Bad output destination') + + fee_paid = locked_coin - tx.vout[0].nValue + assert(fee_paid > 0) + + witness_bytes = len(lock_tx_script) + witness_bytes += 33 # sv, size + witness_bytes += 73 * 2 # 2 signatures (72 + 1 byts size) + witness_bytes += 4 # 1 empty, 1 true witness stack values + witness_bytes += getCompactSizeLen(witness_bytes) + vsize = self.getTxVSize(tx, add_witness_bytes=witness_bytes) + fee_rate_paid = fee_paid * 1000 / vsize + + logging.info('tx amount, vsize, feerate: %ld, %ld, %ld', tx.vout[0].nValue, vsize, fee_rate_paid) + + if not self.compareFeeRates(fee_rate_paid, feerate): + raise ValueError('Bad fee rate') + + return True + + def signTx(self, key_int, tx, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + + eck = ECKey() + eck.set(i2b(key_int), compressed=True) + + return eck.sign_ecdsa(sig_hash) + b'\x01' # 0x1 is SIGHASH_ALL + + def signTxOtVES(self, key_sign, key_encrypt, tx, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + return otves.EncSign(key_sign, key_encrypt, sig_hash) + + def verifyTxOtVES(self, tx, sig, Ks, Ke, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + return otves.EncVrfy(Ks, Ke, sig_hash, sig) + + def decryptOtVES(self, k, esig): + return otves.DecSig(k, esig) + b'\x01' # 0x1 is SIGHASH_ALL + + def verifyTxSig(self, tx, sig, K, prevout_n, prevout_script, prevout_value): + sig_hash = SegwitV0SignatureHash(prevout_script, tx, prevout_n, SIGHASH_ALL, prevout_value) + + ecK = ECPubKey() + ecK.set_int(K.x(), K.y()) + return ecK.verify_ecdsa(sig[: -1], sig_hash) # Pop the hashtype byte + + def fundTx(self, tx, feerate): + feerate_str = format_amount(feerate, self.exp()) + rv = self.rpc_callback('fundrawtransaction', [ToHex(tx), {'feeRate': feerate_str}]) + return FromHex(tx, rv['hex']) + + def signTxWithWallet(self, tx): + rv = self.rpc_callback('signrawtransactionwithwallet', [ToHex(tx)]) + + return FromHex(tx, rv['hex']) + + def publishTx(self, tx): + return self.rpc_callback('sendrawtransaction', [ToHex(tx)]) + + def encodeTx(self, tx): + return tx.serialize() + + def loadTx(self, tx_bytes): + # Load tx from bytes to internal representation + tx = CTransaction() + tx.deserialize(BytesIO(tx_bytes)) + return tx + + def getTxHash(self, tx): + tx.rehash() + return i2b(tx.sha256) + + def getPubkeyHash(self, K): + return hash160(self.encodePubkey(K)) + + def getScriptDest(self, script): + return CScript([OP_0, hashlib.sha256(script).digest()]) + + def getPkDest(self, K): + return CScript([OP_0, self.getPubkeyHash(K)]) + + def scanTxOutset(self, dest): + return self.rpc_callback('scantxoutset', ['start', ['raw({})'.format(dest.hex())]]) + + def getTransaction(self, txid): + try: + return self.rpc_callback('getrawtransaction', [txid.hex()]) + except Exception as ex: + # TODO: filter errors + return None + + def setTxSignature(self, tx, stack): + tx.wit.vtxinwit.clear() + tx.wit.vtxinwit.append(CTxInWitness()) + tx.wit.vtxinwit[0].scriptWitness.stack = stack + return True + + def extractLeaderSig(self, tx): + return tx.wit.vtxinwit[0].scriptWitness.stack[1] + + def extractFollowerSig(self, tx): + return tx.wit.vtxinwit[0].scriptWitness.stack[2] + + def createBLockTx(self, Kbs, output_amount): + tx = CTransaction() + tx.nVersion = self.txVersion() + p2wpkh = self.getPkDest(Kbs) + tx.vout.append(self.txoType(output_amount, p2wpkh)) + return tx + + def publishBLockTx(self, Kbv, Kbs, output_amount, feerate): + b_lock_tx = self.createBLockTx(Kbs, output_amount) + + b_lock_tx = self.fundTx(b_lock_tx, feerate) + b_lock_tx_id = self.getTxHash(b_lock_tx) + b_lock_tx = self.signTxWithWallet(b_lock_tx) + + return self.publishTx(b_lock_tx) + + def recoverEncKey(self, esig, sig, K): + return otves.RecoverEncKey(esig, sig[:-1], K) # Strip sighash type + + def getTxVSize(self, tx, add_bytes=0, add_witness_bytes=0): + wsf = self.witnessScaleFactor() + len_full = len(tx.serialize_with_witness()) + add_bytes + add_witness_bytes + len_nwit = len(tx.serialize_without_witness()) + add_bytes + weight = len_nwit * (wsf - 1) + len_full + return (weight + wsf - 1) // wsf + + def findTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + raw_dest = self.getPkDest(Kbs) + + rv = self.scanTxOutset(raw_dest) + print('scanTxOutset', dumpj(rv)) + + for utxo in rv['unspents']: + if 'height' in utxo and utxo['height'] > 0 and rv['height'] - utxo['height'] > cb_block_confirmed: + if utxo['amount'] * COIN != cb_swap_value: + logging.warning('Found output to lock tx pubkey of incorrect value: %s', str(utxo['amount'])) + else: + return True + return False + + def waitForLockTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed): + + raw_dest = self.getPkDest(Kbs) + + for i in range(20): + time.sleep(1) + rv = self.scanTxOutset(raw_dest) + print('scanTxOutset', dumpj(rv)) + + for utxo in rv['unspents']: + if 'height' in utxo and utxo['height'] > 0 and rv['height'] - utxo['height'] > cb_block_confirmed: + + if utxo['amount'] * COIN != cb_swap_value: + logging.warning('Found output to lock tx pubkey of incorrect value: %s', str(utxo['amount'])) + else: + return True + return False + + def spendBLockTx(self, address_to, kbv, kbs, cb_swap_value, b_fee, restore_height): + print('TODO: spendBLockTx') + + +def testBTCInterface(): + print('testBTCInterface') + script_bytes = bytes.fromhex('6382012088a820aaf125ff9a34a74c7a17f5e7ee9d07d17cc5e53a539f345d5f73baa7e79b65e28852210224019219ad43c47288c937ae508f26998dd81ec066827773db128fd5e262c04f21039a0fd752bd1a2234820707852e7a30253620052ecd162948a06532a817710b5952ae670114b2755221038689deba25c5578e5457ddadbaf8aeb8badf438dc22f540503dbd4ae10e14f512103c9c5d5acc996216d10852a72cd67c701bfd4b9137a4076350fd32f08db39575552ae68') + i = BTCInterface(None) + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes) + assert(csv_val == 20) + + script_bytes_t = script_bytes + bytes((0x00,)) + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad script length') + + script_bytes_t = script_bytes[:-1] + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad script length') + + script_bytes_t = bytes((0x00,)) + script_bytes[1:] + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad opcode') + + # Remove the csv value + script_part_a = script_bytes[:112] + script_part_b = script_bytes[114:] + + script_bytes_t = script_part_a + bytes((0x00,)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 0) + + script_bytes_t = script_part_a + bytes((OP_16,)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 16) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(17)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 17) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(-15)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == -15) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(4000)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == 4000) + + max_pos = 0x7FFFFFFF + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == max_pos) + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos - 1)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == max_pos - 1) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(max_pos + 1)) + script_part_b + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad scriptnum length') + + min_neg = -2147483647 + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(min_neg)) + script_part_b + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(csv_val == min_neg) + + script_bytes_t = script_part_a + CScriptNum.encode(CScriptNum(min_neg - 1)) + script_part_b + try: + sh, a, b, csv_val, c, d = i.extractScriptLockScriptValues(script_bytes_t) + assert(False), 'Should fail' + except Exception as e: + assert(str(e) == 'Bad scriptnum length') + + print('Passed.') + + +if __name__ == "__main__": + testBTCInterface() diff --git a/basicswap/interface_ltc.py b/basicswap/interface_ltc.py new file mode 100644 index 0000000..c052766 --- /dev/null +++ b/basicswap/interface_ltc.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +from .interface_btc import BTCInterface + + +class LTCInterface(BTCInterface): + pass diff --git a/basicswap/interface_part.py b/basicswap/interface_part.py new file mode 100644 index 0000000..a12ce97 --- /dev/null +++ b/basicswap/interface_part.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +from .contrib.test_framework.messages import ( + CTxOutPart, +) + +from .interface_btc import BTCInterface +from .chainparams import CoinInterface +from .rpc import make_rpc_func + + +class PARTInterface(BTCInterface): + @staticmethod + def witnessScaleFactor(): + return 2 + + @staticmethod + def txVersion(): + return 0xa0 + + def __init__(self, coin_settings): + self.rpc_callback = make_rpc_func(coin_settings['rpcport'], coin_settings['rpcauth']) + self.txoType = CTxOutPart diff --git a/basicswap/interface_xmr.py b/basicswap/interface_xmr.py new file mode 100644 index 0000000..59289e1 --- /dev/null +++ b/basicswap/interface_xmr.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import time +import logging + +from .chainparams import CoinInterface +from .rpc_xmr import make_xmr_rpc_func, make_xmr_wallet_rpc_func + +XMR_COIN = 10 ** 12 + + +class XMRInterface(CoinInterface): + @staticmethod + def exp(): + return 12 + + @staticmethod + def nbk(): + return 32 + + @staticmethod + def nbK(): # No. of bytes requires to encode a public key + return 32 + + def __init__(self, coin_settings): + rpc_cb = make_xmr_rpc_func(coin_settings['rpcport']) + rpc_wallet_cb = make_xmr_wallet_rpc_func(coin_settings['walletrpcport'], coin_settings['walletrpcauth']) + + self.rpc_cb = rpc_cb # Not essential + self.rpc_wallet_cb = rpc_wallet_cb + + def getNewSecretKey(self): + return edu.get_secret() + + def pubkey(self, key): + return edf.scalarmult_B(key) + + def encodePubkey(self, pk): + return edu.encodepoint(pk) + + def decodePubkey(self, pke): + return edf.decodepoint(pke) + + def decodeKey(self, k): + i = b2i(k) + assert(i < edf.l and i > 8) + return i + + def sumKeys(self, ka, kb): + return (ka + kb) % edf.l + + def sumPubkeys(self, Ka, Kb): + return edf.edwards_add(Ka, Kb) + + def publishBLockTx(self, Kbv, Kbs, output_amount, feerate): + + shared_addr = xmr_util.encode_address(self.encodePubkey(Kbv), self.encodePubkey(Kbs)) + + # TODO: How to set feerate? + params = {'destinations': [{'amount': output_amount, 'address': shared_addr}]} + rv = self.rpc_wallet_cb('transfer', params) + logging.info('publishBLockTx %s to address_b58 %s', rv['tx_hash'], shared_addr) + + return rv['tx_hash'] + + def findTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + address_b58 = xmr_util.encode_address(Kbv_enc, self.encodePubkey(Kbs)) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + params = { + 'restore_height': restore_height, + 'filename': address_b58, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + } + + try: + rv = self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + except Exception as e: + rv = self.rpc_wallet_cb('generate_from_keys', params) + logging.info('generate_from_keys %s', dumpj(rv)) + rv = self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + + # Debug + try: + current_height = self.rpc_cb('get_block_count')['count'] + logging.info('findTxB XMR current_height %d\nAddress: %s', current_height, address_b58) + except Exception as e: + logging.info('rpc_cb failed %s', str(e)) + current_height = None # If the transfer is available it will be deep enough + + # For a while after opening the wallet rpc cmds return empty data + for i in range(5): + params = {'transfer_type': 'available'} + rv = self.rpc_wallet_cb('incoming_transfers', params) + if 'transfers' in rv: + for transfer in rv['transfers']: + if transfer['amount'] == cb_swap_value \ + and (current_height is None or current_height - transfer['block_height'] > cb_block_confirmed): + return True + time.sleep(1 + i) + + return False + + def waitForLockTxB(self, kbv, Kbs, cb_swap_value, cb_block_confirmed, restore_height): + + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + address_b58 = xmr_util.encode_address(Kbv_enc, self.encodePubkey(Kbs)) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + params = { + 'filename': address_b58, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + 'restore_height': restore_height, + } + self.rpc_wallet_cb('generate_from_keys', params) + + self.rpc_wallet_cb('open_wallet', {'filename': address_b58}) + # For a while after opening the wallet rpc cmds return empty data + + num_tries = 40 + for i in range(num_tries + 1): + try: + current_height = self.rpc_cb('get_block_count')['count'] + print('current_height', current_height) + except Exception as e: + logging.warning('rpc_cb failed %s', str(e)) + current_height = None # If the transfer is available it will be deep enough + + # TODO: Make accepting current_height == None a user selectable option + # Or look for all transfers and check height + + params = {'transfer_type': 'available'} + rv = self.rpc_wallet_cb('incoming_transfers', params) + print('rv', rv) + + if 'transfers' in rv: + for transfer in rv['transfers']: + if transfer['amount'] == cb_swap_value \ + and (current_height is None or current_height - transfer['block_height'] > cb_block_confirmed): + return True + + # TODO: Is it necessary to check the address? + + ''' + rv = self.rpc_wallet_cb('get_balance') + print('get_balance', rv) + + if 'per_subaddress' in rv: + for sub_addr in rv['per_subaddress']: + if sub_addr['address'] == address_b58: + + ''' + + if i >= num_tries: + raise ValueError('Balance not confirming on node') + time.sleep(1) + + return False + + def spendBLockTx(self, address_to, kbv, kbs, cb_swap_value, b_fee_rate, restore_height): + + Kbv_enc = self.encodePubkey(self.pubkey(kbv)) + Kbs_enc = self.encodePubkey(self.pubkey(kbs)) + address_b58 = xmr_util.encode_address(Kbv_enc, Kbs_enc) + + try: + self.rpc_wallet_cb('close_wallet') + except Exception as e: + logging.warning('close_wallet failed %s', str(e)) + + wallet_filename = address_b58 + '_spend' + + params = { + 'filename': wallet_filename, + 'address': address_b58, + 'viewkey': b2h(intToBytes32_le(kbv)), + 'spendkey': b2h(intToBytes32_le(kbs)), + 'restore_height': restore_height, + } + + try: + self.rpc_wallet_cb('open_wallet', {'filename': wallet_filename}) + except Exception as e: + rv = self.rpc_wallet_cb('generate_from_keys', params) + logging.info('generate_from_keys %s', dumpj(rv)) + self.rpc_wallet_cb('open_wallet', {'filename': wallet_filename}) + + # For a while after opening the wallet rpc cmds return empty data + for i in range(10): + rv = self.rpc_wallet_cb('get_balance') + print('get_balance', rv) + if rv['balance'] >= cb_swap_value: + break + + time.sleep(1 + i) + + # TODO: need a subfee from output option + b_fee = b_fee_rate * 10 # Guess + + num_tries = 20 + for i in range(1 + num_tries): + try: + params = {'destinations': [{'amount': cb_swap_value - b_fee, 'address': address_to}]} + rv = self.rpc_wallet_cb('transfer', params) + print('transfer', rv) + break + except Exception as e: + print('str(e)', str(e)) + if i >= num_tries: + raise ValueError('transfer failed.') + b_fee += b_fee_rate + logging.info('Raising fee to %d', b_fee) + + return rv['tx_hash'] diff --git a/basicswap/rpc.py b/basicswap/rpc.py index 3cf87cf..bcd7b59 100644 --- a/basicswap/rpc.py +++ b/basicswap/rpc.py @@ -93,8 +93,8 @@ class Jsonrpc(): def callrpc(rpc_port, auth, method, params=[], wallet=None): try: url = 'http://%s@127.0.0.1:%d/' % (auth, rpc_port) - if wallet: - url += 'wallet/' + wallet + if wallet is not None: + url += 'wallet/' + urllib.parse.quote(wallet) x = Jsonrpc(url) v = x.json_request(method, params) @@ -126,3 +126,14 @@ def callrpc_cli(bindir, datadir, chain, cmd, cli_bin='particl-cli'): except Exception: pass return r + + +def make_rpc_func(port, auth, wallet=None): + port = port + auth = auth + wallet = wallet + + def rpc_func(method, params=None, wallet_override=None): + nonlocal port, auth, wallet + return callrpc(port, auth, method, params, wallet if wallet_override is None else wallet_override) + return rpc_func diff --git a/basicswap/rpc_xmr.py b/basicswap/rpc_xmr.py new file mode 100644 index 0000000..9f15650 --- /dev/null +++ b/basicswap/rpc_xmr.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +import json +import requests + + +def callrpc_xmr(rpc_port, auth, method, params=[], path='json_rpc'): + # auth is a tuple: (username, password) + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, path) + request_body = { + 'method': method, + 'params': params, + 'id': 2, + 'jsonrpc': '2.0' + } + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(request_body), auth=requests.auth.HTTPDigestAuth(auth[0], auth[1]), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + if 'error' in r and r['error'] is not None: + raise ValueError('RPC error ' + str(r['error'])) + + return r['result'] + + +def callrpc_xmr_na(rpc_port, method, params=[], path='json_rpc'): + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, path) + request_body = { + 'method': method, + 'params': params, + 'id': 2, + 'jsonrpc': '2.0' + } + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(request_body), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + if 'error' in r and r['error'] is not None: + raise ValueError('RPC error ' + str(r['error'])) + + return r['result'] + + +def callrpc_xmr2(rpc_port, method, params=[]): + try: + url = 'http://127.0.0.1:{}/{}'.format(rpc_port, method) + headers = { + 'content-type': 'application/json' + } + p = requests.post(url, data=json.dumps(params), headers=headers) + r = json.loads(p.text) + except Exception as ex: + raise ValueError('RPC Server Error: {}'.format(str(ex))) + + return r + + +def make_xmr_rpc_func(port): + port = port + + def rpc_func(method, params=None, wallet=None): + nonlocal port + return callrpc_xmr_na(port, method, params) + return rpc_func + + +def make_xmr_wallet_rpc_func(port, auth): + port = port + auth = auth + + def rpc_func(method, params=None, wallet=None): + nonlocal port, auth + return callrpc_xmr(port, auth, method, params) + return rpc_func + diff --git a/basicswap/util.py b/basicswap/util.py index c700869..024a864 100644 --- a/basicswap/util.py +++ b/basicswap/util.py @@ -9,12 +9,15 @@ import json import hashlib from .contrib.segwit_addr import bech32_decode, convertbits, bech32_encode +OP_1 = 0x51 +OP_16 = 0x60 COIN = 100000000 DCOIN = decimal.Decimal(COIN) -def makeInt(v): - return int(dquantize(decimal.Decimal(v) * DCOIN).quantize(decimal.Decimal(1))) +def assert_cond(v, err='Bad opcode'): + if not v: + raise ValueError(err) def format8(i): @@ -188,3 +191,105 @@ def DeserialiseNum(b, o=0): if b[o + nb - 1] & 0x80: return -(v & ~(0x80 << (8 * (nb - 1)))) return v + + +def decodeScriptNum(script_bytes, o): + v = 0 + num_len = script_bytes[o] + if num_len >= OP_1 and num_len <= OP_16: + return((num_len - OP_1) + 1, 1) + + if num_len > 4: + raise ValueError('Bad scriptnum length') # Max 4 bytes + if num_len + o >= len(script_bytes): + raise ValueError('Bad script length') + o += 1 + for i in range(num_len): + b = script_bytes[o + i] + # Negative flag set in last byte, if num is positive and > 0x80 an extra 0x00 byte will be appended + if i == num_len - 1 and b & 0x80: + b &= (~(0x80) & 0xFF) + v += int(b) << 8 * i + v *= -1 + else: + v += int(b) << 8 * i + return(v, 1 + num_len) + + +def getCompactSizeLen(v): + # Compact Size + if v < 253: + return 1 + if v < 0xffff: # USHRT_MAX + return 3 + if v < 0xffffffff: # UINT_MAX + return 5 + if v < 0xffffffffffffffff: # UINT_MAX + return 9 + raise ValueError('Value too large') + + +def make_int(v, precision=8, r=0): # r = 0, no rounding, fail, r > 0 round up, r < 0 floor + if type(v) == float: + v = str(v) + elif type(v) == int: + return v * 10 ** precision + + ep = 10 ** precision + have_dp = False + rv = 0 + for c in v: + if c == '.': + rv *= ep + have_dp = True + continue + if not c.isdigit(): + raise ValueError('Invalid char') + if have_dp: + ep //= 10 + if ep <= 0: + if r == 0: + raise ValueError('Mantissa too long') + if r > 0: + # Round up + if int(c) > 4: + rv += 1 + break + + rv += ep * int(c) + else: + rv = rv * 10 + int(c) + if not have_dp: + rv *= ep + return rv + + +def validate_amount(amount, precision=8): + str_amount = str(amount) + has_decimal = False + for c in str_amount: + if c == '.' and not has_decimal: + has_decimal = True + continue + if not c.isdigit(): + raise ValueError('Invalid amount') + + ar = str_amount.split('.') + if len(ar) > 1 and len(ar[1]) > precision: + raise ValueError('Too many decimal places in amount {}'.format(str_amount)) + return True + + +def format_amount(i, display_precision, precision=None): + if precision is None: + precision = display_precision + ep = 10 ** precision + n = abs(i) + quotient = n // ep + remainder = n % ep + if display_precision != precision: + remainder %= (10 ** display_precision) + rv = '{}.{:0>{prec}}'.format(quotient, remainder, prec=display_precision) + if i < 0: + rv = '-' + rv + return rv diff --git a/basicswap/util_xmr.py b/basicswap/util_xmr.py new file mode 100644 index 0000000..75f5185 --- /dev/null +++ b/basicswap/util_xmr.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +import xmrswap.contrib.Keccak as Keccak +from .contrib.MoneroPy.base58 import encode as xmr_b58encode + + +def cn_fast_hash(s): + k = Keccak.Keccak() + return k.Keccak((len(s) * 8, s.hex()), 1088, 512, 32 * 8, False).lower() # r = bitrate = 1088, c = capacity, n = output length in bits + + +def encode_address(view_point, spend_point, version=18): + buf = bytes((version,)) + spend_point + view_point + h = cn_fast_hash(buf) + buf = buf + bytes.fromhex(h[0: 8]) + + return xmr_b58encode(buf.hex()) diff --git a/setup.py b/setup.py index 65e89dc..deed653 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ setuptools.setup( "sqlalchemy", "python-gnupg", "Jinja2", + "requests", ], entry_points={ "console_scripts": [ diff --git a/tests/basicswap/__init__.py b/tests/basicswap/__init__.py index 0c173cf..30104e0 100644 --- a/tests/basicswap/__init__.py +++ b/tests/basicswap/__init__.py @@ -4,6 +4,7 @@ import tests.basicswap.test_other as test_other import tests.basicswap.test_prepare as test_prepare import tests.basicswap.test_run as test_run import tests.basicswap.test_reload as test_reload +import tests.basicswap.test_xmr as test_xmr def test_suite(): @@ -12,5 +13,6 @@ def test_suite(): suite.addTests(loader.loadTestsFromModule(test_prepare)) suite.addTests(loader.loadTestsFromModule(test_run)) suite.addTests(loader.loadTestsFromModule(test_reload)) + suite.addTests(loader.loadTestsFromModule(test_xmr)) return suite diff --git a/tests/basicswap/common.py b/tests/basicswap/common.py new file mode 100644 index 0000000..3139587 --- /dev/null +++ b/tests/basicswap/common.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE.txt or http://www.opensource.org/licenses/mit-license.php. + +def checkForks(ro): + if 'bip9_softforks' in ro: + assert(ro['bip9_softforks']['csv']['status'] == 'active') + assert(ro['bip9_softforks']['segwit']['status'] == 'active') + else: + assert(ro['softforks']['csv']['active']) + assert(ro['softforks']['segwit']['active']) diff --git a/tests/basicswap/test_other.py b/tests/basicswap/test_other.py index 08971ea..5d1c92b 100644 --- a/tests/basicswap/test_other.py +++ b/tests/basicswap/test_other.py @@ -9,7 +9,7 @@ import unittest from basicswap.util import ( SerialiseNum, DeserialiseNum, - makeInt, + make_int, format8, ) from basicswap.basicswap import ( @@ -57,19 +57,28 @@ class Test(unittest.TestCase): decoded = decodeSequence(encoded) assert(decoded == blocks_val) - def test_makeInt(self): + def test_make_int(self): def test_case(vs, vf, expect_int): - assert(makeInt(vs) == expect_int) - assert(makeInt(vf) == expect_int) - vs_out = format8(makeInt(vs)) + i = make_int(vs) + assert(i == expect_int and isinstance(i, int)) + i = make_int(vf) + assert(i == expect_int and isinstance(i, int)) + vs_out = format_amount(i, 8) # Strip for i in range(7): if vs_out[-1] == '0': vs_out = vs_out[:-1] - assert(vs_out == vs) + if '.' in vs: + assert(vs_out == vs) + else: + assert(vs_out[:-2] == vs) + test_case('0', 0, 0) + test_case('1', 1, 100000000) + test_case('10', 10, 1000000000) test_case('0.00899999', 0.00899999, 899999) test_case('899999.0', 899999.0, 89999900000000) test_case('899999.00899999', 899999.00899999, 89999900899999) + test_case('0.0', 0.0, 0) test_case('1.0', 1.0, 100000000) test_case('1.1', 1.1, 110000000) test_case('1.2', 1.2, 120000000) @@ -79,6 +88,52 @@ class Test(unittest.TestCase): test_case('0.123', 0.123, 12300000) test_case('123000.000123', 123000.000123, 12300000012300) + try: + make_int('0.123456789') + assert(False) + except Exception as e: + assert(str(e) == 'Mantissa too long') + validate_amount('0.12345678') + + # floor + assert(make_int('0.123456789', r=-1) == 12345678) + # Round up + assert(make_int('0.123456789', r=1) == 12345679) + + def test_make_int12(self): + def test_case(vs, vf, expect_int): + i = make_int(vs, 12) + assert(i == expect_int and isinstance(i, int)) + i = make_int(vf, 12) + assert(i == expect_int and isinstance(i, int)) + vs_out = format_amount(i, 12) + # Strip + for i in range(7): + if vs_out[-1] == '0': + vs_out = vs_out[:-1] + if '.' in vs: + assert(vs_out == vs) + else: + assert(vs_out[:-2] == vs) + test_case('0.123456789', 0.123456789, 123456789000) + test_case('0.123456789123', 0.123456789123, 123456789123) + try: + make_int('0.1234567891234', 12) + assert(False) + except Exception as e: + assert(str(e) == 'Mantissa too long') + validate_amount('0.123456789123', 12) + try: + validate_amount('0.1234567891234', 12) + assert(False) + except Exception as e: + assert('Too many decimal places' in str(e)) + try: + validate_amount(0.1234567891234, 12) + assert(False) + except Exception as e: + assert('Too many decimal places' in str(e)) + if __name__ == '__main__': unittest.main() diff --git a/tests/basicswap/test_run.py b/tests/basicswap/test_run.py index d130870..43a0426 100644 --- a/tests/basicswap/test_run.py +++ b/tests/basicswap/test_run.py @@ -48,6 +48,9 @@ from basicswap.contrib.key import ( from basicswap.http_server import ( HttpThread, ) +from tests.basicswap.common import ( + checkForks, +) from bin.basicswap_run import startDaemon logger = logging.getLogger() @@ -205,15 +208,6 @@ def run_loop(self): btcRpc('generatetoaddress 1 {}'.format(self.btc_addr)) -def checkForks(ro): - if 'bip9_softforks' in ro: - assert(ro['bip9_softforks']['csv']['status'] == 'active') - assert(ro['bip9_softforks']['segwit']['status'] == 'active') - else: - assert(ro['softforks']['csv']['active']) - assert(ro['softforks']['segwit']['active']) - - class Test(unittest.TestCase): @classmethod diff --git a/tests/basicswap/test_xmr.py b/tests/basicswap/test_xmr.py new file mode 100644 index 0000000..06933cc --- /dev/null +++ b/tests/basicswap/test_xmr.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2020 tecnovert +# Distributed under the MIT software license, see the accompanying +# file LICENSE or http://www.opensource.org/licenses/mit-license.php. + +import os +import sys +import unittest +import json +import logging +import shutil +import time +import signal +import threading +from urllib.request import urlopen +from coincurve.ecdsaotves import ( + ecdsaotves_enc_sign, + ecdsaotves_enc_verify, + ecdsaotves_dec_sig, + ecdsaotves_rec_enc_key) +from coincurve.dleag import ( + dleag_prove, + dleag_verify) + +import basicswap.config as cfg +from basicswap.basicswap import ( + BasicSwap, + Coins, + SwapTypes, + BidStates, + TxStates, + SEQUENCE_LOCK_BLOCKS, +) +from basicswap.util import ( + COIN, + toWIF, + dumpje, +) +from basicswap.rpc import ( + callrpc_cli, + waitForRPC, +) +from basicswap.contrib.key import ( + ECKey, +) +from basicswap.http_server import ( + HttpThread, +) +from bin.basicswap_run import startDaemon + +logger = logging.getLogger() +logger.level = logging.DEBUG +if not len(logger.handlers): + logger.addHandler(logging.StreamHandler(sys.stdout)) + +NUM_NODES = 3 +BASE_PORT = 14792 +BASE_RPC_PORT = 19792 +BASE_ZMQ_PORT = 20792 +PREFIX_SECRET_KEY_REGTEST = 0x2e +TEST_HTML_PORT = 1800 +stop_test = False + + + +def prepareOtherDir(datadir, nodeId, conf_file='litecoin.conf'): + node_dir = os.path.join(datadir, str(nodeId)) + if not os.path.exists(node_dir): + os.makedirs(node_dir) + filePath = os.path.join(node_dir, conf_file) + + with open(filePath, 'w+') as fp: + fp.write('regtest=1\n') + fp.write('[regtest]\n') + fp.write('port=' + str(BASE_PORT + nodeId) + '\n') + fp.write('rpcport=' + str(BASE_RPC_PORT + nodeId) + '\n') + + fp.write('daemon=0\n') + fp.write('printtoconsole=0\n') + fp.write('server=1\n') + fp.write('discover=0\n') + fp.write('listenonion=0\n') + fp.write('bind=127.0.0.1\n') + fp.write('findpeers=0\n') + fp.write('debug=1\n') + fp.write('debugexclude=libevent\n') + fp.write('fallbackfee=0.0002\n') + + fp.write('acceptnonstdtxn=0\n') + + +def prepareDir(datadir, nodeId, network_key, network_pubkey): + node_dir = os.path.join(datadir, str(nodeId)) + if not os.path.exists(node_dir): + os.makedirs(node_dir) + filePath = os.path.join(node_dir, 'particl.conf') + + with open(filePath, 'w+') as fp: + fp.write('regtest=1\n') + fp.write('[regtest]\n') + fp.write('port=' + str(BASE_PORT + nodeId) + '\n') + fp.write('rpcport=' + str(BASE_RPC_PORT + nodeId) + '\n') + + fp.write('daemon=0\n') + fp.write('printtoconsole=0\n') + fp.write('server=1\n') + fp.write('discover=0\n') + fp.write('listenonion=0\n') + fp.write('bind=127.0.0.1\n') + fp.write('findpeers=0\n') + fp.write('debug=1\n') + fp.write('debugexclude=libevent\n') + fp.write('zmqpubsmsg=tcp://127.0.0.1:' + str(BASE_ZMQ_PORT + nodeId) + '\n') + + fp.write('acceptnonstdtxn=0\n') + fp.write('minstakeinterval=5\n') + + for i in range(0, NUM_NODES): + if nodeId == i: + continue + fp.write('addnode=127.0.0.1:%d\n' % (BASE_PORT + i)) + + if nodeId < 2: + fp.write('spentindex=1\n') + fp.write('txindex=1\n') + + basicswap_dir = os.path.join(datadir, str(nodeId), 'basicswap') + if not os.path.exists(basicswap_dir): + os.makedirs(basicswap_dir) + + ltcdatadir = os.path.join(datadir, str(LTC_NODE)) + btcdatadir = os.path.join(datadir, str(BTC_NODE)) + settings_path = os.path.join(basicswap_dir, cfg.CONFIG_FILENAME) + settings = { + 'zmqhost': 'tcp://127.0.0.1', + 'zmqport': BASE_ZMQ_PORT + nodeId, + 'htmlhost': 'localhost', + 'htmlport': 12700 + nodeId, + 'network_key': network_key, + 'network_pubkey': network_pubkey, + 'chainclients': { + 'particl': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + nodeId, + 'datadir': node_dir, + 'bindir': cfg.PARTICL_BINDIR, + 'blocks_confirmed': 2, # Faster testing + }, + 'litecoin': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + LTC_NODE, + 'datadir': ltcdatadir, + 'bindir': cfg.LITECOIN_BINDIR, + # 'use_segwit': True, + }, + 'bitcoin': { + 'connection_type': 'rpc', + 'manage_daemon': False, + 'rpcport': BASE_RPC_PORT + BTC_NODE, + 'datadir': btcdatadir, + 'bindir': cfg.BITCOIN_BINDIR, + 'use_segwit': True, + } + }, + 'check_progress_seconds': 2, + 'check_watched_seconds': 4, + 'check_expired_seconds': 60, + 'check_events_seconds': 1, + 'min_delay_auto_accept': 1, + 'max_delay_auto_accept': 5 + } + with open(settings_path, 'w') as fp: + json.dump(settings, fp, indent=4) + + +def partRpc(cmd, node_id=0): + return callrpc_cli(cfg.PARTICL_BINDIR, os.path.join(cfg.TEST_DATADIRS, str(node_id)), 'regtest', cmd, cfg.PARTICL_CLI) + + +def btcRpc(cmd): + return callrpc_cli(cfg.BITCOIN_BINDIR, os.path.join(cfg.TEST_DATADIRS, str(BTC_NODE)), 'regtest', cmd, cfg.BITCOIN_CLI) + + +def signal_handler(sig, frame): + global stop_test + print('signal {} detected.'.format(sig)) + stop_test = True + + +def run_loop(self): + while not stop_test: + time.sleep(1) + for c in self.swap_clients: + c.update() + btcRpc('generatetoaddress 1 {}'.format(self.btc_addr)) + + +def checkForks(ro): + if 'bip9_softforks' in ro: + assert(ro['bip9_softforks']['csv']['status'] == 'active') + assert(ro['bip9_softforks']['segwit']['status'] == 'active') + else: + assert(ro['softforks']['csv']['active']) + assert(ro['softforks']['segwit']['active']) + + +class Test(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(Test, cls).setUpClass() + + cls.swap_clients = [] + cls.xmr_daemons = [] + cls.xmr_wallet_auth = [] + + cls.part_stakelimit = 0 + cls.xmr_addr = None + + signal.signal(signal.SIGINT, signal_handler) + cls.update_thread = threading.Thread(target=run_loop, args=(cls,)) + cls.update_thread.start() + + @classmethod + def tearDownClass(cls): + global stop_test + logging.info('Finalising') + stop_test = True + cls.update_thread.join() + + super(Test, cls).tearDownClass() + + def test_01_part_xmr(self): + logging.info('---------- Test PART to XMR') + #swap_clients = self.swap_clients + + #offer_id = swap_clients[0].postOffer(Coins.PART, Coins.XMR, 100 * COIN, 0.5 * COIN, 100 * COIN, SwapTypes.SELLER_FIRST) + + + +if __name__ == '__main__': + unittest.main()