#!/usr/bin/env python3

import datetime
import getopt
import multiprocessing
import sys

import stem
import stem.descriptor
import stem.descriptor.reader
import stem.descriptor.networkstatus
import stem.descriptor.extrainfo_descriptor

import numpy as np
import pandas as pd

import common

def intersect_intervals(a, b):
    """Return the overlaps between two collections of (begin, end) intervals.

    Both collections are visited in ascending order with a single forward
    sweep: each step compares one interval from a with one from b and then
    advances past whichever of the two ends first. The sweep never backs up,
    so it is presumably meant for collections whose own intervals do not
    overlap one another.

    Parameters:
        a, b: iterables of intervals, where element [0] is the beginning and
            element [1] is the end. They do not need to be sorted.

    Returns:
        A list of tuples (begin, end, i, j), one per overlapping pair, where
        begin and end delimit the overlap and i and j are the positions of
        the contributing intervals in a and b *as passed in*. Intervals that
        merely touch (zero-length overlap) are not reported.
    """
    a = list(a)
    b = list(b)
    # Sort positions rather than the intervals themselves, so that the indices
    # we report remain valid for the caller's own (possibly unsorted) inputs.
    a_order = sorted(range(len(a)), key = lambda k: a[k])
    b_order = sorted(range(len(b)), key = lambda k: b[k])
    result = []
    i = 0
    j = 0
    while i < len(a_order) and j < len(b_order):
        a_begin, a_end = a[a_order[i]][0], a[a_order[i]][1]
        b_begin, b_end = b[b_order[j]][0], b[b_order[j]][1]
        if a_begin < b_end and a_end > b_begin:
            result.append((max(a_begin, b_begin), min(a_end, b_end), a_order[i], b_order[j]))
        # Advance whichever sequence of intervals currently has the leftmost
        # right edge.
        if a_end < b_end:
            i += 1
        else:
            j += 1
    return result

def _latest_per_interval(rows):
    """Build a dataframe from a dict of column lists, one row per interval.

    Rows that share the same (fingerprint, nickname, end) are collapsed into
    one, keeping the values from the row with the latest "published".
    """
    return (
        pd.DataFrame(rows)
            .sort_values("published")
            .groupby(["fingerprint", "nickname", "end"])
            .last()
            .reset_index()
    )

def _sum_by_date(intervals, value_column, hours_column, total_column):
    """Apportion each interval's duration and value to dates, then sum by date.

    intervals must have "begin" and "end" columns plus value_column. The
    result has the columns "date", hours_column, and total_column.
    """
    by_date = {
        "date": [],
        hours_column: [],
        total_column: [],
    }
    for row in intervals.itertuples():
        value = getattr(row, value_column)
        hours = (row.end - row.begin) / datetime.timedelta(hours = 1)
        # frac_int is used as the share of this interval that is attributed to
        # date; the third element of each segment is not needed here.
        for (date, frac_int, _) in common.segment_datetime_interval(row.begin, row.end):
            by_date["date"].append(date)
            by_date[hours_column].append(hours * frac_int)
            by_date[total_column].append(value * frac_int)
    return pd.DataFrame(by_date).groupby("date").sum().reset_index()

