#include "ssl_common.hpp"
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <openssl/ssl.h>
#include <openssl/err.h>


// OpenSSL library initialization, registered as an INIT hook at HOOK_PRIO_EARLY:
// initializes libssl, loads human-readable error strings (SSL and BIO), registers
// all ciphers/digests, and seeds the random number generator.
HOOK(INIT, HOOK_PRIO_EARLY) {
    SSL_library_init();
    SSL_load_error_strings();
    ERR_load_BIO_strings();
    OpenSSL_add_all_algorithms();
    RAND_poll(); // seed the RNG now so it is done before a possible chroot (presumably because entropy sources may be unreachable afterwards)
}


// OpenSSL library teardown, registered as a DEINIT hook at HOOK_PRIO_LATE:
// drains/frees error state, unloads error strings and algorithm tables, frees the
// compression-method stack and ex_data. The call order mirrors the usual OpenSSL 1.0.x
// cleanup sequence; keep it as-is.
HOOK(DEINIT, HOOK_PRIO_LATE) {
    ssl_clear_err();
    ERR_remove_thread_state(NULL); // free the calling thread's error-queue state
    ERR_remove_state(0); // NOTE(review): deprecated predecessor of ERR_remove_thread_state(); looks redundant after the line above — confirm before removing
    // NOTE(review): engine and config-module cleanup are disabled below; the reason is not visible here
    //ENGINE_cleanup();
    //CONF_modules_unload(1);
    ERR_free_strings();
    EVP_cleanup();
    sk_SSL_COMP_free(SSL_COMP_get_compression_methods()); // frees the stack container of SSL compression methods (not the individual entries)
    CRYPTO_cleanup_all_ex_data();
}


// Module configuration.
//   dh_params: path *prefix* for cached key material. It is concatenated directly with
//              "dh_<bits>.pem" / "rsa_<bits>.pem", so a directory needs a trailing '/'.
//              An empty string disables reading/writing these files.
CONF_DEF(config) {
    ConfigKey dh_params;
};
// Registers the configuration keys with their defaults.
CONF_INIT(config) {
    CONF_KEY_INIT(dh_params, false, ""); // not reconfigurable, since the params are generated at startup if not found; persistence is enabled when the string is non-empty
}


// Drain the OpenSSL error queue, logging each queued entry at level `lvl` and tagged
// with the caller's location (f:l).
// If `lvl` is filtered out, the queue is still emptied (via ssl_clear_err()) without
// formatting anything. An empty queue is reported once at error level as "no error?".
void _ssl_log_err(loglevel_t lvl, const char* f, int l) {
    // Level disabled: discard pending errors so they cannot leak into a later report.
    if (&lvl < loglevel) {
        ssl_clear_err();
        return;
    }

    unsigned long code = ERR_get_error();
    unless (code) {
        // Nothing was queued; presumably a call site without a preceding OpenSSL failure.
        log(error, "SSL: %s:%d no error?", f, l);
        return;
    }

    for (; code != 0; code = ERR_get_error()) {
        log_(&lvl, "SSL: %s:%d %s [%lu]", f, l, ERR_error_string(code, NULL), code);
    }
}


// Log a single, already-retrieved OpenSSL error code `e` at level `lvl`, tagged with
// the caller's location (f:l). A NULL `f` is shown as "-". A zero code is reported as
// "no error?". The OpenSSL error queue itself is not read or modified.
void _ssl_log_err(loglevel_t lvl, unsigned long e, const char* f, int l) {
    const char* const where = f ? f : "-";
    if (e != 0) {
        log_(&lvl, "SSL: %s:%d %s [%lu]", where, l, ERR_error_string(e, NULL), e);
    } else {
        log_(&lvl, "SSL: %s:%d no error?", where, l);
    }
}


// Check that the `l` bytes at `b` form an acceptable host name (e.g. for SNI):
// non-empty, at most MAX_SNI_LEN bytes, and consisting only of ASCII letters, digits,
// '.', '-' and '_'. `b` need not be NUL-terminated. This is a character whitelist only;
// label structure (empty labels, leading '-', ...) is not checked.
bool validate_host(const char* b, size_t l) {
    if (l == 0 || l > MAX_SNI_LEN) {
        return false;
    }

    const char* const end = b + l;
    for (const char* p = b; p != end; ++p) {
        const char c = *p;
        const bool is_alpha = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
        const bool is_digit = (c >= '0' && c <= '9');
        const bool is_punct = (c == '.' || c == '-' || c == '_');
        if (!is_alpha && !is_digit && !is_punct) {
            return false;
        }
    }
    return true;
}


