import argparse
import sys
from collections import defaultdict
from pathlib import Path

from .features import Features
from .log import run, Status
from .utils import add_args, CSVDictReader, CSVWriter, M3UWriter


def _genre_summary(genre_counts: dict) -> str:
    """Format per-genre track counts as a ``-> genre: count, ...`` line.

    Genres are listed by descending count; ties keep their first-seen order
    because ``sorted`` is stable even with ``reverse=True``. Falsy genre values
    are shown as ``Unknown``.
    """
    ranked = sorted(genre_counts.items(), key=lambda item: item[1], reverse=True)
    return "-> " + ", ".join(f"{genre or 'Unknown'}: {count}" for genre, count in ranked)


def _cluster_sort_key(cluster_name: str, kmeans_cluster_index: dict, clusters: dict):
    """Return the sort key used to order tracks in one k-means clustering playlist.

    Tracks are ordered by their label in ``cluster_name``. Within a label they
    are ordered by the same ``pca10`` / ``enc10`` value used for the 1-D
    playlists (chosen by a ``_pca_`` / ``_enc_`` substring in the clustering
    name, if present), then by filename.
    """
    if "_pca_" in cluster_name:
        projection = "pca10"
    elif "_enc_" in cluster_name:
        projection = "enc10"
    else:
        projection = None

    def key(f: Features) -> tuple:
        label = kmeans_cluster_index[f.file.index][cluster_name]
        if projection is None:
            return label, f.file.filename
        return label, clusters[f.file.index][projection], f.file.filename

    return key


def _write_cluster_playlist(writer: M3UWriter, ordered_rows: list,
                            kmeans_cluster_index: dict, cluster_name: str) -> None:
    """Write tracks (already sorted by their ``cluster_name`` label) grouped by label.

    Each group is written as: a line with the 1-based cluster number, the
    group's track filenames, then a genre summary line from ``_genre_summary``.
    """
    # None (not -1) so that any integer label, including -1, opens a new group.
    last_cluster = None
    genre_counts = defaultdict(int)
    for f in ordered_rows:
        curr_cluster: int = kmeans_cluster_index[f.file.index][cluster_name]
        if curr_cluster != last_cluster:
            if genre_counts:
                writer.write(_genre_summary(genre_counts))
            writer.write(str(curr_cluster + 1))
            last_cluster = curr_cluster
            genre_counts.clear()
        writer.write(f.file.filename)
        genre_counts[f.tags.genre] += 1
    # Close the final group.
    if genre_counts:
        writer.write(_genre_summary(genre_counts))


def run_main(status: Status, *,
             out_dir: Path,
             extract_features: Path,
             cluster_features: Path,
             kmeans_clusters: Path,
             **kwargs) -> None:
    """Build playlists and coordinate dumps from the feature/cluster CSVs in ``out_dir``.

    Reads three CSV files, with paths taken relative to ``out_dir``:

    * ``extract_features``: one row per track, parsed with ``Features.from_dict``.
    * ``cluster_features``: per-track values keyed by ``index``. Cells containing
      a comma become ``list[float]``; all other cells become ``float``.
    * ``kmeans_clusters``: per-track integer labels, one column per clustering
      (every column except ``index``).

    Writes into ``out_dir``:

    * ``pca-1d.m3u`` / ``enc-1d.m3u``: all tracks ordered by their ``pca10`` /
      ``enc10`` value, then by filename.
    * ``<clustering>.m3u`` for each k-means column: tracks grouped by cluster
      label, as described in ``_write_cluster_playlist``.
    * ``coordinates-pca05.m3u`` / ``coordinates-enc05.m3u``: comma-separated
      ``filename, *coordinates`` rows.

    Extra keyword arguments are accepted and ignored. Presumably this lets the
    whole argparse namespace be forwarded; confirm with ``run``.
    """
    status.start_progress("Reading features")

    with CSVDictReader(out_dir / extract_features) as reader:
        rows: list[Features] = [Features.from_dict(row) for row in status.task_pending_counter(reader.read())]
    file_map: dict[int, Path] = {f.file.index: f.file.filename for f in rows}

    with CSVDictReader(out_dir / cluster_features) as reader:
        clusters: dict[int, dict[str, float | list[float]]] = {
            int(row["index"]): {k: list(map(float, v.split(","))) if "," in v else float(v)
                                for k, v in row.items()}
            for row in status.task_pending_counter(reader.read())
        }

    with CSVDictReader(out_dir / kmeans_clusters) as reader:
        kmeans_cluster_names: list[str] = [name for name in reader.fieldnames() if name != "index"]
        kmeans_cluster_index: dict[int, dict[str, int]] = {
            int(row["index"]): {k: int(v) for k, v in row.items()}
            for row in status.task_pending_counter(reader.read())
        }

    status.finish_progress()

    one_d_playlists = (("pca-1d.m3u", "pca10"), ("enc-1d.m3u", "enc10"))
    coordinate_sets = ("pca05", "enc05")
    status.start_progress("Writing playlists",
                          len(one_d_playlists) + len(kmeans_cluster_names) + len(coordinate_sets))

    # 1-D playlists: every track ordered by a single per-track value.
    for playlist_name, projection in one_d_playlists:
        with M3UWriter(out_dir / playlist_name) as writer:
            ordered = sorted(rows, key=lambda f: (clusters[f.file.index][projection], f.file.filename))
            writer.write_all(f.file.filename for f in ordered)
        status.task_done()

    # One playlist per k-means clustering, grouped by cluster label.
    for cluster_name in kmeans_cluster_names:
        sort_key = _cluster_sort_key(cluster_name, kmeans_cluster_index, clusters)
        with M3UWriter(out_dir / f"{cluster_name}.m3u") as writer:
            _write_cluster_playlist(writer, sorted(rows, key=sort_key), kmeans_cluster_index, cluster_name)
        status.task_done()

    # Coordinate dumps: one "filename, c1, c2, ..." row per track.
    for coordinates in coordinate_sets:
        # NOTE(review): this file holds CSV data but keeps the original ".m3u"
        # extension; ".csv" looks intended. Rename it only together with any
        # consumers of these files.
        with CSVWriter(out_dir / f"coordinates-{coordinates}.m3u", delimiter=",") as w:
            for index, values in clusters.items():
                coords = values[coordinates]
                # A cell without commas was parsed as a bare float; wrap it so the
                # row stays a flat list instead of raising TypeError.
                if not isinstance(coords, list):
                    coords = [coords]
                w.write([file_map[index], *coords])
        status.task_done()


def main() -> int:
    """Command-line entry point: parse options and execute ``run_main`` via ``run``.

    Returns 0, to be used as the process exit status.
    """
    options = add_args(argparse.ArgumentParser()).parse_args()
    run(run_main, **vars(options))
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)