#include "common.hpp"
#include "socks.hpp"
#include "sock.hpp"
#include "config.hpp"
#include "encode.hpp"
#include <signal.h>
#include <stdlib.h>


// Attempts a single write() of the first `len` bytes of `buf` to `fd`.
//
// On success the written bytes are removed from the front of `buf` (the
// remainder is shifted down) and `len` is reduced accordingly.
//
// Returns:
//   > 0  number of bytes written
//     0  nothing was written (EINTR / EAGAIN / EWOULDBLOCK, or write() returned 0)
//    -1  any other write() error (already logged)
static ssize_t buf_write(int fd, char* buf, size_t& len) {
    const ssize_t written = write(fd, buf, len);

    if (written == -1) {
        const bool retryable = (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK);
        if (retryable) {
            return 0;
        }
        LOG_ERRNO("write()");
        return -1;
    }
    if (written == 0) {
        return 0; // can happen?
    }

    // drop what was sent from the front of the buffer
    const size_t sent = (size_t)written;
    if (sent < len) {
        len -= sent;
        memmove(buf, buf+sent, len);
    } else {
        len = 0;
    }
    return written;
}


// Attempts a single read() of up to `len` bytes from `fd` into `buf`.
//
// Returns:
//   > 0  number of bytes read
//     0  nothing available right now (EINTR / EAGAIN / EWOULDBLOCK)
//    -1  failure; errno is preserved across logging for real errors, and is
//        set to ENOPKG to signal end-of-file (read() returned 0)
static ssize_t buf_read(int fd, char* buf, size_t len) {
    const ssize_t got = read(fd, buf, len);

    if (got == 0) {
        errno = ENOPKG; // eof
        return -1;
    }
    if (got != -1) {
        return got;
    }

    if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) {
        return 0;
    }

    // logging may clobber errno, so keep the original value for the caller
    const int saved = errno;
    LOG_ERRNO("read()");
    errno = saved;
    return -1;
}


// Performs one step of removing the chunked framing from `chunkbuf` and
// hex-decoding the chunk payload into `outbuf`.
//
// chunkbuf / chunkbuflen: pending raw input; consumed bytes are popped from
//                         the front and `chunkbuflen` is updated.
// outbuf:                 receives decoded bytes; needs at most chunkbuflen/2.
// in_chunk_state:         true while positioned inside chunk data, false while
//                         positioned at a chunk-length line; updated in place.
//
// Returns the number of bytes written to `outbuf` (possibly 0), or -1 on
// malformed input (already logged).
static ssize_t strip_encoding(char* chunkbuf, size_t& chunkbuflen, char* outbuf, bool& in_chunk_state) {
    if (chunkbuflen == 0) {
        return 0;
    }

    // Look for the first CR; the last byte is excluded from the search so the
    // byte following a hit can always be inspected.
    char* cr = (char*)memchr(chunkbuf, '\r', chunkbuflen-1);
    if (cr != NULL && cr[1] != '\n') {
        LOG("bogus separator");
        return -1;
    }

    if (!in_chunk_state) {
        // positioned at a chunk length line, whose value is ignored
        if (cr == NULL) {
            return 0; // line not complete yet, nothing to pop
        }
        const size_t skip = (size_t)(cr-chunkbuf) + 2;
        chunkbuflen -= skip;
        memmove(chunkbuf, chunkbuf+skip, chunkbuflen);
        in_chunk_state = true;
        // continue with the chunk data that follows
        return strip_encoding(chunkbuf, chunkbuflen, outbuf, in_chunk_state);
    }

    // positioned inside chunk data: figure out how much of it can be decoded
    size_t datalen;
    if (cr != NULL) {
        // the end of the chunk is known
        datalen = (size_t)(cr-chunkbuf);
        if (datalen % 2 != 0) { // payload is hex, so its length must be even
            LOG("uneven stream len");
            return -1;
        }
        in_chunk_state = false; // a chunk length line comes next
    } else {
        // chunk still incomplete: only decode whole hex pairs
        datalen = chunkbuflen - (chunkbuflen % 2);
    }

    // decode, unless this is the noop/keep-alive marker
    ssize_t produced = 0;
    const bool keepalive = (datalen == 2 && strncmp(chunkbuf, "##", 2) == 0);
    if (!keepalive) {
        if (!from_hex((unsigned char*)chunkbuf, (unsigned char*)outbuf, datalen)) {
            LOG("cannot parse hex");
            return -1;
        }
        produced = datalen/2;
    }

    // pop the decoded data (and the trailing CRLF when the chunk ended)
    const size_t consumed = datalen + (cr != NULL? 2: 0);
    if (consumed == chunkbuflen) {
        chunkbuflen = 0;
    } else {
        chunkbuflen -= consumed;
        memmove(chunkbuf, chunkbuf+consumed, chunkbuflen); // TODO:
    }
    return produced;
}


