#include "proxy.hpp"

#include <cstdlib>
#include <cstring>


// SOCKS4 CONNECT request as it arrives on the wire: 1+1+2+4+1 = 9 bytes.
// handle_socks_request only accepts a read of exactly sizeof(socks_req_t), so this
// relies on PACKED (project macro, presumably a no-padding attribute -- confirm)
// to keep the struct free of alignment padding.
typedef struct {
    uint8_t vn; // VN is the SOCKS protocol version number and should be 4
    uint8_t cd; // CD is the SOCKS command code and should be 1 for CONNECT request
    uint16_t dstport; // network byte order; copied unconverted into sin_port
    uint32_t dstip; // network byte order; copied unconverted into sin_addr.s_addr
    uint8_t null; // NUL terminating an empty USERID: requests carrying a USERID are not supported
} PACKED socks_req_t;

// SOCKS4 reply: 1+1+6 = 8 bytes, written to the client byte-for-byte.
typedef struct {
    uint8_t vn; // VN is the version of the reply code and should be 0
    uint8_t cd; // CD is the result code with one of the following values: 90: request granted; 91: request rejected or failed
    uint8_t pad[6]; // occupies the reply's DSTPORT/DSTIP slots; always sent as zeros by this file
} PACKED socks_res_t;


// Canned SOCKS4 replies: reply version 0, then result code 90 (granted) or
// 91 (rejected/failed), followed by six zero bytes.
static const socks_res_t socks_res_ok = { 0, 90, { 0, 0, 0, 0, 0, 0 } };
static const socks_res_t socks_res_fail = { 0, 91, { 0, 0, 0, 0, 0, 0 } };
// Canned HTTP replies to a CONNECT request; they are sent with strlen(), so the
// trailing NUL never goes on the wire.
static const char* connect_res_ok = "HTTP/1.0 200 Connection Established\r\n\r\n";
static const char* connect_res_fail = "HTTP/1.0 400 Bad Request\r\n\r\n";


static bool handle_socks_request(int fd, char* buf, size_t len, sockaddr_t& dst, char*& host) {
    if (len != sizeof(socks_req_t)) { // assume we will get this little bit all at once
        log(debug, "not a SOCKS helo w/o userid: #%zu", len);
        return false;
    }
    const socks_req_t* r = (socks_req_t*)buf;

    if (r->vn != 4 || r->cd != 1 || r->null != 0) {
        log(debug, "not a SOCKS 4 connect: %u/%u/%u", r->vn, r->cd, r->null);
        return false;
    }

    if (!r->dstip || !r->dstport) { // TODO: check for other invalid or bad dst addresses/ports
        log(debug, "not returning SOCKS for %u:%u", htonl(r->dstip), htons(r->dstport));
        (void)write(fd, &socks_res_fail, sizeof(socks_res_fail));
        return false;
    }

    if (write(fd, &socks_res_ok, sizeof(socks_res_ok)) != sizeof(socks_res_ok)) { // assume we can write this little bit all at once
        log_errno(io, "cannot write SOCKS reply");
        return false;
    }

    host = NULL;
    bzero(&dst, sizeof(dst));
    dst.family = AF_INET;
    dst.addr4.sin_addr.s_addr = r->dstip;
    dst.addr4.sin_port = r->dstport;

    log(info, "found SOCKS to %s", addr2str(dst, true));
    return true;
}


static bool handle_connect_request(int fd, char* buf, size_t len, sockaddr_t& dst, char*& host) {
    if (len < cstrlen("CONNECT X:N HTTP/1.X__")) {
        log(debug, "not a CONNECT request: #%zu", len);
        return false;
    }
    if (memcmp("CONNECT ", buf, cstrlen("CONNECT ")) != 0) {
        log(debug, "no CONNECT method");
        return false;
    }
    if (memcmp("\r\n\r\n", buf+len-4, 4) != 0 && memcmp("\n\n", buf+len-2, 2) != 0) {
        log(debug, "truncated CONNECT request");
        return false;
    }
    buf += cstrlen("CONNECT ");

    char* conn_port = strchr(buf, ':');
    unless (conn_port) {
        log(debug, "no CONNECT port given");
        return false;
    }
    *conn_port = '\0'; ++conn_port;
    in_port_t sin_port = ntohs(atoi(conn_port));
    if (!sin_port) {
        log(debug, "cannot parse CONNECT port");
        return false;
    }

    char* p = strchr(conn_port, ' ');
    unless (p) {
        log(debug, "cannot find CONNECT version");
        return false;
    }
    if ((strncmp(" HTTP/1.", p, cstrlen(" HTTP/1.")) != 0) || (p[cstrlen(" HTTP/1.")] != '0' && p[cstrlen(" HTTP/1.")] != '1')) {
        log(debug, "cannot parse CONNECT version");
        return false;
    }

    if (str2addr(buf, dst)) {
        dst.addr4.sin_port = sin_port;
        log(info, "found CONNECT to ip %s", addr2str(dst, true));
    } else if (validate_host(buf, conn_port-buf-1)) {
        host = strdup(buf);
        dst.addr4.sin_port = sin_port;
        log(info, "found CONNECT to host %s:%u", host, htons(sin_port));
    } else { // XXX: won't look for Host header instead, not trusted
        log(debug, "cannot parse CONNECT destination");
        (void)write(fd, connect_res_fail, strlen(connect_res_fail));
        return false;
    }

    if (write(fd, connect_res_ok, strlen(connect_res_ok)) != (ssize_t)strlen(connect_res_ok)) {
        log_errno(io, "cannot write CONNECT reply");
        return false;
    }

    return true;
}


/**
 * Read the first chunk from a client and dispatch it to the HTTP CONNECT or
 * SOCKS4 parser, based on its first byte ('C' or 0x04).
 *
 * Only a single read() is issued, and both parsers expect the whole request in
 * it. The data lands in a function-static buffer that is NUL-terminated at
 * buf[len], so this function is not reentrant.
 *
 * @param fd   client socket to read the request from and write the reply to
 * @param dst  filled in by the selected parser on success
 * @param host filled in by the selected parser on success
 * @return true when a request was parsed and acknowledged; false on read
 *         error, EOF, or an unrecognized or invalid request
 */
bool handle_proxy_request(int fd, sockaddr_t& dst, char*& host) { // TODO: return event_t instead?
    static char buf[BUF_SIZE];
    ssize_t rv = read(fd, buf, sizeof(buf)-1); // leave room for the NUL terminator
    if (rv == -1) {
        log_errno(io, "proxy read(%d)", fd);
        return false;
    } else if (rv == 0) {
        log(io, "proxy read(%d) eof", fd);
        return false;
    } else {
        log(io, "proxy read(%d): %zd", fd, rv);
    }
    const size_t len = (size_t)rv;
    buf[len] = '\0'; // the CONNECT parser relies on C-string functions

    if (*buf == 'C') {
        return handle_connect_request(fd, buf, len, dst, host);
    } else if (*buf == 4) {
        return handle_socks_request(fd, buf, len, dst, host);
    } else {
        // Cast via unsigned char: %x needs an unsigned int, and a plain (possibly
        // signed) char >= 0x80 would otherwise print sign-extended, e.g. ffffff80.
        log(debug, "does not look like SOCKS or CONNECT: %x #%zu", (unsigned)(unsigned char)*buf, len);
        return false;
    }
}