import argparse
import csv
from pathlib import Path
from typing import ContextManager, Iterator, Iterable, TextIO, TypeVar, Any

T = TypeVar("T", bound="_FileContext")


def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
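    """Register the pipeline's command line options on *parser* and return it.

    Minimal usage sketch (the description text is illustrative only):

        parser = add_args(argparse.ArgumentParser(description="audio clustering pipeline"))
        args = parser.parse_args()
    """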
    parser.add_argument("--extract-in-playlist", type=Path, required=True, metavar="M3U",
                        help="main input of files to process")
    parser.add_argument("--out-dir", type=Path, required=True, metavar="DIR",
                        help="output directory for various files")
    parser.add_argument("--debug", action='store_true', default=False,
                        help="enable debug logging, extra plotting, and additional checks")
    parser.add_argument("--concurrency", type=int, default=1, metavar="NUM",
                        help="number of threads for audio processing")
    parser.add_argument("--extract-cache", type=lambda _: bool(int(_)), default=None, metavar="0|1",
                        help="force re-usage of (lengthy) audio feature extraction")
    parser.add_argument("--extract-limit", type=int, default=None, metavar="NUM",
                        help="process at most this number of input audio files")
    parser.add_argument("--epoch-limit", type=int, default=2000, metavar="NUM",
                        help="maximum number of autoencoder epochs, might stop earlier")
    parser.add_argument("--genre-bias", type=float, default=0.0, metavar="NUM",
                        help="use genre embeddings for encoding and guided hierarchical clustering")
    parser.add_argument("--extract-features", type=Path, default="audio-features.csv", metavar="CSV")
    parser.add_argument("--scaled-features", type=Path, default="cluster-features-scaled.csv", metavar="CSV")
    parser.add_argument("--normed-features", type=Path, default="cluster-features-normed.csv", metavar="CSV")
    parser.add_argument("--cluster-features", type=Path, default="cluster-features.csv", metavar="CSV")
    parser.add_argument("--pca-embeddings", type=Path, default="pca-embeddings.csv", metavar="CSV")
    parser.add_argument("--pca-stats", type=Path, default="pca-stats.csv", metavar="CSV")
    parser.add_argument("--enc-stats", type=Path, default="enc-stats.csv", metavar="CSV")
    parser.add_argument("--enc-graph", type=Path, default="enc-graph.dot", metavar="DOT")
    parser.add_argument("--kmeans-stats", type=Path, default="kmeans-stats.csv", metavar="CSV")
    parser.add_argument("--kmeans-clusters", type=Path, default="kmeans-clusters.csv", metavar="CSV")
    return parser


class _FileContext(ContextManager):
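    """Base context manager that owns opening and closing the underlying file."""
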
    def __init__(self, filename: str | Path, mode: str) -> None:
        self._filename: Path = Path(filename)
        self._mode: str = mode
        self._fp: TextIO | None = None

    def __enter__(self: T) -> T:
        assert self._fp is None
        self._fp = self._filename.open(self._mode, newline="")
        return self

    def __exit__(self, *args: Any) -> None:
        assert self._fp is not None
        self._fp.close()
        self._fp = None


class CSVReader(_FileContext):
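    """Context-managed reader for ';'-delimited CSV files.

    Usage sketch ("rows.csv" is a placeholder path):

        with CSVReader("rows.csv") as reader:
            for row in reader.read():
                ...  # each row is a list[str]
    """
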
    def __init__(self, filename: str | Path) -> None:
        super().__init__(filename, "r")
        self._reader: Iterator[list[str]] | None = None

    def __enter__(self: T) -> T:
        super().__enter__()
        assert self._reader is None
        self._reader = csv.reader(self._fp, delimiter=";", quoting=csv.QUOTE_MINIMAL)
        return self

    def __exit__(self, *args) -> None:
        assert self._reader is not None
        self._reader = None
        super().__exit__(*args)

    def read(self) -> Iterator[list[str]]:
        assert self._reader is not None
        yield from self._reader


class CSVDictReader(_FileContext):
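    """Context-managed reader for ';'-delimited CSV files with a header row.

    Usage sketch ("rows.csv" is a placeholder path):

        with CSVDictReader("rows.csv") as reader:
            columns = reader.fieldnames()
            for record in reader.read():
                ...  # record maps column name -> str value
    """
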
    def __init__(self, filename: str | Path) -> None:
        super().__init__(filename, "r")
        self._reader: csv.DictReader[str] | None = None

    def __enter__(self: T) -> T:
        super().__enter__()
        assert self._reader is None
        self._reader = csv.DictReader(self._fp, fieldnames=None, delimiter=";", quoting=csv.QUOTE_MINIMAL)
        return self

    def __exit__(self, *args) -> None:
        assert self._reader is not None
        self._reader = None
        super().__exit__(*args)

    def fieldnames(self) -> list[str]:
        assert self._reader is not None
        return list(self._reader.fieldnames or [])

    def read(self) -> Iterator[dict[str, Any]]:
        assert self._reader is not None
        yield from self._reader


class CSVWriter(_FileContext):
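    """Context-managed row writer for delimited CSV files (';' by default).

    Usage sketch ("rows.csv" is a placeholder path):

        with CSVWriter("rows.csv") as writer:
            writer.write(["id", "value"])
            writer.write_all([[1, 0.5], [2, 0.7]])
    """
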
    def __init__(self, filename: str | Path, delimiter: str = ";") -> None:
        super().__init__(filename, "w")
        self._delimiter: str = delimiter
        self._writer = None

    def __enter__(self: T) -> T:
        super().__enter__()
        assert self._writer is None
        self._writer = csv.writer(self._fp, delimiter=self._delimiter, quoting=csv.QUOTE_MINIMAL)
        return self

    def __exit__(self, *args) -> None:
        assert self._writer is not None
        self._writer = None
        super().__exit__(*args)

    def write(self, data: Iterable[Any]) -> None:
        assert self._writer is not None
        self._writer.writerow(data)

    def write_all(self, data: Iterable[Iterable[Any]]) -> None:
        assert self._writer is not None
        self._writer.writerows(data)


class CSVDictWriter(_FileContext):
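    """Context-managed dict writer for ';'-delimited CSV files; writes the header row on entry.

    Usage sketch ("rows.csv" is a placeholder path):

        with CSVDictWriter("rows.csv", ["id", "value"]) as writer:
            writer.write({"id": 1, "value": 0.5})
    """
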
    def __init__(self, filename: str | Path, fieldnames: list[str]) -> None:
        super().__init__(filename, "w")
        self._fieldnames: list[str] = fieldnames
        self._writer: csv.DictWriter | None = None

    def __enter__(self: T) -> T:
        super().__enter__()
        assert self._writer is None
        self._writer = csv.DictWriter(self._fp, fieldnames=self._fieldnames, delimiter=";", quoting=csv.QUOTE_MINIMAL)
        self._writer.writeheader()
        return self

    def __exit__(self, *args) -> None:
        assert self._writer is not None
        self._writer = None
        super().__exit__(*args)

    def fieldnames(self) -> list[str]:
        return self._fieldnames

    def write(self, data: dict[str, Any]) -> None:
        assert self._writer is not None
        self._writer.writerow(data)

    def write_all(self, data: Iterable[dict[str, Any]]) -> None:
        assert self._writer is not None
        self._writer.writerows(data)


class M3UReader(_FileContext):
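    """Context-managed reader that yields the file paths listed in an M3U playlist.

    Blank lines and comment lines (starting with '#') are skipped.

    Usage sketch ("input.m3u" is a placeholder path):

        with M3UReader("input.m3u") as playlist:
            for idx, path in playlist.read_limited(limit=10):
                ...  # at most 10 (index, Path) pairs
    """
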
    def __init__(self, filename: str | Path) -> None:
        super().__init__(filename, "r")

    def read(self) -> Iterator[Path]:
        assert self._fp is not None
        for line in self._fp:
            line = line.strip()
            if line and not line.startswith("#"):
                yield Path(line)

    def read_limited(self, limit: int | None) -> Iterator[tuple[int, Path]]:
        idx: int = 0
        for path in self.read():
            if limit is not None and idx >= limit:
                break
            yield idx, path
            idx += 1


class M3UWriter(_FileContext):
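    """Context-managed writer for M3U playlists.

    Path entries are written as POSIX paths; plain strings are written as '#' comments.

    Usage sketch ("output.m3u" is a placeholder path):

        with M3UWriter("output.m3u") as playlist:
            playlist.write("generated playlist")  # becomes an M3U comment
            playlist.write(Path("music/track.flac"))
    """
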
    def __init__(self, filename: str | Path) -> None:
        super().__init__(filename, "w")

    def write(self, line: Path | str) -> None:
        assert self._fp is not None
        if isinstance(line, Path):
            self._fp.write(f"{line.as_posix()}\n")
        else:
            self._fp.write(f"# {line}\n")

    def write_all(self, lines: Iterable[Path | str]) -> None:
        for line in lines:
            self.write(line)