Implement CurvePoint functionality

- Ported over the CurvePoint functionality to Kotlin from the derohe project.
- Made minor improvements to how contsants in different classes are defined.
This commit is contained in:
Severiano Jaramillo 2024-05-12 21:30:47 -07:00
parent 8cc60417a8
commit e42a2acc66
5 changed files with 283 additions and 34 deletions

View file

@ -26,4 +26,8 @@ object Constants {
// r3 is R^3 where R = 2^256 mod p. // r3 is R^3 where R = 2^256 mod p.
val r3 = GfP(ulongArrayOf(0xb1cd6dafda1530dfUL, 0x62f210e6a7283db6UL, 0xef7f0b0c0ada0afbUL, 0x20fd6e902d592544UL)) val r3 = GfP(ulongArrayOf(0xb1cd6dafda1530dfUL, 0x62f210e6a7283db6UL, 0xef7f0b0c0ada0afbUL, 0x20fd6e902d592544UL))
// xiTo2PSquaredMinus2Over3 is ξ^((2p²-2)/3) where ξ = i+9 (a cubic root of unity, mod p).
val xiTo2PSquaredMinus2Over3 = GfP(ulongArrayOf(0x71930c11d782e155u, 0xa6bb947cffbe3323u, 0xaa303344d4741444u, 0x2c3b3f0d26594943u))
} }

View file

