package io.aether.utils.streams;

import io.aether.logger.LNode;
import io.aether.logger.Log;
import io.aether.utils.AString;
import io.aether.utils.RU;
import io.aether.utils.dataio.DataInOutStatic;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

/**
 * Client-side WebSocket (RFC 6455) framing node.
 *
 * <p>Values sent through the {@link #gUp() up} gate are wrapped into masked binary frames and
 * forwarded to the inside of the {@link #gDown() down} gate; bytes sent through the down gate are
 * parsed as an HTTP upgrade response followed by WebSocket frames, and decoded payloads are
 * forwarded to the inside of the up gate.
 *
 * <p>Supported incoming opcodes: binary (including fragmented messages), close, ping and pong.
 * Other opcodes (e.g. text frames) are logged and dropped.
 *
 * <p>NOTE(review): the parser state below is plain mutable fields; this class appears to assume
 * that the down gate is driven by one thread at a time - confirm against the gate framework.
 */
public class WebSocketNode implements Node<byte[], byte[], byte[], byte[]> {
    /** GUID appended to the client key when computing the expected Sec-WebSocket-Accept value. */
    private static final String ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    private static final String ACCEPT_HEADER = "Sec-WebSocket-Accept: ";

    private static final int OPCODE_CONTINUATION = 0x00;
    private static final int OPCODE_BINARY = 0x02;
    private static final int OPCODE_CLOSE = 0x08;
    private static final int OPCODE_PING = 0x09;
    private static final int OPCODE_PONG = 0x0A;

    /** Largest payload a control frame (close/ping/pong) may carry. */
    private static final int MAX_CONTROL_PAYLOAD = 125;
    private static final byte[] EMPTY_PAYLOAD = new byte[0];

    private final FGate<byte[], byte[]> up = FGate.of(new WebSocketUpHandler(this));
    private final FGate<byte[], byte[]> down = FGate.of(new WebSocketDownHandler(this));
    private final LNode log = Log.createContext();

    private enum State {
        HANDSHAKE,
        CONNECTED,
        CLOSED
    }

    private volatile State state = State.HANDSHAKE;
    private final String websocketKey;
    /** Value of the {@code Host} header sent in the upgrade request. */
    private final String host;
    /** Request target sent in the upgrade request line. */
    private final String path;

    // --- incoming frame parser state (survives between chunks) ---
    /** Unconsumed bytes of an incomplete header or payload, or {@code null}. */
    private ByteBuffer partialFrame;
    /** Payload length of the frame whose header has been parsed, or -1 when no header is pending. */
    private int expectedLength = -1;
    private boolean isFinalFragment;
    private int opcode;
    /** Masking key of the pending frame, or {@code null} when the frame is not masked. */
    private byte[] maskKey;
    /** Payload collected so far for a fragmented binary message, or {@code null}. */
    private byte[] fragments;

    /**
     * Creates a node that requests {@code GET /} with {@code Host: example.com}.
     */
    public WebSocketNode() {
        this("example.com", "/");
    }

    /**
     * Creates a node that sends the upgrade request for the given host and path.
     *
     * @param host value of the {@code Host} header
     * @param path request target, e.g. {@code "/"}
     * @throws IllegalArgumentException if an argument is null, empty, or contains CR/LF
     */
    public WebSocketNode(String host, String path) {
        this.host = requireHeaderSafe(host, "host");
        this.path = requireHeaderSafe(path, "path");
        byte[] random = new byte[16];
        RU.SECURE_RANDOM.nextBytes(random);
        this.websocketKey = Base64.getEncoder().encodeToString(random);
    }

    @Override
    public FGate<byte[], byte[]> gUp() {
        return up;
    }

    @Override
    public FGate<byte[], byte[]> gDown() {
        return down;
    }

    @Override
    public void toString(AString sb) {
        sb.add("WebSocketNode(state=").add(state.name()).add(")");
    }

    /** Rejects values that would break the request line or allow header injection. */
    private static String requireHeaderSafe(String value, String name) {
        if (value == null || value.isEmpty()
                || value.indexOf('\r') >= 0 || value.indexOf('\n') >= 0) {
            throw new IllegalArgumentException("Invalid WebSocket " + name + ": " + value);
        }
        return value;
    }

