#include "Arduino.h"
#include "ir_nec_4.h"
#include <avr/io.h>
#include <avr/interrupt.h>

// Timing parameters. Durations recorded by the ISR are counted in ticks of
// SRATE microseconds each.
#define SRATE 50 // interrupt sampling rate in microseconds
#define FUZZYNESS 150 // allowed measurement deviation in microseconds
// True when the tick count t lies within us +/- FUZZYNESS microseconds; both
// bounds are converted to ticks with integer division. Note that t is
// evaluated twice, so only pass side-effect-free expressions.
#define TICK_INTERVAL(t, us) ((t) >= ((us)-FUZZYNESS)/SRATE && (t) <= ((us)+FUZZYNESS)/SRATE)

// State shared between the timer ISR (which appends) and decode() (which
// copies and resets it), hence volatile.
#define SAMPLES_MAX 100 // number of saved measurements (100 should be sufficient for a single keypress)
static volatile int samples_num = 0; // number of valid entries in samples[]
static volatile unsigned samples[SAMPLES_MAX] = {}; // duration in ticks of each level between two edges
static volatile int lastread_val = HIGH; // pin level the ISR is currently timing
static volatile unsigned lastswitch_tick = 0; // ticks elapsed since the last recorded edge
static int PIN = -1; // IR receiver pin; set by the IR_NEC_4 constructor, read by the ISR

/**
 * Timer 4 compare-match interrupt, fired every SRATE microseconds once
 * enable() has run. Samples the IR pin and counts how many ticks the current
 * level has lasted; on every level change the finished duration is appended
 * to samples[] (dropped once SAMPLES_MAX entries are buffered) and counting
 * restarts for the new level.
 */
ISR(TIMER4_COMPA_vect) {
    if (digitalRead(PIN) == lastread_val) { // no change: increase the current value's tick duration
        // Saturate instead of wrapping: unsigned is 16 bits on AVR, so at
        // 50us per tick the counter would overflow after ~3.3s of idle and
        // the next recorded duration could fall into a valid NEC window.
        if (lastswitch_tick != ~0u) {
            ++lastswitch_tick;
        }
    } else { // edge triggered: pin state has changed in either direction
        if (samples_num < SAMPLES_MAX) {
            samples[samples_num++] = lastswitch_tick; // TODO: use ring-buffer instead?
        }
        lastread_val ^= 1; // toggle the expected level (HIGH=1 / LOW=0)
        lastswitch_tick = 0; // start counting tick duration from scratch
    }
}

/**
 * Create a decoder reading the IR receiver on pin p.
 * No hardware is touched here; call enable() to start sampling.
 */
IR_NEC_4::IR_NEC_4(int p)
    : lastkey(0)
{
    // The timer ISR has no access to this object, so the pin number is
    // kept in the file-scope global it reads.
    PIN = p;
}

void IR_NEC_4::enable() {
    pinMode(PIN, INPUT);

    // reset some unused 16bit timer, i.e. 4 on arduino mega (controlling also PWM pins 6, 7, 8)
    // TODO: support multiple timers to choose from
    // TODO: support 8 bit timers (with prescale and/or additional counter)
    noInterrupts();
    TCCR4A = 0;
    TCCR4B = 0;

    // start at 0 in CTC mode (Clear Timer on Compare Match)
    TCNT4 = 0;
    TCCR4B |= (1 << WGM12);

    // 16MHz w/o prescale -> 1/16us -> 800 for 50us
    TCCR4B |= (1 << CS10); // no prescale (1)
    OCR4A = (F_CPU / 1000000UL) * SRATE; // Output Compare Register

    // start Output Compare Interrupt
    TIMSK4 |= (1 << OCIE4A);
    interrupts();

    #ifdef IR_NEC_4_DEBUG
        Serial.println("using timer 4 for IR interrupt");
    #endif
}


/**
 * Try to decode one NEC frame from the durations collected by the timer ISR.
 *
 * Until at least 3 + 32*2 samples are buffered nothing happens. Otherwise the
 * samples are copied out and the shared buffer is reset, whether or not
 * decoding then succeeds. The copy is searched for a 9000us leader mark and a
 * 4500us space, followed by 32 (mark, space) bit pairs. Repeat frames
 * (9000us + 2250us) are skipped.
 *
 * Bits are shifted in MSB-first in arrival order. On the little-endian AVR,
 * b[3] therefore holds the first 8 bits received and b[0] the last 8. A frame
 * is only accepted if b[3] == 0x00, b[2] == 0xff, and b[1] / b[0] are bitwise
 * complements.
 *
 * @return b[1] (the command byte as assembled above) on success, 0 otherwise.
 *         A command byte of 0 cannot be told apart from failure.
 *
 * NOTE(review): edges recorded by the ISR between reading samples_num and
 * resetting it are discarded together with the buffer.
 */
