#include <unistd.h>
#include <fcntl.h>
#include <stdlib.h>
#include <signal.h>
#include <string.h>
#include <errno.h>
#include <limits.h> // LONG_MAX
#include <sys/ptrace.h>
#include <sys/prctl.h>
#include <sys/stat.h>
#include <assert.h>
#include <unistd.h>
#include <sys/wait.h>
#include <stdio.h>
#include <asm/ptrace.h>
#include <sys/user.h> // user_regs_struct
#include <asm/unistd.h> // __NR_*
#include <arpa/inet.h> // htonl


// Process environment, declared extern here.
// NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
extern char** environ;

// Debug logging to stderr; `fmt` must be a string literal (it is concatenated with
// "\n"). With NDEBUG the whole call compiles to nothing, so its arguments are not evaluated.
#ifdef NDEBUG
    #define LOG(fmt, ...) do {} while (0)
#else
    #define LOG(fmt, ...) fprintf(stderr, fmt "\n", ##__VA_ARGS__)
#endif
// Like LOG, but appends " - <errno> - <strerror(errno)>". Use it directly after the
// failing call so that errno still describes that failure.
#define LOG_ERRNO(fmt, ...) LOG(fmt " - %d - %s", ##__VA_ARGS__, errno, strerror(errno))


// Round `len` up to the next multiple of sizeof(type); 0 stays 0 and aligned lengths
// are unchanged. Both arguments are parenthesized so expressions expand correctly.
#define ALIGN(len, type) (((((len)) + sizeof(type) - 1) / sizeof(type)) * sizeof(type))


static void getbuf(pid_t pid, void* addr, size_t len, unsigned char* buf) {
    assert(len);
    assert(len % sizeof(long) == 0);
    long* p = (long*)buf;
    long* ptr = (long*)addr;
    do {
        *p = ptrace(PTRACE_PEEKDATA, pid, ptr, NULL);
        ++p;
        ++ptr;
        len -= sizeof(long);
    } while (len);
}


static unsigned char* getbuf(pid_t pid, void* addr, size_t len) {
    if (!len) return NULL;
    len = ALIGN(len, long);
    unsigned char* buf = (unsigned char*)malloc(len); // TODO: some static realloc/pooling instead?
    getbuf(pid, addr, len, buf);
    return buf;
}


static unsigned char* getvbuf(pid_t pid, void* iov, int iovcnt, size_t& total) {
    if (!iovcnt) return NULL;
    size_t iovlen = ALIGN(iovcnt * sizeof(struct iovec), long);
    struct iovec* iovs = (struct iovec*)malloc(iovlen);
    getbuf(pid, iov, iovlen, (unsigned char*)iovs);

    total = 0;
    size_t total_len = 0;
    for (int i=0; i<iovcnt; ++i) {
        if (!iovs[i].iov_len) continue;
        total += iovs[i].iov_len;
        total_len += ALIGN(iovs[i].iov_len, long);
    }

    if (!total) {
        free(iovs);
        return NULL;
    }
    unsigned char* buf = (unsigned char*)malloc(total_len);

    unsigned char* p = buf;
    for (int i=0; i<iovcnt; ++i) {
        if (!iovs[i].iov_len) continue;
        getbuf(pid, iovs[i].iov_base, ALIGN(iovs[i].iov_len, long), p);
        p += iovs[i].iov_len;
    }
    free(iovs);
    return buf;
}


// Append one captured chunk to the per-descriptor log file /tmp/.<pid>.<fd>:
// the single direction byte `prefix` followed by the `len` bytes of `buf`.
// Returns false (after logging) if the file cannot be opened or the write is incomplete.
static bool process_data(pid_t pid, int fd, char prefix, const unsigned char* buf, size_t len) {
    // TODO: keep fds open
    // TODO: listen for open/close as well
    char fn[64];
    (void)snprintf(fn, sizeof fn, "/tmp/.%d.%d", (int)pid, fd);
    // The name is predictable and /tmp is world-writable: O_NOFOLLOW refuses a
    // pre-planted symlink instead of appending tracee data to whatever it points at.
    int f = open(fn, O_WRONLY|O_APPEND|O_CREAT|O_NOFOLLOW, S_IRUSR|S_IWUSR|S_IRGRP|S_IROTH); // XXX: IROTH
    if (f == -1) {
        LOG_ERRNO("open(%s)", fn);
        return false;
    }

    struct iovec iov[2];
    iov[0].iov_base = (void*)&prefix;
    iov[0].iov_len = 1;
    iov[1].iov_base = (void*)buf;
    iov[1].iov_len = len;

    // One writev keeps prefix and payload together in a single O_APPEND write.
    const bool ok = writev(f, iov, 2) == (ssize_t)len + 1;
    if (!ok) {
        LOG("writev(%s)", fn);
    }
    close(f);
    return ok;
}


