#include "ssl.hpp"
#include <arpa/inet.h> // ntoa


// Read a big-endian (network order) unsigned integer of 1, 2, 3 or 4 bytes
// from byte buffer `b` starting at offset `o`. No bounds checking is done;
// callers must make sure the bytes exist.
// Both arguments are parenthesized so that expressions such as `p + 1` or
// `cond ? x : y` expand correctly. `b` and `o` are evaluated more than once,
// so do not pass expressions with side effects.
#define buf2len1(b,o) ((size_t)(b)[(o)])
#define buf2len2(b,o) ((((size_t)(b)[(o)]) << 8) + (size_t)(b)[(o)+1])
#define buf2len3(b,o) ((((size_t)(b)[(o)]) << 16) + (((size_t)(b)[(o)+1]) << 8) + (size_t)(b)[(o)+2])
#define buf2len4(b,o) ((((size_t)(b)[(o)]) << 24) + (((size_t)(b)[(o)+1]) << 16) + (((size_t)(b)[(o)+2]) << 8) + (size_t)(b)[(o)+3])

// Advance the read cursor by `slen` bytes.
// On success `buf` is moved forward, `len` is reduced by `slen` and true is
// returned. If fewer than `slen` bytes remain, neither `buf` nor `len` is
// touched and false is returned.
static bool buf_skip(const unsigned char*& buf, size_t& len, size_t slen) {
    const bool fits = (slen <= len);
    if (fits) {
        len -= slen;
        buf += slen;
    }
    return fits;
}

// Read a big-endian length field that is `slen` bytes wide (1 to 4) from the
// cursor and step over it.
// On success the decoded value is stored in `rv`, `buf`/`len` are advanced
// past the field and true is returned. On failure (field does not fit into
// the remaining bytes, or unsupported width) nothing is modified and false
// is returned.
static bool buf_skip_len(const unsigned char*& buf, size_t& len, size_t slen, size_t& rv) {
    // the field itself has to fit into what is left
    if (len < slen) return false;

    size_t value;
    if (slen == 1) {
        value = buf2len1(buf, 0);
    } else if (slen == 2) {
        value = buf2len2(buf, 0);
    } else if (slen == 3) {
        value = buf2len3(buf, 0);
    } else if (slen == 4) {
        value = buf2len4(buf, 0);
    } else {
        // unsupported field width (this includes 0)
        return false;
    }

    rv = value;
    buf += slen;
    len -= slen;
    return true;
}

// Step over a length-prefixed array: a big-endian length field of `slen`
// bytes followed by that many payload bytes.
// Returns false if either the length field or the payload does not fit into
// the remaining bytes. Note that when only the payload is too short, the
// cursor has already been advanced past the length field.
static bool buf_skip_arr(const unsigned char*& buf, size_t& len, size_t slen) {
    size_t payload_len;
    return buf_skip_len(buf, len, slen, payload_len)
        && buf_skip(buf, len, payload_len);
}


// Extract the server name (SNI) from the body of a TLS ClientHello.
//
// buf/len: the handshake message body, i.e. the bytes that follow the 4-byte
//          handshake header.
// Returns:
//   - "sni: <name>" if a type-0 extension (server_name, RFC 6066) is found;
//     only the first name entry is looked at, it must be of type 0, at most
//     250 bytes long and consist of printable, non-space ASCII only
//   - "-" if the hello has no such extension (or no extensions at all)
//   - NULL if the message is truncated or otherwise malformed
// The "sni: ..." result lives in a function-static buffer that is
// overwritten by the next successful call.
static const char* parse_client_helo(const unsigned char* buf, size_t len) {
    // skip version, random
    if (!buf_skip(buf, len, 2+32)) return NULL;
    // skip sid, ciphers, compressions
    if (!buf_skip_arr(buf, len, 1)) return NULL;
    if (!buf_skip_arr(buf, len, 2)) return NULL;
    if (!buf_skip_arr(buf, len, 1)) return NULL;

    // the extensions block is optional: a hello that ends right after the
    // compression methods is well-formed, it just carries no server_name
    if (!len) return "-";

    // extensions length
    size_t el;
    if (!buf_skip_len(buf, len, 2, el)) return NULL;
    if (el > len) return NULL;
    // from here on only the extensions block is walked
    len = el;
    while (len) {
        // extension type 0: SNI
        size_t et;
        if (!buf_skip_len(buf, len, 2, et)) return NULL;
        if (et != 0) {
            // some other extension: step over its length-prefixed payload
            if (!buf_skip_arr(buf, len, 2)) return NULL;
            continue;
        }

        // skip list/item len
        if (!buf_skip(buf, len, 2+2)) return NULL;

        // item type: 0
        if (!buf_skip_len(buf, len, 1, et) || et) return NULL;

        // len
        size_t slen;
        if (!buf_skip_len(buf, len, 2, slen)) return NULL;
        if (!slen || slen > len) return NULL;

        // found it
        static char rv[256];
        // "sni: " prefix (5 bytes) + name + terminator must fit into rv
        if (slen >= sizeof(rv) - 5) return NULL;
        // validate before touching rv, so a rejected name does not clobber
        // the result of an earlier call
        for (size_t i=0; i<slen; ++i) {
            if (buf[i] <= ' ' || buf[i] > '~') return NULL;
        }
        strcpy(rv, "sni: ");
        memcpy(rv + 5, buf, slen);
        rv[5 + slen] = '\0';
        return rv;
    }

    // no server_name found
    return "-";
}

