#include "sock.hpp"
#include <sys/un.h>
#include <arpa/inet.h>
#include <net/if.h>


TEST(sin6_port) { // Checks #sockaddr_t
    // The port member must have the same offset and the same size in
    // sockaddr_in and sockaddr_in6, so either view can be used to access it.
    const bool same_offset =
        member_offset(sockaddr_in, sin_port) == member_offset(sockaddr_in6, sin6_port);
    const bool same_size =
        member_sizeof(sockaddr_in, sin_port) == member_sizeof(sockaddr_in6, sin6_port);
    return same_offset && same_size;
}


TEST(address_parsing) {
    // Round trip: parse "ip:port", format it back, expect the same text.
    const char* src = "1.2.3.4:1234";

    sockaddr_t parsed;
    if (!str2addr(src, parsed)) {
        return false;
    }

    char text[ADDRSTRLEN];

    // Without the port: the returned pointer must sit on the terminating NUL
    // of a non-empty string.
    const char* end = addr2str(parsed, text, false);
    if (end[0] != '\0' || end[-1] == '\0') {
        return false;
    }

    // With the port: same end-pointer contract...
    end = addr2str(parsed, text, true);
    if (end[0] != '\0' || end[-1] == '\0') {
        return false;
    }

    // ...and the formatted text must equal the input.
    return strcmp(src, text) == 0;
}


// Configuration section used by this file. It declares a single key,
// connect_tos; socket_connect() reads it as config->connect_tos.val.num and,
// when that value is non-zero, applies it to outgoing sockets via IP_TOS.
CONF_DEF(config) {
    ConfigKey connect_tos;
};
CONF_INIT(config) {
    // NOTE(review): the meaning of the `true` and `0` arguments is defined by
    // CONF_KEY_INIT (not visible here) - presumably a flag and the default
    // value; confirm against the macro.
    CONF_KEY_INIT(connect_tos, true, 0);
}


// Keeps the leading m bits of the IPv4 address and clears the remaining ones.
//
// m: prefix length in bits, 0..32.
// Returns true when the mask was applied; false when m > 32 or the address
// is not AF_INET (AF_INET6 is not implemented yet - TODO).
bool sockaddr_t::apply_mask_lower(const unsigned m) {
    if (family != AF_INET || m > 32) {
        return false;
    }

    if (m == 0) {
        // handled separately: shifting a 32-bit value by 32 is undefined
        addr4.sin_addr.s_addr = 0;
    } else if (m < 32) {
        // the mask is built in host order, the address is in network order
        addr4.sin_addr.s_addr &= htonl(0xffffffffu << (32-m));
    }
    // m == 32 keeps every bit, so there is nothing to do

    return true;
}


// Keeps the leading m bits of the IPv4 address and sets the remaining ones.
//
// m: prefix length in bits, 0..32.
// Returns true when the mask was applied; false when m > 32 or the address
// is not AF_INET (AF_INET6 is not implemented yet - TODO).
bool sockaddr_t::apply_mask_upper(const unsigned m) {
    if (family != AF_INET || m > 32) {
        return false;
    }

    if (m < 32) {
        // For m == 0 the shifted value is all ones, so the whole address
        // becomes 0xffffffff; for m == 32 nothing has to change.
        addr4.sin_addr.s_addr |= htonl(0xffffffffu >> m);
    }

    return true;
}


int uds_listen(const char* fn) {
    int fd = socket(AF_UNIX, SOCK_STREAM, 0);

    struct sockaddr_un addr;
    addr.sun_family = AF_UNIX;
    strcpy(addr.sun_path, fn);
    unlink(addr.sun_path);
    int len = strlen(addr.sun_path) + sizeof(addr.sun_family);

    if (bind(fd, (struct sockaddr*)&addr, len) == -1) {
        log_errno(error, "bind(%s)", fn);
        EINTR_RETRY(close(fd));
        return -1;
    }

    if (listen(fd, 5) == -1) {
        log_errno(error, "listen(%s)", fn);
        EINTR_RETRY(close(fd));
        return -1;
    }

    return fd;
}


