package io.aether.crypto.sodium;

import com.goterl.lazysodium.LazySodiumJava;
import com.goterl.lazysodium.SodiumJava;
import com.goterl.lazysodium.exceptions.SodiumException;
import io.aether.crypto.*;
import io.aether.utils.HexUtils;
import io.aether.utils.RU;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * {@link CryptoProvider} implementation that delegates to libsodium through the
 * lazysodium-java bindings.
 *
 * <p>The provider is exposed as a single shared instance, {@link #INSTANCE}; the constructor
 * is private. Every key, signer and engine created here is one of the {@code Sodium*} types,
 * and the factory methods reject keys that are not {@code SodiumKey} instances.
 */
public class SodiumCryptoProvider implements CryptoProvider {

    // Declared before INSTANCE on purpose: static initializers run in textual order, so any
    // static state the constructor might ever touch has to be assigned before INSTANCE is built.
    private static final byte[] SODIUM_KDF_CONTEXT = "_aether_".getBytes(StandardCharsets.UTF_8);

    /** Shared provider instance. */
    public static final SodiumCryptoProvider INSTANCE = new SodiumCryptoProvider();

    private final LazySodiumJava lazySodium;
    private final SodiumJava sodium;

    private SodiumCryptoProvider() {
        this.sodium = new SodiumJava();
        this.lazySodium = new LazySodiumJava(this.sodium);
    }

    /**
     * Wraps raw bytes as a signing public key.
     *
     * @param data raw key bytes
     * @return a {@code SodiumKey.SignPublic} built from {@code data}
     */
    @Override
    public AKey.SignPublic createSignPublicKey(byte[] data) {
        return new SodiumKey.SignPublic(data);
    }

    /**
     * Wraps raw bytes as a signing private key.
     *
     * @param data raw key bytes
     * @return a {@code SodiumKey.SignPrivate} built from {@code data}
     */
    @Override
    public AKey.SignPrivate createSignPrivateKey(byte[] data) {
        return new SodiumKey.SignPrivate(data);
    }

    /**
     * Wraps raw bytes as a symmetric key.
     *
     * @param bytes raw key bytes
     * @return a {@code SodiumKey.Symmetric} built from {@code bytes}
     */
    @Override
    public AKey.Symmetric createSymmetricKey(byte[] bytes) {
        return new SodiumKey.Symmetric(bytes);
    }

    /**
     * Returns the library tag of this provider. The same tag is required as the first
     * {@code ':'}-separated field by {@link #createKey(String)} and {@link #createSign(String)}.
     *
     * @return the constant {@code "SODIUM"}
     */
    @Override
    public String getCryptoLibName() {
        return "SODIUM";
    }

    /**
     * Generates a fresh asymmetric key pair with {@code crypto_box_keypair}.
     *
     * @return the generated public/private pair
     * @throws IllegalStateException if the native call reports a non-zero status
     */
    @Override
    public PairAsymKeys createAsymmetricKeys() {
        byte[] privateKey = new byte[KeySize.SODIUM_CURVE_PRIVATE];
        byte[] publicKey = new byte[KeySize.SODIUM_CURVE25519_PUBLIC];
        int status = sodium.crypto_box_keypair(publicKey, privateKey);
        if (status != 0) {
            // Without this check a failed call would hand back the untouched zero-filled buffers as keys.
            throw new IllegalStateException("Sodium key pair generation failed with error code: " + status);
        }
        return new PairAsymKeys(new SodiumKey.AsymmetricPublic(publicKey), new SodiumKey.AsymmetricPrivate(privateKey));
    }

    /**
     * Generates a fresh symmetric key with {@code crypto_aead_chacha20poly1305_keygen}.
     *
     * @return the generated key
     */
    @Override
    public AKey.Symmetric createSymmetricKey() {
        var key = new byte[KeySize.SODIUM_CHACHA20POLY1305];
        sodium.crypto_aead_chacha20poly1305_keygen(key);
        return new SodiumKey.Symmetric(key);
    }

    /**
     * Generates a fresh signing key pair.
     *
     * @return the generated public/private signing pair
     * @throws EncryptException if lazysodium fails to produce the key pair
     */
    @Override
    public PairSignKeys createSignKeys() {
        try {
            var keys = lazySodium.cryptoSignKeypair();
            var publicKey = keys.getPublicKey().getAsBytes();
            var privateKey = keys.getSecretKey().getAsBytes();
            return new PairSignKeys(new SodiumKey.SignPublic(publicKey), new SodiumKey.SignPrivate(privateKey));
        } catch (SodiumException e) {
            throw new EncryptException("Failed to generate signing keys", e);
        }
    }

    /**
     * Creates a signer from a signing key pair.
     *
     * @param keys pair whose public and private keys are handed to
     *             {@link #createSigner(AKey.SignPublic, AKey.SignPrivate)}
     * @return a signer bound to both keys
     * @throws IllegalArgumentException if either key is not a Sodium signing key
     */
    @Override
    public Signer createSigner(PairSignKeys keys) {
        return createSigner(keys.publicKey, keys.privateKey);
    }

    /**
     * Creates a signer from an explicit public/private signing key pair.
     *
     * @param publicKey  must be a {@code SodiumKey.SignPublic}
     * @param privateKey must be a {@code SodiumKey.SignPrivate}
     * @return a signer bound to both keys
     * @throws IllegalArgumentException if either key has the wrong type (including {@code null})
     */
    @Override
    public Signer createSigner(AKey.SignPublic publicKey, AKey.SignPrivate privateKey) {
        if (!(publicKey instanceof SodiumKey.SignPublic) || !(privateKey instanceof SodiumKey.SignPrivate)) {
            throw new IllegalArgumentException("Keys must be instances of SodiumKey.SignPublic and SodiumKey.SignPrivate");
        }
        return new SodiumSigner(publicKey, privateKey, lazySodium);
    }

    /**
     * Creates a signer that is given only a public key; the private key is passed as {@code null}.
     *
     * @param publicKey must be a {@code SodiumKey.SignPublic}
     * @return a signer bound to the public key only
     * @throws IllegalArgumentException if the key has the wrong type (including {@code null})
     */
    @Override
    public Signer createSigner(AKey.SignPublic publicKey) {
        if (!(publicKey instanceof SodiumKey.SignPublic)) {
            throw new IllegalArgumentException("Public key must be an instance of SodiumKey.SignPublic");
        }
        return new SodiumSigner(publicKey, null, lazySodium);
    }

    /**
     * Creates a symmetric crypto engine for the given key.
     *
     * @param key must be a {@code SodiumKey.Symmetric}
     * @return a new {@code SodiumSymmetricEngine}
     * @throws IllegalArgumentException if the key has the wrong type (including {@code null})
     */
    @Override
    public CryptoEngine createSymmetricEngine(AKey.Symmetric key) {
        if (!(key instanceof SodiumKey.Symmetric)) {
            throw new IllegalArgumentException("Key must be a SodiumKey.Symmetric instance");
        }
        return new SodiumSymmetricEngine(key);
    }

    /**
     * Creates an asymmetric crypto engine that is given only a public key.
     *
     * @param key must be a {@code SodiumKey.AsymmetricPublic}
     * @return a new {@code SodiumAsymmetricEngine}
     * @throws IllegalArgumentException if the key has the wrong type (including {@code null})
     */
    @Override
    public CryptoEngine createAsymmetricEngine(AKey.AsymmetricPublic key) {
        if (!(key instanceof SodiumKey.AsymmetricPublic)) {
            throw new IllegalArgumentException("Key must be a SodiumKey.AsymmetricPublic instance");
        }
        return new SodiumAsymmetricEngine(key);
    }

    /**
     * Creates an asymmetric crypto engine from a private/public key pair.
     *
     * @param privateKey must be a {@code SodiumKey.AsymmetricPrivate}
     * @param publicKey  must be a {@code SodiumKey.AsymmetricPublic}
     * @return a new {@code SodiumAsymmetricEngine}
     * @throws IllegalArgumentException if either key has the wrong type (including {@code null})
     */
    @Override
    public CryptoEngine createAsymmetricEngine(AKey.AsymmetricPrivate privateKey, AKey.AsymmetricPublic publicKey) {
        if (!(privateKey instanceof SodiumKey.AsymmetricPrivate) || !(publicKey instanceof SodiumKey.AsymmetricPublic)) {
            throw new IllegalArgumentException("Keys must be instances of SodiumKey.AsymmetricPrivate and SodiumKey.AsymmetricPublic");
        }
        return new SodiumAsymmetricEngine(privateKey, publicKey);
    }

    /**
     * Creates an asymmetric crypto engine from a key pair object.
     *
     * @param keys pair whose keys are handed to
     *             {@link #createAsymmetricEngine(AKey.AsymmetricPrivate, AKey.AsymmetricPublic)}
     * @return a new {@code SodiumAsymmetricEngine}
     * @throws IllegalArgumentException if either key is not a Sodium asymmetric key
     */
    @Override
    public CryptoEngine createAsymmetricEngine(PairAsymKeys keys) {
        return createAsymmetricEngine(keys.getPrivateKey(), keys.getPublicKey());
    }

    /**
     * Wraps raw bytes in the {@code SodiumKey} subtype that matches {@code keyType}.
     *
     * @param keyType which kind of key {@code data} represents
     * @param data    raw key bytes
     * @param <T>     key type expected by the caller; the result is cast to it unchecked
     * @return the key object for {@code keyType}
     * @throws UnsupportedOperationException if {@code keyType} is not one of the handled constants
     */
    @Override
    public <T extends AKey> T createKey(KeyType keyType, byte[] data) {
        AKey result;
        switch (keyType) {
            case SYMMETRIC:
                result = new SodiumKey.Symmetric(data);
                break;
            case ASYMMETRIC_PUBLIC:
                result = new SodiumKey.AsymmetricPublic(data);
                break;
            case ASYMMETRIC_PRIVATE:
                result = new SodiumKey.AsymmetricPrivate(data);
                break;
            case SIGN_PUBLIC:
                result = new SodiumKey.SignPublic(data);
                break;
            case SIGN_PRIVATE:
                result = new SodiumKey.SignPrivate(data);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported key type: " + keyType);
        }
        return RU.cast(result);
    }

    /**
     * Parses a key from its textual form {@code <lib>:<KeyType>:<hex>}.
     *
     * @param data string with exactly three {@code ':'}-separated fields; the first must equal
     *             {@link #getCryptoLibName()}, the second a {@code KeyType} constant name and the
     *             third the key bytes in hex
     * @param <T>  key type expected by the caller; the result is cast to it unchecked
     * @return the parsed key
     * @throws IllegalArgumentException if the field count or library tag is wrong, or if the
     *                                  second field is not a {@code KeyType} constant name
     */
    @Override
    public <T extends AKey> T createKey(String data) {
        var parts = data.split(":");
        if (parts.length != 3 || !parts[0].equals(getCryptoLibName())) {
            throw new IllegalArgumentException("Invalid key string for this provider.");
        }
        KeyType keyType = KeyType.valueOf(parts[1]);
        byte[] bytes = HexUtils.hexToBytes(parts[2]);
        return RU.cast(createKey(keyType, bytes));
    }

    /**
     * Parses a signature from its textual form {@code <lib>:<hex>}.
     *
     * @param data string with exactly two {@code ':'}-separated fields; the first must equal
     *             {@link #getCryptoLibName()} and the second the signature bytes in hex
     * @return the parsed signature
     * @throws IllegalArgumentException if the field count or library tag is wrong
     */
    @Override
    public Sign createSign(String data) {
        var parts = data.split(":");
        if (parts.length != 2 || !parts[0].equals(getCryptoLibName())) {
            throw new IllegalArgumentException("Invalid sign string for this provider.");
        }
        return new SodiumSign(HexUtils.hexToBytes(parts[1]));
    }

    /**
     * Derives a pair of symmetric keys (for client-to-server and server-to-client communication)
     * from a master key with libsodium's {@code crypto_kdf_derive_from_key}.
     *
     * <p>A single subkey of twice the symmetric key size is derived and split in half: the first
     * half becomes the client-to-server key, the second half the server-to-client key. The 64-bit
     * subkey id carries {@code serverId} in its upper 32 bits and {@code keyNumber}, taken as an
     * unsigned value, in its lower 32 bits.
     *
     * @param masterKey The master symmetric key; must be a {@code SodiumKey.Symmetric}.
     * @param serverId The server identifier (32-bit).
     * @param keyNumber The key number/index (32-bit).
     * @return A {@code PairSymKeys} built from the client-to-server key followed by the
     *         server-to-client key.
     * @throws IllegalArgumentException if {@code masterKey} has the wrong type (including {@code null})
     * @throws IllegalStateException if the native derivation reports a non-zero status
     */
    @Override
    public PairSymKeys deriveSymmetricKeys(AKey.Symmetric masterKey, int serverId, int keyNumber) {
        if (!(masterKey instanceof SodiumKey.Symmetric)) {
            throw new IllegalArgumentException("Key must be a SodiumKey.Symmetric instance");
        }

        // Size of one derived key; the same constant sizes the keys made by createSymmetricKey().
        int keySize = KeySize.SODIUM_CHACHA20POLY1305;
        byte[] derivedKey = new byte[keySize * 2];

        try {
            // C++: std::uint64_t const subkey_id = (static_cast<std::uint64_t>(server_id) << 32) | key_number;
            // The mask stops a negative keyNumber from sign-extending into the serverId half.
            long subkeyId = (((long) serverId) << 32) | (keyNumber & 0xFFFFFFFFL);

            int result = this.sodium.crypto_kdf_derive_from_key(
                derivedKey, derivedKey.length, subkeyId, SODIUM_KDF_CONTEXT, masterKey.getData()
            );

            if (result != 0) {
                throw new IllegalStateException("Sodium KDF derivation failed with error code: " + result);
            }

            // Split the derived buffer into two equally sized keys.
            byte[] clientToServerKeyBytes = Arrays.copyOfRange(derivedKey, 0, keySize);
            byte[] serverToClientKeyBytes = Arrays.copyOfRange(derivedKey, keySize, derivedKey.length);

            AKey.Symmetric clientToServerKey = new SodiumKey.Symmetric(clientToServerKeyBytes);
            AKey.Symmetric serverToClientKey = new SodiumKey.Symmetric(serverToClientKeyBytes);

            return new PairSymKeys(clientToServerKey, serverToClientKey);
        } finally {
            // The combined buffer holds both secrets; the key objects own separate copies,
            // so wipe the scratch buffer rather than leave key material on the heap.
            Arrays.fill(derivedKey, (byte) 0);
        }
    }

    /**
     * Wraps raw bytes as a signature.
     *
     * @param data raw signature bytes
     * @return a {@code SodiumSign} built from {@code data}
     */
    @Override
    public Sign createSign(byte[] data) {
        return new SodiumSign(data);
    }
}

