import argparse
import itertools
import logging
import sys
from math import ceil
from pathlib import Path

import keras
import numpy as np
import pydot
import sklearn.decomposition
import sklearn.preprocessing

from .features import Features
from .log import run, Status
from .utils import add_args, CSVDictReader, CSVWriter, CSVDictWriter

keras.utils.disable_interactive_logging()  # enable absl logging instead of stdout


def _scale_norm(features: list[list[float]]) -> np.ndarray:
    """Clip each feature column to its inner 1st-99th percentile range and min/max-scale it to [0, 1]."""

    def _normalize(d: np.ndarray, percentile: int = 1) -> np.ndarray:
        d -= np.percentile(d, percentile)
        d /= np.percentile(d, 100 - percentile)
        return d.clip(0, 1)

    data: np.ndarray = np.asarray(features)
    normalized = np.asarray([_normalize(_) for _ in data.transpose()]).transpose()
    scaled = sklearn.preprocessing.MinMaxScaler().fit_transform(normalized)
    return scaled


def _scale_auto(features: list[list[float]]) -> np.ndarray:
    """Standardize the percentile-normalized features to zero mean and unit variance per column."""
    return sklearn.preprocessing.StandardScaler().fit_transform(_scale_norm(features))
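
# Illustrative sketch (not used by the pipeline): how the two scalers behave on a hypothetical toy
# feature matrix.  Column means of the standardized output are ~0 by construction:
#
#     toy = [[0.0, 10.0], [1.0, 20.0], [2.0, 400.0]]
#     assert _scale_norm(toy).shape == (3, 2)
#     assert np.allclose(_scale_auto(toy).mean(axis=0), 0.0)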


class Autoencoder:
    """Small symmetric dense autoencoder; keeps both the encoder and the full model for dimensionality reduction."""

    def __init__(self, status: Status, epochs: int, width: int, dim: int) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._status: Status = status
        self._dim: int = dim
        self._max_epochs: int = epochs
        self._batch_size: int | None = 32  # None means full-batch training (see fit)
        self._activation: str = "sigmoid"
        self._dense_layer: int = ceil(((width - dim) / 2) + dim)  # hidden width halfway between input width and bottleneck
        self._model: tuple[keras.Model, keras.Model] = self._create_model(width, dim)

    def _create_model(self, width: int, dim: int) -> tuple[keras.Model, keras.Model]:
        data_in = keras.Input(shape=(width,), name="input")
        encoded = keras.layers.Dense(units=min(width, self._dense_layer), activation=self._activation, name="_dense-in")(data_in)
        encoded = keras.layers.Dense(units=dim, activation=self._activation, name="encoded")(encoded)
        decoded = keras.layers.Dense(units=min(width, self._dense_layer), activation=self._activation, name="_dense-out")(encoded)
        decoded = keras.layers.Dense(units=width, activation="sigmoid", name="reconstructed")(decoded)
        autoencoder = keras.Model(data_in, decoded)
        encoder = keras.Model(data_in, encoded)
        autoencoder.summary(print_fn=self._logger.debug)
        autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
        return encoder, autoencoder

    def to_graphviz(self) -> str:
        """Render the autoencoder topology as a DOT string, with one graph node per unit."""

        # keras.utils.model_to_dot draws layers rather than individual units ("nodes"), the graphviz-python
        # bindings seem abandoned, and the graphviz package cannot re-read and parse dot files, so the
        # pydot package is used here instead.

        dot = pydot.Dot(f"G{self._dim}")
        dot.set_graph_defaults(rankdir="TB", splines="line", outputorder="edgesfirst")
        dot.set_node_defaults(shape="circle", label="", style="filled", fillcolor="white", fixedsize="true", width="0.2", height="0.2")
        dot.set_edge_defaults(arrowsize="0.3", penwidth="0.5", color="#80808080")

        layer_widths: list[int] = [_.output.shape[1] for _ in self._model[1].layers]
        layer_names: list[str] = [_.name for _ in self._model[1].layers]
        first_nodes: list[str] = []
        last_nodes: list[str] = []

        for i, layer in enumerate(layer_widths):
            subgraph = pydot.Subgraph(f"cluster_{i}")
            dot.add_subgraph(subgraph)

            if not layer_names[i].startswith("_"):
                subgraph.set("label", layer_names[i])
                subgraph.set("labelloc", "t" if i < len(layer_widths) - 1 else "b")
            else:
                subgraph.set("color", "transparent")

            subgraph.set("rank", "same")
            new_nodes: list[str] = [f"n_{i}_{_}" for _ in range(layer)]
            for n in new_nodes:
                subgraph.add_node(pydot.Node(n))
                for f in last_nodes:
                    subgraph.add_edge(pydot.Edge(f, n))
            if not first_nodes:
                first_nodes = new_nodes
            last_nodes = new_nodes

        if len(first_nodes) == len(last_nodes):  # try to fix horizontal alignment
            dot.add_edge(pydot.Edge(first_nodes[0], last_nodes[0], style="invis"))

        return dot.to_string(indent=" " * 4)
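
    # Rendering sketch (illustrative, not part of the pipeline): the DOT text returned above can be
    # turned into an image via pydot's parser, assuming the Graphviz binaries are installed:
    #
    #     graph, = pydot.graph_from_dot_data(encoder.to_graphviz())
    #     graph.write_svg("autoencoder.svg")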

    def fit(self, data: np.ndarray) -> tuple[float, np.ndarray]:
        """Train on the given data and return the best loss together with the encoded (bottleneck) representation."""
        status = self._status.get_substatus()
        status.start_progress("", self._max_epochs)
        status_callback = keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs=None: status.task_done())
        stop_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=max(1, self._max_epochs // 100), verbose=1, restore_best_weights=True)

        history: keras.callbacks.History = self._model[1].fit(data, data,
                                                              verbose=0,
                                                              epochs=self._max_epochs,
                                                              batch_size=data.shape[0] if self._batch_size is None else min(self._batch_size, data.shape[0]),
                                                              shuffle=True,
                                                              validation_data=(data, data),
                                                              callbacks=[status_callback, stop_callback])
        self._logger.info(f"{len(history.history['loss'])} epochs, loss: {history.history['loss'][-1]:0.3f}")
        status.finish_progress()

        result = self._model[0].predict(data, verbose=0)
        # reconstructed = self._model[1].predict(data, verbose=0)  # full reconstruction, currently unused

        if self._logger.getEffectiveLevel() <= logging.DEBUG:
            for i, loss in enumerate(history.history["val_loss"]):
                self._logger.debug(f"Epoch {i: 3}: {loss:0.3f} val loss")

        return stop_callback.best, result


class Pca:
    def __init__(self, dim: int) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._dim: int = dim
        self._pca = sklearn.decomposition.PCA(n_components=dim, svd_solver="full")

    def analyze(self, data: np.ndarray) -> tuple[np.ndarray, np.ndarray, float, int]:
        """Fit and return cumulative explained variance ratios, singular values, the log-likelihood score, and the number of negative loadings on the first component."""
        self._pca.fit(data)
        return (self._pca.explained_variance_ratio_.cumsum(),
                self._pca.singular_values_,
                self._pca.score(data),
                sum(1 for _ in self._pca.components_[0] if _ < 0))

    def fit(self, data: np.ndarray) -> np.ndarray:
        result: np.ndarray = self._pca.fit_transform(data)
        assert result.shape == (len(data), self._dim)
        self._logger.info(f"Fitted {self._pca.n_samples_} samples with {self._pca.n_features_in_} features onto {self._pca.n_components_} components, variance ratio: {self._pca.explained_variance_ratio_[:self._pca.n_components_].sum():0.3f}")
        return result


class Embedding:
    """Use the 3D PCA result as somewhat relationship-aware embedding for the categorical genre labels."""

    def __init__(self, labels: dict[int, str]) -> None:
        self._label_map: dict[int, str] = labels
        self._genres: list[str] = sorted(set(self._label_map.values()))

    @classmethod
    def _scale(cls, min_: float | int, max_: float | int, val: np.floating | float | int) -> float:
        assert min_ <= val <= max_
        assert min_ < max_
        return float((val - min_) / (max_ - min_))

    def analyze(self, data: np.ndarray) -> dict[str, tuple[float, float, float]]:
        assert data.shape == (len(self._label_map), 3)

        values = {
            genre: np.asarray([data[i] for i in range(data.shape[0]) if self._label_map[i] == genre])
            for genre in self._genres
        }
        avg = {genre: np.mean(values[genre], axis=0) for genre in self._genres}

        return {genre: (
            self._scale(min(avg[_][0] for _ in self._genres), max(avg[_][0] for _ in self._genres), avg[genre][0]),
            self._scale(min(avg[_][1] for _ in self._genres), max(avg[_][1] for _ in self._genres), avg[genre][1]),
            self._scale(min(avg[_][2] for _ in self._genres), max(avg[_][2] for _ in self._genres), avg[genre][2]),
        ) for genre in self._genres}
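
    # Worked sketch (illustrative): two genres whose 3D PCA means are (0, 0, 0) and (2, 4, 8) come out
    # as (0.0, 0.0, 0.0) and (1.0, 1.0, 1.0) after the per-axis min/max scaling above, assuming the
    # label map covers rows 0 and 1:
    #
    #     Embedding({0: "jazz", 1: "rock"}).analyze(np.asarray([[0., 0., 0.], [2., 4., 8.]]))
    #     # -> {"jazz": (0.0, 0.0, 0.0), "rock": (1.0, 1.0, 1.0)}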


def run_main(status: Status, *,
             debug: bool,
             out_dir: Path,
             epoch_limit: int,
             genre_bias: float,
             extract_features: Path,
             scaled_features: Path,
             normed_features: Path,
             pca_embeddings: Path,
             pca_stats: Path,
             enc_stats: Path,
             enc_graph: Path,
             cluster_features: Path,
             **kwargs) -> None:
    """Scale the extracted features, run PCA and autoencoder reductions, and write stats and reduced features."""

    status.start_progress("Reading features", 1)
    with CSVDictReader(out_dir / extract_features) as reader:
        rows: list[Features] = [Features.from_dict(_) for _ in status.task_pending_counter(reader.read())]
        features: list[list[float]] = [_.to_vector() for _ in rows]
        genres: dict[int, str] = {_: rows[_].tags.genre for _ in range(len(features))}

    scaled: np.ndarray = _scale_auto(features)
    n_features: int = scaled.shape[1]
    dimensions: list[int] = [1, 2, 3, n_features // 2]  # 1-D, 2-D, 3-D and "half" (n/2) reduction targets
    status.finish_progress()

    if debug:
        # note: neither PCA, the encoder, nor k-means is sensitive to a dimension being mirrored into
        # another quadrant, so this probing is purely informational
        neg_axes: int = n_features + 1  # negative loadings on the major component ("contradictions"); start above the worst case
        status.start_progress("Probing signs", 2 ** n_features)
        for signs in itertools.product([-1, 1], repeat=n_features):  # XXX: probing half of the combinations would suffice, signs are symmetric
            flipped = scaled * np.tile(signs, (len(scaled), 1))
            num_neg = Pca(1).analyze(flipped)[3]
            if num_neg <= neg_axes:
                neg_axes = num_neg
                if len(set(signs)) != 1:
                    status.logger.info(f"Suggested sign flip for {neg_axes} contradictions: {' '.join('-' if _ < 0 else '+' for _ in signs)}")
            status.task_done()
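
        # Applying a suggested flip by hand (illustrative sketch only): multiply the flagged feature
        # columns by -1, e.g. for a "+ - +" suggestion over three features:
        #
        #     scaled *= np.asarray([1, -1, 1])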

    status.start_progress("Running PCA", 1 + len(dimensions))
    pca_dim_stats: tuple[np.ndarray, np.ndarray, float, int] = Pca(n_features).analyze(scaled)
    status.task_done()
    pca_results: list[np.ndarray] = list(status.task_done_counter(Pca(i).fit(scaled) for i in dimensions))

    status.start_progress("Getting embedding", 1)
    embeddings: dict[str, tuple[float, float, float]] = Embedding(genres).analyze(pca_results[2])
    if genre_bias > 0:  # acts as an on/off switch: the appended embedding columns are re-normalized per column below
        status.logger.info("Using genre embeddings")
        for i in range(len(features)):
            features[i] += embeddings[rows[i].tags.genre]
    normed: np.ndarray = _scale_norm(features)
    status.task_done()

    status.start_progress("Running encoder", len(dimensions))
    enc_loss_stats: list[float] = []
    enc_graphs: list[str] = []
    enc_results: list[np.ndarray] = []
    for dim in dimensions:
        encoder = Autoencoder(status, epoch_limit, normed.shape[1], dim)
        loss, weights = encoder.fit(normed)
        enc_loss_stats.append(loss)
        enc_results.append(weights)
        enc_graphs.append(encoder.to_graphviz())
        status.task_done()

    status.start_progress("Writing stats", 6)

    Path(out_dir / enc_graph).write_text("\n".join(enc_graphs))
    status.task_done()

    with CSVWriter(out_dir / scaled_features) as csv_writer:
        for i, data in enumerate(scaled):
            csv_writer.write([rows[i].file.index] + data.tolist())
        status.task_done()

    with CSVWriter(out_dir / normed_features) as csv_writer:
        for i, data in enumerate(normed):
            csv_writer.write([rows[i].file.index] + data.tolist())
        status.task_done()

    with CSVDictWriter(out_dir / pca_embeddings, ["genre", "emb1", "emb2", "emb3"]) as writer:
        writer.write_all({
            "genre": g,
            "emb1": v[0],
            "emb2": v[1],
            "emb3": v[2],
        } for g, v in embeddings.items())
        status.task_done()

    with CSVDictWriter(out_dir / pca_stats, ["dim", "variance", "singular"]) as writer:
        writer.write_all({
            "dim": i + 1,
            "variance": pca_dim_stats[0][i],
            "singular": pca_dim_stats[1][i],
        } for i in range(n_features))
        status.task_done()

    with CSVDictWriter(out_dir / enc_stats, ["dim", "loss"]) as writer:
        writer.write_all({
            "dim": d,
            "loss": enc_loss_stats[i],
        } for i, d in enumerate(dimensions))
        status.task_done()

    status.start_progress("Writing results", len(rows))
    with CSVDictWriter(out_dir / cluster_features, ["index",
                                                    "pca10", "pca20", "pca21", "pca30", "pca31", "pca32", "pca05",
                                                    "enc10", "enc20", "enc21", "enc30", "enc31", "enc32", "enc05"]) as writer:
        writer.write_all({
            "index": feature.file.index,
            "pca10": pca_results[0][i][0],
            "pca20": pca_results[1][i][0],
            "pca21": pca_results[1][i][1],
            "pca30": pca_results[2][i][0],
            "pca31": pca_results[2][i][1],
            "pca32": pca_results[2][i][2],
            "pca05": ",".join(map(str, pca_results[3][i])),
            "enc10": enc_results[0][i][0],
            "enc20": enc_results[1][i][0],
            "enc21": enc_results[1][i][1],
            "enc30": enc_results[2][i][0],
            "enc31": enc_results[2][i][1],
            "enc32": enc_results[2][i][2],
            "enc05": ",".join(map(str, enc_results[3][i])),
        } for i, feature in status.task_done_counter(enumerate(rows)))


def main() -> int:
    parser = argparse.ArgumentParser()
    args = add_args(parser).parse_args()
    run(run_main, **vars(args))
    return 0


if __name__ == "__main__":
    sys.exit(main())