#include "dns.hpp"
#include "poll.hpp"
#if (HAVE_ARES_H)
#include <ares.h> // XXX: might want to use getaddrinfo_a instead? (collect all hostnames in one poll iteration, dispatch, and let poll wakeup by signal thread)
#include <netdb.h>
// Process-wide c-ares channel: created by the INIT hook and destroyed (reset to
// NULL) by the DEINIT hook. It stays NULL when c-ares initialization failed;
// the code below then refuses server configuration and returns already-failed
// lookup results.
static ares_channel channel = NULL;
#else
MESSAGE(DNS resolver disabled)
#endif


// DNS answer cache keyed by DNSResult::cache_key_t (address family + hostname).
// A stored sockaddr_t with family 0 is a negative entry (name did not resolve).
typedef HashCache<DNSResult::cache_key_t, sockaddr_t> DNSCache;

// Config setter for "dns_server": hands the comma-separated server list in
// `value` to c-ares for the global channel.
// On success `ctx` is pointed at the channel; dns_server_unconfig() uses this
// as a marker that the server list was set through this key.
// Returns false (after logging) if c-ares is unavailable or rejects the list.
static bool dns_server_config(void*& ctx, char* value) {
    #if (HAVE_ARES_H)
        if (!channel) {
            log(error, "could not set DNS server(s) to '%s', not enabled", value);
            return false;
        }
        int rc;
        #if ARES_VERSION_MAJOR > 1 || (ARES_VERSION_MAJOR == 1 && ARES_VERSION_MINOR >= 11)
        rc = ares_set_servers_ports_csv(channel, value); // NOTE(review): need to dup `value`? confirm c-ares does not keep the pointer
        #else
        // Older c-ares only offers the variant without "host:port" support.
        log(notice, "DNS: ignoring server ports, if any");
        rc = ares_set_servers_csv(channel, value);
        #endif
        if (rc == ARES_SUCCESS) {
            ctx = channel; // marker: the active server list came from this key
            return true;
        }
        log(error, "cannot set DNS server(s) to '%s'", value);
        return false;
    #else
        log(error, "cannot set DNS servers - disabled");
        return false;
    #endif
}

// Config cleanup for "dns_server": if `ctx` carries the marker that
// dns_server_config() set for the current channel, asks c-ares to replace the
// server list with an empty one and clears the marker. Otherwise does nothing.
// NOTE(review): whether an empty CSV actually clears the servers or is a no-op
// depends on the c-ares version in use - confirm.
static void dns_server_unconfig(void*& ctx) {
    #if (HAVE_ARES_H)
        const bool configured_here = ctx && ctx == channel;
        if (configured_here) {
            if (ares_set_servers_csv(channel, "") != ARES_SUCCESS) {
                log(error, "cannot clean up DNS server(s)");
            }
            ctx = NULL;
        }
    #endif
}

