#include "acl.hpp"
#include <fnmatch.h>


/**
 * Tests a host name against an ACL host pattern.
 *
 * The pattern is a shell-style glob evaluated by fnmatch() with no flags.
 *
 * @param str  host name to test; may be NULL when no name is known
 * @param pat  glob pattern; must not be NULL
 * @return true when the name is accepted by the pattern
 */
bool domain_match(const char* str, const char* pat) {
    // A lone "*" accepts everything, including a missing (NULL) host name.
    const bool accepts_anything = (pat[0] == '*' && pat[1] == '\0');
    if (accepts_anything) return true;

    // Every other pattern needs an actual name to compare against.
    if (!str) return false;

    // Plain glob match of the whole name.
    if (fnmatch(pat, str, 0) == 0) return true;

    // s.t. '*.foo.com' matches 'foo.com' as well: retry with the leading
    // "*." stripped from the pattern.
    const bool has_subdomain_prefix = (pat[0] == '*' && pat[1] == '.');
    return has_subdomain_prefix && fnmatch(pat + 2, str, 0) == 0;
}


/**
 * Initialises the range from a single address and a mask.
 *
 * Both ends start out as copies of `a`; `from` is then passed through
 * apply_mask_lower() and `to` through apply_mask_upper() (presumably
 * producing the first and last address covered by the mask -- confirm
 * against sockaddr_t).
 *
 * @param a  base address of the range
 * @param m  mask value handed to the apply_mask_* helpers
 * @return true only if both mask operations succeed
 */
bool AclMatch::addr_range_t::parse(const sockaddr_t& a, unsigned m) {
    from = a;
    to = a;

    // The upper end is only masked once the lower end has succeeded.
    if (!from.apply_mask_lower(m)) {
        return false;
    }
    return to.apply_mask_upper(m);
}


// Parses a non-empty string consisting solely of decimal digits into `out`.
// Fails (leaving `out` untouched) on an empty string, on any non-digit
// character (sign, whitespace, trailing garbage) or when the value would
// exceed `limit`.
static bool acl_parse_decimal(const char* s, unsigned long limit, unsigned long& out) {
    if (*s == '\0') return false;

    unsigned long value = 0;
    for (; *s != '\0'; ++s) {
        if (*s < '0' || *s > '9') return false;
        const unsigned long digit = static_cast<unsigned long>(*s - '0');
        // Equivalent to `value * 10 + digit > limit`, written so it cannot overflow.
        if (digit > limit || value > (limit - digit) / 10) return false;
        value = value * 10 + digit;
    }

    out = value;
    return true;
}


/**
 * Parses a textual range of the form "ADDR/MASK" or "ADDR/MASK:PORT".
 *
 * The buffer is modified in place: the '/' and (if present) the ':' after it
 * are overwritten with NUL so the three parts can be handled as separate
 * C strings.
 *
 * MASK and PORT must be plain decimal numbers. Unlike a bare atoi(), input
 * such as "24abc", "-1" or a port above 65535 is rejected instead of being
 * silently truncated or wrapped. An omitted PORT is stored as 0, which
 * match() treats as "any port".
 *
 * @param i  writable, NUL-terminated range specification
 * @return true if the specification is well formed and the resulting range
 *         could be set up; false otherwise
 */
bool AclMatch::addr_range_t::parse(char* i) {
    char* m = strchr(i, '/');
    if (!m) return false;
    *m = '\0'; ++m;

    char* p = strchr(m, ':');
    if (p) {
        *p = '\0'; ++p;
    }

    // Anything that fits into `unsigned` is passed on; the per-family range
    // check is left to the mask helpers called by parse(a, mask).
    unsigned long mask = 0;
    if (!acl_parse_decimal(m, ~0u, mask)) return false;

    // A port is optional, but when given it must be a valid 16-bit value.
    unsigned long host_port = 0;
    if (p && !acl_parse_decimal(p, 65535ul, host_port)) return false;
    unsigned port = htons(static_cast<unsigned short>(host_port));

    sockaddr_t a;
    if (!str2addr(i, a)) return false;
    a.addr4.sin_port = port;
    return parse(a, static_cast<unsigned>(mask));
}


/**
 * Checks whether an address falls inside this range.
 *
 * The address must compare within [from, to]. If the range carries a
 * non-zero port, the address' port must equal it; a zero port in the range
 * places no restriction on the port.
 *
 * @param a  address to test
 * @return true if `a` is covered by the range
 */