// Opens a non-blocking UNIX domain stream socket and connects it to path p.
//
// p: socket path; must fit into sockaddr_un::sun_path including the NUL.
// Returns the connected fd, or -1 on failure (the error is logged).
int uds_connect(const char* p) {
    // destination
    struct sockaddr_un remote;

    // Reject over-long paths up front instead of overflowing sun_path.
    const size_t path_len = strlen(p);
    if (path_len >= sizeof(remote.sun_path)) {
        errno = ENAMETOOLONG;
        log_errno(error, "connect(%s)", p);
        return -1;
    }

    bzero(&remote, sizeof(remote));
    remote.sun_family = AF_UNIX;
    memcpy(remote.sun_path, p, path_len + 1);
    const socklen_t len = path_len + sizeof(remote.sun_family);

    // new socket
    int fd = socket(AF_UNIX, SOCK_STREAM|SOCK_NONBLOCK, 0);
    if (fd == -1) {
        log_errno(error, "socket(UDS)");
        return -1;
    }

    // connect
    if (connect(fd, (struct sockaddr*)&remote, len) == -1) {
        log_errno(error, "connect(%s)", p);
        EINTR_RETRY(close(fd));
        return -1;
    }

    log(info, "connected to %s", p);
    return fd;
}


void read_empty(int fd) {
    if (shutdown(fd, SHUT_RD) == -1 && errno != ENOTCONN) {
        log_errno(info, "shutdown(%d)", fd);
    }
    static char foo[BUF_SIZE];
    while (read(fd, foo, sizeof(foo)) > 0) {}
}


ssize_t peek(int fd, char** buf) {
    static char peekbuf[BUF_SIZE]; // shared is not a problem, we're single-threaded
    *buf = peekbuf;
    ssize_t rv = recv(fd, peekbuf, sizeof(peekbuf), MSG_DONTWAIT|MSG_PEEK);
    if (!rv) {
        log(io, "recv(%d): EOF", fd);
        return 0;
    } else if (rv == -1) {
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
            return 0;
        }
        log_errno(debug, "recv(%d)", fd);
        return -1;
    } else {
        log(io, "recv(%d): %zd", fd, rv);
        return rv;
    }
}


int socket_listen(in_port_t port) {
    int fd;
    int tmp;

    sa_family_t family;
    if ((fd = socket(AF_INET6, SOCK_STREAM|SOCK_NONBLOCK, 0)) != -1) {
        family = AF_INET6;
    } else if ((errno == EAFNOSUPPORT || errno == EPFNOSUPPORT) && ((fd = socket(AF_INET, SOCK_STREAM|SOCK_NONBLOCK, 0)) != -1)) {
        log(notice, "falling back to IPv4");
        family = AF_INET;
    } else {
        log_errno(error, "socket()");
        return -1;
    }

    if (family == AF_INET6) {
        tmp = 0;
        if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &tmp, sizeof(tmp)) != 0) {
            log_errno(error, "setsockopt(IPV6_V6ONLY)");
            close(fd);
            return -1;
        }
    }

    tmp = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &tmp, sizeof(tmp)) != 0) {
        log_errno(error, "setsockopt(SO_REUSEADDR)");
        close(fd);
        return -1;
    }

    sockaddr_t serv_addr;
    bzero((char*)&serv_addr, sizeof(serv_addr)); // so addr is any
    serv_addr.family = family;
    serv_addr.addr6.sin6_port = htons(port);

    if (bind(fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) == -1) {
        log_errno(error, "bind()");
        close(fd);
        return -1;
    }

    if (listen(fd, SOMAXCONN) == -1) {
        log_errno(error, "listen()");
        close(fd);
        return -1;
    }

    #ifdef SO_ACCEPTFILTER
        struct accept_filter_arg af;
        strcpy(af.af_name, "https"); // httpready dataready
        strcpy(af.af_arg, "");
        tmp = 1;
        if (setsockopt(fd, SOL_SOCKET, SO_ACCEPTFILTER, &af, sizeof(af)) == -1) {
            log_errno(error, "setsockopt(SO_ACCEPTFILTER)");
        } else if (setsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &tmp, sizeof(tmp)) == -1) {
            log_errno(error, "setsockopt(TCP_DEFER_ACCEPT)");
        } else {
            log(debug, "using SO_ACCEPTFILTER/TCP_DEFER_ACCEPT");
        }
    #else
        log(debug, "not using SO_ACCEPTFILTER/TCP_DEFER_ACCEPT");
    #endif

    log(info, "listening on %d #%d...", (int)port, fd);
    return fd;
}


