import argparse
import logging
import sys
from abc import ABC, abstractmethod
from colorsys import hsv_to_rgb
from dataclasses import dataclass
from difflib import SequenceMatcher
from functools import cache
from hashlib import md5
from html import escape as html_escape
from math import floor, ceil, sqrt
from pathlib import Path
from typing import Callable, Literal, Iterable, Any

import numpy as np
import pydot
from matplotlib import colormaps
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from python_tsp.heuristics.local_search import solve_tsp_local_search

from .features import Features
from .log import run, Status
from .utils import add_args, CSVReader, CSVDictReader


# Global matplotlib defaults applied to every figure created by this module.
# https://matplotlib.org/stable/users/explain/fonts.html#fonts-in-svg
plt.rcParams["svg.fonttype"] = "none"  # emit text as text, not as glyph outlines (see link above)
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["svg.hashsalt"] = "XXX"  # fixed salt; together with the stripped "Date" metadata keeps SVG output reproducible
plt.rcParams["figure.dpi"] = 96  # figure sizes in Plotter are chosen in inches for this resolution


class Colorizer(ABC):
    """Interface for turning features or plain values into RGB colors.

    Concrete subclasses decide how a value is mapped to a color and how
    values are ordered; this base class only provides the shared hashing
    and hue conversion helpers.
    """

    @abstractmethod
    def get_color(self, f: Features) -> tuple[float, float, float]:
        """Return the RGB color (components in 0..1) for one feature record."""
        raise NotImplementedError

    @abstractmethod
    def color_by_value(self, s: str) -> tuple[float, float, float]:
        """Return the RGB color (components in 0..1) assigned to value *s*."""
        raise NotImplementedError

    @abstractmethod
    def sort_value(self, s: str) -> float:
        """Return a number that orders value *s* relative to other values."""
        raise NotImplementedError

    @classmethod
    @cache
    def hash_float(cls, s: str) -> float:
        """Map a string to a stable pseudo-random float in the closed range [0, 1].

        The first four bytes of the MD5 digest are read as an unsigned
        integer and divided by ``0xffffffff``; the upper bound 1.0 is
        therefore reachable. Results are memoized per (class, string).
        """
        h: bytes = md5(s.encode(errors="replace"), usedforsecurity=False).digest()
        # Use a fixed byte order: decoding with the host's native order would make the
        # value (and every color derived from it) differ between CPU architectures.
        return int.from_bytes(h[:4], byteorder="little", signed=False) / 0xffffffff

    @classmethod
    def hue_to_rgb(cls, h: float) -> tuple[float, float, float]:
        """Convert hue *h* to RGB with fixed saturation 0.5 and value 0.7.

        *h* is clamped to [0, 1] first; a hue of exactly 1.0 then wraps
        around to 0.0.
        """
        hue = min(max(0.0, h), 1.0) % 1.0
        # Saturation 0.5 gives pastel colors; 0.8 would give more primary-like ones.
        return hsv_to_rgb(hue, 0.5, 0.7)


class RandomColorizer(Colorizer):
    """Colorizer whose hue and sort rank are both derived from a hash of the value."""

    @classmethod
    def sort_value(cls, s: str) -> float:
        """Order values by their hash, i.e. in a stable but arbitrary sequence."""
        rank = cls.hash_float(s)
        return rank

    @classmethod
    def color_by_value(cls, s: str) -> tuple[float, float, float]:
        """Use the hash of *s* directly as the hue."""
        hue = cls.hash_float(s)
        return cls.hue_to_rgb(hue)

    def get_color(self, f: Features) -> tuple[float, float, float]:
        """Color a feature record by its genre tag."""
        genre = f.tags.genre
        return self.color_by_value(genre)


class LinearColorizer(Colorizer):
    """Colorizer that spreads an ordered list of values evenly over the hue range."""

    def __init__(self, values: Iterable[str]) -> None:
        """Store *values* in iteration order; that order defines both hue and sort rank."""
        self._values: list[str] = [*values]

    def get_color(self, f: Features) -> tuple[float, float, float]:
        """Color a feature record by its genre tag."""
        genre = f.tags.genre
        return self.color_by_value(genre)

    def color_by_value(self, s: str) -> tuple[float, float, float]:
        """Return the hue slot of *s*; raises ValueError for an unknown value."""
        position = self._values.index(s)
        total = len(self._values)
        # Half a step is added so that every value sits in the middle of its hue slot.
        return self.hue_to_rgb((position / total) + (1 / total / 2))

    def sort_value(self, s: str) -> float:
        """Return the position of *s* in the stored list; raises ValueError if absent."""
        rank = self._values.index(s)
        return rank


class DiscreteColorizer(Colorizer):
    """Colorizer backed by a fixed qualitative palette.

    Feature records are colored by hashing their genre into the palette
    (``get_color`` / ``color_by_str``), whereas ``color_by_value`` spreads
    the values given to the constructor evenly over the hue range.
    """

    # Concatenation of the tab20, tab20b and tab20c palettes.
    _colormap = colormaps["tab20"].colors + colormaps["tab20b"].colors + colormaps["tab20c"].colors

    def __init__(self, values: Iterable[Any]) -> None:
        """Store *values* in iteration order; that order defines sort rank and hue slot."""
        self._values: list = list(values)

    def sort_value(self, s: str) -> float:
        """Return the position of *s* in the stored list; raises ValueError if absent."""
        return self._values.index(s)

    @classmethod
    def color_by_index(cls, i: int) -> tuple[float, float, float]:
        """Return palette entry *i*, wrapping around at the end of the palette."""
        return cls._colormap[i % len(cls._colormap)]

    @classmethod
    def color_by_str(cls, s: str) -> tuple[float, float, float]:
        """Return the palette entry selected by the hash of *s*."""
        size = len(cls._colormap)
        # hash_float() divides by 0xffffffff, so it can return exactly 1.0; without the
        # clamp that value would index one past the end of the palette.
        slot = min(floor(cls.hash_float(s) * size), size - 1)
        return cls._colormap[slot]

    def color_by_value(self, s: Any) -> tuple[float, float, float]:
        """Return the hue slot of *s*; raises ValueError for an unknown value."""
        # Half a step is added so that every value sits in the middle of its hue slot.
        return self.hue_to_rgb((self._values.index(s) / len(self._values)) +
                               (1 / len(self._values) / 2))

    def get_color(self, f: Features) -> tuple[float, float, float]:
        """Color a feature record by hashing its genre tag into the palette."""
        return self.color_by_str(f.tags.genre)


