#include "ac.hpp"

#include <arpa/inet.h>
#include <string.h>


/// Releases the upstream address allocated by ac_upstream_config().
/// NOTE(review): safe_free presumably also resets ctx to NULL — confirm against its definition.
static void ac_upstream_unconfig(void*& ctx) {
    safe_free(ctx);
}
static bool ac_upstream_config(void*& ctx, char* value) {
    ctx = (void*)tcalloc(sockaddr_t);
    if (!str2addr(value, *(sockaddr_t*)ctx)) {
        safe_free(ctx);
        return false;
    }
    return true;
}

static bool ac_acl_config(void*& ctx, char* value) {
    if (!ctx) {
        ctx = new Acl<bool>(&AclMatch::parser);
    }
    return ((Acl<bool>*)ctx)->parse_line(value);
}
/// Destroys the ACL list built by ac_acl_config() and clears ctx; a NULL ctx is a no-op.
static void ac_acl_unconfig(void*& ctx) {
    if (!ctx) {
        return;
    }
    delete (Acl<bool>*)ctx;
    ctx = NULL;
}

/// Cache of upstream ac lookups keyed by AccessControl::key_t (client address + requested host).
/// Each val_t holds either a pending upstream request (upstream_fd != -1) or the last result with its timestamp.
typedef HashCache<AccessControl::key_t, AccessControl::val_t> ACCache;