// Configuration keys of this module, reachable through `config`:
//  - dns_server: applied/reverted by dns_server_config() / dns_server_unconfig()
//  - dns_cache:  holds a DNSCache instance (read via ctx_as<DNSCache>() in this file)
CONF_DEF(config) {
    ConfigKey dns_server;
    ConfigKey dns_cache;
};
CONF_INIT(config) {
    // NOTE(review): the meaning of the boolean arguments is defined by the
    // CONF_KEY_INIT / CONF_KEY_INST macros, which are not visible here.
    CONF_KEY_INIT(dns_server, true, false, dns_server_config, dns_server_unconfig);
    // presumably constructs the DNSCache from a size_t config value - confirm
    CONF_KEY_INST(dns_cache, true, DNSCache, size_t);
}
// Cache-clean hook: invalidates every cached DNS answer (positive and
// negative), provided a config and a DNS cache exist.
HOOK(CACHE_CLEAN, HOOK_PRIO_MID) {
    if (!config) return;
    DNSCache* const dns_answers = config->dns_cache.ctx_as<DNSCache>();
    if (dns_answers) dns_answers->invalidate();
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


// Already-settled negative result with no lookup behind it: handed out when
// resolving is unavailable or the name cannot be looked up at all.
DNSResult::DNSResult()
    : fd(-1)
    , state(TRI_FALSE)
{
}


// Only settled results may be destroyed directly. A pending one (TRI_NONE) is
// still referenced as c-ares callback data and is freed by DNSResult::cb()
// instead (see delInst()).
DNSResult::~DNSResult() {
    assert(state != TRI_NONE);
}


// Release a result obtained from getInst().
// Settled results are deleted right away. A lookup still in flight is only
// flagged as aborted (fd = -1), because c-ares still holds the pointer;
// DNSResult::cb() frees it when the query completes.
void DNSResult::delInst() {
    if (state != TRI_NONE) {
        delete this;
        return;
    }
    fd = -1;
}


// Return the lookup state. Only when it is TRI_TRUE is `rv` written: its family
// and the IPv4 or IPv6 address are copied from the result. Everything else in
// `rv` is left untouched.
tristate_t DNSResult::get(sockaddr_t& rv) const {
    if (state != TRI_TRUE) {
        return state;
    }
    rv.family = result.family;
    if (result.family == AF_INET) {
        rv.addr4.sin_addr = result.addr4.sin_addr;
    } else {
        rv.addr6.sin6_addr = result.addr6.sin6_addr;
    }
    return state;
}


#if !(HAVE_ARES_H)

// Resolver compiled out (no c-ares): every lookup yields an already-failed result.
DNSResult* DNSResult::getInst(const char*, sa_family_t, int) {
    DNSResult* const failed = new DNSResult();
    return failed;
}

#else

// Start resolving `name` to an address of family `af`.
//
// Always returns a new DNSResult, to be released with delInst():
//  - state TRI_FALSE right away if c-ares is unavailable, if `name` does not
//    fit into the cache key, or on a cached negative answer;
//  - state TRI_TRUE right away on a cached positive answer;
//  - otherwise a c-ares query is dispatched and DNSResult::cb() settles the
//    result later and wakes up `fd`. The callback may also run synchronously
//    inside ares_gethostbyname(); in that case no wakeup is sent, because
//    `wakeup` is only set after the call returns.
// NOTE(review): cb() treats fd == -1 as "aborted" and frees the result, so
// callers presumably never pass -1 for a lookup that gets dispatched - confirm.
DNSResult* DNSResult::getInst(const char* name, sa_family_t af, int fd) {
    if (!channel) return new DNSResult();
    if (strlen(name) >= sizeof(cache_key_t::host)) return new DNSResult();
    DNSResult* rv = new DNSResult(name, af, fd);

    DNSCache* cache = config? config->dns_cache.ctx_as<DNSCache>(): NULL;
    if (cache) {
        const sockaddr_t* hit = cache->at(rv->key);
        if (hit) {
            log(debug, "%s DNS cache hit for '%s'", hit->family? "pos": "neg", name);
            if (hit->family) {
                assert(hit->family == af); // the cache key includes the family
                rv->state = TRI_TRUE;
                rv->result = *hit;
            } else {
                rv->state = TRI_FALSE; // negative entry (family 0)
            }
            return rv;
        }
    }

    log(debug, "dispatching DNS lookup for '%s'", name);
    // Query the requested family. cb() rejects answers whose h_addrtype differs
    // from key.family, so a hard-coded AF_INET made every AF_INET6 lookup fail.
    ares_gethostbyname(channel, name/*internally copied*/, af, &cb, (void*)rv);
    rv->wakeup = true; // prevent parallel wakeup in case cb gets called immediately
    return rv;
}




// Pending lookup (TRI_NONE) of `name` for family `af`; `fd` is the poll fd that
// DNSResult::cb() wakes up when the answer arrives. `name` must fit into
// key.host including its terminator (getInst() checks this beforehand).
DNSResult::DNSResult(const char* name, sa_family_t af, int fd): fd(fd), wakeup(false), state(TRI_NONE) {
    assert(strlen(name) < sizeof(key.host));
    // Zero the whole key first so that any padding bytes are deterministic:
    // the key is looked up in and stored into the DNS cache.
    // NOTE(review): only matters if HashCache hashes/compares keys bytewise - confirm.
    bzero(&key, sizeof(key));
    key.family = af;
    strncpy(key.host, name, sizeof(key.host));
    bzero(&result, sizeof(result));
}


void DNSResult::cb(void* i, int status, int timeouts, hostent* host) {
    DNSResult* inst = (DNSResult*)i;
    assert(inst->state == TRI_NONE);
    DNSCache* cache = config? config->dns_cache.ctx_as<DNSCache>(): NULL;

    if (status != ARES_SUCCESS) {
        log(notice, "DNS: %s", ares_strerror(status));
        inst->state = TRI_FALSE;
        if (cache) {
            log(debug, "caching negative DNS result for '%s'", inst->key.host);
            cache->get(inst->key)->family = 0;
        }
    } else if (!host || !host->h_addr || host->h_addrtype != inst->key.family || host->h_length != (int)(inst->key.family == AF_INET? sizeof(in_addr): sizeof(in6_addr))){
        log(notice, "DNS: invalid response");
        inst->state = TRI_FALSE;
    } else {
        inst->result.family = inst->key.family;
        memcpy(inst->key.family == AF_INET? (void*)&inst->result.addr4.sin_addr: (void*)&inst->result.addr6.sin6_addr, host->h_addr, host->h_length);
        log(debug, "DNS: resolved to %s", addr2str(inst->result, false));
        inst->state = TRI_TRUE;
        if (cache) {
            log(debug, "caching DNS result for '%s'", inst->key.host);
            *(cache->get(inst->key)) = inst->result;
        }
    }

    if (inst->fd == -1) { // aborted
        delete inst;
    } else if (inst->wakeup) {
        Poll::getInst()->wakeup(inst->fd);
    }
}


//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//\\//


// Poll callback for c-ares sockets: forwards readiness of `fd` to
// ares_process_fd(). `fd` is passed as readable when
// event_isset(ev, EVENT_IN|EVENT_CLOSE) and as writable when
// event_isset(ev, EVENT_OUT); -1 means "not ready" for that direction.
// NOTE(review): nothing in this file consults ares_timeout() or calls
// ares_process_fd() without a ready fd, so c-ares query timeouts may only be
// processed when some DNS socket has activity - confirm this is handled elsewhere.
static void dns_event_cb(int fd, event_t ev, unsigned, void*) {
    if (event_isset(ev, EVENT_IN|EVENT_OUT|EVENT_CLOSE)) {
        const int read_fd = event_isset(ev, EVENT_IN|EVENT_CLOSE)? fd: -1;
        const int write_fd = event_isset(ev, EVENT_OUT)? fd: -1;
        ares_process_fd(channel, read_fd, write_fd);
    }
}


// c-ares socket-state callback (installed via ARES_OPT_SOCK_STATE_CB).
// Mirrors the read/write interest c-ares reports for `fd` into Poll, with
// events routed to dns_event_cb(). EVENT_CLOSE is always requested. Interest
// in neither direction removes the fd from Poll.
static void dns_socket_cb(void*, int fd, int read, int write) {
    log(io, "DNS: change state fd %d read:%d write:%d", fd, read, write);
    // Address used as a non-NULL poll context, so that ctx_get() below reveals
    // whether the fd has already been registered.
    static const int sentinel = 42;

    if (!read && !write) {
        (void)Poll::getInst()->del(fd);
        return;
    }

    event_t ev = EVENT_CLOSE;
    if (read) ev |= EVENT_IN;
    if (write) ev |= EVENT_OUT;

    if (!Poll::getInst()->ctx_get(fd)) {
        Poll::getInst()->add(fd, &dns_event_cb, ev, false, (void*)&sentinel);
    } else {
        // assume it is polled already: just update the interest set
        Poll::getInst()->mod(fd, &dns_event_cb, ev, false, (void*)&sentinel);
    }
}


// Init hook: initializes c-ares and creates the global `channel`.
// Options: socket-state callback into Poll, per-try timeout of
// (MAX_SHORT_IDLE * 1000) / 2 ms (ARES_OPT_TIMEOUTMS), a single try, primary
// server only, and truncated answers ignored.
// On any failure `channel` stays NULL and resolving is effectively disabled.
HOOK(INIT, HOOK_PRIO_LATE) {
    int status = ares_library_init(ARES_LIB_INIT_ALL);
    if (status != ARES_SUCCESS){
        log(error, "ares_library_init: %s", ares_strerror(status)); // XXX: too early for proper logging
        return;
    }

    struct ares_options options = {};
    options.sock_state_cb = &dns_socket_cb;
    options.sock_state_cb_data = NULL;
    options.timeout = (MAX_SHORT_IDLE * 1000) / 2; // NOTE(review): presumably MAX_SHORT_IDLE is in seconds
    options.tries = 1;
    options.flags = ARES_FLAG_PRIMARY|ARES_FLAG_IGNTC; // ARES_FLAG_NOSEARCH?
    const int optmask = ARES_OPT_SOCK_STATE_CB | ARES_OPT_TIMEOUTMS | ARES_OPT_TRIES | ARES_OPT_FLAGS;

    status = ares_init_options(&channel, &options, optmask);
    if (status == ARES_SUCCESS) {
        assert(channel);
        return;
    }
    assert(!channel);
    log(error, "ares_init_options: %s", ares_strerror(status));
}


// Deinit hook: tears down the c-ares channel (if INIT created one), then
// releases the c-ares library. While the channel is cancelled/destroyed,
// c-ares invokes DNSResult::cb() for every query still pending.
HOOK(DEINIT, HOOK_PRIO_EARLY) {
    if (channel != NULL) {
        // ares_cancel() completes pending queries with ARES_ECANCELLED; without it,
        // ares_destroy() would complete them with ARES_EDESTRUCTION.
        // TODO: check for side-effects such as woken up callbacks
        ares_cancel(channel);
        ares_destroy(channel);
        channel = NULL;
    }
    ares_library_cleanup();
}


#endif