byte IR_NEC_4::decode() {
    // A complete frame needs: the sample before the leader, leader mark,
    // leader space, and 32 (mark, space) bit pairs.
    int n = samples_num;
    if (n < 3 + (32*2)) {
        return 0; // whole keypress not possible yet
    }
    // Fetch and reset the collected samples. The ISR only appends at indices
    // >= n, so the first n entries are stable while they are copied.
    unsigned s[SAMPLES_MAX];
    memcpy(s, (void*)samples, n*sizeof(unsigned));
    lastread_val = HIGH;
    samples_num = 0;
    #ifdef IR_NEC_4_DEBUG
        Serial.print("trying to decode "); Serial.print(n); Serial.println(" IR samples");
    #endif

    // Find the NEC start sequence: a 9000us mark followed by a 4500us space.
    // Index 0 is skipped: it is the duration of the level before the first
    // recorded edge, not a leader. The upper bound leaves room for the
    // leader space plus 32 bit pairs after the leader mark.
    const int last_leader = n - 1 - (32*2);
    int start = -1;
    for (int i = 1; i < last_leader; ++i) {
        if (!TICK_INTERVAL(s[i], 9000)) {
            continue;
        }
        if (TICK_INTERVAL(s[i+1], 4500)) {
            start = i; // found
            break;
        }
        if (TICK_INTERVAL(s[i+1], 2250)) {
            ++i; // repeat code: skip its space too (TODO: repeat codes skipped)
            continue;
        }
        return 0; // leader mark without a valid space
    }
    if (start < 0) {
        return 0; // no (complete) start sequence in this batch
    }
    #ifdef IR_NEC_4_DEBUG
        Serial.print("found NEC starting sequence at "); Serial.println(start);
    #endif
    start += 2; // first bit mark follows leader mark + space

    // binary decode
    union {
        unsigned long w;
        byte b[4];
    } code = {};
    unsigned code_bits = 0;
    // `i + 1 < n` keeps the space read inside the copied samples: the last
    // recorded sample is often the stop-bit mark, whose space has not
    // ended yet.
    for (int i = start; i + 1 < n; i += 2) {
        if (!TICK_INTERVAL(s[i], 550)) {
            break;
        }
        if (TICK_INTERVAL(s[i+1], 550)) { // 0: 562.5µs pulse burst followed by a 562.5µs space
            code.w = code.w << 1;
        } else if (TICK_INTERVAL(s[i+1], 1600)) { // 1: 562.5µs pulse burst followed by a 1.6875ms space
            code.w = (code.w << 1) | 1;
        } else {
            break;
        }
        code_bits++;
    }

    // sanity checks & done
    #ifdef IR_NEC_4_DEBUG
        Serial.print(code_bits, DEC); Serial.println(" code bits found");
    #endif
    if (code_bits != 32) return 0;
    #ifdef IR_NEC_4_DEBUG
        Serial.print("decoded to 0x"); Serial.println(code.w, HEX);
    #endif
    if (code.b[3] != 0x00 || code.b[2] != 0xff) return 0;
    // b[0] must be the bitwise complement of b[1]: their XOR is then 0xff.
    // (The former `b[1] & 0xff != ~b[0] & 0xff` parsed as `b[1] & 1` because
    // != binds tighter than &.)
    if ((code.b[1] ^ code.b[0]) != 0xff) return 0;
    #ifdef IR_NEC_4_DEBUG
        Serial.print("success - keycode: 0x"); Serial.println(code.b[1], HEX);
    #endif
    return code.b[1];
}

/**
 * Poll for a key without consuming it.
 *
 * If no key is pending yet, decode() is run and its result cached in
 * lastkey, so that a following getResults() returns it.
 *
 * @return true if a decoded key is pending.
 */
boolean IR_NEC_4::checkResults() {
    // Portable form of the former GNU-only `lastkey ?: (lastkey = decode())`.
    if (lastkey == 0) {
        lastkey = decode();
    }
    return lastkey != 0;
}

/**
 * Fetch and consume the next key.
 *
 * A key cached by checkResults() is returned and cleared first. Otherwise
 * decode() is run directly, and its result is returned without being cached.
 *
 * @return the key code, or 0 if none is available.
 */
byte IR_NEC_4::getResults() {
    if (lastkey == 0) {
        return decode();
    }
    const byte pending = lastkey;
    lastkey = 0;
    return pending;
}