#!/usr/bin/env python

import sys, re


# Grammar of one capture line, compiled once instead of on every parse() call:
#   <spaces or '>'><src>:<port> -> <dst>:<port> <ts> <c|h>t: <type>, len: <len>[,<extras>]
_CAPTURE_LINE = re.compile(
    r"^[ >]+([^:]+):([0-9]+) -> ([^:]+):([0-9]+) ([0-9]+) "
    r"([ch])t: *([0-9]+), *len: *([0-9]+)(,(.*))?$")


def parse(l):
    """Parse one line of capture output into a dict.

    Args:
        l: a single capture line (a trailing newline is tolerated).

    Returns:
        An empty dict when the line does not match the capture format.
        Otherwise a dict with the keys 'src', 'src_port', 'dst', 'dst_port',
        'ts', 'is_handshake', 'type' and 'len', plus one string entry for
        every optional trailing "key: value" pair.
    """
    rv = {}
    mat = _CAPTURE_LINE.match(l)
    if not mat:
        return rv

    # the whole address is captured; repeating a one-character group would
    # only keep the last character of the source address
    rv['src'] = mat.group(1)
    rv['src_port'] = int(mat.group(2))
    rv['dst'] = mat.group(3)
    rv['dst_port'] = int(mat.group(4))
    rv['ts'] = int(mat.group(5))
    rv['is_handshake'] = (mat.group(6) == 'h')  # "ht:" vs. "ct:" marker
    rv['type'] = int(mat.group(7))
    rv['len'] = int(mat.group(8))

    extras = mat.group(10)
    if extras:
        for pair in extras.split(","):
            # split on the first colon only, so values may contain colons;
            # fragments without a colon (e.g. after a trailing comma) are
            # skipped instead of raising IndexError
            key, sep, value = pair.partition(":")
            if not sep:
                continue
            rv[key.strip()] = value.strip()
    return rv


def post(o):
    """Validate and postprocess a captured message object in place.

    A message is accepted only when exactly one of its ports is 443. Accepted
    messages gain:
      'id'     -- the port that is not 443 (used as the session key)
      'is_req' -- True when the message is addressed to port 443
    and their optional 'cs' / 'comp' fields are converted from hexadecimal
    strings to integers.

    Args:
        o: dict as produced by parse(); may be empty.

    Returns:
        True when the message is valid (o has been updated), False otherwise.
        A rejected message is left unmodified.
    """
    if not o:
        return False

    src_port = o['src_port']
    dst_port = o['dst_port']
    if src_port != 443 and dst_port != 443:
        return False
    if src_port == dst_port:
        return False

    # Convert the hex fields before touching o, so that a malformed value
    # rejects the message instead of raising out of the validator.
    converted = {}
    try:
        for key in ('cs', 'comp'):
            if key in o:
                converted[key] = int(o[key], 16)
    except ValueError:
        return False

    o['id'] = src_port if src_port != 443 else dst_port
    o['is_req'] = (dst_port == 443)
    o.update(converted)
    return True


def est_len(cipher, len):
    """Estimate the plaintext length hidden behind a ciphertext length.

    For the supported AES-GCM cipher suites a fixed overhead of 24 bytes is
    subtracted (never going below zero). For any other cipher suite a note is
    written to stderr and the ciphertext length is returned unchanged.
    """
    # Supported suites:
    #   0xc02c ECDHE-ECDSA-AES256-GCM-SHA384
    #   0xc030 ECDHE-RSA-AES256-GCM-SHA384
    #   0xc02b ECDHE-ECDSA-AES128-GCM-SHA256
    #   0xc02f ECDHE-RSA-AES128-GCM-SHA256
    gcm_suites = [0xc02c, 0xc030, 0xc02b, 0xc02f]

    if cipher not in gcm_suites:
        sys.stderr.write("cipher %04x not supported\n" % cipher)
        return len

    # overhead: 8 byte counter + 16 bytes, no padding, no per-record sha mac
    # https://tools.ietf.org/html/rfc5246#section-6.2
    overhead = 24
    if len > overhead:
        return len - overhead
    return 0


