from .api.message import ConfigValueError, SttMessageType
from .api.model import AudioSource, AudioSink, RecognizerModel, ProcessorModel, SynthesizerModel, Recognizer, \
    Synthesizer, SpeechSegmenter, SentenceSegmenter, OutputFilter
from .config import Config
from .utils.utils import StringUtils


class FactoryError(Exception):
    """Raised when a configured pipeline component cannot be constructed.

    The message names the component kind, the configured value, and the
    class name and text of the underlying exception.

    :param ty: human-readable name of the component kind (e.g. a base class name).
    :param value: the configured component name that failed.
    :param e: the exception that caused the failure.
    """

    def __init__(self, ty: str, value: str, e: BaseException) -> None:
        cause = f"{type(e).__name__}: {e}"
        super().__init__(f"Cannot create {ty} '{value}': {cause}")


class Factory:
    """Build the pipeline components selected in a :class:`Config`.

    Each ``create_*`` method reads a component name from the config and
    constructs the matching implementation. An implementation's module is
    imported only when its name is the one selected. Its option mapping from
    the config is unpacked into the constructor as keyword arguments.

    An unknown name raises :class:`ConfigValueError`. An ``ImportError`` or
    ``OSError`` raised while importing or constructing the chosen
    implementation is re-raised as :class:`FactoryError`, chained to the
    original exception.
    """

    def __init__(self, config: Config) -> None:
        self._config: Config = config

    def _build(self, kind: type, selected: str, choices: tuple):
        """Construct the implementation registered under ``selected``.

        :param kind: component base class; its ``__name__`` labels the error.
        :param selected: component name read from the config.
        :param choices: ``(name, builder)`` pairs, tested in order with
            ``selected == name``. Each builder is a zero-argument callable
            that imports and constructs one implementation.
        :return: whatever the matching builder returns.
        :raises ConfigValueError: if no name matches ``selected``.
        :raises FactoryError: if the builder raises ``ImportError`` or ``OSError``.
        """
        try:
            for name, builder in choices:
                if selected == name:
                    return builder()
            # Raised inside the ``try``, like the selection itself; the except
            # clause below only intercepts ImportError/OSError.
            raise ConfigValueError(selected)
        except (ImportError, OSError) as exc:
            raise FactoryError(kind.__name__, selected, exc) from exc

    def _keyword_map(self):
        """Return the ``keywords`` argument passed to every recognizer.

        Computed as ``StringUtils.get_enum_map(SttMessageType, config.keywords)``.
        Presumably this maps the configured keyword strings to ``SttMessageType``
        members; TODO confirm against ``StringUtils``.
        """
        return StringUtils.get_enum_map(SttMessageType, self._config.keywords)

    def create_recognizer(self) -> RecognizerModel:
        """Create the recognizer named by ``config.recognizer``.

        Supported names: ``"sphinx"``, ``"vosk"``, ``"whisper"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def sphinx():
            from .recognizer.sphinx import SphinxRecognizer
            return SphinxRecognizer(keywords=self._keyword_map(), **cfg.sphinx_recognizer)

        def vosk():
            from .recognizer.vosk import VoskRecognizer
            return VoskRecognizer(keywords=self._keyword_map(), **cfg.vosk_recognizer)

        def whisper():
            from .recognizer.whisper import WhisperRecognizer
            return WhisperRecognizer(keywords=self._keyword_map(), **cfg.whisper_recognizer)

        return self._build(RecognizerModel, cfg.recognizer,
                           (("sphinx", sphinx), ("vosk", vosk), ("whisper", whisper)))

    def create_processor(self) -> ProcessorModel:
        """Create the processor named by ``config.processor``.

        Supported names: ``"noop"``, ``"ollama"``, ``"gpt4all"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def noop():
            from .processor.noop import NoopProcessorModel
            return NoopProcessorModel()

        def ollama():
            from .processor.ollama import OllamaProcessorModel
            return OllamaProcessorModel(**cfg.ollama_processor)

        def gpt4all():
            from .processor.gpt4all import Gpt4AllProcessorModel
            return Gpt4AllProcessorModel(**cfg.gpt4all_processor)

        return self._build(ProcessorModel, cfg.processor,
                           (("noop", noop), ("ollama", ollama), ("gpt4all", gpt4all)))

    def create_synthesizer(self) -> SynthesizerModel:
        """Create the synthesizer named by ``config.synthesizer``.

        Supported names: ``"espeak"``, ``"coqui"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def espeak():
            from .synthesizer.espeak import EspeakSynthesizer
            return EspeakSynthesizer(**cfg.espeak_synthesizer)

        def coqui():
            from .synthesizer.coqui import CoquiSynthesizerModel
            return CoquiSynthesizerModel(**cfg.coqui_synthesizer)

        return self._build(SynthesizerModel, cfg.synthesizer,
                           (("espeak", espeak), ("coqui", coqui)))

    def create_source_for(self, recognizer: Recognizer) -> AudioSource:
        """Create the audio source named by ``config.source``.

        The source is constructed with ``sample_rate=recognizer.sample_rate()``.
        Supported names: ``"alsa"``, ``"pulse"``, ``"pyaudio"``, ``"wave"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def alsa():
            from .audio.alsa import AlsaRecorder
            return AlsaRecorder(sample_rate=recognizer.sample_rate(), **cfg.alsa_source)

        def pulse():
            from .audio.pulse import PulseRecorder
            return PulseRecorder(sample_rate=recognizer.sample_rate(), **cfg.pulse_source)

        def pyaudio():
            from .audio.pyaudio import PyAudioRecorder
            return PyAudioRecorder(sample_rate=recognizer.sample_rate(), **cfg.pyaudio_source)

        def wave():
            from .audio.wave import WavePlayer
            return WavePlayer(sample_rate=recognizer.sample_rate(), **cfg.wave_source)

        return self._build(AudioSource, cfg.source,
                           (("alsa", alsa), ("pulse", pulse), ("pyaudio", pyaudio), ("wave", wave)))

    def create_sink_for(self, synthesizer: Synthesizer) -> AudioSink:
        """Create the audio sink named by ``config.sink``.

        The sink is constructed with ``sample_rate=synthesizer.sample_rate()``.
        Supported names: ``"alsa"``, ``"pulse"``, ``"pyaudio"``, ``"wave"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def alsa():
            from .audio.alsa import AlsaPlayer
            return AlsaPlayer(sample_rate=synthesizer.sample_rate(), **cfg.alsa_sink)

        def pulse():
            from .audio.pulse import PulsePlayer
            return PulsePlayer(sample_rate=synthesizer.sample_rate(), **cfg.pulse_sink)

        def pyaudio():
            from .audio.pyaudio import PyAudioPlayer
            return PyAudioPlayer(sample_rate=synthesizer.sample_rate(), **cfg.pyaudio_sink)

        def wave():
            from .audio.wave import WaveRecorder
            return WaveRecorder(sample_rate=synthesizer.sample_rate(), **cfg.wave_sink)

        return self._build(AudioSink, cfg.sink,
                           (("alsa", alsa), ("pulse", pulse), ("pyaudio", pyaudio), ("wave", wave)))

    def create_output_filter_for(self, synthesizer: Synthesizer) -> OutputFilter:
        """Create the feedback output filter named by ``config.feedback``.

        Supported names: ``"noop"``, ``"speech"``, ``"beep"``. Only ``"beep"``
        receives ``synthesizer.sample_rate()`` and config options
        (``config.beep_feedback``).

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def noop():
            from .synthesizer.feedback.noop import NoopFeedbackGenerator
            return NoopFeedbackGenerator()

        def speech():
            from .synthesizer.feedback.speech import SpeechFeedbackGenerator
            return SpeechFeedbackGenerator()

        def beep():
            from .synthesizer.feedback.beep import BeepFeedbackGenerator
            return BeepFeedbackGenerator(sample_rate=synthesizer.sample_rate(), **cfg.beep_feedback)

        return self._build(OutputFilter, cfg.feedback,
                           (("noop", noop), ("speech", speech), ("beep", beep)))

    def create_speech_segmenter_for(self, recognizer: Recognizer) -> SpeechSegmenter:
        """Create the speech segmenter named by ``config.speech_segmenter``.

        Every segmenter is given ``recognizer.sample_rate()``, the shared
        ``config.speech_buffer_limit`` / ``config.speech_pause_limit``, and its
        own option mapping. Supported names: ``"simple"``, ``"median"``,
        ``"band"``, ``"sphinx"``.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def simple():
            # NOTE(review): SimpleSegmenter is imported from the ``median`` module,
            # not a module of its own - confirm this is intended.
            from .recognizer.speech.median import SimpleSegmenter
            return SimpleSegmenter(sample_rate=recognizer.sample_rate(),
                                   buffer_limit=cfg.speech_buffer_limit,
                                   pause_limit=cfg.speech_pause_limit,
                                   **cfg.simple_speech_segmenter)

        def median():
            from .recognizer.speech.median import MedianSegmenter
            return MedianSegmenter(sample_rate=recognizer.sample_rate(),
                                   buffer_limit=cfg.speech_buffer_limit,
                                   pause_limit=cfg.speech_pause_limit,
                                   **cfg.median_speech_segmenter)

        def band():
            from .recognizer.speech.rosa import BandSegmenter
            return BandSegmenter(sample_rate=recognizer.sample_rate(),
                                 buffer_limit=cfg.speech_buffer_limit,
                                 pause_limit=cfg.speech_pause_limit,
                                 **cfg.band_speech_segmenter)

        def sphinx():
            from .recognizer.speech.sphinx import SphinxSegmenter
            return SphinxSegmenter(sample_rate=recognizer.sample_rate(),
                                   buffer_limit=cfg.speech_buffer_limit,
                                   pause_limit=cfg.speech_pause_limit,
                                   **cfg.sphinx_speech_segmenter)

        return self._build(SpeechSegmenter, cfg.speech_segmenter,
                           (("simple", simple), ("median", median), ("band", band), ("sphinx", sphinx)))

    def create_sentence_segmenter_for(self, synthesizer: Synthesizer) -> SentenceSegmenter:
        """Create the sentence segmenter named by ``config.sentence_segmenter``.

        Supported names: ``"split"``, ``"sbd"``. ``synthesizer`` is not used
        here; presumably it is kept for symmetry with the other ``*_for``
        methods - confirm before removing.

        :raises ConfigValueError: for an unsupported name.
        :raises FactoryError: if the implementation fails to import or construct
            with ``ImportError``/``OSError``.
        """
        cfg = self._config

        def split():
            from .synthesizer.sentence.segmenter import SentenceSplitSegmenter
            return SentenceSplitSegmenter(**cfg.split_sentence_segmenter)

        def sbd():
            from .synthesizer.sentence.sbd import SentenceBoundarySegmenter
            return SentenceBoundarySegmenter(**cfg.sbd_sentence_segmenter)

        return self._build(SentenceSegmenter, cfg.sentence_segmenter,
                           (("split", split), ("sbd", sbd)))