#include "stream.hpp"
#include <arpa/inet.h> // ntoa
#include <new> // std::bad_alloc


// Maximum age of a tracked stream. Stream::push multiplies it by 1000*1000 before
// comparing it with conn_t timestamps (presumably seconds -> microseconds; confirm
// against the capture code that fills conn_t::ts).
#define STREAM_TTL 120


// 32-bit FNV-1a hash of len bytes starting at buf.
// See http://www.isthe.com/chongo/tech/comp/fnv/#FNV-1a
// h is the running hash state. It defaults to the FNV offset basis, so passing the
// result of a previous call hashes the concatenation of both inputs.
static uint32_t hash(const unsigned char* buf, size_t len, uint32_t h=2166136261u) {
    const unsigned char* const end = buf + len;
    for (; buf != end; ++buf) {
        h = (h ^ *buf) * 16777619u; // xor the octet in, then multiply by the FNV prime
    }
    return h;
}


// Starts out empty: no heap block is allocated and no bytes are held.
SSLStream::Buf::Buf()
    : buf(NULL)
    , len(0)
{
}

// Releases the heap block holding the buffered bytes. free(NULL) is a no-op,
// so an empty buffer is fine here.
SSLStream::Buf::~Buf() {
    void* const mem = (void*)buf;
    free(mem);
}

/*
 * Appends l bytes from b to the end of the buffered data and grows the heap
 * block to fit. l must be non-zero.
 *
 * Throws std::bad_alloc if the new size overflows size_t or the block cannot be
 * grown. In that case the existing contents and ownership are left untouched.
 */
void SSLStream::Buf::push(const unsigned char* b, size_t l) {
    assert(l > 0);
    if (len + l < len) { // size_t overflow
        throw std::bad_alloc();
    }
    // Only adopt realloc's result on success. On failure the old block is still
    // valid and owned by us, and overwriting buf with NULL would leak it.
    void* const grown = realloc((void*)buf, len + l);
    if (!grown) {
        throw std::bad_alloc();
    }
    buf = (const unsigned char*)grown;
    memcpy((void*)(buf + len), b, l);
    len += l;
}

/*
 * Drops the first l bytes of the buffered data (0 < l <= len).
 * When everything is consumed, the heap block is freed. Otherwise the remaining
 * bytes are moved to the front and the block is shrunk to fit.
 */
void SSLStream::Buf::pop(size_t l) {
    assert(l > 0);
    assert(l <= len);
    if (l == len) {
        free((void*)buf);
        buf = NULL;
        len = 0;
        return;
    }
    len -= l;
    memmove((void*)buf, buf + l, len); // regions overlap, so memmove rather than memcpy
    // A shrinking realloc may still fail. Then the original, larger block still
    // holds the data and remains valid, so keep it instead of losing it.
    void* const shrunk = realloc((void*)buf, len);
    if (shrunk) {
        buf = (const unsigned char*)shrunk;
    }
}


/*
 * Feeds the payload of one packet into the session parser (sess.push).
 *
 * Each direction has its own backlog. Bytes the parser did not consume are kept
 * there and prepended to the next payload of the same direction, so the parser
 * always sees one contiguous byte stream.
 *
 * Returns false when the parser reports an error (negative return) or claims to
 * have consumed more bytes than were offered. Returns true otherwise, including
 * for empty payloads.
 */
bool SSLStream::push(const packet_t& p) {
    if (!p.len) {
        return true;
    }

    // Direction index: 1 when the packet goes to port 443, 0 otherwise.
    const int client_side = (p.conn.dst_port == 443)? 1: 0;

    const unsigned char* buf = p.buf;
    size_t len = p.len;
    Buf* blog = &backlog[client_side];
    if (blog->len) {
        // Append to the pending bytes and parse from the start of the backlog.
        blog->push(buf, len);
        buf = blog->buf;
        len = blog->len;
    }

    while (len) {
        const ssize_t rv = sess.push(client_side? true: false, p.conn, buf, len); // the actual work
        if (rv < 0) {
            return false;
        }
        if (rv == 0) {
            break; // nothing consumed; the remainder is kept in the backlog below
        }
        if ((size_t)rv > len) {
            return false; // would wrap len and walk past the end of buf
        }
        buf += rv;
        len -= rv;
    }

    if (blog->len) { // we read from it
        if (len != blog->len) {
            assert(len < blog->len);
            blog->pop(blog->len - len); // drop the consumed prefix
        }
    } else { // backlog wasn't used
        if (len) { // push rest, if any
            blog->push(buf, len);
        }
    }

    return true;
}