// Copy the data of one completed read/write-family syscall out of the tracee and
// append it to the per-fd log via process_data().
//   read    true for the read family (prefix '>'), false for writes (prefix '<')
//   vec     true for readv/writev-style calls: arg2/arg3 are an iovec array/count
//   arg1-3  the syscall's first three arguments (fd, buffer or iov, count or iovcnt)
//   ret     the syscall's return value, i.e. the bytes actually transferred; at most
//           that many bytes are logged. The default means "unknown": log everything
//           that was requested.
// Returns true when there is nothing to log (ret <= 0: error, EOF or empty transfer)
// or when logging succeeded; false if the data could not be fetched or written.
static bool process_io(pid_t pid, bool read, bool vec, long arg1, long arg2, long arg3, long ret = LONG_MAX) {
    if (ret <= 0) {
        return true;
    }

    size_t len = 0;
    unsigned char* buf = NULL;
    if (vec) {
        buf = getvbuf(pid, (void*)arg2, (int)arg3, len);
        // readv/writev fill/drain the iovecs in order, so the concatenation truncates exactly.
        if ((unsigned long)ret < len) len = (size_t)ret;
    } else {
        len = (size_t)arg3;
        if ((unsigned long)ret < len) len = (size_t)ret; // fetch only what was transferred
        buf = getbuf(pid, (void*)arg2, len);
    }
    if (!buf) {
        return false;
    }

    const bool ok = process_data(pid, (int)arg1, read? '>': '<', buf, len);
    free(buf);
    return ok;
}


// Dispatch one syscall-stop. Each traced syscall produces two stops (entry, then
// exit) that PTRACE_SYSCALL does not tell apart, so a toggle per family (reads,
// writes) tracks which one this is. Data is recorded at the exit stop only, where
// `ret` gives the bytes really transferred: read buffers are filled by then, and
// partially-written or failed writes are not over-reported.
// NOTE(review): the toggles assume tracing starts outside any read/write syscall.
// Returns false only when recording failed (traceloop() then stops tracing).
static bool process_syscall(pid_t pid, long nr, long arg1, long arg2, long arg3, long ret = LONG_MAX) {
    static bool in_read = false;
    static bool in_write = false;
    bool is_read = false;
    bool is_vec = false;
    switch (nr) { // TODO: are these the only relevant ones for now?
        case __NR_readv:
        case __NR_preadv:
            is_read = true;
            is_vec = true;
            break;
        case __NR_read:
        case __NR_pread64:
            is_read = true;
            break;
        case __NR_writev:
        case __NR_pwritev:
            is_vec = true;
            break;
        case __NR_write:
        case __NR_pwrite64:
            break;
        default:
            return true; // not a syscall we record
    }

    bool& inside = is_read ? in_read : in_write;
    inside = !inside;
    if (inside) {
        return true; // entry stop: result not known yet
    }
    return process_io(pid, is_read, is_vec, arg1, arg2, arg3, ret);
}


// Read the registers at a syscall-stop and pass syscall number, first three
// arguments and return value on to process_syscall().
// Returns false if the registers cannot be read or recording failed.
static bool getsyscall(pid_t pid) {
    // TODO: PTRACE_GETREGS that are specific to each arch, will be deprecated.
    #if (__x86_64__)
        // Syscall args: RDI, RSI, RDX, R10, R8, R9. The kernel only clobbers RAX (return
        // value; -ENOSYS at entry), RCX and R11, so the args are valid at both stops.
        // ORIG_RAX keeps the syscall number.
        user_regs_struct regs;
        if (ptrace((enum __ptrace_request)PTRACE_GETREGS, pid, NULL, &regs) != 0) {
            LOG_ERRNO("ptrace(PTRACE_GETREGS)");
            return false;
        }
        return process_syscall(pid, regs.orig_rax, regs.rdi, regs.rsi, regs.rdx, regs.rax);
    #elif (__ARM_EABI__)
        // EABI: syscall number in r7, arguments in r0-r6. At syscall exit r0 holds the
        // return value; the original first argument is preserved in ARM_ORIG_r0.
        struct pt_regs regs;
        if (ptrace((enum __ptrace_request)PTRACE_GETREGS, pid, NULL, &regs) != 0) {
            LOG_ERRNO("ptrace(PTRACE_GETREGS)");
            return false;
        }
        return process_syscall(pid, regs.ARM_r7, regs.ARM_ORIG_r0, regs.ARM_r1, regs.ARM_r2, regs.ARM_r0);
    #else
        #error "arch currently unsupported"
        return false;
    #endif
}