// Formats addr as text into buf, optionally followed by ":<port>".
//
// addr: address to format; AF_INET is read through addr4, every other family
//       through addr6 (inet_ntop() rejects families it does not know).
// buf:  output buffer of at least ADDRSTRLEN bytes.
// port: when true, append ':' and the decimal port number.
// Returns a pointer to the terminating NUL of the written string. When the
// address cannot be formatted, buf is set to "" and buf itself is returned.
char* addr2str(const sockaddr_t& addr, char* buf, bool port) {
    const bool is_v4 = (addr.family == AF_INET);
    const void* ip = is_v4? (const void*)&addr.addr4.sin_addr: (const void*)&addr.addr6.sin6_addr;
    if (!inet_ntop(addr.family, ip, buf, ADDRSTRLEN)) {
        *buf = '\0';
        return buf;
    }

    const size_t len = strlen(buf);
    if (!port) {
        return buf + len;
    }

    // The stored port is in network byte order.
    const in_port_t net_port = is_v4? addr.addr4.sin_port: addr.addr6.sin6_port;

    // Bounded write: never go past the ADDRSTRLEN bytes the caller provides.
    const size_t room = ADDRSTRLEN - len;
    const int written = snprintf(buf + len, room, ":%d", (int)ntohs(net_port));
    if (written < 0) {
        buf[len] = '\0';
        return buf + len;
    }
    if ((size_t)written >= room) {
        // truncated: snprintf NUL-terminated at the very last byte
        return buf + ADDRSTRLEN - 1;
    }
    return buf + len + written;
}


const char* addr2str(const sockaddr_t& addr, bool port) {
    static char ip[ADDRSTRLEN];
    if (addr2str(addr, ip, port) == ip) {
        return "-";
    }
    return ip;
}


// Parses "a.b.c.d" or "a.b.c.d:port" into addr (IPv4 only).
//
// dst:  text to parse; the port, if present, must be a decimal number in
//       0..65535 with nothing after it.
// addr: cleared first, then filled in. On success family is AF_INET and the
//       port is stored in network byte order (0 when no port was given).
//       On failure family is 0.
// Returns true on success, false when the address or port is malformed.
bool str2addr(const char* dst, sockaddr_t& addr) {
    // Start from a fully zeroed address so bytes outside the IPv4 fields
    // are deterministic; family stays 0 until parsing succeeds.
    bzero(&addr, sizeof(addr));

    // TODO: need to support AF_INET6?
    char* ip = strdupa(dst);
    char* port = strchr(ip, ':');
    if (port) {
        *port = '\0'; ++port;
    }

    long port_num = 0;
    if (port) {
        // strtol instead of atoi: detects garbage and out-of-range values
        // rather than silently turning them into some other port
        char* end = NULL;
        errno = 0;
        port_num = strtol(port, &end, 10);
        if (errno != 0 || end == port || *end != '\0' || port_num < 0 || port_num > 65535) {
            return false;
        }
    }

    if (inet_pton(AF_INET, ip, &addr.addr4.sin_addr) != 1) {
        return false;
    }

    addr.family = AF_INET;
    addr.addr4.sin_port = htons((in_port_t)port_num);
    return true;
}


// Convenience overload: parses dst and returns the address by value.
// The parse result is deliberately ignored; on failure the returned
// address has family 0.
sockaddr_t str2addr(const char* dst) {
    sockaddr_t parsed = {};
    (void)str2addr(dst, parsed);
    return parsed;
}


