#!/usr/bin/env python3

"""
Minimal SMTP forwarding relay daemon.
Accepts unsolicited mails via SMTP and sends them using a remote SMTP server as configured.
"""

import argparse
import asyncio
import configparser
import logging.config
import os
import signal
import sys
import threading
import time
import uuid

from locale import setlocale, LC_ALL

from email import message_from_bytes, message_from_string
from email.message import Message
from email.policy import default as EmailMessagePolicy

from aiosmtpd.controller import Controller, SMTP as Server
from aiosmtpd.smtp import Session, Envelope
from aiosmtplib import SMTP, SMTPException, SMTPRecipientsRefused, SMTPRecipientRefused, SMTPResponse

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, List, Tuple, Set, ContextManager


@dataclass(frozen=True, eq=True)
class ServerConfig:
    """Settings of the local SMTP listener, read from the ``[server]`` INI section."""

    # When true, received messages are rewritten and logged but never relayed upstream.
    dryrun: bool
    # Address the listener binds to.
    hostname: str
    # TCP port the listener binds to.
    port: int


@dataclass(frozen=True, eq=True)
class ClientConfig:
    """Upstream delivery settings, read from the ``[client]`` INI section."""

    # Upstream SMTP server that messages are relayed through.
    hostname: str
    port: int
    # Envelope sender and From header used for every relayed message.
    sender: str
    # When non-empty, replaces the envelope recipients of every relayed message.
    recipients: List[str]
    # Login credentials handed to the SMTP client.
    username: Union[str, None]
    password: Union[str, None]
    # Set Reply-To to the original envelope sender (when it differs from ``sender``).
    set_reply_to: bool
    # Transport security options handed to the SMTP client.
    use_tls: bool
    start_tls: bool


@dataclass(frozen=True)
class Config:
    """Complete relay configuration, read from an INI file with [server] and [client] sections."""

    server: ServerConfig
    client: ClientConfig

    @staticmethod
    def _split_list(value: str) -> List[str]:
        """Split a comma-separated option value into stripped, non-empty items."""
        # Filter *after* stripping: blank entries such as a trailing ", " must not become "".
        return [item for item in (part.strip() for part in value.split(",")) if item]

    @classmethod
    def _from_config(cls, config: configparser.ConfigParser) -> 'Config':
        """Build a Config from parsed INI data.

        The [server] section is optional. [client] must provide at least hostname, port
        and sender; username defaults to sender.

        Raises:
            configparser.Error: a required section or option is missing.
            ValueError: an integer or boolean option has an unparsable value.
        """
        return cls(
            server=ServerConfig(
                dryrun=config.getboolean("server", "dryrun", fallback=False),
                hostname=config.get("server", "hostname", fallback="localhost"),
                port=config.getint("server", "port", fallback=8025),
            ),
            client=ClientConfig(
                hostname=config.get("client", "hostname"),
                port=config.getint("client", "port"),
                sender=config.get("client", "sender"),
                recipients=cls._split_list(config.get("client", "recipients", fallback="")),
                username=config.get("client", "username", fallback=config.get("client", "sender")),
                password=config.get("client", "password", fallback=""),
                set_reply_to=config.getboolean("client", "set_reply_to", fallback=False),
                use_tls=config.getboolean("client", "use_tls", fallback=True),
                start_tls=config.getboolean("client", "start_tls", fallback=False),
            ),
        )

    @classmethod
    def from_ini(cls, filename: str) -> 'Config':
        """Read and validate the INI file ``filename``.

        Raises:
            RuntimeError: the file cannot be read, decoded or parsed, or holds invalid values.
        """
        try:
            config_parser: configparser.ConfigParser = configparser.ConfigParser()
            # Explicit UTF-8: the default would follow the process locale, which main() forces
            # to "C", so any non-ASCII value (e.g. a password) would fail to decode.
            with open(filename, "r", encoding="utf-8") as fp:
                config_parser.read_file(fp)
            return cls._from_config(config_parser)
        # ValueError covers UnicodeDecodeError as well as bad getint()/getboolean() values.
        except (OSError, configparser.Error, ValueError) as e:
            raise RuntimeError(f"Cannot parse config '{filename}': {str(e)}") from None