// Drive the attached `tracee` from ptrace-stop to ptrace-stop with PTRACE_SYSCALL,
// handing every syscall-stop to getsyscall(), until the tracee exits or is killed,
// recording fails, or a wait/ptrace call fails.
static void traceloop(const pid_t tracee) {
    while (true) {
        int status;
        if (waitpid(tracee, &status, 0) == -1) {
            if (errno == EINTR) {
                continue; // interrupted before anything happened: wait again
            }
            LOG_ERRNO("wait(%d)", tracee);
            break;
        }

        if (WIFEXITED(status) || WIFSIGNALED(status)) {
            break; // tracee is gone
        }
        if (!WIFSTOPPED(status)) {
            assert(false);
            break;
        }

        long inject = 0; // signal to deliver when resuming the tracee
        const int sig = WSTOPSIG(status);
        if (sig == (SIGTRAP|0x80)) { // syscall-stop, tagged via PTRACE_O_TRACESYSGOOD
            if (!getsyscall(tracee)) {
                break;
            }
        } else if (sig != SIGTRAP) {
            // Signal-delivery-stop: pass the signal on, otherwise the traced program
            // never sees it (Ctrl+C, SIGTERM, SIGCHLD, ...). A plain SIGTRAP is what a
            // traced execve() raises without PTRACE_O_TRACEEXEC, so it is suppressed.
            inject = sig;
        }

        if (ptrace(PTRACE_SYSCALL, tracee, NULL, (void*)inject) == -1) { // start again
            LOG_ERRNO("ptrace(PTRACE_SYSCALL)");
            break;
        }
    }
}


// Attach to `tracee` and let it run on to its next syscall-stop.
// PTRACE_ATTACH (not PTRACE_SEIZE) is used because it stops the tracee with SIGSTOP,
// giving a defined point at which to set options before resuming it.
// PTRACE_O_TRACESYSGOOD tags syscall-stops as SIGTRAP|0x80 for traceloop().
// Returns false, with the reason logged, if any step fails.
static bool tracestart(const pid_t tracee) {
    long rv = ptrace(PTRACE_ATTACH, tracee, NULL, NULL);
    if (rv) {
        LOG_ERRNO("ptrace(PTRACE_ATTACH): %ld", rv); // EPERM
        return false;
    }

    int status;
    pid_t waited;
    do { // wait for tracee to be stopped due to attach
        waited = waitpid(tracee, &status, 0);
    } while (waited == -1 && errno == EINTR);
    if (waited == -1) {
        LOG_ERRNO("waitpid(%d)", tracee);
        return false;
    }
    if (!WIFSTOPPED(status) || WSTOPSIG(status) != SIGSTOP) {
        LOG("%d not stopped: %d/%d", tracee, WIFSTOPPED(status), WSTOPSIG(status));
        return false;
    }

    // ptrace() is variadic: pass addr/data as pointers, not bare ints or enums.
    if (ptrace(PTRACE_SETOPTIONS, tracee, NULL, (void*)(long)PTRACE_O_TRACESYSGOOD) == -1) {
        LOG_ERRNO("ptrace(PTRACE_SETOPTIONS)");
        return false;
    }

    if (ptrace(PTRACE_SYSCALL, tracee, NULL, NULL) == -1) { // start again
        LOG_ERRNO("ptrace(PTRACE_SYSCALL)");
        return false;
    }

    return true;
}