// Converts an IPv4-mapped IPv6 address in place into a plain AF_INET
// address; every other address is left untouched.
static void unmapv4(sockaddr_t& addr) {
    if (addr.family != AF_INET6) {
        return;
    }
    if (!IN6_IS_ADDR_V4MAPPED(&addr.addr6.sin6_addr)) {
        return;
    }

    // the IPv4 address is stored in the trailing bytes of the IPv6 address
    const size_t v4_size = sizeof(addr.addr4.sin_addr);
    const char* v6_bytes = (const char*)&addr.addr6.sin6_addr;
    const char* v4_bytes = v6_bytes + sizeof(addr.addr6.sin6_addr) - v4_size;

    addr.family = AF_INET;
    memcpy(&addr.addr4.sin_addr, v4_bytes, v4_size); // should not overlap
}


// Determines the original destination of the connection on fd, as seen
// before redirection, via SO_ORIGINAL_DST.
//
// Build-time overrides:
//   ORIGINAL_DST      - always report this fixed "ip[:port]" instead of
//                       asking the kernel (parsed once, then cached).
//   ORIGINAL_DST_PORT - replace the port of the kernel-reported address.
//
// addr: receives the destination; IPv4-mapped IPv6 is converted to AF_INET.
// Returns true on success, false when the lookup failed (logged).
static bool original_dst(int fd, sockaddr_t& addr) {
    #ifdef ORIGINAL_DST
        #ifndef DEBUG
            MESSAGE(overriding original dst)
        #endif
        static sockaddr_t fixed = {};
        if (!fixed.family && !str2addr(ORIGINAL_DST, fixed)) {
            die("ORIGINAL_DST parsing failed");
            return false;
        }
        memcpy(&addr, &fixed, sizeof(addr));
        return true;
    #endif

    socklen_t size = sizeof(addr);
    bzero(&addr, sizeof(addr));

    if (getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, (void*)&addr, &size) == 0) {
        unmapv4(addr);

        #ifdef ORIGINAL_DST_PORT
            #ifndef DEBUG
                MESSAGE(overriding original dst port)
            #endif
            addr.addr4.sin_port = htons(ORIGINAL_DST_PORT);
        #endif

        return true;
    }

    // ENOPROTOOPT is usually seen for direct/non-transparent connections,
    // so it is only worth a debug message
    if (errno == ENOPROTOOPT) {
        log_errno(debug, "SO_ORIGINAL_DST failed");
    } else {
        log_errno(info, "SO_ORIGINAL_DST failed");
    }
    return false;
}


// Accepts one pending connection on the listening socket fd; the new socket
// is non-blocking.
//
// src: receives the peer address with its port cleared; IPv4-mapped IPv6 is
//      converted to AF_INET.
// dst: receives the original destination. When it cannot be determined the
//      connection is rejected if REQUIRE_ORIGINAL_DST is set, otherwise dst
//      is zeroed.
// Returns the new fd, or -1 when accept4() failed or the connection was
// rejected.
int socket_accept(int fd, sockaddr_t& src, sockaddr_t& dst) {
    // accept4() expects a socklen_t; use one directly rather than casting
    // the address of an int
    socklen_t len = sizeof(src);
    bzero(&src, sizeof(src));
    int newfd = accept4(fd, (struct sockaddr*)&src, &len, SOCK_NONBLOCK);
    if (newfd == -1) return -1;
    unmapv4(src);
    src.addr4.sin_port = 0; // should not be needed and might break some hashing

    if (!original_dst(newfd, dst)) {
        #if (REQUIRE_ORIGINAL_DST)
            EINTR_RETRY(close(newfd));
            return -1;
        #else
            bzero(&dst, sizeof(dst));
        #endif
    }

    // addr2str() returns a shared static buffer, so each result is copied
    // before the next call overwrites it
    log(debug, "accepted fd %d: %s -> %s", newfd, strdupa(addr2str(src)), strdupa(addr2str(dst)));
    return newfd;
}


