# cryptoSL2024/ot.py  (Python, 82 lines, 2.0 KiB)

import ecpy.curves as curves
import secrets
import pickle
import Crypto.Random
import Crypto.Cipher.AES
import Crypto.Protocol.KDF
import contextlib
# Group parameters shared by both parties of the oblivious transfer below.
curve = curves.Curve.get_curve("NIST-P224")
# Order of the generator g, i.e. the modulus of the scalar field; private
# scalars must be drawn from [1, q-1]. Note that ecpy's ``curve.size`` is the
# field size in *bits* (224 for P-224) and must not be used as this bound.
q = curve.order
g = curve.generator
def H(p: curves.Point, salt: bytes, key_len: int = 16) -> bytes:
    """Derive a symmetric key from an elliptic-curve point.

    The point's affine coordinates are serialized and stretched with scrypt
    (N=2**14, r=8, p=1) under *salt*.

    Args:
        p: Curve point whose coordinates are the shared secret.
        salt: Salt for scrypt; both parties must use the same value.
        key_len: Length of the derived key in bytes. The default of 16 yields
            an AES-128 key; 24 or 32 would yield AES-192 / AES-256 keys.

    Returns:
        A key of exactly ``key_len`` bytes.
    """
    # pickle is used only to *serialize* (never to load), so there is no
    # unpickling risk here. NOTE(review): its byte layout is a CPython detail;
    # a fixed-width big-endian encoding of (x, y) would be a more portable KDF
    # input, but changing it would change every derived key.
    secret = pickle.dumps((p.x, p.y), protocol=4)
    # scrypt returns exactly key_len bytes, so no truncation is required.
    return Crypto.Protocol.KDF.scrypt(secret, salt, key_len, N=2**14, r=8, p=1)
def E(key: bytes, message: bytes) -> tuple[bytes, bytes]:
    """Encrypt *message* with AES in CTR mode under *key*.

    No nonce is passed in, so the cipher object picks one itself; it is
    returned alongside the ciphertext so that ``D`` can rebuild the keystream.

    NOTE(review): CTR mode provides confidentiality only. There is no MAC, so
    a tampered ciphertext is not detected on decryption.

    Returns:
        ``(ciphertext, nonce)``.
    """
    aes_ctr = Crypto.Cipher.AES.new(key, Crypto.Cipher.AES.MODE_CTR)
    ciphertext = aes_ctr.encrypt(message)
    nonce = aes_ctr.nonce
    return ciphertext, nonce
def D(key: bytes, encrypted_with_nonce: tuple[bytes, bytes]) -> bytes:
    """Decrypt a ``(ciphertext, nonce)`` pair produced by AES-CTR encryption.

    Because CTR mode carries no authentication tag, decrypting with the wrong
    key does not raise. It simply returns unrelated bytes of the same length.
    """
    ciphertext, nonce = encrypted_with_nonce
    decryptor = Crypto.Cipher.AES.new(
        key,
        Crypto.Cipher.AES.MODE_CTR,
        nonce=nonce,
    )
    plaintext = decryptor.decrypt(ciphertext)
    return plaintext
########################################################
# Workers
########################################################
def alice(m0: bytes, m1: bytes):
    """Sender side of a 1-out-of-2 oblivious transfer, as a generator.

    Drive it with ``send``:

    1. ``send(None)`` yields Alice's public point ``A = a*g``.
    2. ``send(B)``, with Bob's point, yields ``(e0, e1, salt)``. Here ``e0``
       is *m0* encrypted under ``H(a*B, salt)`` and ``e1`` is *m1* encrypted
       under ``H(a*(B - A), salt)``.

    The key structure resembles the Chou-Orlandi "simplest OT". Bob can
    compute ``b*A``, which matches exactly one of the two keys, depending on
    whether he sent ``b*g`` or ``A + b*g``.
    """
    # Private scalar in [1, q-1]. q is meant to be the order of g, and
    # ``1 + randbelow(q)`` could return q itself, i.e. the zero scalar, which
    # would make A the point at infinity.
    a = 1 + secrets.randbelow(q - 1)
    A = curve.mul_point(a, g)
    B = yield A
    # One salt serves both key derivations; it is sent to Bob with the
    # ciphertexts so he can re-derive his key.
    salt = Crypto.Random.get_random_bytes(16)
    k0 = H(curve.mul_point(a, B), salt)
    k1 = H(curve.mul_point(a, curve.sub_point(B, A)), salt)
    e0 = E(k0, m0)
    e1 = E(k1, m1)
    yield e0, e1, salt
def bob(c: bool):
    """Receiver side of a 1-out-of-2 oblivious transfer, as a generator.

    Drive it with ``send``:

    1. ``send(None)`` primes the generator (yields ``None``).
    2. ``send(A)``, with Alice's point, yields Bob's point ``B``. This is
       ``b*g`` when *c* is false and ``A + b*g`` when *c* is true.
    3. ``send((e0, e1, salt))`` yields the decryption of ``e1`` if *c* is
       true, otherwise of ``e0``.
    """
    # Private scalar in [1, q-1]. q is meant to be the order of g, and
    # ``1 + randbelow(q)`` could return q itself, i.e. the zero scalar.
    b = 1 + secrets.randbelow(q - 1)
    A = yield
    B = curve.mul_point(b, g)
    if c:
        # Shifting by A means Alice's second key, H(a*(B - A)) = H(a*b*g),
        # is the one that equals b*A below.
        B = curve.add_point(A, B)
    e0, e1, salt = yield B
    # b*A = a*b*g, the point behind Alice's key for index c.
    kc = H(curve.mul_point(b, A), salt)
    yield D(kc, e1 if c else e0)
########################################################
# Arrows
########################################################
def main(
    m0: bytes = b"msg one",
    m1: bytes = b"msg two",
    c: bool = True,
) -> bytes:
    """Run one in-process oblivious transfer and print Bob's output.

    Args:
        m0: Alice's first message.
        m1: Alice's second message.
        c: Bob's choice; True selects *m1*, False selects *m0*.

    Returns:
        The message Bob recovered (also printed).
    """
    sender = alice(m0, m1)
    receiver = bob(c)
    # Every send below lands on a pending ``yield``, so StopIteration is never
    # expected. If one occurs, the two parties are out of sync, and that must
    # surface rather than be silently suppressed.
    A = sender.send(None)
    receiver.send(None)  # advance Bob to his first ``yield``
    B = receiver.send(A)
    encrypted = sender.send(B)
    result = receiver.send(encrypted)
    print(result)
    return result


if __name__ == "__main__":
    main()