// Extract the chosen cipher suite and compression method from the body of a
// TLS ServerHello.
//
// buf/len: the handshake message body, i.e. the bytes that follow the 4-byte
//          handshake header.
// Returns "cs: <4 hex digits>, comp: <2 hex digits>" on success, or NULL if
// the message is truncated or the session id length is neither 0 nor 32.
// The result lives in a function-static buffer that is overwritten by the
// next successful call.
static const char* parse_server_helo(const unsigned char* buf, size_t len) {
    // skip version, random; the session id length byte that follows must be
    // present as well before it may be read
    static const size_t fixed_len = 2 + 32;
    if (len <= fixed_len) return NULL;
    buf += fixed_len;
    len -= fixed_len;

    // skip sid len + sid
    size_t sid_len = buf[0];
    if ((sid_len != 0 && sid_len != 32) || sid_len + 1 >= len) return NULL;
    buf += sid_len + 1;
    len -= sid_len + 1;

    // return chosen cipher suite & compression method
    if (len < 3) return NULL;
    unsigned cipher_suite = ((unsigned)buf[0] << 8) | (unsigned)buf[1];
    unsigned compr_meth = (unsigned)buf[2];
    // sized from the widest possible output, terminator included
    static char rv[sizeof("cs: 0000, comp: 00")];
    snprintf(rv, sizeof(rv), "cs: %04x, comp: %02x", cipher_suite, compr_meth);
    return rv;
}


// A new session starts with all of parser_state zeroed.
SSLSession::SSLSession() {
    memset(&parser_state[0], 0, sizeof parser_state);
}


// Consume one TLS record from the start of `buf` and report what it contains.
//
// client_side: selects which of the two parser_state slots is used for
//              this direction
// conn:        connection endpoints and timestamp, used for output only
// buf/len:     bytes available for this direction, starting at a record header
//
// Returns 0 if `buf` does not yet hold a complete record, -1 if the record
// version is unsupported or a handshake message cannot be parsed, otherwise
// the number of bytes consumed (5-byte record header plus payload).
ssize_t SSLSession::push(bool client_side, const conn_t& conn, const unsigned char* buf, size_t len) {
    // content type + len
    static const size_t hlen = 5;
    if (len < hlen) return 0;
    uint8_t content_type = buf[0];
    const unsigned char ver_maj = buf[1];
    const unsigned char ver_min = buf[2];
    uint16_t content_len = buf2len2(buf, 3);

    // tls only
    if (ver_maj != 3 || (ver_min != 1 && ver_min != 3)) {
        LOG("unknown/unsupported version for ct %u: %u/%u", content_type, ver_maj, ver_min);
        return -1;
    }

    // process only this handshake
    buf += hlen;
    len -= hlen;
    if (len < content_len) return 0;
    len = content_len;

    // inet_ntoa() hands out a shared static buffer, so both addresses are
    // copied before they are printed together
    static char src[INET_ADDRSTRLEN], dst[INET_ADDRSTRLEN];
    strcpy(src, inet_ntoa(conn.src));
    strcpy(dst, inet_ntoa(conn.dst));
    OUT(">  %s:%d -> %s:%d %" PRIu64 " ct: %u, len: %u", src, conn.src_port, dst, conn.dst_port, conn.ts, content_type, content_len);

    // handshake and previous one was no change cipher spec? (would contain an encrypted handshake msg)
    if (content_type == 0x16 && parser_state[client_side?1:0].last_content_type != 0x14) {
        // handshake header: type (1 byte) + length (3 bytes)
        static const size_t hs_hlen = 4;
        while (len) {
            // the header must be complete before its length field is read,
            // otherwise a record with a 1-3 byte tail is read out of bounds
            if (len < hs_hlen) {
                LOG("handshake len exceeded");
                return -1;
            }
            uint8_t handshake_type = buf[0];
            size_t handshake_len = buf2len3(buf, 1) + hs_hlen; // includes the header bytes
            if (handshake_len > len) {
                LOG("handshake len exceeded");
                return -1;
            }

            // only client hello (1) and server hello (2) are looked into
            const char* helo_info = NULL;
            if (handshake_type == 1) {
                helo_info = parse_client_helo(buf + 4, handshake_len - 4);
                if (!helo_info) {
                    LOG("cannot parse client helo");
                    return -1;
                }
            } else if (handshake_type == 2) {
                helo_info = parse_server_helo(buf + 4, handshake_len - 4);
                if (!helo_info) {
                    LOG("cannot parse server helo");
                    return -1;
                }
            }

            OUT(" > %s:%d -> %s:%d %" PRIu64 " ht: %u, len: %zu, ver: %u/%u%s%s",
                src, conn.src_port, dst, conn.dst_port, conn.ts,
                handshake_type, handshake_len-4,
                ver_maj, ver_min,
                helo_info? ", ": "", helo_info ?: ""
            );

            buf += handshake_len;
            len -= handshake_len;
        }
    }
    // remembered so the next record in this direction can tell whether it
    // follows a change cipher spec (0x14)
    parser_state[client_side?1:0].last_content_type = content_type;

    return hlen + content_len;
}