#!/usr/bin/env python3

"""
Quick prototype for translating PDF and ODT documents using an Ollama chat API.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Iterator
import argparse
import html
import json
import logging
import re
import sys

import colorlog
import odf
import odf.element
import odf.opendocument
import pymupdf
import requests


class Transformer(ABC):
    @abstractmethod
    def transform(self, callback: Callable[[str], str | None]) -> None:
        raise NotImplementedError

    @abstractmethod
    def save(self, outfile: Path) -> None:
        raise NotImplementedError


class PdfTransformer(Transformer):
    """
    Paint translations over each PDF text block inside a toggleable layer.

    https://github.com/pymupdf/PyMuPDF-Utilities/blob/tutorials/tutorials/language-translation/translator.py
    """

    def __init__(self, infile: Path) -> None:
        self._doc: pymupdf.Document = pymupdf.Document(infile)
        # Named optional content group: translations live on their own on/off layer.
        self._ocg = self._doc.add_ocg("Translation", on=True)

    def transform(self, callback: Callable[[str], str | None]) -> None:
        """Run *callback* over every text block; overlay each non-None result."""
        for page in self._doc:  # type: pymupdf.Page
            for x0, y0, x1, y1, text, *_ in page.get_text("blocks", flags=pymupdf.TEXT_DEHYPHENATE):
                translation = callback(text)
                if translation is None:
                    continue
                rect = (x0, y0, x1, y1)
                # White-out the original block, then place the translated HTML on top.
                page.draw_rect(rect, color=None, fill=pymupdf.pdfcolor["white"], oc=self._ocg)
                markup = html.escape(translation).replace("\n", "<br>")
                page.insert_htmlbox(rect, markup, oc=self._ocg)

    def save(self, outfile: Path) -> None:
        """Subset embedded fonts and write the result to *outfile*."""
        self._doc.subset_fonts()
        self._doc.save(outfile)


class OdfTransformer(Transformer):
    """
    Translate all plain-text nodes of an ODF document tree in place.

    https://github.com/eea/odfpy/wiki/ReplaceOneTextToAnother
    """

    def __init__(self, infile: Path) -> None:
        self._doc: odf.opendocument.OpenDocument = odf.opendocument.load(infile)

    @classmethod
    def _saxiter(cls, node: odf.element.Element) -> Iterator[odf.element.Element]:
        """Depth-first pre-order walk over *node* and all of its following siblings."""
        # Test against None explicitly: should a node ever evaluate falsy (e.g. an
        # empty text node, if the element type defines truthiness), a bare
        # 'while node:' would silently cut the sibling walk short.
        while node is not None:
            yield node
            if node.hasChildNodes():
                yield from cls._saxiter(node.firstChild)
            node = node.nextSibling

    def transform(self, callback: Callable[[str], str | None]) -> None:
        """Apply *callback* to every raw Text node; replace its data when not None."""
        for elem in self._saxiter(self._doc.topnode):
            # Exact class check on purpose: only raw Text nodes, no subclasses.
            if elem.__class__ is odf.element.Text:
                if isinstance(elem.data, str):
                    if (translation := callback(elem.data)) is not None:
                        elem.data = translation

    def save(self, outfile: Path) -> None:
        """Write the (possibly modified) document to *outfile*."""
        self._doc.save(outfile)


class TranslationHistory:
    """
    Bounded FIFO of past (original, translation) pairs replayed as chat context.
    """

    def __init__(self, chunk_size: int) -> None:
        # Total character budget across all stored originals and translations.
        self._chunk_size: int = chunk_size
        self._pairs: list[tuple[str, str]] = []

    def _total_chars(self) -> int:
        """Combined length of all stored originals and translations."""
        return sum(len(orig) + len(trans) for orig, trans in self._pairs)

    def _evict(self) -> None:
        """Drop oldest pairs until within budget, always keeping at least one."""
        while len(self._pairs) > 1 and self._total_chars() > self._chunk_size:
            del self._pairs[0]

    def push(self, text: str, translation: str) -> None:
        """Record one original/translation exchange."""
        self._pairs.append((text, translation))

    def get(self) -> tuple[str, str]:
        """Return the retained originals and translations, each newline-joined."""
        self._evict()
        originals = "\n".join(orig for orig, _ in self._pairs)
        translations = "\n".join(trans for _, trans in self._pairs)
        return originals, translations


class OllamaClient:
    """
    Simple requests-based Ollama client for chat with 'translator' prompt and past history.
    """

    def __init__(self,
                 api_url: str, model_name: str,
                 source_lang: str, target_lang: str, context_title: str,
                 history_len: int):
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)

        self._model_name: str = model_name
        self._api_url: str = api_url
        self._session: requests.Session = requests.Session()
        self._history: TranslationHistory = TranslationHistory(history_len)

        self._context_title: str = context_title
        self._source_lang: str = source_lang
        self._target_lang: str = target_lang

    def _prompt(self) -> str:
        prompt: str = f"You are a professional translator. Translate from {self._source_lang} to {self._target_lang}.\n"
        prompt += "Preserve formatting and line breaks.\n"
        prompt += "Return only the direct translation without comments.\n"
        if self._context_title:
            prompt += f"Document type: {self._context_title}\n"
        return prompt

    def _is_hallucination(self, text: str, translation: str) -> bool:
        return len(translation) > len(text) * 4

    def translate(self, text: str) -> str | None:
        messages: list[dict] = [{"role": "system", "content": self._prompt()}]

        context_text, context_trans = self._history.get()
        if context_text and context_trans:
            messages.append({"role": "user", "content": context_text})
            messages.append({"role": "assistant", "content": context_trans})

        messages.append({"role": "user", "content": text})
        self._logger.debug(messages)

        try:
            response: requests.Response = self._session.post(self._api_url, json={
                "model": self._model_name,
                "messages": messages,
                "stream": False,
            }, timeout=(30, 120))
            response.raise_for_status()

            result: dict = json.loads(response.text)
            if "message" not in result or result["message"]["role"] != "assistant":
                self._logger.warning("No Ollama assistant response")
                return None

            translation: str = result["message"]["content"].strip()
            if self._is_hallucination(text, translation):
                self._logger.warning("Hallucination")
                return None
            self._history.push(text, translation)
            return translation
        except requests.exceptions.Timeout:
            self._logger.error("Ollama timeout")
            return None
        except Exception as e:
            self._logger.error(f"Unexpected Ollama error: {e}", exc_info=e)
            return None


class Chunker:
    """
    Split/join sentences to be translated wrt given chunk size.
    """

    def __init__(self, cb: Callable[[str], str | None], chunk_len: int) -> None:
        self._cb: Callable[[str], str | None] = cb
        self._chunk_len: int = chunk_len

    def _chunk_sentence(self, text: str) -> Iterator[str]:
        """If too long, try to recursively split at sentence boundaries."""
        if len(text) > self._chunk_len:
            delim: list[int] = [_.start() for _ in re.finditer("[\n.¡!!¿??::]", text)]
            if len(delim) > 1:
                pos: int = delim[len(delim) // 2]
                yield from self._chunk_sentence(text[:pos])
                yield from self._chunk_sentence(text[pos:])
                return
        yield text

    def _pad_whitespace_strip(self, text: str) -> str:
        """Strip possibly splitted input, but re-add whitespace again."""
        prefix: str = re.search("^\\s*", text)[0]  # type: ignore
        suffix: str = re.search("\\s*$", text[len(prefix):])[0]  # type: ignore
        return "".join((prefix, self._translate(text[len(prefix):-len(suffix)]).strip(), suffix))

    def _translate(self, text: str) -> str:
        translated: str | None = self._cb(text)
        return translated if translated is not None else text  # XXX:

    def run_iter(self, text: str) -> str | None:
        return "".join(self._pad_whitespace_strip(_) for _ in self._chunk_sentence(text))


class Translator:
    """
    Main interface providing the string translation callable.
    """

    def __init__(self, client: OllamaClient, chunk_len: int) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._chunker: Chunker = Chunker(client.translate, chunk_len)

    def translate(self, orig: str) -> str | None:
        """Translate *orig*; None for blank/numeric input or unusable results."""
        stripped: str = orig.strip()
        if not stripped or orig.isnumeric():
            return None
        self._logger.info(f"> {json.dumps(stripped, ensure_ascii=False)}")

        result: str | None = self._chunker.run_iter(orig)
        # Reject empty output and no-op "translations" identical to the input.
        if result is None or not result.strip() or result == orig:
            return None
        self._logger.info(f"< {json.dumps(result.strip(), ensure_ascii=False)}")
        return result


def main(infile: Path, outfile: Path,
         api_url: str, model: str,
         source_lang: str, target_lang: str, title: str,
         history_len: int, chunk_len: int) -> int:
    """Translate *infile* into *outfile*; return a process exit code (0 = success)."""
    logger: logging.Logger = logging.getLogger(main.__name__)

    client = OllamaClient(
        api_url=api_url, model_name=model,
        source_lang=source_lang, target_lang=target_lang, context_title=title,
        history_len=history_len
    )
    translator: Translator = Translator(client, chunk_len=chunk_len)

    # Dispatch on file extension to the matching Transformer implementation.
    backends: dict[str, type[Transformer]] = {".pdf": PdfTransformer, ".odt": OdfTransformer}
    backend = backends.get(infile.suffix)
    if backend is None:
        logger.error(f"Input file extension not supported: {infile.suffix}")
        return 1
    doc: Transformer = backend(infile)

    exit_code: int = 0
    try:
        doc.transform(translator.translate)
    except KeyboardInterrupt:
        logger.warning("Interrupt")
        exit_code = 1
    # Save in any case so an interrupted run keeps its partial translation.
    doc.save(outfile)
    return exit_code


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__.strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--api", metavar="URL", type=str, default="http://localhost:11434/api/chat",
                        help="ollama endpoint to use")
    # NOTE: 'required=True' combined with 'default=' is contradictory: argparse never
    # applies the default of a required option, yet ArgumentDefaultsHelpFormatter
    # advertises it in --help. Keep the defaults and drop 'required' so the help
    # output is truthful; existing invocations that pass the options still work.
    parser.add_argument("--model", type=str, default="llama3.1:latest",
                        help="ollama model to use")

    parser.add_argument("--source-lang", metavar="LANG", type=str, default="English",
                        help="original input document language")
    parser.add_argument("--target-lang", metavar="LANG", type=str, default="German",
                        help="desired translated output language")
    # Default to "" rather than None to match OllamaClient's 'context_title: str'.
    parser.add_argument("--title-context", metavar="TITLE", type=str, default="",
                        help="document title/context as hinted by prompt")

    parser.add_argument("--history-len", metavar="LEN", type=int, default=5000,
                        help="length of past original/translation to replay as context")
    parser.add_argument("--chunk-len", metavar="LEN", type=int, default=5000,
                        help="try to split by sentence boundaries when input length exceeded")

    parser.add_argument("--debug", action="store_true",
                        help="enable debug logging")

    parser.add_argument("infile", type=Path)
    parser.add_argument("outfile", type=Path)

    args = parser.parse_args()
    colorlog.basicConfig(level=logging.DEBUG if args.debug else logging.INFO, stream=sys.stderr,
                         format="%(log_color)s%(levelname)-8s%(reset)s %(name)s: %(message)s")

    sys.exit(main(infile=args.infile, outfile=args.outfile,
                  api_url=args.api, model=args.model,
                  source_lang=args.source_lang, target_lang=args.target_lang, title=args.title_context,
                  history_len=args.history_len, chunk_len=args.chunk_len))