#include "ssl.hpp"
#include <limits.h> // INT_MAX: clamp size_t lengths handed to OpenSSL's int-based APIs


// Pool of ssl_t wrappers: handed out by the ssl_get() overloads, returned by ssl_del().
// make sure these are constructed first (and thus destructed last? - after other, more complex classes that might use them)
INIT_EARLY TPool<ssl_t> ssl_pool;
static SSL_CTX* ssl_client_ctx = NULL; ///< single ctx that gets used for deriving new ssl for communicating with upstream; set in the PRE_IO hook, freed in the POST_IO hook
static SSL_CTX* ssl_server_ctx = NULL; ///< single ctx that gets used for deriving new ssl for communicating with clients; set in the PRE_IO hook, freed in the POST_IO hook
static ssl_t* ssl_server_ssl = NULL; ///< for dummy client handshake for SNI; reused by every ssl_sni() call, released in the POST_IO hook


/// Startup hook: pre-generates DH/RSA parameters, creates the shared client (upstream)
/// and server (client-facing) SSL_CTX objects plus the dummy server-side ssl_t used by
/// ssl_sni(), and logs every cipher the server side offers (dies if there is none).
/// NOTE(review): an old note said gprof counted 571193 calls of this hook; presumably a
/// profiling artefact for a start-up hook — TODO confirm.
HOOK(PRE_IO, HOOK_PRIO_EARLY) {
    log(info, "pre-generating DH/RSA params");
    (void)ssl_dh_callback(NULL, 0, -1);
    (void)ssl_rsa_callback(NULL, 0, -1);

    ssl_client_ctx = ssl_ctx_get(false);
    ssl_server_ctx = ssl_ctx_get(true);
    ssl_server_ssl = ssl_get(); // derived from ssl_server_ctx, hence created after it

    if (ssl_server_ssl != NULL) {
        SSL* const probe = ssl_server_ssl->ssl;
        int idx = 0;
        const char* cipher = SSL_get_cipher_list(probe, idx);
        unless (cipher) die("no ciphers available");
        for (;;) {
            log(debug, "supported cipher: %s", cipher);
            ++idx;
            cipher = SSL_get_cipher_list(probe, idx);
            if (cipher == NULL) break;
        }
    }
}