@ -0,0 +1,232 @@
package net.agorise.library.crypto.bn256
import java.math.BigInteger
/**
* CurvePoint implements the elliptic curve =+3. Points are kept in Jacobian
* form and t= when valid. G₁ is the set of points of this curve on GF(p).
* Ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/curve.go
*/
internal class CurvePoint(
private var x: GfP,
private var y: GfP,
private var z: GfP,
private var t: GfP,
) {
private constructor() : this(GfP(0UL), GfP(0UL), GfP(0UL), GfP(0UL))
override fun toString(): String {
makeAffine()
val x = GfP()
val y = GfP()
montDecode(x, this.x)
montDecode(y, this.y)
return "($x, $y)"
}
fun set(a: CurvePoint) {
this.x.set(a.x)
this.y.set(a.y)
this.z.set(a.z)
this.t.set(a.t)
}
/**
* Returns true if curve point is on the curve.
*/
fun isOnCurve(): Boolean {
makeAffine()
if (isInfinity()) return true
val y2 = GfP()
val x3 = GfP()
gfpMul(y2, this.y, this.y)
gfpMul(x3, this.x, this.x)
gfpMul(x3, x3, this.x)
gfpAdd(x3, x3, curveB)
return y2 == x3
}
fun setInfinity() {
this.x = GfP(0UL)
this.y = GfP.newGfP(1)
this.z = GfP(0UL)
this.t = GfP(0UL)
}
fun isInfinity(): Boolean {
return this.z == GfP(0UL)
}
fun add(a: CurvePoint, b: CurvePoint) {
if (a.isInfinity()) {
set(b)
return
}
if (b.isInfinity()) {
set(a)
return
}
// See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3
// Normalize the points by replacing a = [x1:y1:z1] and b = [x2:y2:z2]
// by [u1:s1:z1·z2] and [u2:s2:z1·z2]
// where u1 = x1·z2², s1 = y1·z2³ and u1 = x2·z1², s2 = y2·z1³
val z12 = GfP().apply { gfpMul(this, a.z, a.z) }
val z22 = GfP().apply { gfpMul(this, b.z, b.z) }
val u1 = GfP().apply { gfpMul(this, a.x, z22) }
val u2 = GfP().apply { gfpMul(this, b.x, z12) }
val t = GfP().apply { gfpMul(this, b.z, z22) }
val s1 = GfP().apply { gfpMul(this, a.y, t) }
gfpMul(t, a.z, z12)
val s2 = GfP().apply { gfpMul(this, b.y, t) }
// Compute x = (2h)²(s²-u1-u2)
// where s = (s2-s1)/(u2-u1) is the slope of the line through
// (u1,s1) and (u2,s2). The extra factor 2h = 2(u2-u1) comes from the value of z below.
// This is also:
// 4(s2-s1)² - 4h²(u1+u2) = 4(s2-s1)² - 4h³ - 4h²(2u1)
// = r² - j - 2v
// with the notations below.
val h = GfP().apply { gfpSub(this, u2, u1) }
val xEqual = h == GfP(0UL)
gfpAdd(t, h, h)
// i = 4h²
val i = GfP().apply { gfpMul(this, t, t) }
// j = 4h³
val j = GfP().apply { gfpMul(this, h, i) }
gfpSub(t, s2, s1)
val yEqual = t == GfP(0UL)
if (xEqual && yEqual) {
double(a)
return
}
val r = GfP().apply { gfpAdd(this, t, t) }
val v = GfP().apply { gfpMul(this, u1, i) }
// t4 = 4(s2-s1)²
val t4 = GfP().apply { gfpMul(this, r, r) }
val t6 = GfP().apply { gfpSub(this, t4, j) }
gfpAdd(t, v, v)
gfpSub(this.x, t6, t)
// Set y = -(2h)³(s1 + s*(x/4h²-u1))
// This is also
// y = - 2·s1·j - (s2-s1)(2x - 2i·u1) = r(v-x) - 2·s1·j
gfpSub(t, v, this.x) // t7
gfpMul(t4, s1, j) // t8
gfpAdd(t6, t4, t4) // t9
gfpMul(t4, r, t) // t10
gfpSub(this.y, t4, t6)
// Set z = 2(u2-u1)·z1·z2 = 2h·z1·z2
gfpAdd(t, a.z, b.z) // t11
gfpMul(t4, t, t) // t12
gfpSub(t, t4, z12) // t13
gfpSub(t4, t, z22) // t14
gfpMul(this.z, t4, h)
}
private fun double(a: CurvePoint) {
// See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3
val A = GfP().apply { gfpMul(this, a.x, a.x) }
val B = GfP().apply { gfpMul(this, a.y, a.y) }
val C = GfP().apply { gfpMul(this, B, B) }
val t = GfP().apply { gfpAdd(this, a.x, B) }
val t2 = GfP().apply { gfpMul(this, t, t) }
gfpSub(t, t2, A)
gfpSub(t2, t, C)
val d = GfP().apply { gfpAdd(this, t2, t2) }
gfpAdd(t, A, A)
val e = GfP().apply { gfpAdd(this, t, A) }
val f = GfP().apply { gfpMul(this, e, e) }
gfpAdd(t, d, d)
gfpSub(this.x, f, t)
gfpMul(this.z, a.y, a.z)
gfpAdd(this.z, this.z, this.z)
gfpAdd(t, C, C)
gfpAdd(t2, t, t)
gfpAdd(t, t2, t2)
gfpSub(this.y, d, this.x)
gfpMul(t2, e, this.y)
gfpSub(this.y, t2, t)
}
fun mul(a: CurvePoint, scalar: BigInteger) {
val precomp = Array(4) { CurvePoint() }
precomp[1].set(a)
precomp[2].set(a)
gfpMul(precomp[2].x, precomp[2].x, Constants.xiTo2PSquaredMinus2Over3)
precomp[3].add(precomp[1], precomp[2])
val multiScalar = Lattice.curveLattice.multi(scalar)
val sum = CurvePoint().apply { setInfinity() }
val t = CurvePoint()
for (i in multiScalar.size - 1 downTo 0) {
t.double(sum)
if (multiScalar[i].toInt() == 0) {
sum.set(t)
} else {
sum.add(t, precomp[multiScalar[i].toInt()])
}
}
set(sum)
}
fun makeAffine() {
if (this.z == GfP.newGfP(1)) {
return
} else if (this.z == GfP.newGfP(0)) {
this.x = GfP(0UL)
this.y = GfP.newGfP(1)
this.t = GfP(0UL)
return
}
val zInv = GfP().apply { invert(this@CurvePoint.z) }
val t = GfP().apply { gfpMul(this, this@CurvePoint.y, zInv) }
val zInv2 = GfP().apply { gfpMul(this, zInv, zInv) }
gfpMul(this.x, this.x, zInv2)
gfpMul(this.y, t, zInv2)
this.z = GfP.newGfP(1)
this.t = GfP.newGfP(1)
}
fun neg(a: CurvePoint) {
this.x.set(a.x)
gfpNeg(this.y, a.y)
this.z.set(a.z)
this.t = GfP(0UL)
}
companion object {
internal val curveB = GfP.newGfP(3)
// curveGen is the generator of G₁.
internal val curveGen = CurvePoint(
GfP.newGfP(1),
GfP.newGfP(2),
GfP.newGfP(1),
GfP.newGfP(1),
)
}
}

View file