// Decodes as much of `chunkbuf` as currently possible into `outbuf` by
// repeatedly running the single-step strip_encoding() overload.
//
// The chunk/length parser state is kept in a function-local static, so this
// wrapper tracks exactly one stream per process.
//
// Each step appends after the bytes produced by the previous steps, and the
// loop continues for as long as a step consumed input, even if that step
// produced no output (chunk length line, keep-alive chunk).
//
// `outbuf` needs room for at most chunkbuflen/2 bytes (two input bytes are
// consumed per output byte).
//
// Returns the total number of bytes written to `outbuf`, or -1 on malformed
// input.
static ssize_t strip_encoding(char* chunkbuf, size_t& chunkbuflen, char* outbuf) {
    static bool in_chunk_state = false;
    ssize_t total = 0;
    while (true) {
        const size_t before = chunkbuflen;
        // write behind what earlier steps already produced
        ssize_t r = strip_encoding(chunkbuf, chunkbuflen, outbuf + total, in_chunk_state);
        if (r == -1) {
            return -1;
        }
        if (r == 0 && chunkbuflen == before) {
            break; // no output and no input consumed: need more data
        }
        total += r;
    }
    return total;
}


// Proxies data between the client and upstream until EOF or error.
//
// cfd:          client socket; data read here is hex-encoded and sent
//               upstream in individual POST requests.
// ufd:          persistent upstream socket delivering the chunked,
//               hex-encoded reply stream.
// session:      session id sent in the Cookie header of every POST.
// session_len:  number of bytes of `session` to send.
// preread:      reply body bytes that were already read from `ufd` together
//               with the response headers (may be NULL).
// preread_len:  number of bytes in `preread`.
//
// At most one POST request is in flight at a time: a new request is only
// created once the previous one has been written and its status checked.
//
// Returns true when the client reached EOF and nothing was left to do,
// false on any error.
static bool loop(int cfd, int ufd, const char* session, size_t session_len, const char* preread, size_t preread_len) {
    // work on client, persistent upstream, and per-request upstream fds
    int rfd = -1;
    if (!sock_nonblock(cfd, true) || !sock_nonblock(ufd, true)) {
        return false;
    }
    bool want_read[3]; // cfd, ufd, rfd
    bool want_write[3];
    bool eof = false;

    // don't care about resource usage atm
    size_t reqbuflen = 0;
    size_t repbuflen = 0;
    size_t chunkbuflen = 0;
    static char reqbuf[16384];
    static char repbuf[16384];
    static char chunkbuf[16384];

    // reply data that arrived together with the response headers must not be lost
    if (preread && preread_len) {
        if (preread_len > sizeof(chunkbuf)) {
            LOG("preread data too large");
            return false;
        }
        memcpy(chunkbuf, preread, preread_len);
        chunkbuflen = preread_len;
    }

    while (true) {
        // track if/why we have to block
        bool noop = true;
        memset(want_read, 0, sizeof(want_read));
        memset(want_write, 0, sizeof(want_write));

        // can write to client?
        if (repbuflen) {
            ssize_t rv = buf_write(cfd, repbuf, repbuflen);
            if (rv == -1) {
                return false;
            } else if (rv == 0) {
                want_write[0] = true;
            } else {
                noop = false;
                LOG("> %zd", rv);
            }
        }

        // can write pre-created request to upstream?
        if (reqbuflen) {
            ssize_t rv = buf_write(rfd, reqbuf, reqbuflen);
            if (rv == -1) {
                return false;
            } else if (rv == 0) {
                want_write[2] = true;
            } else {
                noop = false;
                LOG(">> %zd", rv);
            }
        }

        // can read request status from upstream?
        if (!reqbuflen && rfd != -1) {
            static char tmp[1024];
            ssize_t rv = buf_read(rfd, tmp, sizeof(tmp));
            if (rv == -1) {
                return false;
            } else if (rv == 0) {
                want_read[2] = true;
            } else {
                noop = false;
                close(rfd); // TODO: keep-alive
                rfd = -1;
                // need at least "HTTP/x.y NNN " (13 bytes) so that tmp[12] is fresh data
                if (rv < 13 || tmp[8] != ' ' || tmp[12] != ' ' || atoi(tmp+9) != 200) {
                    LOG("error write response");
                    return false;
                }
            }
        }

        // can read and craft new request? (only when no request is in flight,
        // otherwise the pending rfd would be overwritten and leaked)
        if (!reqbuflen && rfd == -1 && !eof) {
            static char tmp[(sizeof(reqbuf)-1024)/2]; // leave space for header and encoding overhead
            ssize_t rv = buf_read(cfd, tmp, sizeof(tmp));
            if (rv == -1) {
                if (errno == ENOPKG) {
                    eof = true;
                } else {
                    return false;
                }
            } else if (rv == 0) {
                want_read[0] = true;
            } else {
                noop = false;
                LOG("< %zd", rv);

                rfd = sock_connect(&config.dst, true);
                if (rfd == -1) {
                    return false;
                }
                const size_t bodylen = (size_t)rv*2;
                int hdrlen = snprintf(reqbuf, sizeof(reqbuf),
                    "POST /%s?a=write HTTP/1.1\r\n"
                    "Host: %s\r\n"
                    "Cookie: sid=%.*s\r\n"
                    "Content-Length: %zu\r\n"
                    "Connection: close\r\n"
                    "\r\n", config.script, config.host, (int)session_len, session, bodylen
                );
                // header plus hex body must fit (one spare byte is kept in
                // case to_hex terminates its output -- not visible from here)
                if (hdrlen < 0 || (size_t)hdrlen + bodylen >= sizeof(reqbuf)) {
                    LOG("request too large");
                    close(rfd);
                    return false;
                }
                to_hex((unsigned char*)tmp, (unsigned char*)reqbuf+hdrlen, (size_t)rv);
                reqbuflen = (size_t)hdrlen + bodylen;
            }
        }

        // can read reply data?
        if (chunkbuflen < sizeof(chunkbuf)) {
            ssize_t rv = buf_read(ufd, chunkbuf+chunkbuflen, sizeof(chunkbuf)-chunkbuflen);
            if (rv == -1) {
                return false;
            } else if (rv == 0) {
                want_read[1] = true;
            } else {
                noop = false;
                LOG("<< %zd", rv);
                chunkbuflen += (size_t)rv;
            }
        }

        // can convert reply data?
        if (sizeof(repbuf)-repbuflen > chunkbuflen/2) {
            ssize_t rv = strip_encoding(chunkbuf, chunkbuflen, repbuf+repbuflen);
            if (rv == -1) {
                return false;
            } else if (rv > 0) {
                repbuflen += (size_t)rv;
                noop = false; // fresh client data: try to flush it before sleeping
            }
        }

        // have to sleep?
        if (noop) {
            if (eof) {
                return true; // TODO: proper half-closed handling
            }

            fd_set rfds, wfds;
            FD_ZERO(&rfds);
            FD_ZERO(&wfds);
            int max = 0;
            if (want_read[0]) FD_SET(cfd, &rfds);
            if (want_read[1]) FD_SET(ufd, &rfds);
            if (want_read[2]) FD_SET(rfd, &rfds);
            if (want_write[0]) FD_SET(cfd, &wfds);
            if (want_write[1]) FD_SET(ufd, &wfds);
            if (want_write[2]) FD_SET(rfd, &wfds);
            if (want_read[0] || want_write[0]) max = MAX(max, cfd);
            if (want_read[1] || want_write[1]) max = MAX(max, ufd);
            if (want_read[2] || want_write[2]) max = MAX(max, rfd);

            int rv = select(max+1, &rfds, &wfds, NULL, NULL);
            if (rv == -1 && errno != EINTR) { // an interrupted wait is simply retried
                LOG_ERRNO("select()");
                return false;
            }
        }
    }
}