int socket_connect(const sockaddr_t* src, const sockaddr_t* dst) {
    // new nonblocking socket
    int fd;
    if ((fd = socket(dst->family, SOCK_STREAM|SOCK_NONBLOCK, 0)) == -1) {
        log_errno(error, "socket(%d)", dst->family);
        return -1;
    }

    // SO_REUSEADDR
    int tmp = 1;
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &tmp, sizeof(tmp)) != 0) {
        log_errno(error, "setsockopt(SO_REUSEADDR)");
        EINTR_RETRY(close(fd));
        return -1;
    }

    // spoof?
    #if (TRANSPARENT)
        if (src) {
            if (src->family == dst->family) {
                sockaddr_t tmp = *src;
                tmp.addr6.sin6_port = 0;
                if (bind(fd, (struct sockaddr*)&tmp, sizeof(tmp)) == -1) {
                    log_errno(error, "bind()");
                    EINTR_RETRY(close(fd));
                    return -1;
                }
            } else {
                // TODO: map one of them to IPv6?
                log(error, "cannot spoof from family %d to %d", src->family, dst->family);
                EINTR_RETRY(close(fd));
                return -1;
            }
        }
    #else
        #ifndef DEBUG
            MESSAGE(spoofing disabled)
        #endif
    #endif

    // tos?
    int tos = config? config->connect_tos.val.num: 0;
    if (tos) {
        if (setsockopt(fd, IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) == -1) {
            log_errno(error, "setsockopt(tos)");
        }
    }

    // connect
    if ((connect(fd, (struct sockaddr*)dst, sizeof(*dst)) != 0) && (errno != EINPROGRESS)) {
        log_errno(info, "connect()");
        EINTR_RETRY(close(fd));
        return -1;
    }

    // done
    #ifdef DEBUG
        log(debug, "connecting to '%s'...", addr2str(*dst));
    #else
        log(debug, "connecting...");
    #endif
    return fd;
}


#if 0
// Checks whether some local network interface has the given MAC address as
// both its current (SIOCGIFHWADDR) and its permanent (ETHTOOL_GPERMADDR)
// hardware address.
//
// macbuf: MAC address as six colon-separated two-digit hex bytes.
// Returns true when such an interface exists; false when there is none, the
// text cannot be parsed, or a required system call fails.
// (Currently compiled out.)
bool find_mac(const char* macbuf) {
    unsigned char mac[6] = {};
    // %2hhx stores exactly one byte per field; a plain %X would write a full
    // unsigned through each pointer and run past the end of mac[]
    if (sscanf(macbuf, "%2hhx:%2hhx:%2hhx:%2hhx:%2hhx:%2hhx", &mac[0], &mac[1], &mac[2], &mac[3], &mac[4], &mac[5]) != 6) {
        return false;
    }

    // get arbitrary socket handle
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd  == -1) {
        log_errno(error, "socket");
        return false;
    }

    // query available interfaces
    char buf[BUF_SIZE];
    struct ifconf ifc = {};
    ifc.ifc_len = sizeof(buf);
    ifc.ifc_buf = buf;
    if(ioctl(fd, SIOCGIFCONF, &ifc) == -1) {
        log_errno(error, "ioctl");
        close(fd);
        return false;
    }

    // iterate through the list of interfaces
    for (int i = 0; i < ifc.ifc_len/(int)sizeof(struct ifreq); ++i) {
        const char* iface = ifc.ifc_ifcu.ifcu_req[i].ifr_ifrn.ifrn_name;

        // get software mac
        struct ifreq ifr = {};
        strcpy(ifr.ifr_name, iface);
        if (ioctl(fd, SIOCGIFHWADDR, &ifr) == -1) {
            log_errno(error, "ioctl");
            continue;
        }
        if (memcmp(mac, ifr.ifr_hwaddr.sa_data, sizeof(mac)) != 0) {
            continue;
        }

        // upon match, get hardware mac
        const size_t addr_len = 32;
        struct ethtool_perm_addr* edata = (struct ethtool_perm_addr*)malloc(sizeof(struct ethtool_perm_addr) + addr_len);
        if (!edata) {
            log_errno(error, "malloc");
            close(fd);
            return false;
        }
        edata->cmd = ETHTOOL_GPERMADDR;
        edata->size = addr_len;
        ifr.ifr_data = (caddr_t)edata;
        if (ioctl(fd, SIOCETHTOOL, &ifr) == -1) {
            log_errno(error, "ioctl");
            free(edata);
            continue;
        }
        if (memcmp(mac, edata->data, sizeof(mac)) != 0) {
            free(edata);
            continue;
        }

        // done (both match)
        free(edata);
        close(fd);
        return true;
    }

    close(fd);
    return false;
}
#endif