# read capture lines from stdin and group the valid messages into sessions,
# keyed by the message id (the port that is not 443)
sessions = {}
for raw in sys.stdin:
    msg = parse(raw)
    if not post(msg):
        sys.stderr.write("cannot parse '%s'\n" % raw.strip())
        continue

    key = msg['id']
    known = sessions.get(key)
    if known is None:
        # first message of a new session; its timestamp becomes the session
        # timestamp (the input is assumed to be sorted by time)
        sessions[key] = {'list': [msg], 'ts': msg['ts']}
        continue

    # timestamps have to be ascending within a session, but may repeat
    assert msg['ts'] >= known['list'][-1]['ts']
    known['list'].append(msg)


# scan every session's handshake messages for the sni (client hello) and the
# negotiated cipher suite / compression (server hello); only sessions with a
# known cipher suite and without compression are marked as valid
for sess in sessions.values():
    for msg in sess['list']:
        if not msg['is_handshake']:
            continue
        kind = msg['type']
        if kind == 2:  # server hello
            assert 'cs' in msg and 'comp' in msg
            sess['cs'] = msg['cs']
            sess['comp'] = msg['comp']
            break  # nothing of interest after the server hello
        if kind == 1 and 'sni' in msg:  # client hello
            sess['sni'] = msg['sni']

    if 'cs' not in sess:
        sys.stderr.write("no handshake/cipher found\n")
        continue
    if sess.get('comp'):
        sys.stderr.write("compression enabled\n")
        continue
    sess['valid'] = True


# collect request<->response pairs
#
# Only application data records (content type 23) of valid sessions are
# considered. Every request opens a new entry; the responses that follow it
# within the SAME session are accumulated into that entry.
requests = []
for session_id, session in sessions.items():
    if not session.get('valid'):
        continue

    # The request that responses of this session currently belong to. It is
    # tracked per session: falling back to requests[-1] would attribute a
    # response that precedes the session's first request to the last request
    # of a different session.
    current = None
    for m in session['list']:
        if m['is_handshake'] or m['type'] != 23:
            continue # application data only
        m['est_len'] = est_len(session['cs'], m['len'])
        if m['is_req']: # assume there's only one request in flight
            current = {
                'session': session_id,
                'req': m,
                'res': [],
                'res_len': 0,
                'res_est_len': 0
            }
            requests.append(current)
        elif current is not None:
            current['res'].append(m)
            current['res_len'] += m['len']
            current['res_est_len'] += m['est_len']
        # responses seen before any request of this session are dropped

    # the messages are referenced from the request entries now
    session['list'] = []


# order the requests by the timestamp of their request message; the earliest
# one provides the reference time for the report
if len(requests) == 0:
    sys.stderr.write("no requests found\n")
    sys.exit(1)


def _request_ts(entry):
    """Return the timestamp of the request message of a request entry."""
    return entry['req']['ts']


requests.sort(key=_request_ts)
start_ts = _request_ts(requests[0])


# report: one row per request, all times relative to the first request
print("#%9s | %10s | %10s | %10s | %10s | %s" % (
    "req_t", "res_t", "end_t", "req_len", "res_len", "sni"))
for entry in requests:
    owner = sessions[entry['session']]
    responses = entry['res']

    if responses:
        first_res_t = responses[0]['ts'] - start_ts
        last_res_t = responses[-1]['ts'] - start_ts
    else:
        # no response seen for this request
        first_res_t = 0
        last_res_t = 0

    # a missing or empty sni is shown as "-"
    host = owner.get('sni') or "-"

    row = (
        entry['req']['ts'] - start_ts,
        first_res_t,
        last_res_t,
        entry['req']['est_len'],
        entry['res_est_len'],
        host,
    )
    print("%10d | %10d | %10d | %10d | %10d | %s" % row)