/// Configuration keys of the access-control module.
CONF_DEF(config) {
    ConfigKey ac_upstream; ///< upstream address queried for decisions (sockaddr_t, see ac_upstream_config)
    ConfigKey ac_acl; ///< local ACL lines, checked before any cache lookup or upstream query
    ConfigKey ac_cache; ///< lookup result or pending request
};
/// Registers the access-control config keys and their handlers.
CONF_INIT(config) {
    // NOTE(review): the boolean flags are defined by CONF_KEY_INIT. ac_acl is the only key with the second
    // flag set and the only handler that accumulates repeated lines — presumably "multi-valued"; confirm.
    CONF_KEY_INIT(ac_upstream, true, false, &ac_upstream_config, &ac_upstream_unconfig);
    CONF_KEY_INIT(ac_acl, true, true, &ac_acl_config, &ac_acl_unconfig);
    // presumably instantiates an ACCache from a size_t config value — TODO confirm CONF_KEY_INST semantics
    CONF_KEY_INST(ac_cache, true, ACCache, size_t);
}
/// On a global cache-clean request, invalidate the ac lookup cache too (if one is configured).
HOOK(CACHE_CLEAN, HOOK_PRIO_MID) {
    ACCache* ac_cache = config? config->ac_cache.ctx_as<ACCache>(): NULL;
    if (ac_cache != NULL) {
        ac_cache->invalidate();
    }
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/// Pending result: undecided until set() is called; publishes itself in the given queue slot.
AccessControlResult::AccessControlResult(int _fd, AccessControlResult** _handle)
    : fd(_fd)
    , result(TRI_NONE)
    , handle(_handle)
{
    // the slot now points at us; the destructor resets it if we are destroyed before set()
    *_handle = this;
}


/// Immediate result: decided on construction and not attached to any pending queue.
AccessControlResult::AccessControlResult(int _fd, bool _result)
    : fd(_fd)
    , result(bi2tri(_result))
    , handle(NULL)
{
}


void AccessControlResult::set(bool _result) {
    assert(result == TRI_NONE);
    assert(handle && *handle == this);
    result = bi2tri(_result);
    handle = NULL;
    Poll::getInst()->wakeup(fd);
}


/// Aborts a still-pending result by clearing its queue slot, so the upstream callback skips it.
/// Decided (or never-pending) results have no slot and need no cleanup.
AccessControlResult::~AccessControlResult() {
    if (!handle) {
        return;
    }
    assert(*handle == this);
    *handle = NULL;
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/// Pool of work_t upstream request contexts: taken in getInst() for each new upstream query, returned by cb().
CPool<AccessControl::work_t> AccessControl::pool;


/**
 * Starts (or joins) an access-control decision for the client connection fd.
 *
 * Checks, in order: local ACLs, the result cache (joining an already pending upstream request
 * for the same key), and finally a fresh "only-if-cached" request to the configured upstream.
 * If no upstream is configured or the request cannot be sent, access is granted (fail-open).
 *
 * @param fd       client fd; woken through Poll once a deferred result has been set
 * @param host     requested host name, or NULL. Hosts longer than max_host_len are not sent
 *                 upstream; the destination address is used instead, as for a NULL host
 * @param host_len number of host characters to use
 * @param src      client address, sent as X-Forwarded-For and part of the cache key
 * @param dst      destination address; supplies the port, or the whole target if host is NULL
 * @return heap-allocated result, either already decided or pending until cb() runs
 */
AccessControlResult* AccessControl::getInst(int fd, const char* host, size_t host_len, const sockaddr_t& src, const sockaddr_t& dst) {
    // DNS names cannot exceed 255 octets; the bound also guarantees the request fits into buf below
    static const size_t max_host_len = 255;

    // check local acls first. XXX: Caching here, too?
    const Acl<bool>* acls = config? config->ac_acl.ctx_as<Acl<bool> >(): NULL;
    if (acls) {
        bool rv = false;
        switch (acls->match(host, src, dst, rv)) {
            case TRI_TRUE:
            case TRI_FALSE:
                log(debug, "ac: acl match: %d", rv);
                return new AccessControlResult(fd, rv);
            case TRI_NONE:
                log(debug, "ac: no matching acls");
                break;
        }
    }

    // an oversized host would overflow the fixed request buffer: query by address instead
    if (host && host_len > max_host_len) {
        log(error, "ac: host name too long, using destination address");
        host = NULL;
        host_len = 0;
    }

    // build cache-key; value-initialized so uncached requests never carry indeterminate bytes
    ACCache* cache = config? config->ac_cache.ctx_as<ACCache>(): NULL;
    key_t key = key_t();
    if (cache && host && host_len >= sizeof(key.dst)) {
        // truncating would let distinct hosts share one cache entry, so bypass the cache instead
        log(debug, "ac: host name too long for cache key, not caching");
        cache = NULL;
    }
    if (cache) {
        key.src = src;
        bzero(key.dst, sizeof(key.dst));
        if (host) {
            memcpy(key.dst, host, host_len); // fits incl. terminator, checked above
        }
    }

    // check cache for result or already pending request
    val_t* hit = cache? cache->at(key): NULL;
    if (hit) {
        if (hit->upstream_fd != -1) {
            log(debug, "enqueuing to pending ac request");
            work_t* w = (work_t*)Poll::getInst()->ctx_get(hit->upstream_fd);
            assert(w);
            AccessControlResult** handle = w->pending.push_get();
            return new AccessControlResult(fd, handle);
        } else if (!AC_CACHE_TTL || hit->last_result >= NOW-AC_CACHE_TTL) {
            log(debug, "ac cache-hit");
            return new AccessControlResult(fd, hit->result); // TODO: pooling?
        } else {
            log(debug, "ac cache-miss");
            hit = NULL;
        }
    }

    // create new non-blocking socket
    sockaddr_t* ac = config? config->ac_upstream.ctx_as<sockaddr_t>(): NULL;
    int sfd = ac? socket_connect(NULL, ac): -1;
    if (sfd == -1) {
        log(debug, "no ac connection");
        return new AccessControlResult(fd, true);
    }

    // build request (HTTP requires CRLF line endings)
    static char buf[1024];
    char* p = buf;
    p = strccpy(p, "GET https://");
    if (host) {
        memcpy(p, host, host_len);
        p += host_len;
        // sin_port is stored in network byte order
        p += sprintf(p, ":%u", (unsigned)ntohs(dst.addr4.sin_port));
    } else {
        p = addr2str(dst, p, true);
    }
    p = strccpy(p, "/ HTTP/1.0\r\nX-Forwarded-For: ");
    p = addr2str(src, p, false);
    p = strccpy(p, "\r\nCache-Control: only-if-cached\r\n\r\n");
    const size_t len = p-buf;
    assert(len < sizeof(buf));

    // simply write directly, don't use AioFile as its handler is threaded, and its so small that the socket should buffer it for us
    ssize_t rv = write(sfd, buf, len);
    if (rv != (ssize_t)len) {
        log(error, "cannot write ac request");
        EINTR_RETRY(close(sfd));
        return new AccessControlResult(fd, true);
    }

    // create new handle
    work_t* work = pool.pop();
    work->key = key;
    assert(work->pending.empty());
    AccessControlResult** handle = work->pending.push_get();

    // register in Poll (should be room for that as we don't splice etc yet), with short timeout
    // NOTE(review): the return value of add() is not checked
    Poll::getInst()->add(sfd, cb, EVENT_IN, true, work);

    // we're the first miss
    assert(!hit);
    if (cache) {
        hit = cache->get(key);
        hit->upstream_fd = sfd;
    }

    // return handle
    log(debug, "enqueuing new ac request");
    return new AccessControlResult(fd, handle);
}


void AccessControl::cb(int fd, event_t ev, unsigned, void* w) {
    // try to read in any case, close and remove from poll
    char buf[1024];
    ssize_t rv = read(fd, buf, sizeof(buf)-1);
    Poll::getInst()->del(fd);
    EINTR_RETRY(close(fd));
    if (rv <= 0) {
        log(error, "cannot read ac response");
        rv = 0;
    }
    buf[(size_t)rv] = '\0';

    // parse result
    bool result;
    int code;
    if (sscanf(buf, "HTTP/1.%*c %d ", &code) != 1) {
        log(error, "unexpected ac response");
        result = true;
    } else if (code == 504) {
        log(debug, "positive ac response");
        result = true;
    } else if (code == 403) {
        log(debug, "negative ac response");
        result = false;
    } else {
        log(error, "unexpected ac response code %d", code);
        result = true;
    }

    // put into cache
    work_t* work = (work_t*)w;
    assert(work);
    ACCache* cache = config? config->ac_cache.ctx_as<ACCache>(): NULL;
    val_t* hit = cache? cache->at(work->key): NULL;
    if (hit) {
        hit->last_result = NOW;
        hit->result = result;
        hit->upstream_fd = -1;
    }

    // wakeup clients w/ result
    while (!work->pending.empty()) {
        AccessControlResult* handle = work->pending.pop();
        if (handle) {
            handle->set(result);
        }
    }
    pool.push(work);
}