// Obtain DH parameters of `keylength` bits.
// If config->dh_params is non-empty, first try "<dh_params>dh_<keylength>.pem" and accept it
// only if DH_check() passes. Otherwise generate fresh parameters (generator 2, very costly)
// and, when a path is configured, write them back to that file.
// Blocks on file I/O and/or generation. Returns a new DH object owned by the caller,
// or NULL if generation failed.
static DH* ssl_dh_callback(int keylength) { // blocks for i/o and/or generation
    DH* rv = NULL;
    char* fn = NULL;
    if (config && *(config->dh_params.val.str.str)) {
        // asprintf() leaves fn undefined on failure, so reset it explicitly
        if (asprintf(&fn, "%sdh_%d.pem", config->dh_params.val.str.str, keylength) < 1) fn = NULL;
    }

    // try to read and validate from existing file, if configured so
    if (fn) {
        FILE* fp = fopen(fn, "r");
        if (fp) {
            int check_codes = 0;
            rv = PEM_read_DHparams(fp, NULL, NULL, NULL);
            if (!rv) {
                ssl_log_err(error);
            } else if (DH_check(rv, &check_codes) == 0 || check_codes != 0) {
                log(error, "could not validate DH params from '%s'", fn);
                DH_free(rv);
                rv = NULL;
            } else {
                log(debug, "read DH params from '%s'", fn);
            }
            fclose(fp);
        }
    }

    // if not yet found, generate on the fly (very costly) and write back, if configured so
    if (!rv) {
        log(info, "generating %d bits DH params", keylength);
        rv = DH_new();
        unless (rv && DH_generate_parameters_ex(rv, keylength, 2, NULL)) {
            ssl_log_err(error);
            if (rv) DH_free(rv);
            rv = NULL;
        }

        if (rv && fn) {
            FILE* fp = fopen(fn, "w");
            if (!fp) {
                log_errno(error, "fopen(%s, w)", fn);
            } else {
                const bool written = PEM_write_DHparams(fp, rv) != 0;
                if (!written) {
                    ssl_log_err(error);
                }
                // fclose() flushes buffered output; if it fails the file on disk is incomplete
                if (fclose(fp) != 0) {
                    log_errno(error, "fclose(%s)", fn);
                } else if (written) {
                    log(debug, "wrote DH params to '%s'", fn);
                }
            }
        }
    }

    free(fn);
    return rv;
}


// DH-parameter provider whose signature matches OpenSSL's tmp-DH callback, backed by a
// cache of 512/1024/2048-bit parameter sets.
// Special keylength values:
//   -1  initialization: load or generate each missing cached set (returns NULL)
//   -2  cleanup: free all cached sets (returns NULL)
// Any other length returns the cached set of that size, or, for sizes without a cached
// set, the largest available one (generating on demand would be far too costly).
// The returned object stays owned by the cache; callers must not free it.
DH* ssl_dh_callback(SSL* ssl, int, int keylength) {
    static DH* dh_512 = NULL;
    static DH* dh_1024 = NULL;
    static DH* dh_2048 = NULL;

    if (keylength == -1) { // for initialization
        if (!dh_512) {
            dh_512 = ssl_dh_callback(512);
        }
        if (!dh_1024) {
            dh_1024 = ssl_dh_callback(1024);
        }
        if (!dh_2048) {
            dh_2048 = ssl_dh_callback(2048);
        }
        return NULL;
    }

    if (keylength == -2) { // for cleanup
        DH* const cached[] = { dh_512, dh_1024, dh_2048 };
        for (size_t i = 0; i < sizeof cached / sizeof cached[0]; ++i) {
            if (cached[i]) {
                DH_free(cached[i]);
            }
        }
        dh_512 = dh_1024 = dh_2048 = NULL;
        return NULL;
    }

    DH* rv = NULL;
    if (keylength == 512) {
        rv = dh_512;
    } else if (keylength == 1024) {
        rv = dh_1024;
    } else if (keylength == 2048) {
        rv = dh_2048;
    } else {
        // generating a key on the fly is very costly, so fall back to the largest cached set
        log(notice, "no precomputed %d bits DH params", keylength);
        rv = dh_2048 ? dh_2048 : (dh_1024 ? dh_1024 : dh_512);
    }

    log(debug, "using %d bits DH params %p", keylength, rv);
    assert(rv);
    return rv;
}


