#include "certs.hpp"
#include <dirent.h>
#include <fcntl.h>


// Round-trip check for SslCtx: a context inserted for `cn` must be found again under the
// same name (as the very same object), while an unrelated host must miss.
// Every reference handed out by SslCtx::get() is released on all paths, including the
// failure paths, so a failing check does not leak contexts.
TEST(ssl_ctx) {
    SslCtx sslctx(1000);

    const char* cn = "foo.com";
    char* key = NULL;
    char* cert = NULL;
    size_t keylen;
    size_t certlen;

    unless (mkcert(cn, cert, certlen, key, keylen)) return false;
    ssl_ctx_t* ctx1 = sslctx.get(cn, cert, certlen, key, keylen);
    free(key); // the PEM buffers are not used after the context has been built
    free(cert);
    unless (ctx1) return false;

    ssl_ctx_t* ctx2 = sslctx.get(cn);        // expected: cache hit on the object inserted above
    ssl_ctx_t* ctx3 = sslctx.get("bar.com"); // expected: cache miss
    const bool ok = ctx2 == ctx1 && ctx3 == NULL;

    // Release every reference we were handed, whatever the outcome of the checks.
    if (ctx3) ssl_ctx_del(ctx3);
    if (ctx2) ssl_ctx_del(ctx2);
    ssl_ctx_del(ctx1);

    return ok;
}


/// Creates an SSL context cache; `size` is forwarded unchanged to the underlying cache container.
SslCtx::SslCtx(size_t size)
    : cache(size)
{
    // nothing else to set up: slots start out empty
}


/// Drops the cache's own reference on every stored SSL context.
/// Contexts still referenced elsewhere stay alive until their last ssl_ctx_del().
SslCtx::~SslCtx() {
    node_t* entry = cache.begin();
    while (entry != cache.end()) {
        if (entry->ctx != NULL) {
            ssl_ctx_del(entry->ctx);
        }
        ++entry;
    }
}


/**
 * Builds an SSL context from the given certificate/key material and caches it under `host`.
 *
 * The cache keeps its own reference (ssl_ctx_up_ref), so the returned pointer is a
 * reference the caller must release with ssl_ctx_del(). A context previously cached in
 * the same slot gets its cache reference dropped.
 *
 * If `host` does not fit into the slot's fixed-size name buffer, the context is still
 * returned but is not cached. The cache slot is left untouched in that case.
 *
 * @return the new context, or NULL if ssl_ctx_get() failed.
 */
ssl_ctx_t* SslCtx::get(const char* host, const char* cert, size_t certlen, const char* key, size_t keylen) {
    ssl_ctx_t* ctx = ssl_ctx_get(cert, certlen, key, keylen);
    log(debug, "sslctx: '%s': %p", host, ctx);
    unless (ctx) return NULL;

    node_t* node = cache.at(hash(host));

    // Previously only assert()ed: with NDEBUG, strcpy() below would overflow node->host.
    // Check before touching the slot so an oversized name does not evict a valid entry.
    if (strlen(host) >= sizeof(node->host)) {
        log(notice, "sslctx: '%s': hostname too long, not caching", host);
        return ctx;
    }

    if (node->ctx) {
        ssl_ctx_del(node->ctx); // refcount will be decremented s.t. this or the last ssl_del() will actually free it.
    }
    node->ctx = ctx;
    ssl_ctx_up_ref(ctx); // we increase the refcount if we cache it s.t. it stays valid even if it gets "freed" before/after the actual ssl has been derived.
    strcpy(node->host, host); // fits: length checked above

    return ctx;
}


/**
 * Looks up the cached SSL context for `host`.
 *
 * A slot only counts as a hit if it holds a context AND its stored name equals `host`
 * (different hosts can hash to the same slot). On a hit the refcount is incremented, so
 * the caller must release the returned context with ssl_ctx_del().
 *
 * @return the cached context, or NULL on a cache miss.
 */
