diff --git a/main.py b/main.py index 519c5ca..26ed21e 100644 --- a/main.py +++ b/main.py @@ -1,44 +1,48 @@ -from utils import generate_secret, coin_toss, H -from utils import prime -from utils import encrypt, decrypt -from Crypto.Math.Numbers import Integer +import ecpy.curves as curves +import secrets +import pickle +import Crypto.Random +import Crypto.Cipher.AES +import Crypto.Protocol.KDF -def alice_1(): - g = Integer(2) - a = generate_secret() - A_secret = pow(g, Integer(int.from_bytes(a)), prime) - return A_secret, a +m0 = b"alice" +m1 = b"bob" +c = False -def bob_1(A_secret): - g = Integer(2) - b = generate_secret() - B_secret = pow(g, Integer.from_bytes(b), prime) - coin = coin_toss() - if coin: - B_secret = A_secret * B_secret - return B_secret, b, coin +q = 224 +curve = curves.Curve.get_curve('NIST-P224') +g = curve.generator -def alice_2(A_secret, B_secret, a, m0, m1): - a = int.from_bytes(a) - k0 = H(str(pow(int(B_secret), a, prime)).encode()) - k1 = H(str(pow((B_secret // A_secret), a, prime)).encode()) - # Nonce at this point is generated by AES built-in method, not by scrypt - e0, e1 = encrypt(k0, m0.encode()), encrypt(k1, m1.encode()) - return *e0, *e1 +def H(p: curves.Point) -> bytes: + secret = pickle.dumps((p.x, p.y), protocol=4) + salt = Crypto.Random.get_random_bytes(16) + key = Crypto.Protocol.KDF.scrypt(secret, salt, 16, N=2**14, r=8, p=1) + return key[:32] # first 32 bytes of generated key -def bob_2(A_secret, b, c, e0, e1, n0, n1): - kc = H(str(pow(int(A_secret), int.from_bytes(b), prime)).encode()) - if c: - return decrypt(e1, n1, kc) - else: - return decrypt(e0, n0, kc) +def E(key: bytes, message: bytes) -> tuple[bytes, bytes]: + cipher = Crypto.Cipher.AES.new(key, Crypto.Cipher.AES.MODE_CTR) + ct = cipher.encrypt(message) + return (ct, cipher.nonce) -m0, m1 = "alice", "bob" +def D(key: bytes, encrypted_with_nonce: tuple[bytes, bytes]) -> bytes: + ct, nonce = encrypted_with_nonce + cipher = Crypto.Cipher.AES.new(key, Crypto.Cipher.AES.MODE_CTR, nonce=nonce) + return cipher.decrypt(ct) -A, a = alice_1() -B, b, c = bob_1(A) -e0, n0, e1, n1 = alice_2(A, B, a, m0, m1) -result = bob_2(A, b, c, e0, e1, n0, n1) +a = 1 + secrets.randbelow(q) +b = 1 + secrets.randbelow(q) +A = curve.mul_point(a, g) + +B = curve.mul_point(b, g) +if c: + B = curve.add_point(A, B) + +k0 = H(curve.mul_point(a, B)) +k1 = H(curve.mul_point(a, curve.sub_point(B, A))) +e0 = E(k0, m0) +e1 = E(k1, m1) + +kc = H(curve.mul_point(b, A)) +print(D(kc, e1 if c else e0)) -print(result.decode()) diff --git a/utils.py b/utils.py deleted file mode 100644 index 4ba2bf0..0000000 --- a/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from Crypto.PublicKey import ECC -from Crypto.Protocol.KDF import scrypt -from Crypto.Random import get_random_bytes -from Crypto.Cipher import AES -from secrets import randbelow - -prime = 0xffffffffffffffffffffffffffffffff000000000000000000000001 - -def generate_secret(c = 0): - # 'DER' format for byte output - return ECC.generate(curve='NIST P-224').export_key(format='DER') - -def H(secret): - # secret should be bytearray[], pref. from generate_secret() function - salt = get_random_bytes(16) - key = scrypt(bytes(secret), salt, 16, N=2**14, r=8, p=1) - return key[:32] # first 32 bytes of generated key - -def coin_toss(): - x = randbelow(2 ** 64) - if x & 1: - return False - else: - return True - -def encrypt(key, data): - cipher = AES.new(key, AES.MODE_CTR) - ct = cipher.encrypt(data) - nonce = cipher.nonce - return ct, nonce - -def decrypt(ct, nonce, key): - cipher = AES.new(key, AES.MODE_CTR, nonce=nonce) - pt = cipher.decrypt(ct) - return pt