// Obtain an RSA key of `keylength` bits.
// If config->dh_params is non-empty, first try "<dh_params>rsa_<keylength>.pem"
// (public key followed by private key) and accept it only if RSA_check_key() reports it
// valid. Otherwise generate a fresh key (e = RSA_F4, costly) and, when a path is
// configured, write it back to that file with owner-only permissions.
// Blocks on file I/O and/or generation. Returns a new RSA object owned by the caller,
// or NULL if generation failed.
static RSA* ssl_rsa_callback(int keylength) { // blocks for i/o and/or generation
    RSA* rv = NULL;
    char* fn = NULL;
    if (config && *(config->dh_params.val.str.str)) {
        // asprintf() leaves fn undefined on failure, so reset it explicitly
        if (asprintf(&fn, "%srsa_%d.pem", config->dh_params.val.str.str, keylength) < 1) fn = NULL;
    }

    // try to read and validate from existing file, if configured so
    if (fn) {
        FILE* fp = fopen(fn, "r");
        if (fp) {
            rv = PEM_read_RSAPublicKey(fp, NULL, NULL, NULL);
            if (!rv) {
                ssl_log_err(error);
            } else if (!PEM_read_RSAPrivateKey(fp, &rv, NULL, NULL) || RSA_check_key(rv) != 1) {
                // RSA_check_key() returns 1 for a valid key, 0 for an invalid one and -1 on
                // internal error, so anything other than 1 must be rejected.
                ssl_log_err(error);
                RSA_free(rv);
                rv = NULL;
            } else {
                log(debug, "read RSA params from '%s'", fn);
            }
            fclose(fp);
        }
    }

    // if not yet found, generate on the fly (very costly) and write back, if configured so
    if (!rv) {
        log(info, "generating %d bits RSA params", keylength);
        rv = RSA_generate_key(keylength, RSA_F4, NULL, NULL);
        unless (rv) {
            ssl_log_err(error);
        }

        if (rv && fn) {
            // The file holds private key material: create it owner-only (0600) instead of
            // relying on the process umask.
            const int fd = open(fn, O_WRONLY | O_CREAT | O_TRUNC, 0600);
            FILE* fp = (fd < 0) ? NULL : fdopen(fd, "w");
            if (!fp) {
                log_errno(error, "fopen(%s, w)", fn); // logged before close() can clobber errno
                if (fd >= 0) close(fd);
            } else {
                // O_CREAT's mode only applies to new files; tighten a pre-existing one too (best effort)
                (void)fchmod(fd, 0600);
                const bool written = PEM_write_RSAPublicKey(fp, rv) &&
                                     PEM_write_RSAPrivateKey(fp, rv, NULL, NULL, 0, NULL, NULL);
                if (!written) {
                    ssl_log_err(error);
                }
                // fclose() flushes buffered output; if it fails the file on disk is incomplete
                if (fclose(fp) != 0) {
                    log_errno(error, "fclose(%s)", fn);
                } else if (written) {
                    log(debug, "wrote RSA params to '%s'", fn);
                }
            }
        }
    }

    free(fn);
    return rv;
}


// RSA-key provider whose signature matches OpenSSL's tmp-RSA callback, backed by a
// cache of 512/1024/2048-bit keys.
// Special keylength values:
//   -1  initialization: load or generate each missing cached key (returns NULL)
//   -2  cleanup: free all cached keys (returns NULL)
// Any other length returns the cached key of that size, or, for sizes without a cached
// key, the first available of 1024, 512, 2048 bits (generating on demand is too costly).
// The returned object stays owned by the cache; callers must not free it.
RSA* ssl_rsa_callback(SSL* ssl, int, int keylength) {
    static RSA* rsa_512 = NULL;
    static RSA* rsa_1024 = NULL;
    static RSA* rsa_2048 = NULL;

    if (keylength == -1) { // for initialization
        if (!rsa_512) {
            rsa_512 = ssl_rsa_callback(512);
        }
        if (!rsa_1024) {
            rsa_1024 = ssl_rsa_callback(1024);
        }
        if (!rsa_2048) {
            rsa_2048 = ssl_rsa_callback(2048);
        }
        return NULL;
    }

    if (keylength == -2) { // for cleanup
        RSA* const cached[] = { rsa_512, rsa_1024, rsa_2048 };
        for (size_t i = 0; i < sizeof cached / sizeof cached[0]; ++i) {
            if (cached[i]) {
                RSA_free(cached[i]);
            }
        }
        rsa_512 = rsa_1024 = rsa_2048 = NULL;
        return NULL;
    }

    RSA* rv = NULL;
    if (keylength == 512) {
        rv = rsa_512;
    } else if (keylength == 1024) {
        rv = rsa_1024;
    } else if (keylength == 2048) {
        rv = rsa_2048;
    } else {
        // generating a key on the fly is very costly, so fall back to a cached key
        log(notice, "no precomputed %d bits RSA params", keylength);
        rv = rsa_1024 ? rsa_1024 : (rsa_512 ? rsa_512 : rsa_2048);
    }

    log(debug, "using %d bits RSA params %p", keylength, rv);
    assert(rv);
    return rv;
}