"""Toy 1-out-of-2 oblivious transfer over NIST P-224.

Alice holds two messages (m0, m1) and Bob holds a choice bit c. At the end
Bob obtains m_c. The message flow is:

    Alice -> Bob:   A = a*G
    Bob   -> Alice: B = b*G            (c false)
                    B = A + b*G        (c true)
    Alice -> Bob:   (E(H(a*B), m0), E(H(a*(B - A)), m1), salt)

Bob derives H(b*A), which equals H(a*B) when c is false and H(a*(B - A))
when c is true, so he can decrypt exactly the ciphertext he chose.

Each party is a generator: every ``yield`` emits that party's next protocol
message, and the value passed back in with ``send`` is the peer's reply.
"""

import contextlib
import pickle
import secrets

import Crypto.Cipher.AES
import Crypto.Protocol.KDF
import Crypto.Random
import ecpy.curves as curves

curve = curves.Curve.get_curve("NIST-P224")
# Order of the group generated by g; private scalars live in [1, q - 1].
# NOTE: this must be curve.order, not curve.size -- size is the curve's bit
# length (224), which would leave only 224 possible private scalars.
q = curve.order
g = curve.generator


def _random_scalar() -> int:
    """Return a uniformly random non-zero scalar in [1, q - 1]."""
    return 1 + secrets.randbelow(q - 1)


def H(p: curves.Point, salt: bytes) -> bytes:
    """Derive a 16-byte AES-128 key from curve point *p* and *salt*.

    The point's affine coordinates are serialized with pickle (both parties
    use the same call, so they hash identical bytes) and stretched with
    scrypt.
    """
    secret = pickle.dumps((p.x, p.y), protocol=4)
    # key_len=16 -> a 16-byte key, i.e. AES-128.
    return Crypto.Protocol.KDF.scrypt(secret, salt, 16, N=2**14, r=8, p=1)


def E(key: bytes, message: bytes) -> tuple[bytes, bytes]:
    """Encrypt *message* with AES-CTR under *key*.

    Returns ``(ciphertext, nonce)``; the nonce is generated by the cipher
    object and must be passed to ``D`` for decryption.
    NOTE(review): CTR mode gives no integrity -- decrypting with the wrong
    key silently yields garbage rather than raising.
    """
    cipher = Crypto.Cipher.AES.new(key, Crypto.Cipher.AES.MODE_CTR)
    ct = cipher.encrypt(message)
    return (ct, cipher.nonce)


def D(key: bytes, encrypted_with_nonce: tuple[bytes, bytes]) -> bytes:
    """Decrypt a ``(ciphertext, nonce)`` pair produced by ``E``."""
    ct, nonce = encrypted_with_nonce
    cipher = Crypto.Cipher.AES.new(key, Crypto.Cipher.AES.MODE_CTR, nonce=nonce)
    return cipher.decrypt(ct)


########################################################
# Workers
########################################################


def alice(m0: bytes, m1: bytes):
    """Sender side of the OT, written as a generator.

    First ``next()`` yields A. ``send(B)`` then yields ``(e0, e1, salt)``,
    where e0 encrypts *m0* under H(a*B) and e1 encrypts *m1* under
    H(a*(B - A)).
    """
    a = _random_scalar()
    A = curve.mul_point(a, g)
    B = yield A
    # NOTE(review): if B ever arrives over a network it should be validated
    # as a point on this curve before use.
    salt = Crypto.Random.get_random_bytes(16)
    k0 = H(curve.mul_point(a, B), salt)
    k1 = H(curve.mul_point(a, curve.sub_point(B, A)), salt)
    e0 = E(k0, m0)
    e1 = E(k1, m1)
    yield e0, e1, salt


def bob(c: bool):
    """Receiver side of the OT, written as a generator.

    Must be primed with ``next()`` (it yields None). ``send(A)`` yields B,
    and ``send((e0, e1, salt))`` yields the decryption of e1 if *c* is true,
    else of e0.
    """
    b = _random_scalar()
    A = yield
    B = curve.mul_point(b, g)
    if c:
        # Shifting by A makes Alice's k1 (not k0) coincide with H(b*A).
        B = curve.add_point(A, B)
    e0, e1, salt = yield B
    kc = H(curve.mul_point(b, A), salt)
    yield D(kc, e1 if c else e0)


########################################################
# Arrows
########################################################


def main():
    """Run one OT in which Bob chooses message 1, and print what he gets."""
    with contextlib.closing(alice(b"msg one", b"msg two")) as sender, \
            contextlib.closing(bob(True)) as receiver:
        A = next(sender)
        next(receiver)  # advance Bob to his first ``yield`` so he can take A
        B = receiver.send(A)
        encrypted = sender.send(B)
        result = receiver.send(encrypted)
    print(result)


if __name__ == "__main__":
    main()