// Arrange for the current process to be ptrace'd by a detached tracer, then return.
//
// Process layout:
//   caller (mainpid) - forks `child` and blocks in waitpid() on it. It returns once the
//                      tracer has attached (or given up), so no later syscall is missed.
//   child            - daemonizes, forks the tracer, then sleeps in pause() until the
//                      tracer terminates it with SIGTERM.
//   grandchild       - the tracer: attaches to mainpid, releases the caller by killing
//                      `child`, runs traceloop() and exits. It never returns from here.
// If setup fails, the caller returns untraced but otherwise unaffected.
static void detatch() {
    // With Yama ptrace_scope=1 only an ancestor may attach, and our tracer is a descendant.
    // EINVAL means nothing handles PR_SET_PTRACER (no Yama), i.e. no restriction to lift.
    if (prctl(PR_SET_PTRACER, PR_SET_PTRACER_ANY, 0, 0, 0) != 0 && errno != EINVAL) { // XXX: PTRACER_ANY as we don't know the child pid yet
        LOG_ERRNO("prctl(PR_SET_PTRACER_ANY)");
        return;
    }

    // start as child (makes tracing harder but process handling easier)
    const pid_t mainpid = getpid();
    const pid_t child = fork();
    if (child == -1) {
        LOG_ERRNO("fork()");
        return;
    }
    if (child > 0) { // put myself to sleep until success or error - doing this for not to miss any syscall
        // Reaping the child explicitly avoids a zombie without ignoring SIGCHLD here.
        while (waitpid(child, NULL, 0) == -1 && errno == EINTR) {}
        // Attaching is over either way: stop letting arbitrary processes ptrace us.
        (void)prctl(PR_SET_PTRACER, 0, 0, 0, 0);
        return;
    }

    // --- child: daemonize (close fds, leave the terminal session) ---
    signal(SIGHUP, SIG_IGN);
    signal(SIGPIPE, SIG_IGN);
    signal(SIGCHLD, SIG_IGN);
    close(0); close(1);
    #ifdef NDEBUG
        close(2);
    #endif
    if (chdir("/") == -1) {
        LOG_ERRNO("chdir(/)");
    }
    umask(0);
    // A freshly forked process is never a group leader, so setsid() can succeed and
    // drop the controlling terminal; only if it fails, fall back to a new process group.
    if (setsid() == -1 && setpgid(0, 0) == -1) {
        LOG_ERRNO("setpgid");
    }

    // double fork into tracer
    const pid_t middle = getpid(); // the tracer terminates this process to release the caller
    const pid_t tracer = fork();
    if (tracer == -1) {
        LOG_ERRNO("fork()");
        _exit(0);
    }
    if (tracer > 0) { // wait for tracer before bailing out the real parent's waitpid above
        pause(); // ended by the tracer's SIGTERM (default action: terminate)
        _exit(0);
    }

    // --- grandchild: the real tracer ---
    signal(SIGINT, SIG_IGN);
    signal(SIGTERM, SIG_IGN);
    // attach and let run again to next syscall enter/exit (waitpid above)
    const bool attached = tracestart(mainpid);
    kill(mainpid, SIGCONT); // just in case it is still stopped from the attach
    // Signal `middle` explicitly: the first fork's result is 0 in this process, and
    // kill(0, ...) would hit the entire process group.
    kill(middle, SIGTERM); // lets the caller's waitpid return
    if (attached) {
        traceloop(mainpid); // loop now in grandchild & exit when done
    }
    // _exit: forked children must not run atexit handlers or flush inherited stdio buffers.
    _exit(0);
}


// Entry point: set up tracing of this very process via detatch(), then replace the
// process image with TARGET, passing our own arguments through.
// NOTE(review): TARGET is presumably supplied on the compiler command line as a
// string literal (e.g. -DTARGET='"/path/to/binary"') — confirm in the build.
// Returns 1 only if execvp() fails.
int main(int argc, char** argv) {
    detatch();
    #ifndef TARGET
        #error "TARGET binary undefined"
    #else
        // or better leaving us in argv[0]?
        // With argc == 0, argv[0] is the list's terminating NULL; overwriting it would
        // leave execvp() without a terminator, so use a private argument list instead.
        char* fallback_argv[] = { (char*)TARGET, NULL };
        char** target_argv = fallback_argv;
        if (argc > 0) {
            argv[0] = (char*)TARGET;
            target_argv = argv;
        }
        execvp(TARGET, target_argv);
        LOG_ERRNO("execvp(%s)", TARGET);
    #endif
    return 1;
}