/*
 * Routes one captured TCP packet to the SSLStream tracking its connection.
 *
 * Only packets with exactly one endpoint on port 443 are considered. Entries are
 * keyed by a hash that gives the same value for both directions of a connection.
 * A new entry is created only from an initial SYN (without ACK). An existing entry
 * is dropped, and its SSLStream deleted, when any of the following holds:
 *  - a new SYN arrives,
 *  - the entry is older than STREAM_TTL,
 *  - the endpoints differ (hash collision),
 *  - the sequence number is unexpected,
 *  - parsing fails.
 */
void Stream::push(const packet_t& p) {
    // only interested in https port atm, regardless of filter
    if ((p.conn.src_port != 443 && p.conn.dst_port != 443) || p.conn.src_port == p.conn.dst_port) {
        LOG("no https");
        return;
    }
    const int client_side = (p.conn.dst_port == 443);
    const bool first = p.syn && !p.ack;
    const bool seq_inc = p.syn || p.fin; // the presence of the SYN or FIN flag in a received packet triggers an increase of 1 in the sequence.
    const uint32_t next_seq = p.seq + p.len + (seq_inc? 1: 0);

    // find in map. XOR-combining the per-endpoint hashes makes the key
    // independent of direction (swapping src/dst yields the same h).
    const uint32_t hs = hash((const unsigned char*)&p.conn.src, sizeof(p.conn.src), hash((const unsigned char*)&p.conn.src_port, sizeof(p.conn.src_port)));
    const uint32_t hd = hash((const unsigned char*)&p.conn.dst, sizeof(p.conn.dst), hash((const unsigned char*)&p.conn.dst_port, sizeof(p.conn.dst_port)));
    const uint32_t h = hs ^ hd;
    std::map<uint32_t, entry_t>::iterator it = streams.find(h);

    // found, but valid?
    if (it != streams.end()) {
        bool valid = true;
        const conn_t& pkt = it->second.pkt;
        if (first) {
            LOG("stream exists for initial packet");
            valid = false;
        } else if (p.conn.ts > pkt.ts && p.conn.ts - pkt.ts > (uint64_t)STREAM_TTL*1000*1000) {
            // Compare the elapsed time instead of computing ts - TTL, which would
            // underflow (unsigned) for small timestamps and report a spurious timeout.
            // NOTE(review): pkt is only set when the entry is inserted, so this measures
            // age since the opening SYN, not idle time.
            LOG("stream timeout/collision (%" PRIu64 "/%" PRIu64 ")", pkt.ts, p.conn.ts);
            valid = false;
        } else if ((pkt.src != p.conn.src || pkt.src_port != p.conn.src_port || pkt.dst != p.conn.dst || pkt.dst_port != p.conn.dst_port) &&
                   (pkt.src != p.conn.dst || pkt.src_port != p.conn.dst_port || pkt.dst != p.conn.src || pkt.dst_port != p.conn.src_port)) {
            LOG("stream collision");
            valid = false;
        } else if (it->second.next_seq[client_side] && p.seq != it->second.next_seq[client_side]) {
            LOG("unexpected seqnum: %" PRIu32 "/%" PRIu32, p.seq, it->second.next_seq[client_side]);
            valid = false; // TODO: buffer until matching gap is found
        }
        if (!valid) {
            delete it->second.stream;
            streams.erase(it);
            it = streams.end();
        }
    }

    // not found/need to create?
    if (it == streams.end()) {
        if (!first) {
            LOG("won't insert from follow-up packet");
            return;
        }
        // Standard aggregate initialization (the compound literal was a GNU extension).
        const entry_t entry = {p.conn, {0, 0}, new SSLStream()};
        it = streams.insert(std::pair<uint32_t, entry_t>(h, entry)).first;
    }

    // found our entry - push the payload
    it->second.next_seq[client_side] = next_seq;
    if (!it->second.stream->push(p)) {
        LOG("parsing failed, removing stream");
        delete it->second.stream; // the entry owns it; erasing alone would leak it
        streams.erase(it);
    }
}

// Every map entry owns a heap-allocated SSLStream. Delete them all,
// then empty the map.
Stream::~Stream() {
    for (std::map<uint32_t, entry_t>::iterator cur = streams.begin(); cur != streams.end(); ++cur) {
        delete cur->second.stream;
    }
    streams.clear();
}