@ -21,6 +21,14 @@ class GfP(val data: ULongArray) {
return String.format("%016x %016x %016x %016x", data[3].toLong(), data[2].toLong(), data[1].toLong(), data[0].toLong()) return String.format("%016x %016x %016x %016x", data[3].toLong(), data[2].toLong(), data[1].toLong(), data[0].toLong())
} }
override fun equals(other: Any?): Boolean {
if (other !is GfP) return false
repeat(4) { if (data[it] != other.data[it]) return false }
return true
}
override fun hashCode(): Int = data.hashCode()
fun set(f: GfP) { fun set(f: GfP) {
data[0] = f.data[0] data[0] = f.data[0]
data[1] = f.data[1] data[1] = f.data[1]

View file

@ -2,41 +2,14 @@ package net.agorise.library.crypto.bn256
import java.math.BigInteger import java.math.BigInteger
val half: BigInteger = Constants.Order.shiftRight(1)
internal val curveLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("147946756881789319000765030803803410728"), BigInteger("147946756881789319010696353538189108491")),
arrayOf(BigInteger("147946756881789319020627676272574806254"), BigInteger("-147946756881789318990833708069417712965"))
),
inverse = arrayOf(
BigInteger("147946756881789318990833708069417712965"),
BigInteger("147946756881789319010696353538189108491")
),
det = BigInteger("43776485743678550444492811490514550177096728800832068687396408373151616991234")
)
internal val targetLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("9931322734385697761"), BigInteger("9931322734385697761"), BigInteger("9931322734385697763"), BigInteger("9931322734385697764")),
arrayOf(BigInteger("4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("4965661367192848882"), BigInteger("-9931322734385697762")),
arrayOf(BigInteger("-9931322734385697762"), BigInteger("-4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("-4965661367192848882")),
arrayOf(BigInteger("9931322734385697763"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"))
),
inverse = arrayOf(
BigInteger("734653495049373973658254490726798021314063399421879442165"),
BigInteger("147946756881789319000765030803803410728"),
BigInteger("-147946756881789319005730692170996259609"),
BigInteger("1469306990098747947464455738335385361643788813749140841702")
),
det = Constants.Order
)
/** /**
* Lattice implementation ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/lattice.go * Lattice implementation ported over from https://github.com/deroproject/derohe/blob/main/cryptography/bn256/lattice.go
*/ */
class Lattice(val vectors: Array<Array<BigInteger>>, val inverse: Array<BigInteger>, val det: BigInteger) { internal class Lattice(
val vectors: Array<Array<BigInteger>>,
val inverse: Array<BigInteger>,
val det: BigInteger,
) {
/** /**
* takes a scalar mod Order as input and finds a short, positive decomposition of it * takes a scalar mod Order as input and finds a short, positive decomposition of it
* wrt to the lattice basis. * wrt to the lattice basis.
@ -105,4 +78,36 @@ class Lattice(val vectors: Array<Array<BigInteger>>, val inverse: Array<BigInteg
quotient quotient
} }
} }
companion object {
private val half: BigInteger = Constants.Order.shiftRight(1)
internal val curveLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("147946756881789319000765030803803410728"), BigInteger("147946756881789319010696353538189108491")),
arrayOf(BigInteger("147946756881789319020627676272574806254"), BigInteger("-147946756881789318990833708069417712965"))
),
inverse = arrayOf(
BigInteger("147946756881789318990833708069417712965"),
BigInteger("147946756881789319010696353538189108491")
),
det = BigInteger("43776485743678550444492811490514550177096728800832068687396408373151616991234")
)
internal val targetLattice = Lattice(
vectors = arrayOf(
arrayOf(BigInteger("9931322734385697761"), BigInteger("9931322734385697761"), BigInteger("9931322734385697763"), BigInteger("9931322734385697764")),
arrayOf(BigInteger("4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("4965661367192848882"), BigInteger("-9931322734385697762")),
arrayOf(BigInteger("-9931322734385697762"), BigInteger("-4965661367192848881"), BigInteger("4965661367192848881"), BigInteger("-4965661367192848882")),
arrayOf(BigInteger("9931322734385697763"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"), BigInteger("-4965661367192848881"))
),
inverse = arrayOf(
BigInteger("734653495049373973658254490726798021314063399421879442165"),
BigInteger("147946756881789319000765030803803410728"),
BigInteger("-147946756881789319005730692170996259609"),
BigInteger("1469306990098747947464455738335385361643788813749140841702")
),
det = Constants.Order
)
}
} }

View file

@ -10,7 +10,7 @@ class LatticeTest {
fun `given a random number - when decompose is called on curveLattice - then reduction is small`() { fun `given a random number - when decompose is called on curveLattice - then reduction is small`() {
val random = SecureRandom() val random = SecureRandom()
val k = BigInteger(Constants.Order.bitLength(), random) val k = BigInteger(Constants.Order.bitLength(), random)
val ks = curveLattice.decompose(k) val ks = Lattice.curveLattice.decompose(k)
if (ks[0].bitLength() > 130 || ks[1].bitLength() > 130) { if (ks[0].bitLength() > 130 || ks[1].bitLength() > 130) {
fail("reduction too large") fail("reduction too large")
@ -23,7 +23,7 @@ class LatticeTest {
fun `given a random number - when decompose is called on targetLattice - then reduction is small`() { fun `given a random number - when decompose is called on targetLattice - then reduction is small`() {
val random = SecureRandom() val random = SecureRandom()
val k = BigInteger(Constants.Order.bitLength(), random) val k = BigInteger(Constants.Order.bitLength(), random)
val ks = targetLattice.decompose(k) val ks = Lattice.targetLattice.decompose(k)
if (ks.any { it.bitLength() > 66 }) { if (ks.any { it.bitLength() > 66 }) {
fail("reduction too large") fail("reduction too large")