/// Shutdown hook: releases everything the PRE_IO hook created — both SSL_CTX objects and
/// the dummy SNI ssl_t — and then calls the DH/RSA callbacks with -2 (presumably telling
/// them to drop their cached parameters — TODO confirm).
HOOK(POST_IO, HOOK_PRIO_LATE) {
    // client-facing ctx first, then the upstream one
    if (ssl_server_ctx != NULL) {
        SSL_CTX_free(ssl_server_ctx);
    }
    ssl_server_ctx = NULL;

    if (ssl_client_ctx != NULL) {
        SSL_CTX_free(ssl_client_ctx);
    }
    ssl_client_ctx = NULL;

    // the dummy ssl_t is only ever driven over in-memory BIOs by ssl_sni(): no TLS shutdown
    if (ssl_server_ssl != NULL) {
        ssl_del(ssl_server_ssl, false);
    }
    ssl_server_ssl = NULL;

    (void)ssl_dh_callback(NULL, 0, -2);
    (void)ssl_rsa_callback(NULL, 0, -2);
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/// Number of bytes currently buffered in memory BIO @p b.
/// BIO_get_mem_data() also yields a pointer to that data, which is not needed here.
static INLINE long ssl_bio_data_len(BIO* b) {
    char* ignored = NULL;
    const long pending = BIO_get_mem_data(b, &ignored);
    return pending;
}


/// Drive the TLS handshake on @p ssl until it is complete.
/// @param reading true if the caller is about to read, false if it is about to write;
///        decides which of EAGAIN/ERESTART is reported for a WANT_READ/WANT_WRITE.
/// @return 0 once the handshake is finished (cached in ssl->init_finished), otherwise -1
///         with errno set:
///  - EAGAIN:   OpenSSL needs the I/O direction the caller is already doing;
///  - ERESTART: OpenSSL needs the opposite direction;
///  - ENOTCONN: the handshake was shut down in a controlled way (SSL_do_handshake() == 0);
///  - ENOMSG:   SSL_ERROR_SYSCALL without any syscall error, i.e. unexpected EOF from peer;
///  - EPROTO:   any other TLS error (logged);
///  - otherwise the errno of the failing syscall.
/// The OpenSSL error queue must be empty on entry (SSL_get_error() relies on it).
static INLINE ssize_t ssl_finish_handshake(ssl_t* ssl, bool reading) {
    if (likely(ssl->init_finished == TRI_TRUE)) {
        return 0;
    } else if (SSL_is_init_finished(ssl->ssl)) {
        ssl->init_finished = TRI_TRUE;
        return 0;
    }
    ssl->init_finished = TRI_FALSE;
    log(io, "ssl handshake needed..");

    errno = 0; // lets SSL_ERROR_SYSCALL without an actual syscall failure be detected below
    int rv = SSL_do_handshake(ssl->ssl);
    if (rv == 1) {
        ssl->init_finished = TRI_TRUE;
        return 0;
    } else if (rv == 0) {
        // The TLS/SSL handshake was not successful but was shut down controlled and by the specifications of the TLS/SSL protocol. Call SSL_get_error() with the return value ret to find out the reason.
        ssl_log_err(notice); // TODO: SSL_get_error(ssl->ssl, rv)
        errno = ENOTCONN;
        return -1;
    } else {
        switch (SSL_get_error(ssl->ssl, rv)) {
            case SSL_ERROR_WANT_READ:
                errno = reading? EAGAIN: ERESTART;
                return -1;
            case SSL_ERROR_WANT_WRITE:
                errno = reading? ERESTART: EAGAIN;
                return -1;
            case SSL_ERROR_SYSCALL:
                if (!ERR_peek_error()) {
                    // rv < 0 in this branch, so a `!rv` EOF test could never fire. With an empty
                    // error queue and errno still 0, OpenSSL signals an unexpected EOF from the peer.
                    if (errno == 0) errno = ENOMSG;
                    return -1;
                }
                // fall thru
            default:
                ssl_log_err(notice);
                errno = EPROTO;
                return -1;
        }
    }
}


/// Read up to @p num decrypted bytes from @p ssl into @p buf with read(2)-like semantics.
/// @return the number of bytes read (> 0; may be less than @p num), 0 on EOF, or -1 with
///         errno set: EAGAIN (wait for readability), ERESTART (wait for writability, then
///         repeat with the same arguments), EPROTO (TLS error, logged), or a socket errno.
///
/// If the connection was created with a pre-read (client hello) buffer, that data is first
/// fed into the memory read BIO. While the handshake is still running the buffer is topped up
/// directly from the fd. Once the handshake has finished and the memory BIO is drained, the
/// buffer is returned to buf_pool and the SSL is switched over to plain fd reads.
ssize_t ssl_read(ssl_t* ssl, char* buf, size_t num) {
    ssl_clear_err();

    // do we (still) have pre-read (client helo) data? then write it into the read bio.
    if (unlikely(ssl->rbuf)) {
        char* b;
        size_t l;

        // as the handshake needs to be completely available, read directly from fd here in case there's still s.t. missing.
        if (ssl->init_finished == TRI_TRUE) {
            log(debug, "not peeking even more data, handshake already finished");
        } else if (ssl->rbuf->has_space()) {
            ssl->rbuf->get_space(b, l);
            ssize_t rv = read(ssl->fd, b, l); // XXX: could read too much?
            const int read_errno = errno; // save before logging, which may clobber errno
            log(io, "ssl_read read(%zu): %zd", l, rv);
            if (rv == -1) {
                if (read_errno != EAGAIN && read_errno != EWOULDBLOCK && read_errno != EINTR) {
                    errno = read_errno; // report the read() failure, not whatever log() left behind
                    return rv;
                }
            } else if (rv == 0) {
                // XXX: go on despite eof
            } else {
                ssl->rbuf->add_data((size_t)rv);
            }
        } else {
            log(error, "client helo peek buf full"); // so long handshake?
            errno = EPROTO;
            return -1; // or use chain?
        }

        // write new buf data (if any) into ssl read bio - TODO: can use const buf instead of copying?
        if (ssl->rbuf->has_data()) {
            ssl->rbuf->get_data(b, l);
            int rv = BIO_write(SSL_get_rbio(ssl->ssl), b, l);
            log(io, "ssl_read BIO_write(%zu): %d (now %ld)", l, rv, ssl_bio_data_len(SSL_get_rbio(ssl->ssl)));
            if (rv <= 0) {
                log(error, "cannot write to ssl buf");
                errno = EPROTO;
                return -1;
            } else {
                ssl->rbuf->del_data((size_t)rv); // TODO: this is recoverable in case we want to replay it lateron when connecting directly (in case we did not write)
                assert(!ssl->rbuf->has_data()); // XXX: can this happen?
            }
        }
    }

    int rv = ssl_finish_handshake(ssl, true);
    if (unlikely(rv != 0)) return rv;

    // handshake finished - clean up client helo buf (if any) and re-set ssl bios to fd
    if (unlikely(ssl->rbuf)) {
        if (ssl_bio_data_len(SSL_get_rbio(ssl->ssl))) { // TODO: how to prevent this?
            log(info, "still reading from buf after handshake");
        } else {
            log(debug, "switching to fd %d i/o", ssl->fd);
            buf_pool.push(ssl->rbuf);
            ssl->rbuf = NULL;
            if (SSL_set_rfd(ssl->ssl, ssl->fd) == 0) { // should call bio_free
                ssl_log_err(notice);
                errno = EPROTO;
                return -1;
            }
        }
    }

    // SSL_read() takes an int length: clamp oversized requests (a short read is allowed anyway)
    const int want = (num > (size_t)INT_MAX)? INT_MAX: (int)num;
    rv = SSL_read(ssl->ssl, buf, want);
    if (rv > 0) {
        MAKE_MEM_DEFINED(buf, rv); // http://www.hardening-consulting.com/en/posts/20140512openssl-and-valgrind.html
        return rv; // NOTE: SSL_read might give us not everything
    }

    switch (SSL_get_error(ssl->ssl, rv)) {
        case SSL_ERROR_WANT_READ:
            errno = EAGAIN;
            return -1;
        case SSL_ERROR_WANT_WRITE: // XXX: When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments.
            errno = ERESTART;
            return -1;
        case SSL_ERROR_ZERO_RETURN:
            return 0; // EOF
        case SSL_ERROR_SYSCALL:
            if (!ERR_peek_error()) {
                return rv? -1: 0; // 0: EOF, -1: see errno
            }
            // fall thru
        default:
            ssl_log_err(notice);
            errno = EPROTO;
            return -1;
    }
}


/// Encrypt and send up to @p num bytes from @p buf over @p ssl with write(2)-like semantics.
/// @return the number of bytes written (> 0; may be less than @p num for requests larger
///         than INT_MAX), or -1 with errno set: EAGAIN (wait for writability), ERESTART
///         (wait for readability, then repeat with the same arguments), EPROTO (TLS error,
///         logged), or the socket errno.
ssize_t ssl_write(ssl_t* ssl, const char* buf, size_t num) {
    ssl_clear_err();

    int rv = ssl_finish_handshake(ssl, false);
    if (unlikely(rv != 0)) return rv;

    // SSL_write() takes an int length: clamp so huge buffers become a short write instead of
    // a truncated or negative length. Same num -> same clamp, so retries keep their arguments.
    const int want = (num > (size_t)INT_MAX)? INT_MAX: (int)num;
    rv = SSL_write(ssl->ssl, buf, want);
    if (rv > 0) {
        MAKE_MEM_DEFINED(buf, rv);
        return rv;
    }

    switch (SSL_get_error(ssl->ssl, rv)) {
        case SSL_ERROR_WANT_WRITE:
            errno = EAGAIN;
            return -1;
        case SSL_ERROR_WANT_READ: // XXX: When an SSL_write() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments.
            errno = ERESTART;
            return -1;
        case SSL_ERROR_SYSCALL:
            if (!ERR_peek_error()) {
                return -1; // see errno
            }
            // fall thru: queued OpenSSL errors make errno meaningless, as in ssl_read()
        default:
            ssl_log_err(notice);
            errno = EPROTO;
            return -1;
    }
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


/// Finish the handshake on @p ssl (if needed) and fetch the peer's certificate.
/// @return EVENT_IN / EVENT_OUT while the handshake must wait for readability /
///         writability, EVENT_NONE on a handshake error, and EVENT_CLOSE once the handshake
///         is done. Only in the EVENT_CLOSE case is @p x assigned: the result of
///         SSL_get_peer_certificate(), which may be NULL and (per OpenSSL docs) carries a
///         reference the caller must X509_free().
event_t getcert(ssl_t* ssl, X509*& x) {
    ssl_clear_err(); // SSL_get_error() inside the handshake needs an empty error queue
    if (ssl_finish_handshake(ssl, true) == -1) {
        assert(ssl->init_finished != TRI_TRUE);
        switch (errno) {
            case EAGAIN:
                return EVENT_IN;
            case ERESTART:
                return EVENT_OUT;
            default:
                return EVENT_NONE;
        }
    } else {
        assert(ssl->init_finished == TRI_TRUE);
        x = SSL_get_peer_certificate(ssl->ssl);
        return EVENT_CLOSE;
    }
}


bool ssl_sni(const char* buf, size_t len, char*& sni) {
    // (re-)init our client handhake ssl*
    unless (ssl_server_ssl) return false;
    SSL_set_accept_state(ssl_server_ssl->ssl);
    BIO* bio = BIO_new_mem_buf((void*)buf, len); // will be r/o anyhow
    BIO_set_mem_eof_return(bio, -1);
    SSL_set_bio(ssl_server_ssl->ssl, bio, bio); // set r/o write mem bio to detect the error below

    bool rv;
    ssl_clear_err();
    const int hs = SSL_do_handshake(ssl_server_ssl->ssl);
    if (hs < 0) {
        switch (SSL_get_error(ssl_server_ssl->ssl, hs)) {
            case SSL_ERROR_WANT_READ:
                rv = false;
                break;
            case SSL_ERROR_WANT_WRITE:
                rv = true;
                break;
            case SSL_ERROR_SSL:
                // 2007507E:BIO routines:MEM_WRITE:write to read only BIO [537350270]
                rv = (ERR_peek_last_error() == 0x2007507Elu)? true: false;
                break;
            default:
                rv = false;
                break;
        }
    } else {
        rv = false; // a successful handshake should not happen in out case; a controlled handshake shutdown is an error as well.
    }

    const char* host = SSL_get_servername(ssl_server_ssl->ssl, TLSEXT_NAMETYPE_host_name);
    if (host) { // can happen even upon error above
        sni = strdup(host);
    }
    return rv;
}


/// ssl_sni() must accept a complete TLS 1.2 client hello (316 bytes) carrying the SNI
/// "www.heise.de" and return a strdup()'d copy of that name.
TEST(client_helo_ssl_sni) {
    char* host = NULL; // ssl_sni() leaves it untouched if no SNI is found: never read garbage
    bool rv = ssl_sni("\x16\x03\x01\x01\x37\x01\x00\x01\x33\x03\x03\x32\x8a\x89\x1f\x16\x79\x27\xd0\xef\xc0\x71\xd9\x9f\xf8\xf8\x9e\xf9\x22\xcd\xf7\xc3\x45\xac\x84\x78\x35\x90\x48\x0c\xb8\x80\xef\x00\x00\x88\xc0\x30\xc0\x2c\xc0\x28\xc0\x24\xc0\x14\xc0\x0a\x00\xa3\x00\x9f\x00\x6b\x00\x6a\x00\x39\x00\x38\x00\x88\x00\x87\xc0\x32\xc0\x2e\xc0\x2a\xc0\x26\xc0\x0f\xc0\x05\x00\x9d\x00\x3d\x00\x35\x00\x84\xc0\x12\xc0\x08\x00\x16\x00\x13\xc0\x0d\xc0\x03\x00\x0a\xc0\x2f\xc0\x2b\xc0\x27\xc0\x23\xc0\x13\xc0\x09\x00\xa2\x00\x9e\x00\x67\x00\x40\x00\x33\x00\x32\x00\x9a\x00\x99\x00\x45\x00\x44\xc0\x31\xc0\x2d\xc0\x29\xc0\x25\xc0\x0e\xc0\x04\x00\x9c\x00\x3c\x00\x2f\x00\x96\x00\x41\xc0\x11\xc0\x07\xc0\x0c\xc0\x02\x00\x05\x00\x04\x00\x15\x00\x12\x00\x09\x00\xff\x01\x00\x00\x82\x00\x00\x00\x11\x00\x0f\x00\x00\x0c\x77\x77\x77\x2e\x68\x65\x69\x73\x65\x2e\x64\x65\x00\x0b\x00\x04\x03\x00\x01\x02\x00\x0a\x00\x34\x00\x32\x00\x0e\x00\x0d\x00\x19\x00\x0b\x00\x0c\x00\x18\x00\x09\x00\x0a\x00\x16\x00\x17\x00\x08\x00\x06\x00\x07\x00\x14\x00\x15\x00\x04\x00\x05\x00\x12\x00\x13\x00\x01\x00\x02\x00\x03\x00\x0f\x00\x10\x00\x11\x00\x23\x00\x00\x00\x0d\x00\x20\x00\x1e\x06\x01\x06\x02\x06\x03\x05\x01\x05\x02\x05\x03\x04\x01\x04\x02\x04\x03\x03\x01\x03\x02\x03\x03\x02\x01\x02\x02\x02\x03\x00\x0f\x00\x01\x01", 316, host);
    TEST_ASSERT(rv);
    TEST_ASSERT(host);
    TEST_ASSERT(!strcmp(host, "www.heise.de"));
    free(host);
    return true;
}


/// Create a server-side (accepting) ssl_t from the shared server ctx, with no fd (-1) and
/// no BIOs attached. In this file it becomes the dummy handshake object used by ssl_sni().
/// @return the new ssl_t, or NULL if the server ctx is not (yet) set or SSL_new() fails.
ssl_t* ssl_get() {
    unless (ssl_server_ctx) return NULL;

    ssl_t* ssl = ssl_pool.pop();
    ssl->fd = -1;
    ssl->init_finished = TRI_NONE;
    ssl->rbuf = NULL;

    ssl->ssl = SSL_new(ssl_server_ctx);
    unless (ssl->ssl) {
        ssl_log_err(notice); // report the OpenSSL reason, like the other ssl_get() overloads
        ssl_pool.push(ssl);
        return NULL;
    }
    SSL_set_accept_state(ssl->ssl);

    return ssl;
}


// http://www.roxlu.com/2014/042/using-openssl-with-memory-bios
// http://blog.davidwolinsky.com/2009/10/memory-bios-and-openssl.html
// http://www.roxlu.com/2014/042/using-openssl-with-memory-bios
// http://blog.davidwolinsky.com/2009/10/memory-bios-and-openssl.html
/// Create a server-side (accepting) ssl_t for client connection @p fd from @p ctx.
/// @param rbuf optional already pre-read data (client hello). If given, the SSL reads from a
///        memory BIO that ssl_read() feeds from @p rbuf, while writes go straight to the socket.
///        Ownership of @p rbuf passes to the returned ssl_t only on success.
/// @return the new ssl_t, or NULL on OpenSSL failure (logged).
ssl_t* ssl_get(ssl_ctx_t* ctx, int fd, buf_t* rbuf) {
    assert(fd != -1);

    ssl_t* ssl = ssl_pool.pop();
    ssl->fd = fd;
    ssl->init_finished = TRI_NONE;
    ssl->rbuf = NULL;

    ssl->ssl = SSL_new(ctx);
    unless (ssl->ssl) {
        ssl_log_err(notice);
        ssl_pool.push(ssl);
        return NULL;
    }
    SSL_set_accept_state(ssl->ssl);

    if (rbuf) {
        BIO* sck = BIO_new_socket(fd, false); // false == BIO_NOCLOSE: the BIO does not close fd
        BIO* mem = sck? BIO_new(BIO_s_mem()): NULL;
        unless (sck && mem) {
            // don't leak the socket BIO nor install a NULL read BIO
            ssl_log_err(notice);
            if (sck) BIO_free(sck);
            SSL_free(ssl->ssl);
            ssl_pool.push(ssl);
            return NULL;
        }
        BIO_set_mem_eof_return(mem, -1); // empty buffer means "retry", not EOF
        SSL_set_bio(ssl->ssl, mem, sck); // SSL_set_wfd?
        ssl->rbuf = rbuf;
        log(io, "created ssl read buf");
    } else {
        unless (SSL_set_fd(ssl->ssl, fd)) { // XXX: File descriptor BIOs should not be used for socket I/O. Use socket BIOs instead.
            ssl_log_err(notice);
            SSL_free(ssl->ssl);
            ssl_pool.push(ssl);
            return NULL;
        }
    }

    return ssl;
}


/// Create a client-side (connecting) ssl_t for upstream connection @p fd from the shared
/// client ctx. A non-NULL @p host is sent as SNI; failing to set it is logged, not fatal.
/// @return the new ssl_t, or NULL if there is no client ctx or OpenSSL setup fails.
ssl_t* ssl_get(const char* host, int fd) {
    unless (ssl_client_ctx) return NULL;

    SSL* conn = SSL_new(ssl_client_ctx);
    if (conn == NULL) {
        ssl_log_err(notice);
        return NULL;
    }

    // NOTE(review): OpenSSL appears to keep its own copy of the name, so host need not outlive this call — confirm for the linked version.
    if (host && !SSL_set_tlsext_host_name(conn, host)) {
        log(notice, "cannot set SNI");
    }

    SSL_set_connect_state(conn);
    assert(fd != -1);
    if (!SSL_set_fd(conn, fd)) {
        ssl_log_err(notice);
        SSL_free(conn);
        return NULL;
    }

    // only take a pool entry once the OpenSSL side is fully set up
    ssl_t* wrapped = ssl_pool.pop();
    wrapped->ssl = conn;
    wrapped->fd = fd;
    wrapped->init_finished = TRI_NONE;
    wrapped->rbuf = NULL;
    return wrapped;
}


/// Release @p ssl and return it to ssl_pool.
/// @param shutdown if true, first attempt a TLS shutdown (best effort; the outcome is only
///        logged, never retried).
/// The SSL object is freed (together with its BIOs), and any still-pending client-hello
/// buffer goes back to buf_pool.
void ssl_del(ssl_t* ssl, bool shutdown) {
    if (shutdown) {
        ssl_clear_err(); // SSL_get_error() below needs an empty error queue
        int rv = SSL_shutdown(ssl->ssl); // TODO: check for previous errors or internal state in case this is not needed?
        if (rv < 0) {
            // compare BOTH codes; `x == A || B` was always true and hid real errors
            const int err = SSL_get_error(ssl->ssl, rv);
            if (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ) {
                log(io, "unclean SSL_shutdown");
            } else {
                ssl_log_err(info);
            }
        } else if (rv == 0) {
            log(io, "unidirectional SSL_shutdown");
        } else {
            log(io, "SSL_shutdown");
        }
    }
    SSL_free(ssl->ssl); // frees ctx (if the last one), bios, ...
    ssl->ssl = NULL;
    if (ssl->rbuf) {
        buf_pool.push(ssl->rbuf);
        ssl->rbuf = NULL;
    }
    ssl_pool.push(ssl);
}