"""Clustering stage: runs k-means, agglomerative, and genre-biased agglomerative clustering
over several feature sets, scores each against the genre labels, and writes the results."""

import argparse
import sys
from pathlib import Path

import numpy as np
import sklearn.cluster
import sklearn.metrics

from .features import Features
from .log import run, Status
from .utils import add_args, CSVReader, CSVDictReader, CSVDictWriter


class Kmeans:
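    """Thin wrapper around sklearn's KMeans; fit returns (inertia, labels)."""
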
    def __init__(self, dim: int, epochs: int) -> None:
        self._encoder = sklearn.cluster.KMeans(n_clusters=dim, max_iter=epochs, random_state=0)

    def fit(self, data: np.ndarray) -> tuple[float, list[int]]:
        self._encoder.fit(data)
        return self._encoder.inertia_, self._encoder.labels_.tolist()


class Agglomerative:
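    """Plain agglomerative (Ward) clustering; fit returns (n_connected_components, labels)."""
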
    def __init__(self, dim: int) -> None:
        # Euclidean metric: Ward linkage requires it, and plain distance is how we build up
        # and interpret the clustering
        self._encoder = sklearn.cluster.AgglomerativeClustering(n_clusters=dim, metric="euclidean", linkage="ward")

    def fit(self, data: np.ndarray) -> tuple[int, list[int]]:
        self._encoder.fit(data)
        return self._encoder.n_connected_components_, self._encoder.labels_.tolist()


class Supervised:
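    """Agglomerative clustering over a precomputed distance matrix in which same-genre
    pairs appear closer together, nudging the clustering toward the known labels."""
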
    def __init__(self, dim: int, labels: list[int]) -> None:
        self._labels: list[int] = labels
        self._encoder = sklearn.cluster.AgglomerativeClustering(n_clusters=dim, metric="precomputed", linkage="complete")

    @classmethod
    def _dist(cls, x: np.ndarray, y: np.ndarray, data: np.ndarray, labels: list[int], bias: float) -> np.floating:
        # x and y are single-element index arrays handed in by pairwise_distances, which
        # casts them to float, so convert back before indexing
        i, j = int(x.item()), int(y.item())
        dist = np.linalg.norm(data[i] - data[j])  # Euclidean: np.sqrt(np.sum(np.square(...)))
        return dist * bias if labels[i] == labels[j] else dist  # same-label pairs appear closer

    @classmethod
    def compute_weights(cls, data: np.ndarray, labels: list[int], bias: float) -> np.ndarray:
        # NB: the result is symmetric, so computing half the matrix would be sufficient
        assert 0.0 <= bias <= 1.0
        indices = np.arange(len(data)).reshape(-1, 1)  # pairwise_distances needs 2-D input
        return sklearn.metrics.pairwise_distances(indices, metric=cls._dist, data=data, labels=labels, bias=1.0 - bias)

    def fit_weights(self, dist: np.ndarray) -> tuple[int, list[int]]:
        self._encoder.fit(dist)
        return self._encoder.n_connected_components_, self._encoder.labels_.tolist()

    def fit(self, data: np.ndarray, bias: float) -> tuple[int, list[int]]:
        self._encoder.fit(self.compute_weights(data, self._labels, bias))
        return self._encoder.n_connected_components_, self._encoder.labels_.tolist()


class Scoring:
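    """Scores predicted cluster assignments against the reference labelling, keyed by file index."""
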
    def __init__(self, reference: dict[int, int]) -> None:
        self._reference: dict[int, int] = reference

    def score(self, prediction: dict[int, int]) -> tuple[float, float, float, float, float]:
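        """Return (adjusted Rand, adjusted mutual info, homogeneity, completeness, Fowlkes-Mallows)."""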
        assert len(prediction) == len(self._reference)
        reference = list(self._reference.values())
        labels = [prediction[i] for i in self._reference.keys()]
        return (sklearn.metrics.adjusted_rand_score(reference, labels),
                sklearn.metrics.adjusted_mutual_info_score(reference, labels),
                sklearn.metrics.homogeneity_score(reference, labels),
                sklearn.metrics.completeness_score(reference, labels),
                sklearn.metrics.fowlkes_mallows_score(reference, labels))


def run_main(status: Status, *,
             out_dir: Path,
             epoch_limit: int,
             genre_bias: float,
             extract_features: Path,
             normed_features: Path,
             scaled_features: Path,
             cluster_features: Path,
             kmeans_stats: Path,
             kmeans_clusters: Path,
             **kwargs) -> None:
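    """Cluster every feature variant with k-means, agglomerative, and genre-biased agglomerative
    clustering, score each result against the genre labels, and write stats and cluster files."""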
    status.start_progress("Reading features")

    with CSVDictReader(out_dir / extract_features) as reader:
        rows: list[Features] = [Features.from_dict(row) for row in status.task_pending_counter(reader.read())]
        genres: list[str] = sorted(set(_.tags.genre for _ in rows))
        labels: dict[int, int] = {_.file.index: genres.index(_.tags.genre) for _ in rows}

    with CSVDictReader(out_dir / cluster_features) as reader:
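        # vector-valued columns are stored comma-separated; scalar columns are plain floats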
        _clusters: dict[int, dict[str, float | list[float]]] = {int(row["index"]): {
            k: list(map(float, v.split(","))) if "," in v else float(v)
            for k, v in row.items()
        } for row in status.task_pending_counter(reader.read())}
        indices: list[int] = list(_clusters.keys())
        pca_coords: np.ndarray = np.asarray([[_["pca30"], _["pca31"], _["pca32"]] for _ in _clusters.values()])
        enc_coords: np.ndarray = np.asarray([[_["enc30"], _["enc31"], _["enc32"]] for _ in _clusters.values()])
        pca_values: np.ndarray = np.asarray([_["pca05"] for _ in _clusters.values()])
        enc_values: np.ndarray = np.asarray([_["enc05"] for _ in _clusters.values()])
        encoded_genres: list[int] = [labels[_] for _ in _clusters.keys()]

    with CSVReader(out_dir / scaled_features) as csv_reader:
        scaled_coords: dict[int, np.ndarray] = {int(row[0]): np.asarray([float(_) for _ in row[1:]]) for row in status.task_pending_counter(csv_reader.read())}
        scaled_values: np.ndarray = np.asarray(list(scaled_coords.values()))
        scaled_genres: list[int] = [labels[_] for _ in scaled_coords.keys()]

    with CSVReader(out_dir / normed_features) as csv_reader:
        normed_coords: dict[int, np.ndarray] = {int(row[0]): np.asarray([float(_) for _ in row[1:]]) for row in status.task_pending_counter(csv_reader.read())}
        normed_values: np.ndarray = np.asarray(list(normed_coords.values()))
        normed_genres: list[int] = [labels[_] for _ in normed_coords.keys()]

    status.finish_progress()
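    # candidate cluster counts to sweep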
    dimensions: list[int] = list(range(2, 21))

    status.start_progress("Running kmeans", 6 * len(dimensions))
    km_results_pca_3d: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(pca_coords)) for dim in dimensions}
    km_results_enc_3d: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(enc_coords)) for dim in dimensions}
    km_results_pca_values: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(pca_values)) for dim in dimensions}
    km_results_enc_values: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(enc_values)) for dim in dimensions}
    km_results_scaled: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(scaled_values)) for dim in dimensions}
    km_results_normed: dict[int, tuple[float, list[int]]] = {dim: status.task_done_wrap(1, Kmeans(dim, epoch_limit).fit(normed_values)) for dim in dimensions}

    status.start_progress("Running agglomerative", 6 * len(dimensions))
    agg_results_pca_3d: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(pca_coords)) for dim in dimensions}
    agg_results_enc_3d: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(enc_coords)) for dim in dimensions}
    agg_results_pca_values: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(pca_values)) for dim in dimensions}
    agg_results_enc_values: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(enc_values)) for dim in dimensions}
    agg_results_scaled: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(scaled_values)) for dim in dimensions}
    agg_results_normed: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Agglomerative(dim).fit(normed_values)) for dim in dimensions}

    status.start_progress("Running hierarchical", 6 * len(dimensions))
    dist = Supervised.compute_weights(pca_coords, encoded_genres, genre_bias)
    lbl_results_pca_3d: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, encoded_genres).fit_weights(dist)) for dim in dimensions}
    dist = Supervised.compute_weights(enc_coords, encoded_genres, genre_bias)
    lbl_results_enc_3d: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, encoded_genres).fit_weights(dist)) for dim in dimensions}
    dist = Supervised.compute_weights(pca_values, encoded_genres, genre_bias)
    lbl_results_pca_values: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, encoded_genres).fit_weights(dist)) for dim in dimensions}
    dist = Supervised.compute_weights(enc_values, encoded_genres, genre_bias)
    lbl_results_enc_values: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, encoded_genres).fit_weights(dist)) for dim in dimensions}
    dist = Supervised.compute_weights(scaled_values, scaled_genres, genre_bias)
    lbl_results_scaled: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, scaled_genres).fit_weights(dist)) for dim in dimensions}
    dist = Supervised.compute_weights(normed_values, normed_genres, genre_bias)
    lbl_results_normed: dict[int, tuple[int, list[int]]] = {dim: status.task_done_wrap(1, Supervised(dim, normed_genres).fit_weights(dist)) for dim in dimensions}

    status.start_progress("Getting scores", len(dimensions) * 3)
    scoring = Scoring(labels)

    km_scores: dict[int, dict[str, tuple[float, float, float, float, float]]] = {}
    for dim in dimensions:
        km_scores[dim] = {
            "pca_3d": scoring.score({idx: km_results_pca_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_3d": scoring.score({idx: km_results_enc_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "pca_values": scoring.score({idx: km_results_pca_values[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_values": scoring.score({idx: km_results_enc_values[dim][1][i] for i, idx in enumerate(indices)}),
            "scaled": scoring.score({idx: km_results_scaled[dim][1][i] for i, idx in enumerate(indices)}),
            "normed": scoring.score({idx: km_results_normed[dim][1][i] for i, idx in enumerate(indices)}),
        }
        status.task_done()

    agg_scores: dict[int, dict[str, tuple[float, float, float, float, float]]] = {}
    for dim in dimensions:
        agg_scores[dim] = {
            "pca_3d": scoring.score({idx: agg_results_pca_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_3d": scoring.score({idx: agg_results_enc_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "pca_values": scoring.score({idx: agg_results_pca_values[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_values": scoring.score({idx: agg_results_enc_values[dim][1][i] for i, idx in enumerate(indices)}),
            "scaled": scoring.score({idx: agg_results_scaled[dim][1][i] for i, idx in enumerate(indices)}),
            "normed": scoring.score({idx: agg_results_normed[dim][1][i] for i, idx in enumerate(indices)}),
        }
        status.task_done()

    lbl_scores: dict[int, dict[str, tuple[float, float, float, float, float]]] = {}
    for dim in dimensions:
        lbl_scores[dim] = {
            "pca_3d": scoring.score({idx: lbl_results_pca_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_3d": scoring.score({idx: lbl_results_enc_3d[dim][1][i] for i, idx in enumerate(indices)}),
            "pca_values": scoring.score({idx: lbl_results_pca_values[dim][1][i] for i, idx in enumerate(indices)}),
            "enc_values": scoring.score({idx: lbl_results_enc_values[dim][1][i] for i, idx in enumerate(indices)}),
            "scaled": scoring.score({idx: lbl_results_scaled[dim][1][i] for i, idx in enumerate(indices)}),
            "normed": scoring.score({idx: lbl_results_normed[dim][1][i] for i, idx in enumerate(indices)}),
        }
        status.task_done()

    status.start_progress("Writing clusters", len(dimensions) + len(indices))
    with CSVDictWriter(out_dir / kmeans_stats, ["dim"] + [f"{a}_{b}_{c}"
                                                          for a in ["ars", "ami", "homo", "complete", "fowlkes"]
                                                          for b in ["km", "agg", "lbl"]
                                                          for c in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]]) as writer:
        writer.write_all(
            {"dim": dim} |
            {f"ars_km_{t}": km_scores[dim][t][0] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"ami_km_{t}": km_scores[dim][t][1] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"homo_km_{t}": km_scores[dim][t][2] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"complete_km_{t}": km_scores[dim][t][3] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"fowlkes_km_{t}": km_scores[dim][t][4] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"ars_agg_{t}": agg_scores[dim][t][0] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"ami_agg_{t}": agg_scores[dim][t][1] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"homo_agg_{t}": agg_scores[dim][t][2] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"complete_agg_{t}": agg_scores[dim][t][3] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"fowlkes_agg_{t}": agg_scores[dim][t][4] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"ars_lbl_{t}": lbl_scores[dim][t][0] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"ami_lbl_{t}": lbl_scores[dim][t][1] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"homo_lbl_{t}": lbl_scores[dim][t][2] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"complete_lbl_{t}": lbl_scores[dim][t][3] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]} |
            {f"fowlkes_lbl_{t}": lbl_scores[dim][t][4] for t in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]}
            for i, dim in enumerate(dimensions)
        )
        status.task_done(len(dimensions))

    with CSVDictWriter(out_dir / kmeans_clusters, ["index"] + [f"cluster_{c}_{d}_{t}"
                                                               for c in all_results
                                                               for t in feature_sets
                                                               for d in dimensions]) as writer:
        writer.write_all(
            {"index": idx} |
            {f"cluster_{c}_{d}_{t}": all_results[c][t][d][1][i]
             for c in all_results
             for t in feature_sets
             for d in dimensions}
            for i, idx in enumerate(indices)
        )
        status.task_done(len(indices))


def main() -> int:
    parser = argparse.ArgumentParser()
    args = add_args(parser).parse_args()
    run(run_main, **vars(args))
    return 0


if __name__ == "__main__":
    sys.exit(main())