From e5a4f72eb7ea1e2c784673fe2d56c519e9f9aada Mon Sep 17 00:00:00 2001 From: Miguel Mota Date: Wed, 30 Nov 2022 20:22:03 -0800 Subject: [PATCH] Add MerkleSumTree --- src/MerkleSumTree.ts | 161 +++++++++++++++++++++++++++++++++++++ src/index.ts | 1 + test/MerkleSumTree.test.js | 29 +++++++ 3 files changed, 191 insertions(+) create mode 100644 src/MerkleSumTree.ts create mode 100644 test/MerkleSumTree.test.js diff --git a/src/MerkleSumTree.ts b/src/MerkleSumTree.ts new file mode 100644 index 0000000..e5bbff8 --- /dev/null +++ b/src/MerkleSumTree.ts @@ -0,0 +1,161 @@ +import { Base } from './Base' + +// @credit: https://github.com/finalitylabs/pymst + +type TValue = Buffer | BigInt | string | number | null | undefined +type THashFn = (value: TValue) => Buffer + +export class Bucket { + size: BigInt + hashed: Buffer + parent: Bucket | null + left: Bucket | null + right: Bucket | null + + constructor (size: BigInt | number, hashed: Buffer) { + this.size = BigInt(size) + this.hashed = hashed + + // each node in the tree can have a parent, and a left or right sibling + this.parent = null + this.left = null + this.right = null + } +} + +export class Leaf { + hashFn: THashFn + rng: BigInt[] + data: Buffer | null + + constructor (hashFn: THashFn, rng: (number | BigInt)[], data: Buffer | null) { + this.hashFn = hashFn + this.rng = rng.map(x => BigInt(x)) + this.data = data + } + + getBucket () { + let hashed : Buffer + if (this.data) { + hashed = this.hashFn(this.data) + } else { + hashed = Buffer.alloc(32) + } + return new Bucket(BigInt(this.rng[1]) - BigInt(this.rng[0]), hashed) + } +} + +export class ProofStep { + bucket: Bucket + right: boolean + + constructor (bucket: Bucket, right: boolean) { + this.bucket = bucket + this.right = right // whether the bucket hash should be appeded on the right side in this step (default is left + } +} + +export class MerkleSumTree extends Base { + hashFn: THashFn + leaves: Leaf[] + buckets: Bucket[] + root: Bucket + + constructor (leaves: Leaf[], hashFn: THashFn) { + super() + this.leaves = leaves + this.hashFn = hashFn + + MerkleSumTree.checkConsecutive(leaves) + + this.buckets = [] + for (const l of leaves) { + this.buckets.push(l.getBucket()) + } + + let buckets = [] + for (const bucket of this.buckets) { + buckets.push(bucket) + } + + while (buckets.length !== 1) { + const newBuckets = [] + while (buckets.length) { + if (buckets.length >= 2) { + const b1 = buckets.shift() + const b2 = buckets.shift() + const size = b1.size + b2.size + const hashed = this.hashFn(Buffer.concat([this.sizeToBuffer(b1.size), this.bufferify(b1.hashed), this.sizeToBuffer(b2.size), this.bufferify(b2.hashed)])) + const b = new Bucket(size, hashed) + b2.parent = b + b1.parent = b2.parent + b1.right = b2 + b2.left = b1 + newBuckets.push(b) + } else { + newBuckets.push(buckets.shift()) + } + } + buckets = newBuckets + } + this.root = buckets[0] + } + + sizeToBuffer (size: BigInt) { + const buf = Buffer.alloc(8) + const view = new DataView(buf.buffer) + view.setBigInt64(0, BigInt(size), false) // true when little endian + return buf + } + + static checkConsecutive (leaves: Leaf[]) { + let curr = BigInt(0) + for (const leaf of leaves) { + if (leaf.rng[0] !== curr) { + throw new Error('leaf ranges are invalid') + } + curr = BigInt(leaf.rng[1]) + } + } + + // gets inclusion/exclusion proof of a bucket in the specified index + getProof (index: number | BigInt) { + let curr = this.buckets[Number(index)] + const proof = [] + while (curr && curr.parent) { + const right = !!curr.right + const bucket = curr.right ? curr.right : curr.left + curr = curr.parent + proof.push(new ProofStep(bucket, right)) + } + return proof + } + + sum (arr: BigInt[]) { + let total = BigInt(0) + for (const value of arr) { + total += BigInt(value) + } + return total + } + + // validates the suppplied proof for a specified leaf according to the root bucket + verifyProof (root: Bucket, leaf: Leaf, proof: ProofStep[]) { + const rng = [this.sum(proof.filter(x => !x.right).map(x => x.bucket.size)), BigInt(root.size) - this.sum(proof.filter(x => x.right).map(x => x.bucket.size))] + if (!(rng[0] === leaf.rng[0] && rng[1] === leaf.rng[1])) { + // supplied steps are not routing to the range specified + return false + } + let curr = leaf.getBucket() + let hashed :Buffer + for (const step of proof) { + if (step.right) { + hashed = this.hashFn(Buffer.concat([this.sizeToBuffer(curr.size), this.bufferify(curr.hashed), this.sizeToBuffer(step.bucket.size), this.bufferify(step.bucket.hashed)])) + } else { + hashed = this.hashFn(Buffer.concat([this.sizeToBuffer(step.bucket.size), this.bufferify(step.bucket.hashed), this.sizeToBuffer(curr.size), this.bufferify(curr.hashed)])) + } + curr = new Bucket(BigInt(curr.size) + BigInt(step.bucket.size), hashed) + } + return curr.size === root.size && curr.hashed.toString('hex') === root.hashed.toString('hex') + } +} diff --git a/src/index.ts b/src/index.ts index d58c21f..c922c43 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,4 +2,5 @@ import MerkleTree from './MerkleTree' export { MerkleTree } export { MerkleMountainRange } from './MerkleMountainRange' export { IncrementalMerkleTree } from './IncrementalMerkleTree' +export { MerkleSumTree } from './MerkleSumTree' export default MerkleTree diff --git a/test/MerkleSumTree.test.js b/test/MerkleSumTree.test.js new file mode 100644 index 0000000..8346ab0 --- /dev/null +++ b/test/MerkleSumTree.test.js @@ -0,0 +1,29 @@ +const test = require('tape') +const crypto = require('crypto') +const { MerkleSumTree, Leaf } = require('../dist/MerkleSumTree') + +const sha256 = (data) => crypto.createHash('sha256').update(data).digest() + +test('MerkleSumTree', t => { + t.plan(4) + + const treeSize = 2n ** 64n + // const treeSize = 18446744073709551616n + const leaves = [ + new Leaf(sha256, [0, 4], null), + new Leaf(sha256, [4, 10], Buffer.from('tx1')), + new Leaf(sha256, [10, 15], null), + new Leaf(sha256, [15, 20], Buffer.from('tx2')), + new Leaf(sha256, [20, 90], Buffer.from('tx4')), + new Leaf(sha256, [90, treeSize], null) + ] + const tree = new MerkleSumTree(leaves, sha256) + const root = tree.root.hashed.toString('hex') + t.equal(root, 'b12575680dad581d8b70dcb517f8a2e0c547ffb0eedb5eb59f4db03da3fa1c6d') + + const proof = tree.getProof(3) + t.deepEqual(proof.length, 3) + + t.true(tree.verifyProof(tree.root, leaves[3], proof)) + t.false(tree.verifyProof(tree.root, leaves[2], proof)) +})