def process_relay_extra_infos(reader):
    """Summarize directory write history and directory stats by date.

    Parameters:
        reader: an iterable of
            stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor.

    Returns:
        A dataframe with one row per date and the columns "date",
        "relay_dir_write_hours", "relay_dir_write_bytes",
        "relay_dir_stats_hours", "relay_dir_stats_resp_ok", "both_hours", and
        "both_bytes". The "both_*" columns cover only the time spans for which
        a relay reported both write history and directory stats. A column is
        NaN on dates for which the corresponding source has no data (the
        parts are combined with outer joins).
    """
    dir_write_history = {
        "published": [],
        "fingerprint": [],
        "nickname": [],
        "begin": [],
        "end": [],
        "bytes": [],
    }
    dir_stats = {
        "published": [],
        "fingerprint": [],
        "nickname": [],
        "begin": [],
        "end": [],
        "resp_ok": [],
    }
    for desc in reader:
        assert isinstance(desc, stem.descriptor.extrainfo_descriptor.RelayExtraInfoDescriptor), type(desc)

        # Ignore a write history that ended too long before the descriptor was
        # published, or whose interval length is too long.
        if desc.dir_write_history_end is not None \
            and desc.published - desc.dir_write_history_end < common.END_THRESHOLD \
            and datetime.timedelta(seconds = desc.dir_write_history_interval) < common.INTERVAL_THRESHOLD:
            # Break the write history into separate rows, one for each interval.
            # The last value is the interval that finishes at
            # dir_write_history_end, so walk backwards from there.
            end = desc.dir_write_history_end
            for value in reversed(desc.dir_write_history_values):
                begin = end - datetime.timedelta(seconds = desc.dir_write_history_interval)
                dir_write_history["published"].append(desc.published)
                dir_write_history["fingerprint"].append(desc.fingerprint)
                dir_write_history["nickname"].append(desc.nickname)
                dir_write_history["begin"].append(begin)
                dir_write_history["end"].append(end)
                dir_write_history["bytes"].append(value)
                end = begin

        # The same two thresholds apply to directory stats.
        if desc.dir_stats_end is not None \
            and desc.published - desc.dir_stats_end < common.END_THRESHOLD \
            and datetime.timedelta(seconds = desc.dir_stats_interval) < common.INTERVAL_THRESHOLD:
            # NOTE(review): the "- 4" presumably compensates for the reported
            # response count being rounded -- confirm against the dir-spec.
            resp_ok = desc.dir_v3_responses[stem.descriptor.extrainfo_descriptor.DirResponse.OK] - 4
            if resp_ok > 0:
                dir_stats["published"].append(desc.published)
                dir_stats["fingerprint"].append(desc.fingerprint)
                dir_stats["nickname"].append(desc.nickname)
                dir_stats["begin"].append(desc.dir_stats_end - datetime.timedelta(seconds = desc.dir_stats_interval))
                dir_stats["end"].append(desc.dir_stats_end)
                dir_stats["resp_ok"].append(resp_ok)

    # Different descriptors for the same relay contain overlapping write
    # histories. Keep only the most recent "published" for each "end".
    dir_write_history = _latest_per_interval(dir_write_history)
    # Do the same for directory responses, though we don't expect these to
    # overlap.
    dir_stats = _latest_per_interval(dir_stats)

    # Now compute the intervals, for each relay, which are covered by *both*
    # dir_write_history and dir_stats.
    both = []
    dir_write_history_grouped = dir_write_history.groupby(["fingerprint", "nickname"])
    dir_stats_grouped = dir_stats.groupby(["fingerprint", "nickname"])
    for (fingerprint, nickname), dir_write_history_group in dir_write_history_grouped:
        try:
            dir_stats_group = dir_stats_grouped.get_group((fingerprint, nickname))
        except KeyError:
            # This relay has write history but no directory stats.
            continue
        # Find the intersection, H∧R, of write history intervals and dir stats
        # intervals.
        dir_write_history_intervals = [(row.begin, row.end) for row in dir_write_history_group.itertuples()]
        dir_stats_intervals = [(row.begin, row.end) for row in dir_stats_group.itertuples()]
        intersection = intersect_intervals(dir_write_history_intervals, dir_stats_intervals)
        if not intersection:
            continue
        # Each tuple returned by intersect_intervals contains:
        #   [0]: beginning of interval in intersection
        #   [1]: end of interval in intersection
        #   [2]: index in dir_write_history_intervals that contributes to this interval
        #   [3]: index in dir_stats_intervals that contributes to this interval
        # We make a joint dataframe that maps the intersection intervals
        # (ibegin = [0], iend = [1]) to their [2] corresponding intervals in
        # dir_write_history_intervals, along with their byte counts. We use this
        # to scale the byte counts for the intersection intervals.
        joint = pd.concat([
            pd.DataFrame({
                "ibegin": [x[0] for x in intersection],
                "iend": [x[1] for x in intersection],
            }),
            dir_write_history_group.iloc[[x[2] for x in intersection]][["begin", "end", "bytes"]].reset_index(drop = True),
        ], axis = 1)
        # Scale each byte count by the fraction of its write history interval
        # that lies inside the intersection.
        covered = pd.TimedeltaIndex(joint["iend"] - joint["ibegin"]).to_pytimedelta()
        whole = pd.TimedeltaIndex(joint["end"] - joint["begin"]).to_pytimedelta()
        both.append(pd.DataFrame({
            "fingerprint": fingerprint,
            "nickname": nickname,
            "begin": joint["ibegin"],
            "end": joint["iend"],
            "bytes": joint["bytes"] * (covered / whole),
        }))
    if both:
        both = pd.concat(both)
    else:
        # pd.concat raises ValueError on an empty list. No relay having both
        # kinds of data is a legitimate outcome, so represent it as an empty
        # frame with the same columns.
        both = pd.DataFrame({
            "fingerprint": [],
            "nickname": [],
            "begin": [],
            "end": [],
            "bytes": [],
        })

    # Sum by date over all relays.
    dir_write_history_bydate = _sum_by_date(
        dir_write_history, "bytes", "relay_dir_write_hours", "relay_dir_write_bytes")
    dir_stats_bydate = _sum_by_date(
        dir_stats, "resp_ok", "relay_dir_stats_hours", "relay_dir_stats_resp_ok")
    both_bydate = _sum_by_date(
        both, "bytes", "both_hours", "both_bytes")
    return pd.merge(
        pd.merge(dir_write_history_bydate, dir_stats_bydate, on = ["date"], how = "outer"),
        both_bydate, on = ["date"], how = "outer",
    )

def process_file(f):
    """Read the descriptors found at path f and return their per-date summary.

    The path is handed to a stem DescriptorReader, and the reader is passed to
    process_relay_extra_infos, whose result is returned unchanged.
    """
    targets = [f]
    with stem.descriptor.reader.DescriptorReader(targets) as reader:
        summary = process_relay_extra_infos(reader)
        return summary

if __name__ == "__main__":
    # No options are defined, so every command-line argument is an input path.
    _, inputs = getopt.gnu_getopt(sys.argv[1:], "")
    output_columns = [
        "date",
        "relay_dir_write_hours",
        "relay_dir_write_bytes",
        "relay_dir_stats_hours",
        "relay_dir_stats_resp_ok",
        "both_hours",
        "both_bytes",
    ]
    with multiprocessing.Pool(common.NUM_PROCESSES) as pool:
        # Each input is summarized in a worker process; results arrive in
        # whatever order they finish, which is fine because they are summed.
        per_file = pool.imap_unordered(process_file, inputs)
        combined = pd.concat(per_file)
        # The same date can appear in several inputs, so add them up.
        totals = combined.groupby("date").sum().reset_index()
        totals.to_csv(sys.stdout, index = False, float_format = "%.2f", columns = output_columns)