@dataclass(frozen=True)
class PlotStrategy(ABC):
    """Common base type of the immutable plot descriptions that ``Plotter.plot`` dispatches on."""
    pass


@dataclass(frozen=True)
class HLegendPlotStrategy(PlotStrategy):
    """Horizontal legend strip: labelled bars laid end to end, one strip per subplot row."""
    # One list per subplot row; each entry is (label, value). Bars are ordered by value and
    # each bar ends at its own value.
    getter: list[list[tuple[str, float]]]
    # (left, right) x-axis limits; the left limit is also where the first bar starts.
    limits: tuple[float, float]
    # RGB color per entry, indexed like the entries of each getter row.
    color_by: list[tuple[float, float, float]]


@dataclass(frozen=True)
class ScatterPlotStrategy(PlotStrategy):
    """Dotted line plot with point markers, one series per subplot row."""
    # One iterable of (x, y) points per subplot row.
    getter: list[Iterable[tuple[float, float]]]


@dataclass(frozen=True)
class MultiScatterPlotStrategy(PlotStrategy):
    """Dotted line plots with several series per subplot row."""
    # getter[row][series] is an iterable of (x, y) points.
    getter: list[list[Iterable[tuple[float, float]]]]
    # legends[row][series] is the series label; if any label is None, one shared figure-level
    # legend is drawn instead of a legend per row.
    legends: list[list[str | None]]
    # Optional title per subplot row.
    titles: list[str | None]
    # color_by[row][series] is the RGB color of that series.
    color_by: list[list[tuple[float, float, float]]]


@dataclass(frozen=True)
class Scatter1dPlotStrategy(PlotStrategy):
    """Scatter of one value per feature against the feature's sorted rank, with a side histogram."""
    # Sort key that determines the order of the features along the x axis.
    sort_by: Callable[[Features], tuple]
    # One callable per subplot row, returning the y value of a feature.
    getter: list[Callable[[Features], float]]
    # Returns the RGB color of a feature.
    color_by: Callable[[Features], tuple[float, float, float]]


@dataclass(frozen=True)
class Scatter2dPlotStrategy(PlotStrategy):
    """Scatter of (x, y) per feature, with a side histogram of the y values."""
    # One callable per subplot row, returning the (x, y) position of a feature.
    getter: list[Callable[[Features], tuple[float, float]]]
    # Returns the RGB color of a feature.
    color_by: Callable[[Features], tuple[float, float, float]]


@dataclass(frozen=True)
class Scatter3dPlotStrategy(PlotStrategy):
    """3D scatter of (x, y, z) per feature, drawn from three axis rotations with histograms below."""
    # One callable per plot row, returning the (x, y, z) position of a feature.
    getter: list[Callable[[Features], tuple[float, float, float]]]
    # Returns the RGB color of a feature.
    color_by: Callable[[Features], tuple[float, float, float]]


@dataclass(frozen=True)
class PolarBarPlotStrategy(PlotStrategy):
    """Stacked bar charts on polar axes, one chart per getter entry."""
    # getter[plot][x][y]: height of stack layer y at angular position x.
    getter: list[list[list[float]]]  # x -> list[h]
    # Optional title per plot, drawn in the centre of the chart.
    titles: list[str | None]
    # RGB color per stack layer.
    color_by: list[tuple[float, float, float]]  # y


@dataclass(frozen=True)
class SpectrumPlotStrategy(PlotStrategy):
    """Heat map with one column per feature, one image per subplot row."""
    # Sort key that determines the order of the features along the x axis.
    sort_by: Callable[[Features], tuple]
    # One callable per subplot row, returning the column of values for a feature.
    getter: list[Callable[[Features], list[float]]]
    # Optional y tick labels per subplot row.
    labels: list[list[str] | None]
    # Color normalization per subplot row, passed to imshow as ``norm``.
    scale: list[Literal["linear", "log"]]


