#include "proxy.hpp"
// Wire layout of a SOCKS 4 request that carries an empty USERID.
// handle_socks_request() only accepts input whose length equals sizeof(socks_req_t),
// so the struct must have no padding (PACKED is presumably the project's packing attribute — confirm in proxy.hpp).
typedef struct {
uint8_t vn; // VN is the SOCKS protocol version number and should be 4
uint8_t cd; // CD is the SOCKS command code and should be 1 for CONNECT request
uint16_t dstport; // destination port; copied unchanged into sin_port, so presumably network byte order
uint32_t dstip; // destination IPv4 address; copied unchanged into sin_addr.s_addr, so presumably network byte order
uint8_t null; // terminator of the (empty) USERID; must be 0, requests with a USERID are not supported
} PACKED socks_req_t;
// Wire layout of the SOCKS 4 reply sent back to the client.
typedef struct {
uint8_t vn; // VN is the version of the reply code and should be 0
uint8_t cd; // CD is the result code with one of the following values: 90: request granted; 91: request rejected or failed
uint8_t pad[6]; // remaining reply bytes (the port/address fields in SOCKS 4 — confirm against the spec); this file always sends them zeroed
} PACKED socks_res_t;
// Canned SOCKS replies: version 0, result code 90 (granted) or 91 (rejected or failed), padding zeroed.
static const socks_res_t socks_res_ok = { 0, 90, {0} };
static const socks_res_t socks_res_fail = { 0, 91, {0} };
// Canned HTTP replies to a CONNECT request. The pointers themselves are const
// so the replies cannot be re-pointed by accident.
static const char* const connect_res_ok = "HTTP/1.0 200 Connection Established\r\n\r\n";
static const char* const connect_res_fail = "HTTP/1.0 400 Bad Request\r\n\r\n";
static bool handle_socks_request(int fd, char* buf, size_t len, sockaddr_t& dst, char*& host) {
if (len != sizeof(socks_req_t)) { // assume we will get this little bit all at once
log(debug, "not a SOCKS helo w/o userid: #%zu", len);
return false;
}
const socks_req_t* r = (socks_req_t*)buf;
if (r->vn != 4 || r->cd != 1 || r->null != 0) {
log(debug, "not a SOCKS 4 connect: %u/%u/%u", r->vn, r->cd, r->null);
return false;
}
if (!r->dstip || !r->dstport) { // TODO: check for other invalid or bad dst addresses/ports
log(debug, "not returning SOCKS for %u:%u", htonl(r->dstip), htons(r->dstport));
(void)write(fd, &socks_res_fail, sizeof(socks_res_fail));
return false;
}
if (write(fd, &socks_res_ok, sizeof(socks_res_ok)) != sizeof(socks_res_ok)) { // assume we can write this little bit all at once
log_errno(io, "cannot write SOCKS reply");
return false;
}
host = NULL;
bzero(&dst, sizeof(dst));
dst.family = AF_INET;
dst.addr4.sin_addr.s_addr = r->dstip;
dst.addr4.sin_port = r->dstport;
log(info, "found SOCKS to %s", addr2str(dst, true));
return true;
}
static bool handle_connect_request(int fd, char* buf, size_t len, sockaddr_t& dst, char*& host) {
if (len < cstrlen("CONNECT X:N HTTP/1.X__")) {
log(debug, "not a CONNECT request: #%zu", len);
return false;
}
if (memcmp("CONNECT ", buf, cstrlen("CONNECT ")) != 0) {
log(debug, "no CONNECT method");
return false;
}
if (memcmp("\r\n\r\n", buf+len-4, 4) != 0 && memcmp("\n\n", buf+len-2, 2) != 0) {
log(debug, "truncated CONNECT request");
return false;
}
buf += cstrlen("CONNECT ");
char* conn_port = strchr(buf, ':');
unless (conn_port) {
log(debug, "no CONNECT port given");
return false;
}
*conn_port = '\0'; ++conn_port;
in_port_t sin_port = ntohs(atoi(conn_port));
if (!sin_port) {
log(debug, "cannot parse CONNECT port");
return false;
}
char* p = strchr(conn_port, ' ');
unless (p) {
log(debug, "cannot find CONNECT version");
return false;
}
if ((strncmp(" HTTP/1.", p, cstrlen(" HTTP/1.")) != 0) || (p[cstrlen(" HTTP/1.")] != '0' && p[cstrlen(" HTTP/1.")] != '1')) {
log(debug, "cannot parse CONNECT version");
return false;
}
if (str2addr(buf, dst)) {
dst.addr4.sin_port = sin_port;
log(info, "found CONNECT to ip %s", addr2str(dst, true));
} else if (validate_host(buf, conn_port-buf-1)) {
host = strdup(buf);
dst.addr4.sin_port = sin_port;
log(info, "found CONNECT to host %s:%u", host, htons(sin_port));
} else { // XXX: won't look for Host header instead, not trusted
log(debug, "cannot parse CONNECT destination");
(void)write(fd, connect_res_fail, strlen(connect_res_fail));
return false;
}
if (write(fd, connect_res_ok, strlen(connect_res_ok)) != (ssize_t)strlen(connect_res_ok)) {
log_errno(io, "cannot write CONNECT reply");
return false;
}
return true;
}
// Reads the first chunk of data from a freshly accepted client and hands it to
// the matching handshake handler, chosen by the first byte:
// 'C' -> HTTP CONNECT, 4 -> SOCKS 4.
//
// fd   - client socket to read the request from
// dst  - filled in by the handler on success
// host - filled in by the handler on success
//
// Returns the handler's result, or false on a read error, on EOF, or when the
// data matches neither protocol. The whole request is expected to arrive in a
// single read(). The receive buffer is function-static, so all calls share it.
bool handle_proxy_request(int fd, sockaddr_t& dst, char*& host) { // TODO: return event_t instead?
	static char buf[BUF_SIZE];

	ssize_t rv = read(fd, buf, sizeof(buf)-1); // leave room for the terminator added below
	if (rv == -1) {
		log_errno(io, "proxy read(%d)", fd);
		return false;
	} else if (rv == 0) {
		log(io, "proxy read(%d) eof", fd);
		return false;
	} else {
		log(io, "proxy read(%d): %zd", fd, rv);
	}

	const size_t len = (size_t)rv;
	buf[len] = '\0'; // the CONNECT parser relies on a NUL-terminated buffer

	if (*buf == 'C') {
		return handle_connect_request(fd, buf, len, dst, host);
	} else if (*buf == 4) {
		return handle_socks_request(fd, buf, len, dst, host);
	} else {
		// Go through unsigned char so bytes >= 0x80 are not sign-extended in the log.
		log(debug, "does not look like SOCKS or CONNECT: %x #%zu", (unsigned)(unsigned char)*buf, len);
		return false;
	}
}