Implemented Lattice functionality.

- Ported over Lattice functionality from the derohe project.
- Ported over LatticeTest too, to verify that Lattice implementation is correct.
This commit is contained in:
Severiano Jaramillo 2024-05-12 16:39:06 -07:00
parent ef440c8a9f
commit 8cc60417a8
2 changed files with 142 additions and 0 deletions

View file

@ -0,0 +1,108 @@
package net.agorise.library.crypto.bn256
import java.math.BigInteger
val half: BigInteger = Constants.Order.shiftRight(1)
internal val curveLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("147946756881789319000765030803803410728"), BigInteger("147946756881789319010696353538189108491")),
arrayOf(BigInteger("147946756881789319020627676272574806254"), BigInteger("-147946756881789318990833708069417712965"))
),
inverse = arrayOf(
BigInteger("147946756881789318990833708069417712965"),
BigInteger("147946756881789319010696353538189108491")
),
det = BigInteger("43776485743678550444492811490514550177096728800832068687396408373151616991234")
)
internal val targetLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("9931322734385697761"), BigInteger("9931322734385697761"), BigInteger("9931322734385697763"), BigInteger("9931322734385697764")),
arrayOf(BigInteger("4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("4965661367192848882"), BigInteger("-9931322734385697762")),
arrayOf(BigInteger("-9931322734385697762"), BigInteger("-4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("-4965661367192848882")),
arrayOf(BigInteger("9931322734385697763"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"))
),
inverse = arrayOf(
BigInteger("734653495049373973658254490726798021314063399421879442165"),
BigInteger("147946756881789319000765030803803410728"),
BigInteger("-147946756881789319005730692170996259609"),
BigInteger("1469306990098747947464455738335385361643788813749140841702")
),
det = Constants.Order
)
/**
* Lattice implementation ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/lattice.go
*/
class Lattice(val vectors: Array<Array<BigInteger>>, val inverse: Array<BigInteger>, val det: BigInteger) {
/**
* takes a scalar mod Order as input and finds a short, positive decomposition of it
* wrt to the lattice basis.
*/
fun decompose(k: BigInteger): Array<BigInteger> {
val n = inverse.size
// Calculate closest vector in lattice to <k,0,0,...> with Babai's rounding.
val c = Array(n) { i ->
val ci = k.multiply(inverse[i])
round(ci, det)
}
// Transform vectors according to c and subtract <k,0,0,...>.
val out = Array(n) { BigInteger.ZERO }
for (i in 0 until n) {
for (j in 0 until n) {
val temp = c[j].multiply(vectors[j][i])
out[i] = out[i].add(temp)
}
out[i] = out[i].negate().add(vectors[0][i]).add(vectors[0][i])
}
out[0] = out[0].add(k)
return out
}
fun precompute(add: (UInt, UInt) -> Unit) {
val n = vectors.size
val total = 1u shl n
for (i in 0 until n) {
for (j in 0u until total) {
if (j shr i and 1u == 1u) {
add(i.toUInt(), j)
}
}
}
}
fun multi(scalar: BigInteger): ByteArray {
val decomp = decompose(scalar)
val maxLen = decomp.maxOf { it.bitLength() }
val out = ByteArray(maxLen)
for ((j, x) in decomp.withIndex()) {
for (i in 0 until maxLen) {
out[i] = (out[i] + (if (x.testBit(i)) 1 else 0) shl j).toByte()
}
}
return out
}
/**
* Returns num/denom rounded to the nearest integer.
*/
private fun round(num: BigInteger, denom: BigInteger): BigInteger {
val remainder = num.remainder(denom)
val quotient = num.divide(denom)
return if (remainder > half) {
quotient.add(BigInteger.ONE)
} else {
quotient
}
}
}

View file

@ -0,0 +1,34 @@
package net.agorise.library.crypto.bn256
import java.math.BigInteger
import java.security.SecureRandom
import kotlin.test.Test
import kotlin.test.fail
class LatticeTest {
@Test
fun `given a random number - when decompose is called on curveLattice - then reduction is small`() {
val random = SecureRandom()
val k = BigInteger(Constants.Order.bitLength(), random)
val ks = curveLattice.decompose(k)
if (ks[0].bitLength() > 130 || ks[1].bitLength() > 130) {
fail("reduction too large")
} else if (ks[0].signum() < 0 || ks[1].signum() < 0) {
fail("reduction must be positive")
}
}
@Test
fun `given a random number - when decompose is called on targetLattice - then reduction is small`() {
val random = SecureRandom()
val k = BigInteger(Constants.Order.bitLength(), random)
val ks = targetLattice.decompose(k)
if (ks.any { it.bitLength() > 66 }) {
fail("reduction too large")
} else if (ks.any { it.signum() < 0 }) {
fail("reduction must be positive")
}
}
}