// Reports failure to the SOCKS client, closes the upstream socket if one is
// open (ufd != -1), and returns false so callers can `return handle_fail(...)`.
static bool handle_fail(int fd, int ufd) {
    (void)socks_rep(fd, false);
    if (ufd != -1) {
        close(ufd);
    }
    return false;
}


// Serves one SOCKS client on `fd`:
//   1. reads the requested destination from the SOCKS request,
//   2. sends it (obfuscated, hex-encoded) to upstream in a "bind" POST,
//   3. validates the response (status 200, chunked encoding, hex session cookie),
//   4. acknowledges the SOCKS request and hands over to loop().
//
// The response headers are expected to arrive within a single read().
// Returns the result of loop(), or false if any setup step fails.
static bool handle(int fd) {
    static char buf[4096];

    // read desired destination from socks request
    sockaddr_in dst;
    if (!socks_req(fd, &dst)) {
        return handle_fail(fd, -1);
    }

    // obfuscate it for upstream
    uint64_t nonce = get_nonce();
    uint64_t val = obfuscate_dst(nonce, dst);
    char sess[(sizeof(nonce) + sizeof(val)) * 2 + 1] = {};
    to_hex((unsigned char*)&nonce, (unsigned char*)sess, sizeof(nonce));
    to_hex((unsigned char*)&val, (unsigned char*)sess + (sizeof(nonce)*2), sizeof(val));
    LOG("sending request for: %s", sess);

    // connect to upstream
    int ufd = sock_connect(&config.dst, false);
    if (ufd == -1) {
        return handle_fail(fd, -1);
    }

    // craft request and blocking write
    int reqlen = snprintf(buf, sizeof(buf),
        "POST /%s?a=bind HTTP/1.1\r\n"
        "Host: %s\r\n"
        "If-None-Match: %s\r\n"
        "\r\n", config.script, config.host, sess
    );
    if (reqlen < 0 || (size_t)reqlen >= sizeof(buf)) {
        // truncated output: writing `reqlen` bytes would read past the buffer
        LOG("bind request too large");
        return handle_fail(fd, ufd);
    }
    size_t len = (size_t)reqlen;
    ssize_t rv = write(ufd, buf, len);
    if (rv != (ssize_t)len) {
        if (rv == -1) {
            LOG_ERRNO("write(upstream)");
        } else {
            LOG("incomplete write to upstream");
        }
        return handle_fail(fd, ufd);
    }

    // get response and check reply
    rv = read(ufd, buf, sizeof(buf)-1);
    if (rv <= 0) {
        if (rv == -1) {
            LOG_ERRNO("read(upstream)");
        } else {
            LOG("upstream closed connection"); // EOF: errno carries no information
        }
        return handle_fail(fd, ufd);
    }
    buf[rv] = '\0';
    if (rv < 12 || buf[8] != ' ' || buf[12] != ' ' || atoi(buf+9) != 200) {
        LOG("unsuccessful reply");
        return handle_fail(fd, ufd);
    }

    // locate the end of the headers while the buffer is still unmodified
    // (terminating the cookie below may overwrite the first byte of it)
    char* hdr_end = strstr(buf, "\r\n\r\n");
    if (!hdr_end) {
        LOG("no header end");
        return handle_fail(fd, ufd);
    }
    char* start = hdr_end + 4;

    // header matches are only accepted inside the header section
    char* chunked = strstr(buf, "Transfer-Encoding: chunked");
    if (!chunked || chunked > hdr_end) {
        LOG("no chunked encoding");
        return handle_fail(fd, ufd);
    }

    // extract session cookie
    char* cookie_s = strstr(buf, "Set-Cookie: sid=");
    if (!cookie_s || cookie_s > hdr_end) {
        LOG("no session cookie");
        return handle_fail(fd, ufd);
    }
    cookie_s += sizeof("Set-Cookie: sid=")-1;
    char* cookie_e = strchr(cookie_s, '\r');
    if (!cookie_e) {
        LOG("cannot parse session cookie");
        return handle_fail(fd, ufd);
    }
    *cookie_e = '\0';
    if ((intptr_t)strspn(cookie_s, "0123456789abcdef") != cookie_e-cookie_s) {
        LOG("cannot parse session cookie value");
        return handle_fail(fd, ufd);
    }
    LOG("got session '%s'", cookie_s);

    // done, start to proxy
    if (!socks_rep(fd, true)) {
        close(ufd);
        return false;
    }
    // everything after the headers is reply data that was read ahead
    bool ok = loop(fd, ufd, cookie_s, (size_t)(cookie_e-cookie_s), start, (size_t)(buf+rv-start));
    close(ufd);
    return ok;
}


static bool detach(int fd) {
    pid_t pid = fork();
    if (pid == -1) {
        LOG_ERRNO("fork");
        return false;
    } else if (pid != 0) {
        return true;
    } else {
        for (int i=0; i<1024; ++i) {
            if (i != fd && i != STDERR_FILENO) close(i);
        }
        bool rv = handle(fd);
        LOG("done: %s", rv? "success": "error");
        exit(rv? 0: 1);
    }
}


// Entry point: parses the configuration, listens on the configured port and
// hands every accepted connection to a forked child via detach().
//
// SIGCHLD is ignored so that the children are never waited for.
// Exit status: 0 when accepting stops because sock_accept() failed,
// 1 on configuration, listen, or fork failure.
int main(int argc, char** argv) {
    if (!config_parse(argc, argv)) {
        return 1;
    }

    signal(SIGCHLD, SIG_IGN);

    const int lfd = sock_listen(config.port);
    if (lfd == -1) {
        return 1;
    }

    for (;;) {
        const int conn = sock_accept(lfd, false);
        if (conn == -1) {
            return 0;
        }
        if (!detach(conn)) {
            return 1;
        }
        // the child owns the connection now
        close(conn);
    }
}