#!/usr/bin/env python3

"""
Minimal ollama UI.
Featuring embedded web frontend, response streaming, and conversation history.
No external dependencies or files are required.
"""

import argparse
import json
import os
import signal
import socket
import sys
import threading
import logging
import logging.config

from urllib.parse import ParseResult, urlparse, parse_qs, urljoin
from urllib.request import Request, urlopen
from urllib.error import URLError

from html import escape
from http import HTTPStatus
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from http.client import HTTPResponse

from dataclasses import dataclass, field
from typing import Iterator, Callable, Iterable, ClassVar, Literal


class ConversationHistory:
    """
    Each frontend instance carries a unique token to tell concurrent or new sessions apart.
    For each, the message history is tracked and can be re-fed alongside new prompts.
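
    A short usage sketch (the token and messages are made up):

        history = ConversationHistory(max_num=4, max_len=8)
        history.push("abc123", {"role": "user", "content": "Hi"})
        history.push("abc123", {"role": "assistant", "content": "Hello!"})
        history.clear("abc123")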
    """

    def __init__(self, max_num: int = 24, max_len: int = 32) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._hist: dict[str, list[dict]] = {}
        self._lock: threading.Lock = threading.Lock()
        self._max_num: int = max_num  # number of sessions
        self._max_len: int = max_len  # number of messages # TODO: limit by str/token length instead?

    def push(self, token: str | None, message: dict) -> list[dict]:
        if token is None:
            return [message]

        with self._lock:
            if token not in self._hist:
                # Evict the oldest sessions (dict insertion order) to make room.
                while len(self._hist) >= self._max_num:
                    old_token: str = next(iter(self._hist.keys()))
                    del self._hist[old_token]
                    self._logger.debug(f"Removing conversation history for '{old_token}'")
                self._hist[token] = [message]
                self._logger.debug(f"Adding conversation '{token}'")
            elif self._hist[token][-1]["role"] == message["role"]:
                self._hist[token][-1]["content"] += message["content"]
            else:
                self._hist[token].append(message)

            self._hist[token] = self._hist[token][-self._max_len:]
            return self._hist[token]

    def clear(self, token: str | None) -> None:
        with self._lock:
            if token is not None and token in self._hist:
                self._logger.warning(f"Clearing conversation history for '{token}'")
                del self._hist[token]


class OllamaRequestError(Exception):
    pass


class OllamaClient:
    """
    Minimal Ollama HTTP JSON client with streaming support.
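
    Streaming responses from /api/chat arrive as NDJSON, one JSON object per
    line; a content-bearing line looks roughly like this (other fields omitted):

        {"model": "llama3.1:latest", "message": {"role": "assistant", "content": "Hi"}, "done": false}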
    """

    _decoder: ClassVar[json.JSONDecoder] = json.JSONDecoder()
    _encoder: ClassVar[json.JSONEncoder] = json.JSONEncoder(ensure_ascii=False, check_circular=False,
                                                            allow_nan=False, sort_keys=False)

    def __init__(self, base_url: str, model: str, history: ConversationHistory, timeout: float = 30.0) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._base_url: str = base_url
        self._model: str = model
        self._timeout: float = timeout
        self._history: ConversationHistory = history

    def _request(self, method: str, url: str, expect_ct: str, data: bytes | None) -> Iterator[bytes]:
        full_url: str = urljoin(self._base_url, url)
        try:
            response: HTTPResponse
            with urlopen(Request(method=method, url=full_url, data=data), timeout=self._timeout) as response:  # nosec
                if response.status != 200:
                    raise OllamaRequestError(f"Response status {response.status}: {full_url}")
                elif (ct := response.headers.get("content-type", "")) != expect_ct:
                    raise OllamaRequestError(f"Unexpected response type '{ct}': {full_url}")
                else:
                    self._logger.info(f"{method} {url} HTTP {response.status}")
                    while (line := response.readline()):  # XXX: suboptimal implementation when .chunked
                        yield line
        except (OSError, URLError) as e:
            raise OllamaRequestError(f"Request failed with {str(e)}: {full_url}") from None

    def _request_json(self, method: str, url: str, data: dict | None = None) -> dict:
        return self._decoder.decode(b"".join(self._request(
            method, url, "application/json; charset=utf-8",
            data=self._encoder.encode(data).encode() if data is not None else None
        )).decode(encoding="utf-8", errors="surrogatepass"))

    def _request_ndjson(self, method: str, url: str, data: dict | None = None) -> Iterator[dict]:
        body: bytes | None = self._encoder.encode(data).encode() if data is not None else None
        for line in self._request(method, url, "application/x-ndjson", data=body):
            yield self._decoder.decode(line.decode(encoding="utf-8", errors="surrogatepass"))

    def check(self, token: str | None) -> str:
        """Start new sessions by checking connectivity and whether the model exists."""
        models: list[str] = [m["name"] for m in self._request_json("GET", "/api/tags")["models"]]
        if self._model not in models:
            raise OllamaRequestError(f"Model '{self._model}' not found in: {', '.join(models)}")
        self._history.clear(token)
        return self._model

    def generate(self, token: str | None, prompt: str) -> Iterator[str]:
        query: dict = {
            "model": self._model,
            "messages": self._history.push(token, {
                "role": "user",
                "content": prompt,
            })
        }
        chunks: int = 0
        for chunk in self._request_ndjson("POST", "/api/chat", query):
            if "message" in chunk:
                chunks += 1
                self._history.push(token, chunk["message"])
                yield chunk["message"]["content"]
        self._logger.debug(f"Read {chunks} message chunks")


@dataclass
class HttpRequest:
    method: str
    path: str
    query: dict[str, str]
    body: bytes = b""


@dataclass
class HttpResponse:
    code: int
    headers: dict[str, str] = field(default_factory=dict)
    body: bytes | Iterable[bytes] = b""


class OllamaUiHTTPServer(ThreadingHTTPServer):
    """
    HTTP server with systemd socket support that delegates requests to an external handler callable.
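
    A minimal sketch of matching systemd units (unit and path names are
    hypothetical); the listening socket is inherited as fd 3 and readiness is
    reported back through NOTIFY_SOCKET:

        # ollama-ui.socket
        [Socket]
        ListenStream=8080

        # ollama-ui.service
        [Service]
        Type=notify
        ExecStart=/usr/bin/python3 /opt/ollama-ui.py --systemd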
    """

    class RequestHandler(BaseHTTPRequestHandler):
        """
        Read request bodies and stream back response bodies for accepted requests. Log to the server's Logger.
        """

        server: 'OllamaUiHTTPServer'
        protocol_version = "HTTP/1.1"
        _response_mode: Literal['chunked', 'stream', 'full'] = "stream"

        def do_GET(self) -> None:
            self._handle("GET")

        def do_POST(self) -> None:
            self._handle("POST")

        def _handle(self, method: str) -> None:
            try:
                try:
                    content_length: int = int(self.headers.get("Content-Length", "0"))
                    post_data: bytes = self.rfile.read(content_length)
                except (ValueError, OSError) as e:
                    self.server.logger.warning(f"Cannot read request body: {str(e)}")
                    response = HttpResponse(400, {"X-Exception": e.__class__.__name__})
                else:
                    response = self.server.handle(method, self.path, post_data)

                self.send_response(response.code)
                self.send_header("Cache-Control", "no-cache")
                for header, value in response.headers.items():
                    self.send_header(header.title(), value)

                if isinstance(response.body, bytes):
                    self.send_header("Content-Length", str(len(response.body)))
                    self.end_headers()
                    self.wfile.write(response.body)
                elif self._response_mode == "chunked":  # explicit transfer encoding
                    self.send_header("Transfer-Encoding", "chunked")
                    self.end_headers()
                    for chunk in response.body:
                        if chunk:
                            self.wfile.write(b"".join(("{:X}\r\n".format(len(chunk)).encode(), chunk, b"\r\n")))
                    self.wfile.write(b"0\r\n\r\n")
                elif self._response_mode == "stream":  # should be sufficient for local connections
                    self.send_header("Connection", "close")
                    self.end_headers()
                    for chunk in response.body:
                        self.wfile.write(chunk)
                else:  # wait for complete response for max compatibility
                    self.end_headers()
                    self.wfile.write(b"".join(response.body))
            except OSError:
                pass

        def log_request(self, code='-', size='-') -> None:
            if isinstance(code, HTTPStatus):
                code = code.value
            self._log(logging.DEBUG if code == 200 else logging.INFO,
                      '"%s" %s %s',
                      self.requestline, str(code), str(size))

        def log_message(self, format, *args) -> None:
            self._log(logging.WARNING, format, *args)

        def _log(self, level: int, format, *args) -> None:
            self.server.logger.log(level, "%s - - [%s] %s",
                                   self.address_string(), self.log_date_time_string(), format % args)

    def __init__(self,
                 server_address: tuple[str, int], systemd: bool,
                 handler: Callable[[HttpRequest], HttpResponse]) -> None:
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._systemd: bool = systemd and os.getenv("LISTEN_FDS", "") == "1"
        self._handler: Callable[[HttpRequest], HttpResponse] = handler
        super().__init__(server_address, self.RequestHandler, bind_and_activate=True)

    def server_bind(self) -> None:
        if not self._systemd:
            super().server_bind()
        else:
            self.socket.close()
            self.socket = socket.fromfd(3, self.address_family, self.socket_type)  # SD_LISTEN_FDS_START
            self.server_address = self.socket.getsockname()
            self.logger.info("Obtained systemd socket")

    def server_activate(self) -> None:
        if not self._systemd:
            super().server_activate()
        else:
            self._sd_notify(self._sd_notify_path(), b"READY=1\n")

    def _sd_notify_path(self) -> str | None:
        notify_path: str | None = os.getenv("NOTIFY_SOCKET") or None
        if notify_path is not None and notify_path.startswith("@"):  # abstract namespace socket
            notify_path = "\0" + notify_path[1:]
        if notify_path is None:
            self.logger.warning("Cannot get NOTIFY_SOCKET")
        return notify_path

    def _sd_notify(self, notify_path: str | None, msg: bytes) -> None:
        if notify_path is None:
            return
        else:
            sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
            sock.setblocking(False)
        try:
            sock.connect(notify_path)
            sock.sendall(msg)
        except (OSError, UnicodeError) as e:
            self.logger.warning(f"Cannot write to NOTIFY_SOCKET: {str(e)}")
        else:
            self.logger.debug("Sent READY status")
        finally:
            sock.close()

    def handle(self, method: str, path: str, body: bytes) -> HttpResponse:
        """RequestHandler callback, calling external Handler class."""
        try:
            parsed: ParseResult = urlparse(path)
            path = parsed.path
            query: dict[str, str] = {k: v[0] for k, v in parse_qs(parsed.query).items()}
        except ValueError as e:
            self.logger.warning(f"Cannot parse request: {str(e)}")
            return HttpResponse(400, {"X-Exception": e.__class__.__name__})
        try:
            return self._handler(HttpRequest(method, path, query, body))
        except Exception as e:
            self.logger.error(f"Cannot handle request: {str(e)}")
            return HttpResponse(500, {"X-Exception": e.__class__.__name__})

    def serve(self) -> bool:
        """Run until signal."""
        shutdown_requested: threading.Event = threading.Event()

        def _handler(signum: int, frame) -> None:
            shutdown_requested.set()

        signal.signal(signal.SIGINT, _handler)
        signal.signal(signal.SIGTERM, _handler)

        thread: threading.Thread = threading.Thread(target=self.serve_forever)
        thread.start()

        self.logger.info("Serving on {}:{}".format(*self.server_address))
        shutdown_requested.wait()
        self.logger.info("Shutting down")
        self.shutdown()
        self.server_close()
        thread.join()
        return True

    @classmethod
    def run(cls, localhost: bool, port: int, systemd: bool, handler: Callable[[HttpRequest], HttpResponse]) -> bool:
        try:
            httpd: OllamaUiHTTPServer = OllamaUiHTTPServer(
                server_address=("127.0.0.1" if localhost else "0.0.0.0", port),  # nosec
                systemd=systemd,
                handler=handler,
            )
        except Exception as e:
            raise RuntimeError(str(e)) from None
        else:
            return httpd.serve()


class TokenBuffer:
    """
    Buffer input tokens into an HTTP output stream, flushing on word boundaries.
    With more buffering than whole words, this could already interpret markdown for innerHTML.
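
    A doctest-style illustration of the boundary flushing:

        >>> list(TokenBuffer.translate(iter(["Hel", "lo wor", "ld!"])))
        [b'Hello', b' world!']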
    """

    @classmethod
    def translate(cls, slop: Iterator[str]) -> Iterator[bytes]:
        buffer: str = ""
        for token in slop:
            buffer += token
            bound: int = max(buffer.rfind(" "), buffer.rfind("\n"))
            if bound > 3:
                token, buffer = buffer[:bound], buffer[bound:]
                yield token.encode()
        yield buffer.rstrip().encode()


class Handlers:
    """Routing of requests, handling actual functionality."""

    def __init__(self, client: OllamaClient) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._client: OllamaClient = client

    def handle(self, r: HttpRequest) -> HttpResponse:
        try:
            if r.path == "/":
                return HttpResponse(
                    code=200,
                    headers={"Content-Type": "text/html; charset=utf-8"},
                    body=_INDEX_HTML,
                )
            elif r.path == "/check":
                return HttpResponse(
                    code=200,
                    headers={"Content-Type": "text/plain; charset=utf-8"},
                    body=b"`" + escape(self._client.check(r.query.get("t"))).encode() + b"`",
                )
            elif r.path == "/q":  # disregard /?q= as index
                return HttpResponse(
                    code=200,
                    headers={"Content-Type": "text/plain; charset=utf-8"},
                    body=TokenBuffer.translate(self._client.generate(r.query.get("t"), r.body.decode())),
                )
            else:
                return HttpResponse(404)
        except OllamaRequestError as e:
            self._logger.error(str(e))
            return HttpResponse(502, {"X-Exception": e.__class__.__name__})


def _main(localhost: bool, port: int, systemd: bool, ollama_url: str, model: str) -> int:
    try:
        handlers: Handlers = Handlers(OllamaClient(ollama_url, model, ConversationHistory()))
        return 0 if OllamaUiHTTPServer.run(localhost, port, systemd, handlers.handle) else 1
    except RuntimeError as e:
        logging.getLogger().error(str(e))
        return 1


def _setup_logging(debug: bool) -> None:
    logging.raiseExceptions = False
    logging.logThreads = True
    logging.logMultiprocessing = False
    logging.logProcesses = False
    logging.config.dictConfig({
        'version': 1,
        'formatters': {'standard': {
            'format': '%(levelname)s %(name)s: %(message)s',
        }},
        'handlers': {'default': {
            'formatter': 'standard',
            'class': 'logging.StreamHandler',
            'stream': 'ext://sys.stderr',
        }},
        'loggers': {'': {
            'handlers': ['default'],
            'level': 'DEBUG' if debug else 'INFO',
            'propagate': False,
        }},
    })


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__.strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action='store_const', const=True, default=False,
                        help='enable debug logging')
    parser.add_argument('--ollama-url', default='http://127.0.0.1:11434/',
                        help='ollama API base url', metavar="URL")
    parser.add_argument('--model', default='llama3.1:latest',
                        help='ollama model to use', metavar="NAME")
    parser.add_argument('--localhost', action='store_const', const=True, default=False,
                        help='bind to localhost only')
    parser.add_argument('--port', type=int, default=8080,
                        help='port to bind to')
    parser.add_argument('--systemd', action='store_const', const=True, default=False,
                        help='use inherited socket for systemd activation')
    args = parser.parse_args()

    _setup_logging(args.verbose)
    return _main(args.localhost, args.port, args.systemd, args.ollama_url, args.model)


# language=JS
_INDEX_JS: str = r"""
"use strict";

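// Random per-page-load session token; the server keys conversation history
// on it, so reloading the page starts a fresh conversation.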
const conversation_token = (
  Math.random().toString(16).substring(2) +
  Math.random().toString(16).substring(2)
).substring(0, 12);

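// Escape HTML first, then apply block-level rules (hr, headings, lists,
// fenced code) before inline ones (bold, inline code, line breaks); the
// order matters because later patterns rely on the earlier escaping.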
function sloppy_markdown(text) {
  text = text
    .replace(/\r\n/g, "\n").replace(/\r/g, "\n")
    .replace(/&/g, "&amp;")
    .replace(/</g, "&lt;").replace(/>/g, "&gt;")
    .replace(/"/g, "&quot;").replace(/'/g, "&#039;");
  text = text.replace(/^---+$/gm, "<hr>");
  text = text.replace(/^##+([^#\n]+)#*$/gm, "<h3>$1</h3>");
  text = text.replace(/^[ \t]*[*+-]+ +(.*)$/gm, "<ul><li>$1</li></ul>");
  text = text.replace(/^[ \t]*([0-9]+[.)] +.*)$/gm, "\n<ul><li>$1</li></ul>\n");
  text = text.replace(/\n+```+([^\n]+)?\n*([^`]+)\n+```\n+/g, "\n\n<pre><code title=\"$1\">$2</code></pre>\n\n");
  text = text.replace(/\*\*+([^*\n]+)\*\*+/g, "<strong>$1</strong>");
  text = text.replace(/`+([^`\n]+)`+/g, "<code>$1</code>");
  text = text.replace(/([^>\n])\n([^<\n])/g, "$1<br>$2");
  text = text.replace(/([^>\n])\n+([^<\n])/g, "$1<br><br>$2");
  return text;
}

function question_cleanup(text) {
  return sloppy_markdown(text)
    .replace(/&lt;([a-zA-Z-]+)&gt;/g, "<em title=\"$1\">")
    .replace(/&lt;\/[a-zA-Z-]+&gt;/g, "</em>");
}

function add_bubble(cls, html) {
  const msg = document.createElement("div");
  msg.classList.add(cls);
  msg.innerHTML = html;
  document.getElementById("convo").appendChild(msg);
  msg.scrollIntoView();
  return msg;
}

function add_question(text) {
  return add_bubble("question", question_cleanup(text));
}

function add_answer(text) {
  return add_bubble("answer", sloppy_markdown(text));
}

async function add_answer_stream(readable_stream) {
  const msg = add_answer("");
  const decoder = new TextDecoder("utf-8", {fatal: true});
  let buffer = "";
  for await (const chunk of readable_stream) {
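    // A strict decode flushes complete chunks immediately; if a chunk ends
    // mid UTF-8 sequence the fatal decoder throws, and the chunk is re-decoded
    // in streaming mode so the trailing bytes are held for the next chunk.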
    try {
      buffer += decoder.decode(chunk, {stream: false});
    } catch {
      buffer += decoder.decode(chunk, {stream: true});
    }
    msg.innerHTML = sloppy_markdown(buffer);
  }
  buffer += decoder.decode(new Uint8Array(), {stream: false});
  msg.innerHTML = sloppy_markdown(buffer);
  msg.scrollIntoView();
  return msg;
}

function set_busy(is_busy) {
  document.getElementsByTagName("fieldset")[0].disabled = !!is_busy;
  document.getElementById("spinner").style.opacity = is_busy ? 1 : 0;
}

function send_query(form, query=null) {
  if (query === null) {
    if (!form.reportValidity()) return false;
    query = form.getElementsByTagName("textarea")[0].value.trim();
  }
  set_busy(true);
  add_question(query);

  fetch(form.getAttribute("action") + "?t=" + conversation_token, {
    method: form.getAttribute("method"),
    headers: {"Content-Type": "text/plain"},
    body: query.trim()
  }).then(response => {
    if (!response.ok) throw new Error(`Response status: ${response.status}`);
    add_answer_stream(response.body).then(msg => {
      form.reset();
    }).catch(error => {
      add_answer(String(error));
    }).finally(() => {
      set_busy(false);
    })
  }).catch(error => {
    add_answer(String(error));
    set_busy(false);
  });

  return false;
}

function send_check() {
  const page_params = new URLSearchParams(window.location.search);
  const page_query = page_params.get("q");

  set_busy(true);
  fetch("check" + "?t=" + conversation_token, {
    method: "POST",
    headers: {"Content-Type": "text/plain"},
  }).then(response => {
    if (!response.ok) throw new Error(`Response status: ${response.status}`);
    return response.text();
  }).then(text => {
    add_answer(text);
    if (page_query) {
      send_query(document.getElementById("form"), page_query);
    } else {
      set_busy(false);
    }
  }).catch(error => {
    add_answer(String(error));
    set_busy(false);
  });
}

document.getElementById("q").addEventListener("keypress", e => {
  if (e.key === "Enter" && e.ctrlKey) {
    e.preventDefault();
    send_query(document.getElementById("form"));
  }
});
send_check();
"""

# language=CSS
_INDEX_CSS: str = r"""
:root {
    --font-size: 100%;
    --text-color: #15141a;
    --main-bg-color: #ffffff;
    --highlight-color: #145ba6;
    --ui-bg-color: #f0f0f4;
    --box-bg-color-q: #eeeeef;
    --box-bg-color-a: #f8f8f9;
    --border-radius: 0.5rem;
    --pad: 1rem;
    --pad-s: 0.5rem;
    --input-height: 20vh;
}

@media (prefers-color-scheme: dark) {
    :root {
        --text-color: #f8f9f9;
        --main-bg-color: #1c1b22;
        --ui-bg-color: #33323a;
        --box-bg-color-q: #2c2b32;
        --box-bg-color-a: #222128;
    }
}

@media (max-width: 600px) {
    :root {
        --font-size: 0.9rem;
        --pad: 0.8rem;
        --pad-s: 0.4rem;
        --input-height: 10vh;
    }
}

* {
    box-sizing: border-box;
}

html, body {
    width: 100%;
    height: 100%;
    padding: 0;
    margin: 0;
    background-color: var(--main-bg-color);
    color: var(--text-color);
    font-size: var(--font-size);
    font-family: sans-serif;
}

body {
    display: grid;
    grid-template-rows: 1fr var(--input-height);
}

fieldset {
    padding: 0;
    margin: 0;
    border: none;
    height: 100%;
    display: grid;
    grid-template-columns: 1fr 3rem;
    justify-items: stretch;
    gap: var(--pad-s);
}

form {
    position: relative;
    padding: var(--pad);
    padding-top: var(--pad-s);
    margin: 0;
    border: none;
}

form textarea, form input {
    outline: none;
    border: none;
    margin: 0;
    transition: background-color 0.1s;
    background-color: var(--ui-bg-color);
    color: var(--text-color);
    border-radius: var(--border-radius);
    font-size: var(--font-size);
    font-family: sans-serif;
}

form textarea:disabled, form input:disabled {
    background-color: var(--box-bg-color-a);
}

form textarea {
    padding: var(--pad);
    resize: none;
    min-height: 0;
    min-width: 0;
}

form input {
    cursor: pointer;
}

div#convo {
    display: flex;
    flex-direction: column;
    gap: var(--pad);
    padding-top: var(--pad);
    padding-bottom: var(--pad-s);
    overflow-y: scroll;
    scroll-behavior: smooth;
}

div#convo > div:first-child {
    margin-top: auto;
}

div#convo .question, div#convo .answer {
    border-radius: var(--border-radius);
    padding: var(--pad);
    overflow-wrap: anywhere;
}

div#convo pre {
    white-space: pre-wrap;
}

div#convo .question {
    align-self: flex-end;
    margin: 0 var(--pad) 0 10vw;
    background-color: var(--box-bg-color-q);
}

div#convo .answer {
    align-self: flex-start;
    margin: 0 10vw 0 var(--pad);
    background-color: var(--box-bg-color-a);
}

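/* Indeterminate progress bar: two page-colored gradient panels sweep across
   the highlight-colored strip, leaving a moving visible segment. */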
#spinner {
    opacity: 0;
    display: block;
    position: absolute;
    top: -1px; left: 0;
    height: 2px; width: 100%;
    background-color: var(--highlight-color);
    background-repeat: no-repeat;
    background-image: linear-gradient(var(--main-bg-color) 0 0), linear-gradient(var(--main-bg-color) 0 0);
    background-size: 60% 100%;
    animation: css-loaders-progress-16 3s infinite;
    transition: opacity 0.1s;
}

@keyframes css-loaders-progress-16 {
    0%   {background-position:-150% 0,-150% 0}
    66%  {background-position: 250% 0,-150% 0}
    100% {background-position: 250% 0, 250% 0}
}
"""

# language=HTML
_INDEX_HTML: bytes = fr"""<!DOCTYPE html>
<html lang="en">
    <head>
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width,initial-scale=1">
        <meta name="robots" content="noindex,nofollow,nosnippet,notranslate,noarchive">
        <meta name="theme-color" content="#ffffff">
        <meta name="theme-color" content="#1c1b22" media="(prefers-color-scheme: dark)">
        <title>Ollama</title>
        <link rel="icon" href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAFfSURBVHgBfVPBcYMwEDx7XAAdmBLowLgCU4JdgXEFDBUQKiBUAFQAHUAH8OUFX17K7QU5igLZmR2E7m61dwgiC0opn5kwe/WDlpkxXdoDBx0UTtOkiqJQNuq6Vn0vmhB3topxirrf70IUmMW+7wsNR44pkOgIksqyVEEQiFAYhm9Bz/NMUwlqT2tfoemobVsahoGyLJP3x+NBruvaXYdcWx15EekdPpnmeZZkngPxiUKsz+ezxJBjICDdu7aPIe4BMWMOQA8HninpOA7FcUxd18nz9XrJmmdBG3BP9g5sXi4XqqqKbreb7OV5TnyyxGzAwWAL6GRNaTYIxB1ooINA9cvTOu3n8ynWwSiK3u1ZLjp9dQXWd96ENUT3eDgcGlZKIQfrTdPQHhAzWki5dpCVvsqW+uZnxA1VxlXGDIiV0Nh1WZaU/gFOH8cROde15i/QF/NTGRdMff/aHxiBnf8FaBDEF08szE0AAAAASUVORK5CYII=">
        <style>{_INDEX_CSS}</style>
    </head>
    <body>
        <div id="convo"></div>
        <form method="POST" action="q" onsubmit="send_query(this); return false;" id="form">
            <div id="spinner"></div>
            <fieldset>
                <textarea id="q" placeholder="Ask AI" autocomplete="off" autofocus required></textarea>
                <input type="submit" value="➤" title="Send (Ctrl+Enter)">
            </fieldset>
        </form>
        <script>{_INDEX_JS}</script>
    </body>
</html>
""".encode()  # noqa

if __name__ == "__main__":
    sys.exit(main())