    /** Returns the index of the first case-insensitive occurrence of {@code needle}, or -1. */
    private static int indexOfIgnoreCase(String text, String needle) {
        int last = text.length() - needle.length();
        for (int i = 0; i <= last; i++) {
            if (text.regionMatches(true, i, needle, 0, needle.length())) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Builds a single final (FIN=1) client frame.
     *
     * <p>Clients must mask every frame they send, so a fresh random key is generated per frame.
     * Extended lengths are written byte by byte in network order so the result does not depend
     * on the byte order of any serialization helper.
     */
    private static byte[] encodeFrame(int frameOpcode, byte[] payload) {
        int len = payload.length;
        int extended = len < 126 ? 0 : (len < 65536 ? 2 : 8);
        byte[] frame = new byte[2 + extended + 4 + len];
        int pos = 0;

        frame[pos++] = (byte) (0x80 | frameOpcode);
        if (extended == 0) {
            frame[pos++] = (byte) (0x80 | len);
        } else if (extended == 2) {
            frame[pos++] = (byte) (0x80 | 126);
            frame[pos++] = (byte) (len >>> 8);
            frame[pos++] = (byte) len;
        } else {
            frame[pos++] = (byte) (0x80 | 127);
            long wide = len;
            for (int shift = 56; shift >= 0; shift -= 8) {
                frame[pos++] = (byte) (wide >>> shift);
            }
        }

        byte[] mask = new byte[4];
        RU.SECURE_RANDOM.nextBytes(mask);
        System.arraycopy(mask, 0, frame, pos, 4);
        pos += 4;

        for (int i = 0; i < len; i++) {
            frame[pos + i] = (byte) (payload[i] ^ mask[i & 3]);
        }
        return frame;
    }

    /** Encodes a frame and pushes it to the inside of the down gate. */
    private void sendFrameDown(int frameOpcode, byte[] payload) {
        down.inSide.send(Value.of(encodeFrame(frameOpcode, payload)));
    }

    /**
     * Handles values travelling from the application towards the transport.
     */
    private class WebSocketUpHandler extends FGate.Pair<byte[], byte[], byte[]> {
        public WebSocketUpHandler(Object owner) {
            super(owner);
        }

        @Override
        public FGate<?, byte[]>.InsideGate pair() {
            return down.inSide;
        }

        /**
         * Dispatches on the connection state: while the handshake is pending every value
         * triggers the HTTP upgrade request (the value itself is not forwarded); once connected,
         * data values become binary frames and a close value becomes a close frame; after close
         * the value is dropped with a warning.
         */
        @Override
        public void send(FGate<byte[], byte[]> fGate, Value<byte[]> value) {
            try {
                switch (state) {
                    case HANDSHAKE:
                        sendHandshakeRequest();
                        break;
                    case CONNECTED:
                        if (value.isData()) {
                            sendDataFrame(value.data());
                        } else if (value.isClose()) {
                            sendCloseFrame();
                            state = State.CLOSED;
                        }
                        break;
                    case CLOSED:
                        Log.warn("Attempt to send data on closed WebSocket");
                        break;
                }
            } catch (IOException e) {
                Log.error("WebSocket send error", e);
                value.reject(this);
            }
        }

        /** Sends the HTTP/1.1 upgrade request carrying this node's random key. */
        private void sendHandshakeRequest() throws IOException {
            String request = "GET " + path + " HTTP/1.1\r\n" +
                             "Host: " + host + "\r\n" +
                             "Upgrade: websocket\r\n" +
                             "Connection: Upgrade\r\n" +
                             "Sec-WebSocket-Key: " + websocketKey + "\r\n" +
                             "Sec-WebSocket-Version: 13\r\n\r\n";

            pair().send(Value.of(request.getBytes(StandardCharsets.UTF_8)));
        }

        /** Sends {@code data} as one final, masked binary frame. */
        private void sendDataFrame(byte[] data) throws IOException {
            sendFrameDown(OPCODE_BINARY, data);
        }

        /** Sends a masked close frame with an empty payload. */
        private void sendCloseFrame() throws IOException {
            sendFrameDown(OPCODE_CLOSE, EMPTY_PAYLOAD);
        }
    }

    /**
     * Handles bytes travelling from the transport towards the application.
     */
    private class WebSocketDownHandler extends FGate.Pair<byte[], byte[], byte[]> {
        public WebSocketDownHandler(Object owner) {
            super(owner);
        }

        @Override
        public FGate<?, byte[]>.InsideGate pair() {
            return up.inSide;
        }

        /**
         * Non-data values are passed through unchanged. Data is prepended with any bytes left
         * over from the previous chunk and then consumed as a handshake response and/or a
         * sequence of frames; an incomplete trailing header or payload is stored for next time.
         */
        @Override
        public void send(FGate<byte[], byte[]> fGate, Value<byte[]> value) {
            try {
                if (!value.isData()) {
                    pair().send(value);
                    return;
                }

                byte[] data = value.data();
                if (partialFrame != null) {
                    data = mergePartialData(data);
                }

                DataInOutStatic in = new DataInOutStatic(data);
                while (in.isReadable()) {
                    if (state == State.HANDSHAKE) {
                        if (!processHandshakeResponse(in)) {
                            return;
                        }
                        continue;
                    }

                    if (expectedLength == -1 && !parseFrameHeader(in, data)) {
                        // Header incomplete: nothing was consumed, keep every remaining byte.
                        savePartialData(in);
                        return;
                    }

                    if (in.getSizeForRead() >= expectedLength) {
                        processCompleteFrame(in);
                    } else {
                        // Header already parsed (expectedLength stays set); wait for the payload.
                        savePartialData(in);
                        return;
                    }
                }
            } catch (IOException e) {
                Log.error("WebSocket protocol error", e);
                value.reject(this);
            }
        }

        /**
         * Validates the server's upgrade response and switches to {@code CONNECTED}.
         *
         * <p>Header names and the Upgrade/Connection tokens are compared case-insensitively;
         * the accept value is compared exactly.
         *
         * <p>NOTE(review): the response text is obtained via {@code readString1()}, whose wire
         * format is not visible here - confirm it yields the complete HTTP response.
         *
         * @return always {@code true}; failures are reported by exception
         * @throws IOException if the response is not a valid upgrade or the accept key differs
         */
        private boolean processHandshakeResponse(DataInOutStatic in) throws IOException {
            String response = in.readString1();

            if (!response.contains("HTTP/1.1 101")
                    || indexOfIgnoreCase(response, "Upgrade: websocket") < 0
                    || indexOfIgnoreCase(response, "Connection: Upgrade") < 0) {
                throw new IOException("Invalid WebSocket handshake response");
            }

            String acceptKey = websocketKey + ACCEPT_GUID;
            String expectedAccept = Base64.getEncoder().encodeToString(
                    RU.sha1(acceptKey.getBytes(StandardCharsets.UTF_8)));

            int headerAt = indexOfIgnoreCase(response, ACCEPT_HEADER);
            if (headerAt < 0
                    || !response.startsWith(expectedAccept, headerAt + ACCEPT_HEADER.length())) {
                throw new IOException("WebSocket accept key mismatch");
            }

            state = State.CONNECTED;
            pair().send(Value.ofRequest()); // notify the upper side that the connection is established
            return true;
        }

        /**
         * Parses one frame header located at the current read position of {@code in}.
         *
         * <p>The header is first inspected in {@code data} without consuming anything, so that an
         * incomplete header can be retried in full when more bytes arrive. Only when the whole
         * header is present is it skipped in {@code in} and the parser state updated.
         *
         * @param in   reader over {@code data}
         * @param data the array {@code in} was created from
         * @return {@code false} if the header is not fully available yet (nothing consumed)
         * @throws IOException on an oversized control frame or an unrepresentable payload length
         */
        private boolean parseFrameHeader(DataInOutStatic in, byte[] data) throws IOException {
            int available = in.getSizeForRead();
            if (available < 2) return false;

            int offset = data.length - available;
            int b1 = data[offset];
            int b2 = data[offset + 1];

            boolean masked = (b2 & 0x80) != 0;
            int lengthCode = b2 & 0x7F;
            int extended = lengthCode == 126 ? 2 : (lengthCode == 127 ? 8 : 0);
            int headerSize = 2 + extended + (masked ? 4 : 0);
            if (available < headerSize) return false;

            int frameOpcode = b1 & 0x0F;
            if (frameOpcode >= OPCODE_CLOSE && lengthCode > MAX_CONTROL_PAYLOAD) {
                throw new IOException("WebSocket control frame payload too large, opcode: " + frameOpcode);
            }

            // Extended lengths are big-endian on the wire.
            int cursor = offset + 2;
            long payloadLength = lengthCode;
            if (extended > 0) {
                payloadLength = 0;
                for (int i = 0; i < extended; i++) {
                    payloadLength = (payloadLength << 8) | (data[cursor++] & 0xFF);
                }
            }
            if (payloadLength < 0 || payloadLength > Integer.MAX_VALUE) {
                throw new IOException("WebSocket frame too large: " + payloadLength);
            }

            byte[] key = null;
            if (masked) {
                key = new byte[4];
                System.arraycopy(data, cursor, key, 0, 4);
            }

            in.skipBytes(headerSize);

            isFinalFragment = (b1 & 0x80) != 0;
            opcode = frameOpcode;
            maskKey = key;
            expectedLength = (int) payloadLength;
            return true;
        }

        /**
         * Reads the payload of the pending frame, unmasks it if needed, and acts on the opcode.
         */
        private void processCompleteFrame(DataInOutStatic in) throws IOException {
            byte[] payload = in.readBytes(expectedLength);
            expectedLength = -1;

            if (maskKey != null) {
                byte[] clear = new byte[payload.length];
                for (int i = 0; i < payload.length; i++) {
                    clear[i] = (byte) (payload[i] ^ maskKey[i & 3]);
                }
                payload = clear;
                maskKey = null;
            }

            switch (opcode) {
                case OPCODE_BINARY:
                    if (isFinalFragment) {
                        fragments = null;
                        pair().send(Value.of(payload));
                    } else {
                        // First fragment of a multi-frame message; wait for continuations.
                        fragments = payload;
                    }
                    break;
                case OPCODE_CONTINUATION: {
                    if (fragments == null) {
                        Log.warn("Unexpected WebSocket continuation frame");
                        break;
                    }
                    byte[] joined = new byte[fragments.length + payload.length];
                    System.arraycopy(fragments, 0, joined, 0, fragments.length);
                    System.arraycopy(payload, 0, joined, fragments.length, payload.length);
                    if (isFinalFragment) {
                        fragments = null;
                        pair().send(Value.of(joined));
                    } else {
                        fragments = joined;
                    }
                    break;
                }
                case OPCODE_CLOSE:
                    // Reply only if this side has not already sent its own close frame.
                    if (state != State.CLOSED) {
                        sendCloseFrame();
                    }
                    state = State.CLOSED;
                    pair().send(Value.ofClose());
                    break;
                case OPCODE_PING:
                    sendPongFrame(payload);
                    break;
                case OPCODE_PONG:
                    // Unsolicited or reply pong: nothing to do.
                    break;
                default:
                    Log.warn("Unsupported WebSocket opcode: " + opcode);
            }
        }

        /** Sends a masked close frame with an empty payload towards the transport. */
        private void sendCloseFrame() throws IOException {
            sendFrameDown(OPCODE_CLOSE, EMPTY_PAYLOAD);
        }

        /** Answers a ping by echoing its payload (at most 125 bytes, enforced at header parse). */
        private void sendPongFrame(byte[] payload) throws IOException {
            sendFrameDown(OPCODE_PONG, payload);
        }

        /** Stores every unread byte of {@code in} to be prepended to the next chunk. */
        private void savePartialData(DataInOutStatic in) {
            partialFrame = ByteBuffer.wrap(in.readBytes(in.getSizeForRead()));
        }

        /** Returns the stored leftover bytes followed by {@code newData} and clears the store. */
        private byte[] mergePartialData(byte[] newData) {
            ByteBuffer merged = ByteBuffer.allocate(
                    partialFrame.remaining() + newData.length);
            merged.put(partialFrame);
            merged.put(newData);
            partialFrame = null;
            return merged.array();
        }
    }
}