class Plotter:
    """Render ``PlotStrategy`` descriptions into SVG files.

    Every file written through this object is remembered in order, so that
    ``write_index`` can later emit an HTML page referencing all of them.
    """

    def __init__(self, debug: bool) -> None:
        """Create a plotter.

        Args:
            debug: Selects the larger 16x4 inch base canvas instead of 9.375x3 and
                enables extra decorations in some plot types (colorbar, zero-based y axis).
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._debug: bool = debug
        self._width: float = 16 if debug else 9.375  # 900px @ 96dpi
        self._height: float = 4 if debug else 3
        self._history: list[Path] = []

    def _createfig(self, nrows: int, d3: bool = False, polar: bool = False, side: bool = False, wide: bool = False) -> tuple[Figure, np.ndarray]:
        """Create a figure and a 2D array of axes for the requested layout.

        Args:
            nrows: Number of logical plot rows.
            d3: Three 3D axes per row. Combined with ``side``, every row is followed by
                a row of three regular axes, so plot ``i`` uses array rows ``2 * i`` and
                ``2 * i + 1``.
            polar: Polar axes, three per figure row; the returned array has a single
                column with ``3 * ceil(nrows / 3)`` entries.
            side: Without ``d3``, two columns with a 4:1 width ratio.
            wide: Halve the per-row height (rounded up); only used by the non-3D,
                non-polar layouts.

        Returns:
            The figure and its axes, indexed as ``ax[row][column]``.
        """
        if d3 and side:
            fig = plt.figure(figsize=(self._width, self._height * nrows * 1.5), layout="constrained")
            gs = GridSpec(3 * nrows, 3, figure=fig)
            ax = []
            for r in range(nrows):
                # Each logical row owns three grid rows: two for the 3D axes and one for the
                # histograms. Striding by 3 keeps the rows of consecutive plots from overlapping.
                top = r * 3
                ax.append([fig.add_subplot(gs[top:top + 2, col], projection="3d") for col in range(3)])
                ax.append([fig.add_subplot(gs[top + 2, col]) for col in range(3)])
            return fig, np.asarray(ax)
        elif d3:
            fig = plt.figure(figsize=(self._width, self._height * nrows), layout="constrained")
            ax = []
            for r in range(nrows):
                ax.append([fig.add_subplot(nrows, 3, r * 3 + 1, projection="3d"),
                           fig.add_subplot(nrows, 3, r * 3 + 2, projection="3d"),
                           fig.add_subplot(nrows, 3, r * 3 + 3, projection="3d")])
            return fig, np.asarray(ax)
        elif polar:
            # Three charts per figure row; the array stays one chart per entry.
            nrows = ceil(nrows / 3)
            fig = plt.figure(figsize=(self._width, self._height * nrows), layout="constrained")
            ax = []
            for r in range(nrows):
                ax.extend([[fig.add_subplot(nrows, 3, r * 3 + 1, projection="polar")],
                           [fig.add_subplot(nrows, 3, r * 3 + 2, projection="polar")],
                           [fig.add_subplot(nrows, 3, r * 3 + 3, projection="polar")]])
            return fig, np.asarray(ax)
        elif side:
            return plt.subplots(nrows=nrows, ncols=2, squeeze=False, width_ratios=[4, 1], sharex="col", sharey="row",
                                figsize=(self._width, (ceil(self._height / 2) if wide else self._height) * nrows),
                                layout="constrained")
        else:
            return plt.subplots(nrows=nrows, squeeze=False, sharex=True,
                                figsize=(self._width, (ceil(self._height / 2) if wide else self._height) * nrows),
                                layout="constrained")

    def _savefig(self, fig: Figure, dst: Path) -> None:
        """Save *fig* to *dst* and record the file for the index page."""
        self._logger.info(f"Writing {dst.name}")
        self._history.append(dst)
        fig.savefig(dst, metadata={"Date": None})  # reproducible output, alongside with hashsalt

    def _marker_size(self, num: int, sq: bool) -> float:
        """Return the marker size for a plot with *num* points.

        Args:
            num: Number of points drawn.
            sq: Return the squared size (as used for ``scatter(s=...)``) instead of the
                linear size (as used for ``plot(ms=...)``).
        """
        if num <= 2000:
            return 36 if sq else 6
        return 25 if sq else 5

    @staticmethod
    def _hist_bins(count: int) -> int:
        """Return the histogram bin count for *count* samples: sqrt(count) limited to 10..50."""
        return max(10, min(50, floor(sqrt(count))))

    @staticmethod
    def _hide_spines(axis: Any, sides: Iterable[str] = ("top", "right", "bottom", "left")) -> None:
        """Hide the given border lines of *axis*."""
        for side_name in sides:
            axis.spines[side_name].set_visible(False)

    def write_index(self, out_file: Path) -> None:
        """Write an HTML page listing and embedding every plot written so far.

        Image paths are made relative to the directory of *out_file*, so all recorded
        plots must be located below that directory.
        """
        self._logger.info(f"Writing {out_file.name}")
        # The page declares charset utf-8, so it must not be written in the platform default encoding.
        with out_file.open("w", encoding="utf-8") as fp:
            fp.write("""<!DOCTYPE html><html lang="en"><head><meta charset="utf-8"><title>Plots</title><link rel="icon" href="data:,"><meta name="viewport" content="width=device-width,initial-scale=1">"""
                     """<style>figure {display: inline-block; max-width: 100%;} img {max-width: 100%;} figcaption {text-align: center;}</style>"""
                     """</head><body>\n""")
            fp.write("""<ul class="toc">\n""")
            for plot in self._history:
                fp.write(f"""<li><a href="#{html_escape(plot.stem)}">{html_escape(plot.stem)}</a></li>\n""")
            fp.write("""</ul>\n""")
            for plot in self._history:
                plot = plot.relative_to(out_file.parent)
                fp.write(f"""<figure><figcaption>{html_escape(plot.stem)}</figcaption><img src="{html_escape(plot.as_posix())}" title="{html_escape(plot.stem)}" alt="{html_escape(plot.stem)}" id="{html_escape(plot.stem)}" loading="lazy"></figure>\n""")
            fp.write("""</body></html>\n""")

    def plot(self, features: list[Features], strategy: PlotStrategy, out_file: Path) -> None:
        """Render *strategy* to *out_file*, choosing the renderer by strategy type.

        Args:
            features: Feature records; only used by the strategies that read per-feature data.
            strategy: Description of the plot.
            out_file: Destination image file.

        Raises:
            TypeError: If the strategy type is not supported.
        """
        if isinstance(strategy, HLegendPlotStrategy):
            self._plot_hlegend(strategy, out_file)
        elif isinstance(strategy, ScatterPlotStrategy):
            self._plot_scatter(strategy, out_file)
        elif isinstance(strategy, MultiScatterPlotStrategy):
            self._plot_scatter_multi(strategy, out_file)
        elif isinstance(strategy, Scatter1dPlotStrategy):
            self._plot_scatter_1d(features, strategy, out_file)
        elif isinstance(strategy, Scatter2dPlotStrategy):
            self._plot_scatter_2d(features, strategy, out_file)
        elif isinstance(strategy, Scatter3dPlotStrategy):
            self._plot_scatter_3d(features, strategy, out_file)
        elif isinstance(strategy, PolarBarPlotStrategy):
            self._plot_polar_bar(strategy, out_file)
        elif isinstance(strategy, SpectrumPlotStrategy):
            self._plot_spectrum(features, strategy, out_file)
        else:
            raise TypeError(type(strategy))

    def plot_dot(self, source: str, out_prefix: Path) -> None:
        """Render every graph found in the DOT *source* to its own SVG file.

        The output name is *out_prefix* with ``_<graph name>`` appended to its stem.
        Rendering is best effort: any failure is logged and not raised.
        """
        try:
            for dot in pydot.graph_from_dot_data(source):
                dot.set("size", f"{self._width},{self._width}!")  # 96dpi -> 675pt/720pt -> 900px or 960px
                dot.set("pad", "-0.05")
                dot.set("fontname", "'Arial', sans-serif")
                dot.set("fontsize", "8")  # pt vs px vs scale-transform

                name: str = dot.get_name() or "G"
                out_file: Path = out_prefix.with_stem(out_prefix.stem + f"_{name}")
                self._logger.info(f"Writing {out_file.name}")
                dot.write(out_file, "dot", "svg")
                self._history.append(out_file)
        except Exception as e:
            self._logger.error(f"Cannot plot to {out_prefix}: {str(e)}")

    def _plot_hlegend(self, strategy: HLegendPlotStrategy, out_file: Path) -> None:
        """Draw one strip of labelled, value-ordered bars per getter row."""
        fig, ax = self._createfig(len(strategy.getter), wide=True)

        for i, getter in enumerate(strategy.getter):
            axis = ax[i][0]
            # All lists below are indexed by sorted position, not by original entry index.
            indices = sorted(range(len(getter)), key=lambda n: getter[n][1])
            labels = [getter[n][0] for n in indices]
            starts = [strategy.limits[0] if pos == 0 else getter[indices[pos - 1]][1] for pos in range(len(indices))]
            widths = [getter[indices[pos]][1] - starts[pos] for pos in range(len(indices))]
            colors = [strategy.color_by[n] for n in indices]

            # `indices` is a permutation of range(len(getter)), so using its items as positions
            # still draws every bar exactly once; the order is kept so the SVG element order is stable.
            for pos in indices:
                b = axis.barh([0], [widths[pos]], left=[starts[pos]], color=[colors[pos]], edgecolor='white', linewidth=0.1)
                axis.bar_label(b, [labels[pos]], label_type='center', rotation=90)

            axis.set_xticks([])
            axis.set_yticks([])
            axis.set_xlim(left=strategy.limits[0], right=strategy.limits[1])
            self._hide_spines(axis)

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_scatter(self, strategy: ScatterPlotStrategy, out_file: Path) -> None:
        """Draw one dotted series with markers per getter row, y axis starting at 0."""
        fig, ax = self._createfig(len(strategy.getter))

        for i, getter in enumerate(strategy.getter):
            x, y = zip(*list(getter))
            ax[i][0].plot(x, y, linestyle="dotted", marker=".", ms=self._marker_size(len(strategy.getter) * len(x), False))
            ax[i][0].set_ylim(bottom=0)

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_scatter_multi(self, strategy: MultiScatterPlotStrategy, out_file: Path) -> None:
        """Draw several dotted series per row, with per-row or one shared legend."""
        fig, ax = self._createfig(len(strategy.getter), wide=not self._debug)

        # Per-row legends are only possible when every series of every row has a label.
        legends: bool = all(None not in _ for _ in strategy.legends)
        for i, getters in enumerate(strategy.getter):
            for j, getter in enumerate(getters):
                x, y = zip(*list(getter))
                ax[i][0].plot(x, y, linestyle="dotted", marker=".", ms=self._marker_size(len(getters) * len(x), False), label=strategy.legends[i][j], c=strategy.color_by[i][j])
                ax[i][0].set_xticks(x)
            if strategy.titles[i] is not None:
                ax[i][0].set_title(strategy.titles[i], loc="left", fontsize=10, y=0.9, pad=0, x=0.01)
            if legends:
                ax[i][0].legend(loc="upper left")
            if self._debug:
                ax[i][0].set_ylim(bottom=0)
            ax[i][0].set_yticks([])
        if not legends:
            fig.legend(loc="outside right upper", frameon=False)

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_scatter_1d(self, features: list[Features], strategy: Scatter1dPlotStrategy, out_file: Path) -> None:
        """Scatter one value per feature against its sorted rank, with a stacked histogram beside it."""
        fig, ax = self._createfig(len(strategy.getter), side=True, wide=True)

        v = sorted(features, key=strategy.sort_by)
        x = list(range(len(v)))
        c = [strategy.color_by(_) for _ in v]

        for i, getter in enumerate(strategy.getter):
            values = [getter(_) for _ in v]
            # Values grouped per color, for the stacked histogram.
            by_color = {cc: [values[_] for _ in range(len(v)) if c[_] == cc] for cc in set(c)}
            ax[i][0].scatter(x=x, y=values, marker=".", s=self._marker_size(len(x), True), c=c)
            ax[i][1].hist(x=by_color.values(), facecolor=by_color.keys(), bins=self._hist_bins(len(v)), orientation="horizontal", stacked=True)
            ax[i][0].set_xticks([])
            ax[i][0].set_yticks([])
            ax[i][1].set_xticks([])
            ax[i][1].set_yticks([])
            self._hide_spines(ax[i][1])

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_scatter_2d(self, features: list[Features], strategy: Scatter2dPlotStrategy, out_file: Path) -> None:
        """Scatter (x, y) per feature, with a stacked histogram of the y values beside it."""
        fig, ax = self._createfig(len(strategy.getter), side=True, wide=True)

        v = features
        c = [strategy.color_by(_) for _ in v]

        for i, getter in enumerate(strategy.getter):
            x, y = zip(*[getter(_) for _ in v])
            # y values grouped per color, for the stacked histogram.
            by_color = {cc: [y[_] for _ in range(len(v)) if c[_] == cc] for cc in set(c)}
            ax[i][0].scatter(x=x, y=y, marker=".", s=self._marker_size(len(v), True), c=c)
            ax[i][1].hist(x=by_color.values(), facecolor=by_color.keys(), bins=self._hist_bins(len(v)), orientation="horizontal", stacked=True)
            ax[i][0].set_xticks([])
            ax[i][0].set_yticks([])
            ax[i][1].set_xticks([])
            ax[i][1].set_yticks([])
            self._hide_spines(ax[i][1])

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_scatter_3d(self, features: list[Features], strategy: Scatter3dPlotStrategy, out_file: Path) -> None:
        """Scatter (x, y, z) per feature from three axis rotations, with a histogram below each view."""
        fig, ax = self._createfig(len(strategy.getter), d3=True, side=True)

        v = features
        c = [strategy.color_by(_) for _ in v]

        for i, getter in enumerate(strategy.getter):
            x, y, z = zip(*[getter(_) for _ in v])
            # The three views rotate the coordinates, so a different one leads in each column.
            ax[i * 2][0].scatter(x, y, z, marker=".", s=self._marker_size(len(v), True), c=c)
            ax[i * 2][2].scatter(y, z, x, marker=".", s=self._marker_size(len(v), True), c=c)
            ax[i * 2][1].scatter(z, x, y, marker=".", s=self._marker_size(len(v), True), c=c)
            ax[i * 2][0].azim = -70
            ax[i * 2][1].azim = -60  # default
            ax[i * 2][2].azim = -20

            # Feature positions grouped per color, then the coordinates of each group.
            by_color_i = {cc: [_ for _ in range(len(v)) if c[_] == cc] for cc in set(c)}
            by_color = [{cc: [x[_] for _ in members] for cc, members in by_color_i.items()},
                        {cc: [y[_] for _ in members] for cc, members in by_color_i.items()},
                        {cc: [z[_] for _ in members] for cc, members in by_color_i.items()}]
            # Each histogram shows the coordinate that is passed first to the view above it.
            ax[i * 2 + 1][0].hist(by_color[0].values(), facecolor=by_color[0].keys(), bins=self._hist_bins(len(v)), stacked=True)
            ax[i * 2 + 1][1].hist(by_color[2].values(), facecolor=by_color[2].keys(), bins=self._hist_bins(len(v)), stacked=True)
            ax[i * 2 + 1][2].hist(by_color[1].values(), facecolor=by_color[1].keys(), bins=self._hist_bins(len(v)), stacked=True)

        for i in range(len(strategy.getter)):
            for j in range(3):
                ax[i * 2][j].tick_params(axis="both", which="both",
                                         bottom=False, top=False, left=False, right=False,
                                         labelbottom=False, labeltop=False, labelleft=False, labelright=False)
                ax[i * 2 + 1][j].yaxis.set_inverted(True)
                ax[i * 2 + 1][j].set_xticks([])
                ax[i * 2 + 1][j].set_yticks([])
                self._hide_spines(ax[i * 2 + 1][j], ("right", "left", "bottom"))

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_polar_bar(self, strategy: PolarBarPlotStrategy, out_file: Path) -> None:
        """Draw one stacked polar bar chart per getter entry."""
        fig, ax = self._createfig(len(strategy.getter), polar=True)

        for i, getter in enumerate(strategy.getter):
            # After transposing, data[layer][position] holds the bar heights.
            data = np.asarray(getter).transpose()
            carry = np.zeros(data.shape[1])
            for y, c in enumerate(strategy.color_by):
                x = [_ * 2 * np.pi / data.shape[1] for _ in range(data.shape[1])]
                ax[i][0].bar(x, data[y], width=2 * np.pi / data.shape[1], bottom=carry, facecolor=c)
                carry += data[y]  # the next layer is stacked on top of this one
            ax[i][0].set_rorigin(carry.max() / -5)
            ax[i][0].set_xticks([])
            ax[i][0].set_yticks([])
            if strategy.titles[i] is not None:
                ax[i][0].set_title(strategy.titles[i], loc="center", fontsize=10, y=0.5, pad=0, verticalalignment='center')

        self._savefig(fig, out_file)
        plt.close(fig)

    def _plot_spectrum(self, features: list[Features], strategy: SpectrumPlotStrategy, out_file: Path) -> None:
        """Draw one heat map per getter, with one column per feature in sorted order."""
        fig, ax = self._createfig(len(strategy.getter))

        v = sorted(features, key=strategy.sort_by)
        for i, getter in enumerate(strategy.getter):
            im = ax[i][0].imshow(np.transpose([getter(_) for _ in v]), norm=strategy.scale[i],
                                 origin="lower", aspect="auto", interpolation="none", resample=False, cmap="viridis")
            ax[i][0].set_xticks([])
            if self._debug:
                fig.colorbar(im, ax=ax[i][0], fraction=0.05, pad=0.01)
            if strategy.labels[i] is not None:
                ax[i][0].set_yticks(range(len(strategy.labels[i])))
                ax[i][0].set_yticklabels(strategy.labels[i])

        self._savefig(fig, out_file)
        plt.close(fig)


def genre_sort(genres: list[str], embeddings: dict[str, tuple[float, ...]]) -> list[str]:
    """Genre sorting by quick Travelling Salesman Problem solver on the distance matrix.

    The genres are ordered in two passes. The first pass starts from alphabetical
    order and uses pairwise string dissimilarity as distance; its result seeds the
    second pass, which uses the Euclidean distance between the genre embeddings.

    Args:
        genres: Genre names to order.
        embeddings: Embedding vector per genre; must contain every given genre
            whenever two or more genres are passed.

    Returns:
        A new list with the same genres in the resulting order. With fewer than two
        genres there is nothing to order and the (sorted) input is returned as is.
    """
    genres = sorted(genres)
    if len(genres) < 2:
        # An empty or 1x1 distance matrix is not a meaningful tour problem.
        return genres

    # as first pass, seed the ordering by string similarity matrix
    distances: np.ndarray = np.asarray([[1.0 - SequenceMatcher(lambda junk: junk in " /-&',", other.lower(), g.lower()).ratio() for other in genres] for g in genres])
    permutation, _distance = solve_tsp_local_search(distances, x0=list(range(len(genres))))
    genres = [genres[position] for position in permutation]

    # second pass: refine the seeded order by distance in embedding space
    distances = np.asarray([[np.linalg.norm(np.asarray(embeddings[other]) - np.asarray(embeddings[g])) for other in genres] for g in genres])
    permutation, _distance = solve_tsp_local_search(distances, x0=list(range(len(genres))))
    genres = [genres[position] for position in permutation]

    return genres


def run_main(status: Status, *,
             debug: bool,
             genre_bias: float,
             out_dir: Path,
             extract_features: Path,
             scaled_features: Path,
             normed_features: Path,
             pca_embeddings: Path,
             pca_stats: Path,
             enc_stats: Path,
             enc_graph: Path,
             cluster_features: Path,
             kmeans_stats: Path,
             kmeans_clusters: Path,
             **kwargs) -> None:
    """Read the pipeline's CSV outputs, build all plot descriptions and render them.

    All input paths are joined onto *out_dir*, which is also where the plots are written.
    Each plot is written as ``<name>.svg``, followed by the encoder graph rendered from
    *enc_graph* and finally ``index.html`` listing everything.

    Args:
        status: Progress reporter and logger.
        debug: Passed to ``Plotter``.
        genre_bias: The "lbl" cluster type is only plotted when this is greater than 0.
        out_dir: Base directory for every input and output file.
        extract_features: CSV with one feature record per row.
        scaled_features: Header-less CSV: integer key, then the scaled feature values.
        normed_features: Header-less CSV: integer key, then the normed feature values.
        pca_embeddings: CSV with columns ``genre``, ``emb1``, ``emb2``, ``emb3``.
        pca_stats: CSV with columns ``dim``, ``variance``, ``singular``.
        enc_stats: CSV with columns ``dim``, ``loss``.
        enc_graph: DOT source of the encoder graph.
        cluster_features: CSV keyed by ``index`` with the reduced coordinates.
        kmeans_stats: CSV keyed by ``dim`` with the clustering scores.
        kmeans_clusters: CSV keyed by ``index`` with one cluster assignment per column.
        **kwargs: Further arguments; not used by this function.
    """
    status.start_progress("Reading features")

    with CSVDictReader(out_dir / extract_features) as reader:
        rows: list[Features] = list(Features.from_dict(row) for row in status.task_pending_counter(reader.read()))
        genres: list[str] = sorted(set(_.tags.genre for _ in rows))

    # Per feature index: column -> float, or list of floats when the cell contains commas.
    with CSVDictReader(out_dir / cluster_features) as reader:
        clusters: dict[int, dict[str, float | list[float]]] = {int(row["index"]): {
            k: list(map(float, v.split(","))) if "," in v else float(v)
            for k, v in row.items()
        } for row in status.task_pending_counter(reader.read())}

    with CSVDictReader(out_dir / pca_embeddings) as reader:
        genre_embeddings: dict[str, tuple[float, float, float]] = {
            row["genre"]: (float(row["emb1"]), float(row["emb2"]), float(row["emb3"]))
            for row in status.task_pending_counter(reader.read())
        }

    with CSVDictReader(out_dir / pca_stats) as reader:
        pca_dim_stats: dict[int, tuple[float, float]] = {int(row["dim"]): (float(row["variance"]), float(row["singular"]))
                                                         for row in status.task_pending_counter(reader.read())}

    with CSVDictReader(out_dir / enc_stats) as reader:
        enc_dim_stats: dict[int, tuple[float]] = {int(row["dim"]): (float(row["loss"]),)
                                                  for row in status.task_pending_counter(reader.read())}

    # First column is the integer key, the remaining columns are the values.
    with CSVReader(out_dir / scaled_features) as csv_reader:
        scaled_cluster_features: dict[int, list[float]] = {int(row[0]): [float(_) for _ in row[1:]]
                                                           for row in status.task_pending_counter(csv_reader.read())}

    with CSVReader(out_dir / normed_features) as csv_reader:
        normed_cluster_features: dict[int, list[float]] = {int(row[0]): [float(_) for _ in row[1:]]
                                                           for row in status.task_pending_counter(csv_reader.read())}

    # Every column, including "dim" itself, is stored as float.
    with CSVDictReader(out_dir / kmeans_stats) as reader:
        kmeans_cluster_stats: dict[int, dict[str, float]] = {int(row["dim"]): {k: float(v) for k, v in row.items()}
                                                             for row in status.task_pending_counter(reader.read())}

    # Column names other than "index" identify the individual clusterings.
    with CSVDictReader(out_dir / kmeans_clusters) as reader:
        kmeans_cluster_names: list[str] = [_ for _ in reader.fieldnames() if _ != "index"]
        kmeans_cluster_index: dict[int, dict[str, int]] = {int(row["index"]): {k: int(v) for k, v in row.items()}
                                                           for row in status.task_pending_counter(reader.read())}
    status.finish_progress()

    status.start_progress("Sorting genres")
    status.task_pending(len(genres))
    genres = genre_sort(genres, genre_embeddings)
    # The sorted order defines both the hue and the sort rank of each genre from here on.
    genre_colorizer: Colorizer = LinearColorizer(genres)
    status.logger.info(f"Sorted genres: {', '.join(genres)}")
    status.finish_progress()

    status.start_progress("Getting data")
    status.task_pending(1)

    encoder_types: list[str] = ["pca", "enc"]
    cluster_types: list[str] = ["km", "agg", "lbl"] if genre_bias > 0 else ["km", "agg"]
    # Output name (without extension) -> plot description; insertion order is the rendering order.
    plots: dict[str, PlotStrategy] = {}

    # Color legend: one equally wide bar per genre, in sorted genre order.
    plots |= {"genre": HLegendPlotStrategy(
        getter=[[(g, i) for i, g in enumerate(genres)]],
        limits=(-1, len(genres) - 1),
        color_by=[genre_colorizer.color_by_value(g) for g in genres],
    )}

    # One strip per embedding component, genres placed by their embedding value.
    plots |= {"genre-embeddings": HLegendPlotStrategy(
        getter=[[(g, genre_embeddings[g][embedding]) for g in genre_embeddings.keys()] for embedding in range(3)],
        limits=(-1.0 / len(genre_embeddings), 1.0),
        color_by=[genre_colorizer.color_by_value(g) for g in genre_embeddings.keys()],
    )}

    # Per audio feature: its three components, sorted by each component in turn.
    # The lambdas bind name/index as default arguments to avoid late binding in the comprehension.
    plots |= {f"{name}-{idx}": Scatter1dPlotStrategy(
        getter=[lambda _, f=name, i=o: getattr(_.audio, f)[i] for o in range(3)],
        sort_by=lambda _, f=name, i=idx: (getattr(_.audio, f)[i],),
        color_by=genre_colorizer.get_color,
    ) for name in ["centroid", "bandwidth", "flatness", "crossing", "rollon", "rolloff", "hpss", "tempogram"] for idx in range(3)}

    # The same components again, but sorted by genre and then by file name.
    plots |= {f"{name}-genres": Scatter1dPlotStrategy(
        getter=[lambda _, f=name, i=o: getattr(_.audio, f)[i] for o in range(3)],
        sort_by=lambda _: (genre_colorizer.sort_value(_.tags.genre), _.file.filename),
        color_by=genre_colorizer.get_color,
    ) for name in ["centroid", "bandwidth", "flatness", "crossing", "rollon", "rolloff", "hpss", "tempogram"]}

    plots |= {f"{name}": SpectrumPlotStrategy(
        getter=[lambda _, f=name, i=o: getattr(_.audio, f)[i] for o in range(3)],
        sort_by=lambda _: (genre_colorizer.sort_value(_.tags.genre), _.file.filename),
        scale=["linear"] * 3, labels=[None] * 3,
    ) for name in ["spectrum"]}

    # Heat maps of the feature vectors that were fed into clustering.
    plots |= {
        "cluster-scaled-features": SpectrumPlotStrategy(
            sort_by=lambda _: (genre_colorizer.sort_value(_.tags.genre), _.file.filename),
            getter=[lambda _: scaled_cluster_features[_.file.index]],
            scale=["linear"],
            labels=[None],
        ),
        "cluster-normed-features": SpectrumPlotStrategy(
            sort_by=lambda _: (genre_colorizer.sort_value(_.tags.genre), _.file.filename),
            getter=[lambda _: normed_cluster_features[_.file.index]],
            scale=["linear"],
            labels=[["Centroid Med", "Bandwidth Avg", "Bandwidth -Std", "Flatness Avg", "Crossing Avg",
                     "Crossing Std", "Rollon Avg", "Rolloff Med", "HPSS Med", "Tempogram -Med",
                     "100Hz -Med", "510Hz -Med", "1080Hz Med", "2000Hz Med", "3700Hz Med", "15.5kHz Med",
                     "Genre1", "Genre2", "Genre3"]]
        ),
    }

    # Three rows over the PCA dimensions: variance, singular value, encoder loss
    # (0.0 where the encoder statistics have no entry for a dimension).
    plots |= {
        "pca-variance-singular-loss": ScatterPlotStrategy(
            getter=[[(_, pca_dim_stats[_][0]) for _ in sorted(pca_dim_stats.keys())],
                    [(_, pca_dim_stats[_][1]) for _ in sorted(pca_dim_stats.keys())],
                    [(_, enc_dim_stats.get(_, (0.0,))[0]) for _ in sorted(pca_dim_stats.keys())]],  # TODO: None support
        ),
    }

    # Reduced coordinates per encoder: columns "<type>10", "<type>20"/"21", "<type>30"/"31"/"32".
    for cluster_type in encoder_types:
        plots |= {
            f"cluster-{cluster_type}-1d": Scatter1dPlotStrategy(
                getter=[lambda _, t=cluster_type: (clusters[_.file.index][f"{t}10"])],
                sort_by=lambda _, t=cluster_type: (clusters[_.file.index][f"{t}10"],),
                color_by=genre_colorizer.get_color,
            ),
            f"cluster-{cluster_type}-2d": Scatter2dPlotStrategy(
                getter=[lambda _, t=cluster_type: (clusters[_.file.index][f"{t}20"], clusters[_.file.index][f"{t}21"]),
                        lambda _, t=cluster_type: (clusters[_.file.index][f"{t}21"], clusters[_.file.index][f"{t}20"])],
                color_by=genre_colorizer.get_color,
            ),
            f"cluster-{cluster_type}-3d": Scatter3dPlotStrategy(
                getter=[lambda _, t=cluster_type: (clusters[_.file.index][f"{t}30"], clusters[_.file.index][f"{t}31"], clusters[_.file.index][f"{t}32"])],
                color_by=genre_colorizer.get_color,
            ),
        }

    # Clustering scores, one row per scoring metric; a distinct color per (input, algorithm) pair.
    c: Colorizer = LinearColorizer([f"{k}_{t}" for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types])
    plots |= {
        "kmeans-scoring-full": MultiScatterPlotStrategy(
            getter=[[
                [(d, kmeans_cluster_stats[d][f"{scoring}_{t}_{k}"]) for d in kmeans_cluster_stats.keys()]
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            legends=[[
                f"{k} {t}" if scoring == "ars" else None
                for k in ["PCA 3D", "Enc. 3D", "PCA 8D", "Enc. 8D", "Scaled", "Normed"] for t in ["kMeans", "Agglom.", "Guided"][:len(cluster_types)]
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            titles=["Adjusted Rand Index", "Adjusted Mutual Information", "Homogeneity", "Completeness", "Fowlkes-Mallows Index"],
            color_by=[[
                c.color_by_value(f"{k}_{t}")
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for _ in ["ars", "ami", "homo", "complete", "fowlkes"]],
        ),
    }

    # Same scores colored by clustering input only.
    # NOTE(review): despite the "-log" suffix the values are square-root scaled, with negative scores clamped to 0.
    # NOTE(review): titles has six entries for five scoring rows; only one title per row is read.
    c = LinearColorizer(["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"])
    plots |= {
        "kmeans-scoring-encoding-log": MultiScatterPlotStrategy(
            getter=[[
                [(d, sqrt(max(0.0, kmeans_cluster_stats[d][f"{scoring}_{t}_{k}"]))) for d in kmeans_cluster_stats.keys()]
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            legends=[[
                f"{k} {t}" if scoring == "ars" else None
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            titles=[None] * 6,
            color_by=[[
                c.color_by_value(k)
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for _ in ["ars", "ami", "homo", "complete", "fowlkes"]],
        ),
    }

    # Same scores colored by clustering algorithm only.
    c = LinearColorizer(cluster_types)
    plots |= {
        "kmeans-scoring-clustering": MultiScatterPlotStrategy(
            getter=[[
                [(d, kmeans_cluster_stats[d][f"{scoring}_{t}_{k}"]) for d in kmeans_cluster_stats.keys()]
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            legends=[[
                f"{k} {t}" if scoring == "ars" else None
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for scoring in ["ars", "ami", "homo", "complete", "fowlkes"]],
            titles=[None] * 6,
            color_by=[[
                c.color_by_value(t)
                for k in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"] for t in cluster_types
            ] for _ in ["ars", "ami", "homo", "complete", "fowlkes"]],
        ),
    }

    # 3D scatter of the reduced coordinates, colored by assigned cluster, for the selected
    # cluster counts of every clustering that was run on that encoder's 3D output.
    for cluster_type in encoder_types:
        for cluster_idx in kmeans_cluster_names:
            if any(cluster_idx.endswith(f"_{_}_{cluster_type}_3d") for _ in [2, 4, 8, 12, 16, 20]):
                # One hue per distinct cluster id that occurs in this column.
                colorizer: Colorizer = LinearColorizer(map(str, sorted(set(v for val in kmeans_cluster_index.values() for k, v in val.items() if k == cluster_idx))))
                plots |= {f"kmeans-{cluster_idx}": Scatter3dPlotStrategy(
                    getter=[lambda _, t=cluster_type: (clusters[_.file.index][f"{t}30"], clusters[_.file.index][f"{t}31"], clusters[_.file.index][f"{t}32"])],
                    color_by=lambda _, i=cluster_idx, c=colorizer: c.color_by_value(str(kmeans_cluster_index[_.file.index][i])),
                )}

    # Genre composition of every cluster: one polar chart per cluster count, one bar per
    # cluster, stacked by genre.
    for cluster_name in cluster_types:
        for cluster_type in ["pca_3d", "enc_3d", "pca_values", "enc_values", "scaled", "normed"]:
            cluster_names = [f"cluster_{cluster_name}_{_}_{cluster_type}" for _ in [2, 4, 8, 12, 16, 20]]
            # Distinct cluster ids per clustering column.
            idxs = {name: sorted(set(_[name] for _ in kmeans_cluster_index.values())) for name in cluster_names}
            plots |= {f"kmeans-{cluster_name}-{cluster_type}": PolarBarPlotStrategy(
                getter=[[[sum([
                    1 for r in rows if r.tags.genre == genre and kmeans_cluster_index[r.file.index][name] == i
                ]) for genre in genres] for i in idxs[name]] for name in cluster_names],
                titles=list(map(str, [2, 4, 8, 12, 16, 20])),
                color_by=[genre_colorizer.color_by_value(genre) for genre in genres],
            )}

    status.finish_progress()
    # Two tasks beyond the plots themselves: the encoder graph and the index page.
    status.start_progress("Writing plots", len(plots) + 2)
    plotter: Plotter = Plotter(debug)

    for name, strategy in plots.items():
        plotter.plot(rows, strategy, out_dir / f"{name}.svg")
        status.task_done()

    plotter.plot_dot((out_dir / enc_graph).read_text(), out_dir / enc_graph.with_suffix(".svg"))
    status.task_done()

    plotter.write_index(out_dir / "index.html")
    status.task_done()


def main() -> int:
    """Parse the command line, run the plotting job and return the process exit code."""
    cli = add_args(argparse.ArgumentParser())
    options = vars(cli.parse_args())
    run(run_main, **options)
    return 0


# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())