class SMTPClient(SMTP):
    """SMTP client that rewrites a received message and relays it through the configured upstream."""

    def __init__(self, config: ClientConfig, enabled: bool) -> None:
        """Set up the upstream connection parameters from ``config``.

        With ``enabled`` false (dry run), messages are rewritten and logged but never sent.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._config: ClientConfig = config
        self._enabled: bool = enabled
        # Serializes deliveries: the connection is opened and closed once per message.
        self._lock: asyncio.Lock = asyncio.Lock()
        super().__init__(
            hostname=self._config.hostname, port=self._config.port,
            username=self._config.username, password=self._config.password,
            use_tls=self._config.use_tls, start_tls=self._config.start_tls,
        )

    async def _send_message(self, message: Message, sender: str, recipients: List[str]) -> None:
        """Send ``message`` over the current connection.

        Raises:
            SMTPRecipientsRefused: the upstream refused at least one recipient; it holds one
                SMTPRecipientRefused per refused address with the upstream reply code and text.
        """
        refused, _ = await self.send_message(message, sender=sender, recipients=recipients)
        if refused:  # raise also when not all have been refused
            raise SMTPRecipientsRefused([
                SMTPRecipientRefused(response.code, response.message, recipient)
                for recipient, response in refused.items()
            ])

    async def forward_message(self, message: Message, sender: Optional[str], recipients: List[str]) -> None:
        """Rewrite the headers of ``message`` and relay it upstream.

        From is replaced by the configured sender. When the envelope ``sender`` differs from
        it, the sender is recorded in Original-Sender (and in Reply-To if configured). When
        fixed recipients are configured, they replace ``recipients``, and the original
        recipients are recorded in Original-Recipient. To lists the addresses delivered to.

        Raises:
            RuntimeError: delivery failed; the underlying aiosmtplib/OS error is the ``__cause__``.
        """
        # Drop every header this method owns before setting it. Values supplied by the submitting
        # client (e.g. a forged Original-Sender) are then neither passed on nor duplicated.
        for header in ("From", "Original-Sender", "Reply-To", "Original-Recipient", "To"):
            del message[header]

        foreign_sender: bool = sender is not None and sender != self._config.sender
        message["From"] = self._config.sender
        if foreign_sender:
            message["Original-Sender"] = sender
        if self._config.set_reply_to and foreign_sender:
            message["Reply-To"] = sender

        if self._config.recipients:
            message["Original-Recipient"] = ", ".join(recipients)
            recipients = self._config.recipients
        message["To"] = ", ".join(recipients)

        self._logger.info(f"{message['From']} ({message['Original-Sender']}) -> "
                          f"{message['To']} ({message['Original-Recipient']}): "
                          f"'{message.get('Subject', '')}'")

        if not self._enabled:
            return
        async with self._lock:  # TODO: consumer task from spool queue, reusing connections
            try:
                await self.connect()
                await self._send_message(message, self._config.sender, recipients)
            except SMTPRecipientsRefused as e:
                raise RuntimeError(f"Recipients refused: {', '.join(_.recipient for _ in e.recipients)}") from e
            except (SMTPException, OSError) as e:
                raise RuntimeError(str(e)) from e
            finally:
                self.close()


class Handler:
    """aiosmtpd handler that forwards every received message through an SMTPClient."""

    def __init__(self, client: SMTPClient) -> None:
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._client: SMTPClient = client
        self._enabled: bool = True  # when False, DATA is answered with 421

    async def handle_DATA(self, server: Server, session: Session, envelope: Envelope) -> str:
        """aiosmtpd DATA hook: forward the received message and return the SMTP reply line.

        Returns "250 OK" on success. On failure, the reply text is the class name of the
        underlying error. The code is 451 (transient: the client should retry later) for
        connection-level failures and upstream 4xx replies, and 550 otherwise.
        """
        if not self._enabled:
            return "421 Service not available"
        try:
            await self._client.forward_message(self._prepare_message(server, envelope),
                                               sender=envelope.mail_from,
                                               recipients=envelope.rcpt_tos)
        except (RuntimeError, ValueError) as e:
            self._logger.error(f"Cannot forward: {str(e)}")
            error: BaseException = e.__cause__ if e.__cause__ is not None else e
            return f"{self._failure_code(error)} {error.__class__.__name__}"
        else:
            return "250 OK"

    @staticmethod
    def _failure_code(error: BaseException) -> int:
        """Choose the SMTP reply code for a delivery that failed because of ``error``."""
        # Connection refused/lost and timeouts are transient: a permanent 5xx here would make
        # the submitting client bounce the message instead of queueing and retrying it.
        if isinstance(error, (OSError, asyncio.TimeoutError)):
            return 451
        # Upstream reply exceptions presumably carry the server's reply code (aiosmtplib
        # SMTPResponseException); keep a transient 4xx answer transient.
        upstream_code = getattr(error, "code", None)
        if isinstance(upstream_code, int) and 400 <= upstream_code < 500:
            return 451
        return 550

    @classmethod
    def _parse_message(cls, data: Optional[Union[bytes, str]]) -> Message:
        """Parse raw DATA content (bytes or str) into a Message using the modern email policy.

        Raises:
            ValueError: ``data`` is neither bytes nor str (e.g. None).
        """
        if isinstance(data, bytes):
            return message_from_bytes(data, policy=EmailMessagePolicy)
        elif isinstance(data, str):
            return message_from_string(data, policy=EmailMessagePolicy)
        else:
            raise ValueError(str(type(data)))

    def _prepare_message(self, server: Server, envelope: Envelope) -> Message:
        """Parse the envelope content and add trace headers (Received, X-Peer).

        Date and Message-ID are also added when the message lacks them.

        Raises:
            RuntimeError: the session has no transport to read peer/socket addresses from.
            ValueError: the envelope content has an unexpected type.
        """
        message: Message = self._parse_message(envelope.content)

        if server.transport is None:
            raise RuntimeError("Cannot get transport socket")
        peer: Tuple[str, int] = server.transport.get_extra_info("peername")
        sock: Tuple[str, int] = server.transport.get_extra_info("sockname")
        host: str = server.hostname  # getfqdn
        # %a/%b are locale-dependent; main() switches to the C locale for this reason.
        now: str = time.strftime("%a, %d %b %Y %H:%M:%S %z")

        # NOTE(review): item assignment appends the header at the end of the header block,
        # whereas trace headers are conventionally prepended.
        message["Received"] = f"from {peer[0]} ([{peer[0]}]) by {host} ([{sock[0]}]); {now}"
        message["X-Peer"] = f"[{peer[0]}]:{peer[1]}"
        if "Date" not in message:
            message["Date"] = now
        if "Message-ID" not in message:
            message["Message-ID"] = f"<{str(uuid.uuid4())}@{host}>"

        return message


class SMTPServer(Controller):
    """aiosmtpd controller that serves in a background thread until a control signal arrives."""

    def __init__(self, config: ServerConfig, handler: Handler) -> None:
        super().__init__(handler=handler, hostname=config.hostname, port=config.port, ready_timeout=10.0)
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)

    def run(self) -> Optional[bool]:
        """Start serving, block until SIGHUP, SIGINT or SIGTERM arrives, then stop.

        Returns:
            None when stopped by SIGHUP (the caller reloads its configuration and restarts),
            True for SIGINT or SIGTERM.
        """
        watched: Set[int] = {signal.SIGHUP, signal.SIGINT, signal.SIGTERM}
        # Block the signals and collect them synchronously with sigwait() instead of installing
        # Python handlers. A handler calling threading.Event.set() can deadlock when the signal
        # lands while this thread holds the Event's internal lock inside Event.wait().
        # The server thread started below inherits this mask, so deliveries stay pending until
        # sigwait() takes them. The mask is deliberately kept after returning: a signal that
        # arrives while the caller reloads the configuration stays pending for the next run()
        # instead of killing the process with its default action.
        signal.pthread_sigmask(signal.SIG_BLOCK, watched)

        self._logger.info(f"Starting on {self.hostname}:{self.port}")
        self.start()
        shutdown_signal: int = signal.sigwait(watched)
        self._logger.info(f"Stopping due to signal {shutdown_signal}")
        self.stop()
        return None if shutdown_signal == signal.SIGHUP else True


class PidFile(ContextManager[None]):
    """Write the current process id to a file for the duration of a ``with`` block.

    With ``filename`` None, entering and leaving the context does nothing.
    """

    def __init__(self, filename: Optional[str]) -> None:
        self._filename: Optional[Path] = Path(filename) if filename is not None else None

    def __enter__(self) -> None:
        if self._filename is not None:
            # Conventional pid file format: the decimal pid followed by a newline.
            self._filename.write_text(f"{os.getpid()}\n")

    def __exit__(self, *args, **kwargs) -> None:
        if self._filename is not None:
            # Tolerate the file having been removed in the meantime; exceptions propagate.
            self._filename.unlink(missing_ok=True)


def _configure_logging(config_file: Optional[str]) -> None:
    logging.raiseExceptions = False
    logging.logThreads = False
    logging.logMultiprocessing = False
    logging.logProcesses = True

    if config_file is not None:  # https://docs.python.org/3/library/logging.config.html
        logging.config.fileConfig(config_file)
    else:
        logging.config.dictConfig({
            "version": 1,
            "formatters": {"standard": {
                "format": "%(levelname)s %(name)s: %(message)s",
            }},
            "handlers": {"default": {
                "formatter": "standard",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            }},
            "loggers": {"": {
                "handlers": ["default"],
                "level": "INFO",
            }},
        })


def main() -> int:
    """Entry point: parse arguments, then run the relay, reloading the configuration on SIGHUP.

    Returns the process exit status: 0 after a regular shutdown, 1 on setup errors.
    """
    # __doc__ is None under `python -OO`.
    parser = argparse.ArgumentParser(description=(__doc__ or "").strip(),
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--config", metavar="CONFIG.INI", type=str, default="./config.ini",
                        help="configuration file")
    parser.add_argument("--logging-config", metavar="LOGCONFIG.INI", type=str, default=None,
                        help="logging configuration file")
    parser.add_argument("--pid-file", metavar="SMTPRD.PID", type=str, default=None,
                        help="write pid to file")
    args = parser.parse_args()
    setlocale(LC_ALL, "C")  # for strftime

    _configure_logging(args.logging_config)
    exit_code: Optional[bool] = None
    while exit_code is None:  # run() returns None after SIGHUP: reload and start again
        try:
            config: Config = Config.from_ini(args.config)
            controller = SMTPServer(config=config.server, handler=Handler(
                SMTPClient(config.client, not config.server.dryrun),
            ))
        except (RuntimeError, ValueError) as e:
            # RuntimeError: unreadable or invalid config file (Config.from_ini).
            # ValueError: presumably option combinations rejected while constructing the
            # SMTP client/server (e.g. use_tls together with start_tls).
            logging.getLogger().error(str(e))
            return 1
        with PidFile(args.pid_file):
            exit_code = controller.run()
    return 0 if exit_code else 1


if __name__ == "__main__":
    # Script entry: main()'s return value becomes the process exit status.
    raise SystemExit(main())