import os
from dataclasses import dataclass, fields
from json import JSONEncoder, JSONDecoder
from pathlib import Path
from typing import TypeVar, ClassVar, Type

T = TypeVar('T')


@dataclass
class _FeatureSet:
    _encoder: ClassVar[JSONEncoder] = JSONEncoder(ensure_ascii=False, indent=None, separators=(',', ':'))
    _decoder: ClassVar[JSONDecoder] = JSONDecoder()

    def to_dict(self) -> dict[str, str | int | float]:
        return {f.name: getattr(self, f.name) if f.type in (str, int, float, Path) else self._encoder.encode(getattr(self, f.name))
                for f in fields(self)}

    @classmethod
    def from_dict(cls: Type[T], dct: dict[str, str | int | float]) -> T:
        return cls(**{f.name: f.type(dct[f.name] if f.type in (str, int, float, Path) else cls._decoder.decode(dct[f.name]))
                      for f in fields(cls)})


@dataclass()
class FileFeatures(_FeatureSet):
    """File-level identification of one feature record."""

    # Integer index of the record; its meaning is assigned by the caller (TODO confirm).
    index: int
    # Location of the source file; Features.__str__ displays only its basename.
    filename: Path


@dataclass()
class TagFeatures(_FeatureSet):
    """Descriptive text attributes of one feature record."""

    artist: str  # stored verbatim by to_dict (str is a scalar type)
    genre: str  # stored verbatim by to_dict (str is a scalar type)


@dataclass()
class AudioFeatures(_FeatureSet):
    """Numeric audio descriptors of one feature record.

    ``rate`` and ``length`` are plain floats. Every other field is a
    3-tuple, which the flat-dict serialization stores as a JSON string.
    """

    # Scalars; presumably sample rate and duration -- units not visible here, TODO confirm.
    rate: float
    length: float

    # Per-descriptor 3-tuples; presumably summary statistics (TODO confirm).
    tempogram: tuple[float, float, float]
    centroid: tuple[float, float, float]
    bandwidth: tuple[float, float, float]
    flatness: tuple[float, float, float]
    crossing: tuple[float, float, float]
    rollon: tuple[float, float, float]
    rolloff: tuple[float, float, float]
    hpss: tuple[float, float, float]

    # Three float lists; Features.to_vector reads indices 0-5 of the middle list.
    spectrum: tuple[list[float], list[float], list[float]]


@dataclass
class Features:
    """All features of one record, grouped into file, tag and audio parts.

    The groups are flattened into (and rebuilt from) one dict whose keys are
    the field names of FileFeatures, TagFeatures and AudioFeatures.
    """

    file: FileFeatures
    tags: TagFeatures
    audio: AudioFeatures

    @classmethod
    def fieldnames(cls) -> list[str]:
        """Return the flat key names in file, tags, audio field order."""
        return [f.name
                for group in (FileFeatures, TagFeatures, AudioFeatures)
                for f in fields(group)]

    @classmethod
    def from_dict(cls, dct: dict[str, str | int | float]) -> 'Features':
        """Rebuild a Features from a flat dict such as the one ``to_dict`` returns.

        Each group reads only its own keys from the same dict.
        """
        return cls(file=FileFeatures.from_dict(dct),
                   tags=TagFeatures.from_dict(dct),
                   audio=AudioFeatures.from_dict(dct))

    def to_dict(self) -> dict[str, str | int | float]:
        """Merge the three groups' dicts into one flat dict.

        Relies on field names being unique across the groups; on a clash the
        later group (tags, then audio) would overwrite the earlier value.
        """
        return self.file.to_dict() | self.tags.to_dict() | self.audio.to_dict()

    def to_vector(self) -> list[float]:
        """Pick the values actually used for encoding.

        Some components are negated so that each one runs in the direction
        noted beside it.
        """
        audio = self.audio
        spectrum = audio.spectrum[1]
        return [
            audio.centroid[1],  # low -> hi
            audio.bandwidth[0],  # narrow -> wide
            -audio.bandwidth[2],  # changing -> constant/narrow/wide
            audio.flatness[0],  # narrow -> noisy
            audio.crossing[0],  # constant -> crossings
            audio.crossing[2],  # constant crossings -> varying
            audio.rollon[0],  # low -> hi
            audio.rolloff[1],  # low -> hi
            audio.hpss[1],  # harm -> perc
            -audio.tempogram[1],  # flat -> varying
            -spectrum[0],  # low/loud -> flat
            -spectrum[1],  # low/loud -> flat
            spectrum[2],  # loud -> mid/flat
            spectrum[3],  # loud -> mid/flat
            spectrum[4],  # low -> flat
            spectrum[5],  # low -> flat
        ]

    def __str__(self) -> str:
        """Return the basename of the record's file."""
        return os.path.basename(self.file.filename)