ssl_ctx_t* SslCtx::get(const char* host) {
    node_t* const slot = cache.at(hash(host));
    const bool hit = slot->ctx != NULL && !strcmp(host, slot->host);

    ssl_ctx_t* found = hit ? slot->ctx : NULL;
    if (hit) {
        ssl_ctx_up_ref(found);
    }

    log(debug, "sslctx: '%s': cache-%s", host, found? "hit": "miss");
    return found;
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/**
 * Opens the wallet directory at `path` and calls load() for every entry whose name does
 * not start with '.' (this also skips "." and "..").
 *
 * `dirfd` stays open for the lifetime of the wallet because load() passes it to
 * AioFileIn::enqueue(). On any failure an error is logged, `dirfd` is left at -1 and the
 * wallet stays empty.
 */
Wallet::Wallet(const char* path): data(1000) {
    dirfd = open(path, O_RDONLY|O_DIRECTORY);
    if (dirfd == -1) {
        log_errno(error, "open(%s, O_DIRECTORY)", path);
        return;
    }

    // closedir() also closes the descriptor handed to fdopendir(), so iterate over a
    // duplicate and keep dirfd itself open.
    int dirpfd = dup(dirfd);
    if (dirpfd == -1) {
        // Checked separately: otherwise fdopendir(-1) would replace the real errno with EBADF.
        log_errno(error, "dup(%s)", path);
        close(dirfd);
        dirfd = -1;
        return;
    }

    DIR* dirp = fdopendir(dirpfd);
    unless (dirp) {
        log_errno(error, "opendir(%s)", path);
        close(dirpfd); // fdopendir() does not take ownership on failure
        close(dirfd);
        dirfd = -1;
        return;
    }

    struct dirent* dent;
    while ((dent = readdir(dirp)) != NULL) {
        if (*dent->d_name != '.') {
            log(io, "found wallet file '%s'", dent->d_name);
            load(dent->d_name);
        }
    }
    closedir(dirp); // also closes dirpfd

    log(info, "crawled wallet dir '%s'", path);
}


/// Releases every pending or completed file read, then closes the wallet directory
/// descriptor if the constructor managed to open it.
Wallet::~Wallet() {
    data_t* entry = data.begin();
    while (entry != data.end()) {
        if (entry->aio_result != NULL) {
            entry->aio_result->delInst(); // aborts and/or frees buf
        }
        ++entry;
    }

    if (dirfd != -1) {
        close(dirfd);
    }
}

/**
 * Claims the hash slot for wallet file `name` and enqueues an asynchronous read of it,
 * relative to `dirfd`.
 *
 * The file is skipped, with a log message, when its slot is already taken by another
 * file or when `name` does not fit into the slot's fixed-size name buffer. If enqueueing
 * fails, the slot keeps the name but no aio_result, and get() later reports TRI_FALSE
 * for it.
 */
void Wallet::load(const char* name) {
    data_t* node = data[hash(name)];
    if (*node->host) {
        log(notice, "wallet cache collision: '%s', '%s'", node->host, name);
        return;
    }
    if (strlen(name) >= sizeof(node->host)) {
        // Previously dropped silently; make it visible why this cert is never served.
        log(notice, "wallet file name too long: '%s'", name);
        return;
    }
    strcpy(node->host, name);

    assert(!node->aio_result);
    node->aio_result = AioFileIn::enqueue(name, dirfd);
    if (node->aio_result) {
        log(debug, "loading wallet cert from '%s'", name);
    }
}


/**
 * Splits the completely read file buffer of `node` into its certificate and key parts.
 *
 * Expected layout: "<cert>~<key>". The '~' separator is overwritten with '\0' so that
 * the certificate part is NUL-terminated. On success, node->cert/certlen/key/keylen
 * point into the aio buffer and true is returned.
 *
 * Returns false if no separator is found. In that case node->cert is left unset, so no
 * pointer into the buffer survives when the caller releases it.
 */
bool Wallet::parse(data_t* node) {
    assert(node->aio_result->len > 0); // available
    assert(!node->cert); // not yet parsed
    char* sep = strnchr(node->aio_result->buf, '~', node->aio_result->len); // TODO: support plain pem boundaries
    unless (sep) {
        return false;
    }
    *sep = '\0';
    // Only publish the pointers once the buffer is known to be well-formed.
    node->cert = node->aio_result->buf;
    node->certlen = sep-node->cert;
    node->key = sep+1;
    node->keylen = node->aio_result->len - node->certlen - 1;
    return true;
}


/**
 * Looks up the wallet certificate/key for common name `cn`.
 *
 * @return TRI_NONE  if `cn` is not the host stored in its hash slot (wallet not responsible),
 *         TRI_FALSE if the entry is unusable right now: an earlier error, the read is
 *                   still pending, the read failed, or the file could not be parsed
 *                   (the last two also release the read),
 *         TRI_TRUE  on success; cert/certlen/key/keylen then point into the wallet's
 *                   read buffer and remain owned by the wallet.
 */
tristate_t Wallet::get(const char* cn, const char*& cert, size_t& certlen, const char*& key, size_t& keylen) {
    data_t* node = data[hash(cn)];
    if (strcmp(cn, node->host) != 0) { // not responsible for this host
        return TRI_NONE;
    }
    if (!node->aio_result) { // some previous error
        return TRI_FALSE;
    }
    if (node->aio_result->len == 0) { // still pending
        return TRI_FALSE;
    }
    if (node->aio_result->len == -1) { // aio error
        node->aio_result->delInst();
        node->aio_result = NULL;
        return TRI_FALSE;
    }
    if (!node->cert) { // available but not yet parsed
        if (!parse(node)) {
            node->aio_result->delInst();
            node->aio_result = NULL;
            // Never keep pointers into the released read buffer around.
            node->cert = NULL;
            node->certlen = 0;
            node->key = NULL;
            node->keylen = 0;
            return TRI_FALSE;
        }
        assert(node->cert);
    }

    log(debug, "found wallet cert '%s'", cn);
    cert = node->cert;
    certlen = node->certlen;
    key = node->key;
    keylen = node->keylen;
    return TRI_TRUE;
}


/**
 * Builds an SSL context for `cn` from the wallet certificate, going through `sslctx`
 * (which also caches the context) when it is given, or ssl_ctx_get() directly otherwise.
 *
 * @return TRI_NONE / TRI_FALSE as the cert/key overload of get() reports (ctx untouched),
 *         TRI_FALSE if the context could not be created (ctx is NULL and the wallet
 *                   entry is reset so it is not retried),
 *         TRI_TRUE  with `ctx` set. Presumably the caller owns this reference and
 *                   releases it with ssl_ctx_del() — confirm against callers.
 */
tristate_t Wallet::get(const char* cn, SslCtx* sslctx, ssl_ctx_t*& ctx) {
    const char* cert;
    size_t certlen;
    const char* key;
    size_t keylen;
    tristate_t rv = get(cn, cert, certlen, key, keylen);
    if (rv != TRI_TRUE) {
        return rv;
    }

    ctx = sslctx?
        sslctx->get(cn, cert, certlen, key, keylen):
        ssl_ctx_get(cert, certlen, key, keylen);

    unless (ctx) {
        log(notice, "reset wallet cert for '%s'", cn);
        data_t* node = data[hash(cn)];
        assert(node->aio_result && strcmp(cn, node->host) == 0);
        node->aio_result->delInst();
        node->aio_result = NULL;
        // cert/key pointed into the buffer released above; drop them so they cannot dangle.
        node->cert = NULL;
        node->certlen = 0;
        node->key = NULL;
        node->keylen = 0;
        return TRI_FALSE;
    }
    return TRI_TRUE;
}