From e42a2acc66dac2f68a3f4f9a1d3f5e9dea27f823 Mon Sep 17 00:00:00 2001 From: Severiano Jaramillo Date: Sun, 12 May 2024 21:30:47 -0700 Subject: [PATCH] Implement CurvePoint functionality - Ported over the CurvePoint functionality to Kotlin from the derohe project. - Made minor improvements to how contsants in different classes are defined. --- .../agorise/library/crypto/bn256/Constants.kt | 4 + .../library/crypto/bn256/CurvePoint.kt | 232 ++++++++++++++++++ .../net/agorise/library/crypto/bn256/GfP.kt | 8 + .../agorise/library/crypto/bn256/Lattice.kt | 69 +++--- .../library/crypto/bn256/LatticeTest.kt | 4 +- 5 files changed, 283 insertions(+), 34 deletions(-) create mode 100644 library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/CurvePoint.kt diff --git a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Constants.kt b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Constants.kt index dbd15a2..ee059e1 100644 --- a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Constants.kt +++ b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Constants.kt @@ -26,4 +26,8 @@ object Constants { // r3 is R^3 where R = 2^256 mod p. val r3 = GfP(ulongArrayOf(0xb1cd6dafda1530dfUL, 0x62f210e6a7283db6UL, 0xef7f0b0c0ada0afbUL, 0x20fd6e902d592544UL)) + + // xiTo2PSquaredMinus2Over3 is ξ^((2p²-2)/3) where ξ = i+9 (a cubic root of unity, mod p). + val xiTo2PSquaredMinus2Over3 = GfP(ulongArrayOf(0x71930c11d782e155u, 0xa6bb947cffbe3323u, 0xaa303344d4741444u, 0x2c3b3f0d26594943u)) + } diff --git a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/CurvePoint.kt b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/CurvePoint.kt new file mode 100644 index 0000000..3376823 --- /dev/null +++ b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/CurvePoint.kt @@ -0,0 +1,232 @@ +package net.agorise.library.crypto.bn256 + +import java.math.BigInteger + +/** + * CurvePoint implements the elliptic curve y²=x³+3. Points are kept in Jacobian + * form and t=z² when valid. G₁ is the set of points of this curve on GF(p). + * Ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/curve.go + */ +internal class CurvePoint( + private var x: GfP, + private var y: GfP, + private var z: GfP, + private var t: GfP, +) { + private constructor() : this(GfP(0UL), GfP(0UL), GfP(0UL), GfP(0UL)) + + override fun toString(): String { + makeAffine() + val x = GfP() + val y = GfP() + montDecode(x, this.x) + montDecode(y, this.y) + return "($x, $y)" + } + + fun set(a: CurvePoint) { + this.x.set(a.x) + this.y.set(a.y) + this.z.set(a.z) + this.t.set(a.t) + } + + /** + * Returns true if curve point is on the curve. + */ + fun isOnCurve(): Boolean { + makeAffine() + if (isInfinity()) return true + + val y2 = GfP() + val x3 = GfP() + gfpMul(y2, this.y, this.y) + gfpMul(x3, this.x, this.x) + gfpMul(x3, x3, this.x) + gfpAdd(x3, x3, curveB) + + return y2 == x3 + } + + fun setInfinity() { + this.x = GfP(0UL) + this.y = GfP.newGfP(1) + this.z = GfP(0UL) + this.t = GfP(0UL) + } + + fun isInfinity(): Boolean { + return this.z == GfP(0UL) + } + + fun add(a: CurvePoint, b: CurvePoint) { + if (a.isInfinity()) { + set(b) + return + } + if (b.isInfinity()) { + set(a) + return + } + + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3 + + // Normalize the points by replacing a = [x1:y1:z1] and b = [x2:y2:z2] + // by [u1:s1:z1·z2] and [u2:s2:z1·z2] + // where u1 = x1·z2², s1 = y1·z2³ and u1 = x2·z1², s2 = y2·z1³ + val z12 = GfP().apply { gfpMul(this, a.z, a.z) } + val z22 = GfP().apply { gfpMul(this, b.z, b.z) } + + val u1 = GfP().apply { gfpMul(this, a.x, z22) } + val u2 = GfP().apply { gfpMul(this, b.x, z12) } + + val t = GfP().apply { gfpMul(this, b.z, z22) } + val s1 = GfP().apply { gfpMul(this, a.y, t) } + + gfpMul(t, a.z, z12) + val s2 = GfP().apply { gfpMul(this, b.y, t) } + + // Compute x = (2h)²(s²-u1-u2) + // where s = (s2-s1)/(u2-u1) is the slope of the line through + // (u1,s1) and (u2,s2). The extra factor 2h = 2(u2-u1) comes from the value of z below. + // This is also: + // 4(s2-s1)² - 4h²(u1+u2) = 4(s2-s1)² - 4h³ - 4h²(2u1) + // = r² - j - 2v + // with the notations below. + val h = GfP().apply { gfpSub(this, u2, u1) } + val xEqual = h == GfP(0UL) + + gfpAdd(t, h, h) + // i = 4h² + val i = GfP().apply { gfpMul(this, t, t) } + // j = 4h³ + val j = GfP().apply { gfpMul(this, h, i) } + + gfpSub(t, s2, s1) + val yEqual = t == GfP(0UL) + if (xEqual && yEqual) { + double(a) + return + } + val r = GfP().apply { gfpAdd(this, t, t) } + + val v = GfP().apply { gfpMul(this, u1, i) } + + // t4 = 4(s2-s1)² + val t4 = GfP().apply { gfpMul(this, r, r) } + val t6 = GfP().apply { gfpSub(this, t4, j) } + + gfpAdd(t, v, v) + gfpSub(this.x, t6, t) + + // Set y = -(2h)³(s1 + s*(x/4h²-u1)) + // This is also + // y = - 2·s1·j - (s2-s1)(2x - 2i·u1) = r(v-x) - 2·s1·j + gfpSub(t, v, this.x) // t7 + gfpMul(t4, s1, j) // t8 + gfpAdd(t6, t4, t4) // t9 + gfpMul(t4, r, t) // t10 + gfpSub(this.y, t4, t6) + + // Set z = 2(u2-u1)·z1·z2 = 2h·z1·z2 + gfpAdd(t, a.z, b.z) // t11 + gfpMul(t4, t, t) // t12 + gfpSub(t, t4, z12) // t13 + gfpSub(t4, t, z22) // t14 + gfpMul(this.z, t4, h) + } + + private fun double(a: CurvePoint) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + val A = GfP().apply { gfpMul(this, a.x, a.x) } + val B = GfP().apply { gfpMul(this, a.y, a.y) } + val C = GfP().apply { gfpMul(this, B, B) } + + val t = GfP().apply { gfpAdd(this, a.x, B) } + val t2 = GfP().apply { gfpMul(this, t, t) } + gfpSub(t, t2, A) + gfpSub(t2, t, C) + + val d = GfP().apply { gfpAdd(this, t2, t2) } + gfpAdd(t, A, A) + val e = GfP().apply { gfpAdd(this, t, A) } + val f = GfP().apply { gfpMul(this, e, e) } + + gfpAdd(t, d, d) + gfpSub(this.x, f, t) + + gfpMul(this.z, a.y, a.z) + gfpAdd(this.z, this.z, this.z) + + gfpAdd(t, C, C) + gfpAdd(t2, t, t) + gfpAdd(t, t2, t2) + gfpSub(this.y, d, this.x) + gfpMul(t2, e, this.y) + gfpSub(this.y, t2, t) + } + + fun mul(a: CurvePoint, scalar: BigInteger) { + val precomp = Array(4) { CurvePoint() } + precomp[1].set(a) + precomp[2].set(a) + gfpMul(precomp[2].x, precomp[2].x, Constants.xiTo2PSquaredMinus2Over3) + precomp[3].add(precomp[1], precomp[2]) + + val multiScalar = Lattice.curveLattice.multi(scalar) + + val sum = CurvePoint().apply { setInfinity() } + val t = CurvePoint() + + for (i in multiScalar.size - 1 downTo 0) { + t.double(sum) + if (multiScalar[i].toInt() == 0) { + sum.set(t) + } else { + sum.add(t, precomp[multiScalar[i].toInt()]) + } + } + set(sum) + } + + fun makeAffine() { + if (this.z == GfP.newGfP(1)) { + return + } else if (this.z == GfP.newGfP(0)) { + this.x = GfP(0UL) + this.y = GfP.newGfP(1) + this.t = GfP(0UL) + return + } + + val zInv = GfP().apply { invert(this@CurvePoint.z) } + + val t = GfP().apply { gfpMul(this, this@CurvePoint.y, zInv) } + val zInv2 = GfP().apply { gfpMul(this, zInv, zInv) } + + gfpMul(this.x, this.x, zInv2) + gfpMul(this.y, t, zInv2) + + this.z = GfP.newGfP(1) + this.t = GfP.newGfP(1) + } + + fun neg(a: CurvePoint) { + this.x.set(a.x) + gfpNeg(this.y, a.y) + this.z.set(a.z) + this.t = GfP(0UL) + } + + companion object { + internal val curveB = GfP.newGfP(3) + + // curveGen is the generator of G₁. + internal val curveGen = CurvePoint( + GfP.newGfP(1), + GfP.newGfP(2), + GfP.newGfP(1), + GfP.newGfP(1), + ) + } +} diff --git a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/GfP.kt b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/GfP.kt index 8483c7b..2277b65 100644 --- a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/GfP.kt +++ b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/GfP.kt @@ -21,6 +21,14 @@ class GfP(val data: ULongArray) { return String.format("%016x %016x %016x %016x", data[3].toLong(), data[2].toLong(), data[1].toLong(), data[0].toLong()) } + override fun equals(other: Any?): Boolean { + if (other !is GfP) return false + repeat(4) { if (data[it] != other.data[it]) return false } + return true + } + + override fun hashCode(): Int = data.hashCode() + fun set(f: GfP) { data[0] = f.data[0] data[1] = f.data[1] diff --git a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt index 677aef0..1167d3d 100644 --- a/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt +++ b/library/crypto/src/main/kotlin/net/agorise/library/crypto/bn256/Lattice.kt @@ -2,41 +2,14 @@ package net.agorise.library.crypto.bn256 import java.math.BigInteger -val half: BigInteger = Constants.Order.shiftRight(1) - -internal val curveLattice = Lattice( - vectors = arrayOf( - arrayOf(BigInteger("147946756881789319000765030803803410728"), BigInteger("147946756881789319010696353538189108491")), - arrayOf(BigInteger("147946756881789319020627676272574806254"), BigInteger("-147946756881789318990833708069417712965")) - ), - inverse = arrayOf( - BigInteger("147946756881789318990833708069417712965"), - BigInteger("147946756881789319010696353538189108491") - ), - det = BigInteger("43776485743678550444492811490514550177096728800832068687396408373151616991234") -) - -internal val targetLattice = Lattice( - vectors = arrayOf( - arrayOf(BigInteger("9931322734385697761"), BigInteger("9931322734385697761"), BigInteger("9931322734385697763"), BigInteger("9931322734385697764")), - arrayOf(BigInteger("4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("4965661367192848882"), BigInteger("-9931322734385697762")), - arrayOf(BigInteger("-9931322734385697762"), BigInteger("-4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("-4965661367192848882")), - arrayOf(BigInteger("9931322734385697763"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881")) - ), - inverse = arrayOf( - BigInteger("734653495049373973658254490726798021314063399421879442165"), - BigInteger("147946756881789319000765030803803410728"), - BigInteger("-147946756881789319005730692170996259609"), - BigInteger("1469306990098747947464455738335385361643788813749140841702") - ), - det = Constants.Order -) - /** * Lattice implementation ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/lattice.go */ -class Lattice(val vectors: Array>, val inverse: Array, val det: BigInteger) { - +internal class Lattice( + val vectors: Array>, + val inverse: Array, + val det: BigInteger, +) { /** * takes a scalar mod Order as input and finds a short, positive decomposition of it * wrt to the lattice basis. @@ -105,4 +78,36 @@ class Lattice(val vectors: Array>, val inverse: Array 130 || ks[1].bitLength() > 130) { fail("reduction too large") @@ -23,7 +23,7 @@ class LatticeTest { fun `given a random number - when decompose is called on targetLattice - then reduction is small`() { val random = SecureRandom() val k = BigInteger(Constants.Order.bitLength(), random) - val ks = targetLattice.decompose(k) + val ks = Lattice.targetLattice.decompose(k) if (ks.any { it.bitLength() > 66 }) { fail("reduction too large")