bool AclMatch::addr_range_t::match(const sockaddr_t& a) const {
    //log(io, "match %s vs %s-%s", strdupa(addr2str(a, true)), strdupa(addr2str(from, true)), strdupa(addr2str(to, true)));

    // Sanity check: both ends of the range are expected to carry the same port.
    assert(from.addr4.sin_port == to.addr6.sin6_port);

    if (!(a >= from)) return false;
    if (!(a <= to)) return false;

    // Port 0 in the range means "do not filter on port".
    if (!from.addr4.sin_port) return true;
    return a.addr4.sin_port == from.addr4.sin_port;
}


/**
 * Parses one ACL line of the form "SRC DST HOST REST".
 *
 * The first three single-space separators are overwritten with NUL in the
 * caller's buffer. SRC and DST are handed to the address-range parsers and
 * HOST is copied into `host`.
 *
 * @param s  writable, NUL-terminated ACL line
 * @return pointer (into `s`) to the text following the third space, or NULL
 *         if a field is missing, an address range is invalid, or HOST does
 *         not fit into `host`
 */
char* AclMatch::parse(char* s) {
    // field[0..2] = SRC, DST, HOST; field[3] = whatever follows.
    char* field[4] = { s, NULL, NULL, NULL };
    for (int k = 1; k < 4; ++k) {
        char* sep = strchr(field[k - 1], ' ');
        if (sep == NULL) return NULL;
        *sep = '\0';
        field[k] = sep + 1;
    }

    if (!src.parse(field[0]) || !dst.parse(field[1])) {
        return NULL;
    }

    // Refuse host patterns that would not fit (terminator included).
    if (strlen(field[2]) >= sizeof(host)) {
        return NULL;
    }
    strcpy(host, field[2]);

    return field[3];
}


/**
 * Interprets a keyword as a boolean.
 *
 * Accepted (case-sensitive) spellings:
 *   true:  "1", "true", "yes", "allow"
 *   false: "0", "false", "no", "deny"
 *
 * @param p   NUL-terminated keyword
 * @param rv  receives the parsed value; left untouched on failure
 * @return true if the keyword was recognised
 */
bool AclMatch::parser(char* p, bool& rv) {
    static const char* const truthy[] = { "1", "true", "yes", "allow" };
    static const char* const falsy[] = { "0", "false", "no", "deny" };

    for (unsigned k = 0; k < sizeof(truthy) / sizeof(truthy[0]); ++k) {
        if (strcmp(p, truthy[k]) == 0) {
            rv = true;
            return true;
        }
    }

    for (unsigned k = 0; k < sizeof(falsy) / sizeof(falsy[0]); ++k) {
        if (strcmp(p, falsy[k]) == 0) {
            rv = false;
            return true;
        }
    }

    return false;
}


/**
 * Interprets a keyword as a tri-state value.
 *
 * Accepted (case-sensitive) spellings:
 *   TRI_TRUE:  "true", "yes"
 *   TRI_FALSE: "false", "no"
 *   TRI_NONE:  "none", "noop"
 *
 * @param p   NUL-terminated keyword
 * @param rv  receives the parsed value; left untouched on failure
 * @return true if the keyword was recognised
 */
bool AclMatch::parser(char* p, tristate_t& rv) {
    const bool says_true = strcmp(p, "true") == 0 || strcmp(p, "yes") == 0;
    if (says_true) {
        rv = TRI_TRUE;
        return true;
    }

    const bool says_false = strcmp(p, "false") == 0 || strcmp(p, "no") == 0;
    if (says_false) {
        rv = TRI_FALSE;
        return true;
    }

    const bool says_none = strcmp(p, "none") == 0 || strcmp(p, "noop") == 0;
    if (says_none) {
        rv = TRI_NONE;
        return true;
    }

    return false;
}


// Unit test: parses a permissive ACL line and checks the resulting matcher
// against a loopback source and an unrelated address.
TEST(acl_match_parse) {
    const sockaddr_t src = str2addr("127.0.0.1:1337");
    const sockaddr_t dst = str2addr("1.2.3.4:443");
    AclMatch acl;

    char* p;
    // parse() writes NUL bytes into its argument, so it needs a writable copy.
    // Fields: source range, destination range, host pattern, remainder.
    char* acl_any = strdup("127.0.0.1/24 0.0.0.0/0 * foo");
    p = acl.parse(acl_any);
    // On success parse() returns the text after the third space.
    TEST_ASSERT(p && !strcmp(p, "foo"));
    // Host pattern "*" is expected to accept a missing (NULL) host name.
    TEST_ASSERT(acl.match(NULL, src, dst));
    // 1.2.3.4 used as the source lies outside 127.0.0.1/24.
    TEST_ASSERT(!acl.match(NULL, dst, dst));
    // ... and "*" is expected to accept an arbitrary host name as well.
    TEST_ASSERT(acl.match("bar.baz", src, dst));

    free(acl_any);
    return true;
}