#include "certs.hpp"
#include <dirent.h>
#include <fcntl.h>
// Self-test for SslCtx: builds a context from a freshly generated certificate,
// looks it up again by host name, and checks that an unrelated host misses.
// Returns true on success, false on any failed step.
TEST(ssl_ctx) {
    SslCtx sslctx(1000);
    const char* cn = "foo.com";
    char *key = NULL, *cert = NULL;
    size_t keylen, certlen;
    unless (mkcert(cn, cert, certlen, key, keylen)) return false;

    ssl_ctx_t* ctx1 = sslctx.get(cn, cert, certlen, key, keylen);
    // The buffers are released right after the context is built, as before.
    free(key);
    free(cert);
    unless (ctx1) return false;

    // Lookup by name must hit; a different host must miss. The miss check is
    // only attempted when the hit succeeded, matching the original order.
    ssl_ctx_t* ctx2 = sslctx.get(cn);
    ssl_ctx_t* stray = NULL;
    if (ctx2) stray = sslctx.get("bar.com");
    const bool ok = ctx2 && !stray;

    // Release every reference this test obtained, on success and on failure
    // alike, so a failing run does not additionally leak contexts.
    ssl_ctx_del(ctx1);
    if (ctx2) ssl_ctx_del(ctx2);
    if (stray) ssl_ctx_del(stray);
    return ok;
}
// Constructs the context cache; `size` is forwarded unchanged to the
// constructor of the `cache` member.
SslCtx::SslCtx(size_t size)
    : cache(size)
{}
// Walks every cache slot and releases the context stored there, if any.
SslCtx::~SslCtx() {
    node_t* slot = cache.begin();
    while (slot != cache.end()) {
        if (slot->ctx) ssl_ctx_del(slot->ctx);
        ++slot;
    }
}
// Builds a new SSL context from the given certificate/key buffers and stores
// it in the cache slot selected by hash(host), replacing whatever was there.
//
// host     NUL-terminated host name used as the cache key.
// cert/key certificate and key buffers handed to ssl_ctx_get().
// Returns the new context, or NULL when ssl_ctx_get() fails. When the context
// is cached, one extra reference is taken via ssl_ctx_up_ref(), so the cache
// and the caller each release their own with ssl_ctx_del().
//
// If `host` does not fit into the slot's host buffer, the context is returned
// without being cached and the slot is left untouched. Previously this was
// guarded only by an assert placed after the slot had been overwritten, so a
// build with NDEBUG would overflow node->host.
ssl_ctx_t* SslCtx::get(const char* host, const char* cert, size_t certlen, const char* key, size_t keylen) {
    ssl_ctx_t* ctx = ssl_ctx_get(cert, certlen, key, keylen);
    log(debug, "sslctx: '%s': %p", host, ctx);
    unless (ctx) return NULL;

    node_t* node = cache.at(hash(host));
    // Validate before mutating the slot: an oversized name must neither evict
    // the current entry nor be copied into the fixed-size buffer.
    if (strlen(host) >= sizeof(node->host)) {
        log(notice, "sslctx: '%s': host name too long to cache", host);
        return ctx;
    }

    // Drop the entry that currently occupies the slot, if any.
    if (node->ctx) {
        ssl_ctx_del(node->ctx);
    }
    node->ctx = ctx;
    ssl_ctx_up_ref(ctx);  // additional reference for the caller
    strcpy(node->host, host);
    return ctx;
}
// Cache lookup by host name. A hit requires the slot selected by hash(host)
// to hold a context and to carry exactly this host name. On a hit the
// context's reference count is raised before it is returned; on a miss NULL
// is returned.
ssl_ctx_t* SslCtx::get(const char* host) {
    node_t* slot = cache.at(hash(host));
    const bool hit = slot->ctx && strcmp(host, slot->host) == 0;
    if (hit) {
        ssl_ctx_up_ref(slot->ctx);
    }
    log(debug, "sslctx: '%s': cache-%s", host, hit? "hit": "miss");
    if (!hit) {
        return NULL;
    }
    return slot->ctx;
}
// Opens the wallet directory `path` and calls load() for every entry whose
// name does not start with '.'.
//
// `dirfd` stays open for the lifetime of the object (it is passed on by
// load()); it is set to -1 when the directory cannot be opened or scanned.
Wallet::Wallet(const char* path): data(1000) {
    dirfd = open(path, O_RDONLY|O_DIRECTORY);
    if (dirfd == -1) {
        log_errno(error, "open(%s, O_DIRECTORY)", path);
        return;
    }

    // closedir() closes the descriptor handed to fdopendir(), so the scan
    // works on a duplicate and `dirfd` itself survives.
    int dirpfd = dup(dirfd);
    if (dirpfd == -1) {
        // Report the real failure instead of passing -1 on to fdopendir().
        log_errno(error, "dup(%s)", path);
        close(dirfd);
        dirfd = -1;
        return;
    }

    DIR* dirp = fdopendir(dirpfd);
    unless (dirp) {
        log_errno(error, "opendir(%s)", path);
        close(dirpfd);
        close(dirfd);
        dirfd = -1;
        return;
    }

    struct dirent* dent;
    while ((dent = readdir(dirp)) != NULL) {
        // Skip ".", ".." and any other dot-prefixed entry.
        if (*dent->d_name != '.') {
            log(io, "found wallet file '%s'", dent->d_name);
            load(dent->d_name);
        }
    }
    closedir(dirp);
    log(info, "crawled wallet dir '%s'", path);
}
// Releases every outstanding AIO result held by the table, then closes the
// wallet directory descriptor if it is open.
Wallet::~Wallet() {
    data_t* slot = data.begin();
    while (slot != data.end()) {
        if (slot->aio_result) {
            slot->aio_result->delInst();
        }
        ++slot;
    }
    if (dirfd == -1) {
        return;
    }
    close(dirfd);
}
void Wallet::load(const char* name) {
data_t* node = data[hash(name)];
if (*node->host) {
log(notice, "wallet cache collision: '%s', '%s'", node->host, name);
return;
}
if (strlen(name) >= sizeof(node->host)) {
return;
}
strcpy(node->host, name);
assert(!node->aio_result);
node->aio_result = AioFileIn::enqueue(name, dirfd);
if (node->aio_result) {
log(debug, "loading wallet cert from '%s'", name);
}
}
// Splits the node's loaded buffer into certificate and key at the first '~'
// found within the first `len` bytes.
//
// On success the '~' is overwritten with NUL, node->cert/certlen describe the
// part before it, node->key/keylen the part after it, and true is returned.
// On failure (no '~' found) false is returned and the node is left
// unmodified. Previously node->cert was set before the search, so a failed
// parse left it pointing into a buffer the caller then releases.
//
// Preconditions (asserted): the AIO result has len > 0 and the node has not
// been parsed yet.
bool Wallet::parse(data_t* node) {
    assert(node->aio_result->len > 0);
    assert(!node->cert);

    char* sep = strnchr(node->aio_result->buf, '~', node->aio_result->len);
    unless (sep) {
        return false;
    }

    // Terminate the certificate part in place; the key follows the separator.
    *sep = '\0';
    node->cert = node->aio_result->buf;
    node->certlen = sep-node->cert;
    node->key = sep+1;
    node->keylen = node->aio_result->len - node->certlen - 1;
    return true;
}
// Looks up the certificate/key pair stored for `cn`.
//
// Returns TRI_NONE  when the slot selected by hash(cn) carries another host,
//         TRI_FALSE when the entry has no AIO result, its length is 0
//                   (presumably: read not finished yet - confirm against
//                   AioFileIn), the read reported -1, or parsing failed;
//                   in the last two cases the AIO result is dropped,
//         TRI_TRUE  when cert/certlen/key/keylen were filled in. The output
//                   pointers refer to the entry's own buffer as set up by
//                   parse(); they are not copies.
tristate_t Wallet::get(const char* cn, const char*& cert, size_t& certlen, const char*& key, size_t& keylen) {
    data_t* entry = data[hash(cn)];
    if (strcmp(cn, entry->host) != 0) {
        return TRI_NONE;
    }
    if (!entry->aio_result || entry->aio_result->len == 0) {
        return TRI_FALSE;
    }

    // Either the read failed, or the buffer still needs parsing and that
    // fails: in both cases the result is released. parse() is only invoked
    // when the read did not fail and the entry has not been parsed before.
    const bool read_failed = entry->aio_result->len == -1;
    if (read_failed || (!entry->cert && !parse(entry))) {
        entry->aio_result->delInst();
        entry->aio_result = NULL;
        return TRI_FALSE;
    }
    assert(entry->cert);

    log(debug, "found wallet cert '%s'", cn);
    keylen = entry->keylen;
    key = entry->key;
    certlen = entry->certlen;
    cert = entry->cert;
    return TRI_TRUE;
}
// Produces an SSL context for `cn` from the wallet.
//
// The certificate/key pair is fetched through the five-argument get(); any
// result other than TRI_TRUE is passed straight through and `ctx` is left
// untouched. Otherwise the context is built through `sslctx` when one is
// given, or directly with ssl_ctx_get() when `sslctx` is NULL. If building
// fails, the entry's AIO result is dropped and TRI_FALSE is returned.
tristate_t Wallet::get(const char* cn, SslCtx* sslctx, ssl_ctx_t*& ctx) {
    const char* cert_buf;
    size_t cert_size;
    const char* key_buf;
    size_t key_size;

    const tristate_t found = get(cn, cert_buf, cert_size, key_buf, key_size);
    if (found != TRI_TRUE) {
        return found;
    }

    if (sslctx) {
        ctx = sslctx->get(cn, cert_buf, cert_size, key_buf, key_size);
    } else {
        ctx = ssl_ctx_get(cert_buf, cert_size, key_buf, key_size);
    }
    if (ctx) {
        return TRI_TRUE;
    }

    // The stored material could not be turned into a context: discard it.
    log(notice, "reset wallet cert for '%s'", cn);
    data_t* entry = data[hash(cn)];
    assert(entry->aio_result && strcmp(cn, entry->host) == 0);
    entry->aio_result->delInst();
    entry->aio_result = NULL;
    return TRI_FALSE;
}