#include "bump.hpp"


static bool bump_op_parser(char* p, bump_op_t& rv) {
    if (!strcmp(p, "none")) {
        rv = BUMP_NONE;
        return true;
    } else if (!strcmp(p, "parent")) {
        rv = BUMP_PARENT;
        return true;
    } else if (!strcmp(p, "bump")) {
        rv = BUMP_BUMP;
        return true;
#if (SPLICE)
    } else if (!strcmp(p, "splice")) {
        rv = BUMP_SPLICE;
        return true;
#endif
    } else {
        return false;
    }
}
static bool bump_acl_config(void*& ctx, char* value) {
    if (!ctx) {
        ctx = new Acl<bump_op_t>(&bump_op_parser);
    }
    return ((Acl<bump_op_t>*)ctx)->parse_line(value);
}
static void bump_acl_unconfig(void*& ctx) {
    if (ctx) {
        delete (Acl<bump_op_t>*)ctx;
        ctx = NULL;
    }
}

// Configuration keys of this module, declared through the project's CONF_DEF macro.
// The rest of this file accesses them through `config`, which it treats as a nullable
// pointer (e.g. `config? config->ssl_cache.ctx_as<SslCtx>(): NULL`).
CONF_DEF(config) {
    ConfigKey wallet_dir; // context is read as a Wallet of predefined certificates (see Bump::getInst)
    ConfigKey ssl_cache; // context is read as the SslCtx certificate cache
    ConfigKey bump_acl; // context is an Acl<bump_op_t>, managed by bump_acl_config/bump_acl_unconfig
};
// Registers how the context of each key declared by CONF_DEF(config) is set up.
// NOTE(review): the boolean arguments are interpreted by the CONF_* macros, which are not
// visible here — confirm their meaning there before changing them.
CONF_INIT(config) {
    // wallet_dir: context instantiated as a Wallet; char* is presumably the type of the configured value passed to it — TODO confirm
    CONF_KEY_INST(wallet_dir, true, Wallet, char*);
    // ssl_cache: context instantiated as an SslCtx; size_t presumably its configured capacity — TODO confirm
    CONF_KEY_INST(ssl_cache, true, SslCtx, size_t); // TODO: cache clean?
    // bump_acl: custom callbacks build the ACL one line per call and free it on unconfig
    CONF_KEY_INIT(bump_acl, true, true, &bump_acl_config, &bump_acl_unconfig);
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/**
 * Tells whether cn is an address literal as understood by str2addr() that contains no ':'.
 * Names containing ':' are never reported as IPs by this helper.
 *
 * Uses a function-static scratch address, so the parsed result is discarded.
 */
static INLINE bool is_ip(const char* cn) {
    static sockaddr_t scratch;
    if (!str2addr(cn, scratch)) {
        return false;
    }
    return strchr(cn, ':') == NULL;
}


/**
 * Decides what to do with a connection by matching it against the configured bump ACL.
 *
 * @param host  requested host name
 * @param src   client address
 * @param dst   destination address
 * @return the operation of the matching ACL rule, or BUMP_NONE when there is no
 *         configuration, no ACL, or the ACL match does not report TRI_TRUE
 */
bump_op_t bump_ruleset_lookup(const char* host, const sockaddr_t& src, const sockaddr_t& dst) {
    if (!config) {
        return BUMP_NONE;
    }
    Acl<bump_op_t>* const acl = config->bump_acl.ctx_as<Acl<bump_op_t> >();
    if (!acl) {
        return BUMP_NONE;
    }
    bump_op_t op = BUMP_NONE;
    return (acl->match(host, src, dst, op) == TRI_TRUE)? op: BUMP_NONE;
}


/**
 * Creates a Bump handle that provides a certificate context for the connection on fd.
 *
 * Lookup order:
 *   1. the SslCtx cache (ssl_cache key) -> ready handle on hit,
 *   2. the Wallet (wallet_dir key)       -> ready handle, NULL on error, or fall through,
 *   3. an asynchronous request through ipc_enqueue() -> pending handle.
 * With WILDCARD_CERTS enabled, a non-IP cn is first rewritten to its wildcard form.
 *
 * @param fd   connection descriptor; stored in the Bump and passed to ipc_enqueue()
 * @param cn   NUL-terminated common name (presumably the SNI host; assumed strlen(cn) == len — TODO confirm)
 * @param len  length of cn
 * @return a ready or pending Bump, or NULL if the wallet reported an error or enqueueing failed
 */
Bump* Bump::getInst(int fd, const char* cn, size_t len) {
    #if (WILDCARD_CERTS)
        // Declared at function scope on purpose: cn may be redirected into this buffer and is
        // used until the end of the function. A block-scoped buffer would leave cn dangling.
        char wc[MAX_SNI_LEN+1];

        // as wildcards match only one level of subdomains, we replace the first part with a '*' (but not for 1st level), e.g: foo.bar.baz -> *.bar.baz
        if (!is_ip(cn)) {
            assert(len < sizeof(wc));
            const char* dot = strchr(cn, '.');
            if (dot && dot != cn && static_cast<size_t>(dot - cn) < len && strchr(dot+1, '.') != NULL) {
                // suffix = ".bar.baz", including the leading dot
                const size_t suffix = len - static_cast<size_t>(dot - cn);
                // '*' + suffix + NUL must fit. This always holds when the assert above does,
                // and it keeps release builds (assert compiled out) from overflowing wc.
                if (suffix + 1 < sizeof(wc)) {
                    wc[0] = '*';
                    strncpy(wc+1, dot, suffix);
                    len = suffix + 1;
                    wc[len] = '\0';
                    log(debug, "wildcard: '%s' -> '%s'", cn, wc);
                    cn = wc;
                }
            }
        }
    #endif

    // immediate cache-hit?
    stats_inc(certs);
    ssl_ctx_t* ssl = NULL;
    SslCtx* sslctx = config? config->ssl_cache.ctx_as<SslCtx>(): NULL;
    if (sslctx && (ssl = sslctx->get(cn)) != NULL) {
        stats_inc(cert_hits);
        return new Bump(fd, ssl);
    }

    // do we have predefined certs in the wallet?
    Wallet* wallet = config? config->wallet_dir.ctx_as<Wallet>(): NULL;
    if (wallet) {
        switch (wallet->get(cn, sslctx, ssl)) {
            case TRI_TRUE:
                return new Bump(fd, ssl); // great success
            case TRI_FALSE:
                return NULL; // permanent or temporary error
            case TRI_NONE:
                break; // not responsible for this host, can go on
        }
    }

    // have to create inst before enqueing, for preventing races and for passing as ctx
    Bump* bump = new Bump(fd);

    // enqueue or revert
    if (!ipc_enqueue(fd, &handler, bump, cn, len)) {
        bump->pending = false; // keep the destructor from calling ipc_dequeue() for a request that was never queued
        delete bump;
        return NULL;
    }

    // return handle
    return bump;
}


/**
 * Creates a handle for an outstanding certificate request on _fd: pending, with no
 * context yet. The result is delivered later through Bump::handler.
 */
Bump::Bump(int _fd)
    : fd(_fd)
    , ssl_ctx(NULL)
    , pending(true)
{
    log(debug, "dispatching cert generation");
}


/**
 * Creates an already-completed handle on _fd that owns sslctx. Used for cache and
 * wallet hits; the destructor frees sslctx unless get() took it.
 */
Bump::Bump(int _fd, ssl_ctx_t* sslctx)
    : fd(_fd)
    , ssl_ctx(sslctx)
    , pending(false)
{
    log(debug, "dispatching cert generation: cache hit");
}


/**
 * Releases whatever the handle still owns. A pending request is withdrawn via
 * ipc_dequeue(); otherwise a certificate context that get() did not take is freed.
 */
Bump::~Bump() {
    if (!pending) {
        if (ssl_ctx) {
            log(debug, "discarding cert result");
            ssl_ctx_del(ssl_ctx);
        }
        return;
    }

    assert(!ssl_ctx); // a pending request cannot hold a context yet
    log(debug, "aborting cert generation");
    ipc_dequeue(fd);
}


/**
 * Reports the state of the request.
 * @return TRI_NONE while still pending, TRI_TRUE once a context is available,
 *         TRI_FALSE if the request finished without one
 */
tristate_t Bump::ready() const {
    if (!pending) {
        return ssl_ctx? TRI_TRUE: TRI_FALSE;
    }
    return TRI_NONE;
}


/**
 * Hands the certificate context over to the caller; must not be called while pending.
 * @return the context (may be NULL on failure). Ownership moves to the caller, so the
 *         destructor will no longer free it.
 */
ssl_ctx_t* Bump::get() {
    assert(!pending);
    ssl_ctx_t* const taken = ssl_ctx;
    ssl_ctx = NULL;
    return taken;
}


void Bump::handler(int fd, void* inst, ssl_ctx_t* sslctx=NULL) {
    if (!sslctx) {
        log(error, "no certd result for %d", fd);
    }
    Bump* bump = (Bump*)inst;
    assert(bump && bump->fd == fd && bump->pending);
    bump->pending = false;
    bump->ssl_ctx = sslctx;
    Poll::getInst()->wakeup(fd);
}


void Bump::handler(int fd, void* inst, char* buf, size_t len) {
    if (!len) {
        handler(fd, inst);
        return;
    }

    SslCtx* sslctx = config? config->ssl_cache.ctx_as<SslCtx>(): NULL;
    ssl_ctx_t* rv;

    char* cn = buf;
    char* cert = strchrnul(cn, '~'); // allow parsing multiple times
    if (cert >= buf+len) {
        handler(fd, inst);
        return;
    } else {
        *cert = '\0';
        ++cert;
    }

    if (sslctx) {
        rv = sslctx->get(cn); // if multiple ipc requests had been forwarded, there is a good chance for a cache-hit now w/o need for further parsing or cert generation
        if (rv) {
            log(debug, "late cert hit for '%s'", cn);
            handler(fd, inst, rv);
            return;
        }
    }

    char* key = strchrnul(cert, '~');
    if (key >= buf+len) {
        handler(fd, inst);
        return;
    } else {
        *key = '\0';
        ++key;
    }
    size_t certlen = key - cert - 1;
    size_t keylen = buf + len - key;

    rv = sslctx?
        sslctx->get(cn, cert, certlen, key, keylen):
        ssl_ctx_get(cert, certlen, key, keylen);
    handler(fd, inst, rv);
}