From 8cc60417a8b8d30ef36a00e3ece838a06e99d759 Mon Sep 17 00:00:00 2001 From: Severiano Jaramillo Date: Sun, 12 May 2024 16:39:06 -0700 Subject: [PATCH] Implemented Lattice functionality. - Ported over Lattice functionality from the derohe project. - Ported over LatticeTest too, to verify that Lattice implementation is correct. --- .../agorise/library/crypto/bn256/Lattice.kt | 108 ++++++++++++++++++ .../library/crypto/bn256/LatticeTest.kt | 34 ++++++ 2 files changed, 142 insertions(+) create mode 100644 library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt create mode 100644 library/crypto/src/test/kotlin/net/agorise/library/crypto/bn256/LatticeTest.kt diff --git a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt new file mode 100644 index 0000000..677aef0 --- /dev/null +++ b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt @@ -0,0 +1,108 @@ +package net.agorise.library.crypto.bn256 + +import java.math.BigInteger + +val half: BigInteger = Constants.Order.shiftRight(1) + +internal val curveLattice = Lattice( + vectors = arrayOf( + arrayOf(BigInteger("147946756881789319000765030803803410728"), BigInteger("147946756881789319010696353538189108491")), + arrayOf(BigInteger("147946756881789319020627676272574806254"), BigInteger("-147946756881789318990833708069417712965")) + ), + inverse = arrayOf( + BigInteger("147946756881789318990833708069417712965"), + BigInteger("147946756881789319010696353538189108491") + ), + det = BigInteger("43776485743678550444492811490514550177096728800832068687396408373151616991234") +) + +internal val targetLattice = Lattice( + vectors = arrayOf( + arrayOf(BigInteger("9931322734385697761"), BigInteger("9931322734385697761"), BigInteger("9931322734385697763"), BigInteger("9931322734385697764")), + arrayOf(BigInteger("4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("4965661367192848882"), BigInteger("-9931322734385697762")), + arrayOf(BigInteger("-9931322734385697762"), BigInteger("-4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("-4965661367192848882")), + arrayOf(BigInteger("9931322734385697763"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881")) + ), + inverse = arrayOf( + BigInteger("734653495049373973658254490726798021314063399421879442165"), + BigInteger("147946756881789319000765030803803410728"), + BigInteger("-147946756881789319005730692170996259609"), + BigInteger("1469306990098747947464455738335385361643788813749140841702") + ), + det = Constants.Order +) + +/** + * Lattice implementation ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/lattice.go + */ +class Lattice(val vectors: Array>, val inverse: Array, val det: BigInteger) { + + /** + * takes a scalar mod Order as input and finds a short, positive decomposition of it + * wrt to the lattice basis. + */ + fun decompose(k: BigInteger): Array { + val n = inverse.size + + // Calculate closest vector in lattice to with Babai's rounding. + val c = Array(n) { i -> + val ci = k.multiply(inverse[i]) + round(ci, det) + } + + // Transform vectors according to c and subtract . + val out = Array(n) { BigInteger.ZERO } + + for (i in 0 until n) { + for (j in 0 until n) { + val temp = c[j].multiply(vectors[j][i]) + out[i] = out[i].add(temp) + } + + out[i] = out[i].negate().add(vectors[0][i]).add(vectors[0][i]) + } + out[0] = out[0].add(k) + + return out + } + + fun precompute(add: (UInt, UInt) -> Unit) { + val n = vectors.size + val total = 1u shl n + + for (i in 0 until n) { + for (j in 0u until total) { + if (j shr i and 1u == 1u) { + add(i.toUInt(), j) + } + } + } + } + + fun multi(scalar: BigInteger): ByteArray { + val decomp = decompose(scalar) + val maxLen = decomp.maxOf { it.bitLength() } + + val out = ByteArray(maxLen) + for ((j, x) in decomp.withIndex()) { + for (i in 0 until maxLen) { + out[i] = (out[i] + (if (x.testBit(i)) 1 else 0) shl j).toByte() + } + } + return out + } + + /** + * Returns num/denom rounded to the nearest integer. + */ + private fun round(num: BigInteger, denom: BigInteger): BigInteger { + val remainder = num.remainder(denom) + val quotient = num.divide(denom) + + return if (remainder > half) { + quotient.add(BigInteger.ONE) + } else { + quotient + } + } +} diff --git a/library/crypto/src/test/kotlin/net/agorise/library/crypto/bn256/LatticeTest.kt b/library/crypto/src/test/kotlin/net/agorise/library/crypto/bn256/LatticeTest.kt new file mode 100644 index 0000000..19576aa --- /dev/null +++ b/library/crypto/src/test/kotlin/net/agorise/library/crypto/bn256/LatticeTest.kt @@ -0,0 +1,34 @@ +package net.agorise.library.crypto.bn256 + +import java.math.BigInteger +import java.security.SecureRandom +import kotlin.test.Test +import kotlin.test.fail + +class LatticeTest { + @Test + fun `given a random number - when decompose is called on curveLattice - then reduction is small`() { + val random = SecureRandom() + val k = BigInteger(Constants.Order.bitLength(), random) + val ks = curveLattice.decompose(k) + + if (ks[0].bitLength() > 130 || ks[1].bitLength() > 130) { + fail("reduction too large") + } else if (ks[0].signum() < 0 || ks[1].signum() < 0) { + fail("reduction must be positive") + } + } + + @Test + fun `given a random number - when decompose is called on targetLattice - then reduction is small`() { + val random = SecureRandom() + val k = BigInteger(Constants.Order.bitLength(), random) + val ks = targetLattice.decompose(k) + + if (ks.any { it.bitLength() > 66 }) { + fail("reduction too large") + } else if (ks.any { it.signum() < 0 }) { + fail("reduction must be positive") + } + } +}