#!/usr/bin/python3

"""
Parse a log stream and track per-IP-address 'anomaly scores'.
An ipset is maintained for blocking subsequent access according to the configured ruleset and threshold.
"""

import argparse
import logging.config
import os
import re
import socket
import subprocess
import sys
import time
from collections import OrderedDict, defaultdict
from configparser import ConfigParser, Error as ConfigParserError
from dataclasses import dataclass
from ipaddress import IPv4Address, AddressValueError
from pathlib import Path
from string import Template
from threading import Thread, Lock, Event
from typing import TextIO, List, Tuple, Dict, Optional, Set, Iterator, Literal, DefaultDict, ContextManager


class ScoreCounter:
    """
    Sliding window of per-IP score values.

    Scores are keyed by the integer form of the IPv4 address. An entry is dropped once it
    has not been updated for `timeout` seconds, or when more than `maxelem` entries are
    tracked (least recently updated first). Changed entries are remembered as 'dirty'
    until they are handed out by `get_updates()`.
    `add()` and `get_updates()` serialize on an internal lock.
    """

    def __init__(self, timeout: float, score_min: int, maxelem: int, exclude_local: bool) -> None:
        """
        :param timeout: Seconds after which an entry without updates is purged (negative values are treated as 0).
        :param score_min: Lower bound a score is clamped to.
        :param maxelem: Maximum number of tracked addresses (negative values are treated as 0).
        :param exclude_local: Ignore addresses that are not globally reachable.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__)
        self._debug: bool = self._logger.getEffectiveLevel() <= logging.DEBUG
        self._timeout: float = max(0.0, timeout)
        self._score_min: int = score_min
        self._maxelem: int = max(0, maxelem)
        self._exclude_local: bool = exclude_local
        self._scores: DefaultDict[int, int] = defaultdict(int)
        # Insertion order doubles as 'least recently updated first', see add().
        self._last_seen: OrderedDict[int, float] = OrderedDict()
        self._dirty: Set[int] = set()
        self._lock: Lock = Lock()

    @classmethod
    def _now(cls) -> float:
        """Return the monotonic clock value used for all timestamps."""
        return time.monotonic()

    def _maintain(self) -> None:
        """Periodically purge too old or too many entries. The caller must hold the lock."""

        limit: float = self._now() - self._timeout
        while self._last_seen:
            ip, last_seen = next(iter(self._last_seen.items()))
            if last_seen >= limit:
                break  # entries are ordered by update time, so all remaining ones are newer
            self._del(ip)

        while len(self._last_seen) > self._maxelem:
            self._del(next(iter(self._last_seen)))

    def _del(self, ip: int) -> None:
        """Forget everything about the given address, including a pending update."""
        self._logger.debug(f"Removing {IPv4Address(ip)} with score {self._scores[ip]}")
        del self._scores[ip]
        del self._last_seen[ip]
        self._dirty.discard(ip)

    def add(self, address: IPv4Address, score: int) -> None:
        """
        Add the given score value to an IP address, leading to an update on next occasion.

        The entry is only touched (and flagged dirty) if its clamped score actually changes,
        or if `score` is 0, which merely refreshes the entry.
        """

        if address.is_unspecified or (self._exclude_local and not address.is_global):
            self._logger.info(f"Ignoring local address {address}")
            return

        ip: int = int(address)
        with self._lock:
            # Read via .get(): indexing the defaultdict would insert an entry even when the address
            # is not going to be tracked, and such an entry would never be purged by _maintain().
            old_score: int = self._scores.get(ip, 0)
            new_score: int = max(self._score_min, old_score + score)
            if new_score != old_score or score == 0:
                self._scores[ip] = new_score
                self._logger.debug(f"Score for {address} set to {new_score}")
                self._dirty.add(ip)
                self._last_seen[ip] = self._now()
                self._last_seen.move_to_end(ip)

    def get_updates(self) -> Iterator[Tuple[IPv4Address, int]]:
        """Yield all 'dirty' entries as (address, score) and reset their dirty state."""

        with self._lock:
            self._maintain()
            updates: List[Tuple[IPv4Address, int]] = [(IPv4Address(ip), self._scores[ip]) for ip in self._dirty]
            self._dirty.clear()
        # Logging and yielding happen outside the lock, on the private snapshot.
        for address, score in updates:
            self._logger.info(f"Score for {address} now {score}")
        if updates:
            self._logger.info(f"Flushing {len(updates)} updates")
            yield from updates


class IpSet:
    """
    Thin wrapper around invocations of the 'ipset' binary.
    The set is either created on startup or expected to exist already, with a timeout configured, for example:
    `ipset -exist create $NAME hash:ip family inet timeout $TIMEOUT [maxelem $MAXELEM]`
    """

    def __init__(self,
                 ipset_exe: str, name: str, maxelem: int,
                 timeout: int, create: bool, dry_run: bool = False) -> None:
        """
        :param ipset_exe: Path of the ipset binary to execute.
        :param name: Name of the set to maintain.
        :param maxelem: Set size used when creating the set (at least 1).
        :param timeout: Set timeout used when creating the set (at least 0).
        :param create: Create the set instead of merely verifying its presence.
        :param dry_run: Only log commands instead of running them.
        :raises RuntimeError: If the set can neither be created nor found.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__).getChild(name)
        self._debug: bool = self._logger.getEffectiveLevel() <= logging.DEBUG
        self._ipset_exe: str = ipset_exe
        self._name: str = name
        self._maxelem: int = max(1, maxelem)
        self._timeout: int = max(0, timeout)
        self._create: bool = create
        self._dry_run: bool = dry_run
        self._check()

    def _run(self, stdin: Optional[str], *cmd: str) -> bool:
        """Execute the command, optionally feeding it the given input, and report success."""
        if self._debug:
            self._logger.debug(" ".join(cmd))
            if stdin is not None:
                self._logger.debug(stdin.replace("\n", "␤"))
        if self._dry_run:
            return True

        # Either pipe the encoded input to the child or detach its stdin completely.
        payload: Optional[bytes] = None
        child_stdin: Optional[int] = subprocess.DEVNULL
        if stdin is not None:
            payload = stdin.encode("ascii", errors="replace")
            child_stdin = None

        try:
            subprocess.run(cmd, check=True, shell=False, close_fds=True, env={},
                           input=payload, stdin=child_stdin,
                           stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except (subprocess.SubprocessError, OSError) as e:
            self._logger.error(f"Command failed: {' '.join(cmd)} - {str(e)}")
            return False
        return True

    def _ipset(self, *args: str) -> bool:
        """Run a single ipset command right away."""
        return self._run(None, self._ipset_exe, *args)

    def _check(self) -> None:
        """Early check to verify the set is present or can be created."""
        if self._create:
            succeeded: bool = self._run(None, self._ipset_exe, "-exist", "create", self._name,
                                        "hash:ip", "family", "inet",
                                        "maxelem", str(self._maxelem), "timeout", str(self._timeout))
            failure: str = f"Cannot create ipset {self._name}"
        else:
            succeeded = self._run(None, self._ipset_exe, "-terse", "list", self._name)
            failure = f"Cannot find ipset '{self._name}'"
        if not succeeded:
            raise RuntimeError(failure)

    def update(self, *args: IPv4Address) -> None:
        """Add or re-add the given IPs, resetting their timeout."""
        for ip in args:
            self._logger.info(f"Update {ip}")
            self._ipset("-exist", "add", self._name, str(ip))

    def remove(self, *args: IPv4Address) -> None:
        """Remove the given IPs, if present."""
        for ip in args:
            self._logger.info(f"Remove {ip}")
            self._ipset("-exist", "del", self._name, str(ip))

    def apply(self) -> None:
        """Flush pending changes, which is a no-op as commands are run immediately."""
        return None


class BulkIpSet(IpSet):
    """
    Queue add/del calls instead of running them, and submit the queue in one go via 'restore'.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Argument tuples of the commands recorded since the last apply().
        self._calls: List[Tuple[str, ...]] = []

    def _ipset(self, *args: str) -> bool:
        """Record the command for later instead of executing it."""
        self._calls.append(args)
        return True

    def apply(self) -> None:
        """Run all recorded commands and start over with an empty queue."""
        pending: List[Tuple[str, ...]] = self._calls
        self._calls = []
        if len(pending) == 1:
            # A lone command is run directly, without going through 'restore'.
            self._run(None, self._ipset_exe, *pending[0])
        elif pending:
            # One command per line, each line newline-terminated.
            script: str = "".join(" ".join(call) + "\n" for call in pending)
            self._run(script, self._ipset_exe, "restore")


class Match:
    """
    A configured regular expression that is applied to a log line or to a value derived from the context.
    """

    def __init__(self,
                 name: str, target: Optional[str],
                 match: str, ignore_case: bool, negated: bool, defaults: Dict[str, str],
                 score: Optional[int],
                 enabled: bool,
                 last: bool) -> None:
        """
        :param name: Arbitrary but unique identifier as defined by the config section.
        :param target: Template expanded with the current variable context to obtain the text to search,
                       if not given the log line itself is searched.
        :param match: Regular expression, its named capture groups are added to the current line's context.
        :param ignore_case: Perform case-insensitive matching.
        :param negated: Invert the outcome, i.e., succeed if the pattern does not apply.
        :param defaults: Variables that are set regardless of whether the pattern or a capture group applies.
        :param score: Value to add to the IP address score upon successful match.
        :param enabled: Whether this match should be evaluated.
        :param last: Stop processing of the current ruleset after successful match.
        :raises ValueError: If the regular expression cannot be compiled.
        """
        try:
            compiled = re.compile(match, flags=re.IGNORECASE if ignore_case else 0)
        except re.error as e:
            raise ValueError(f"Invalid pattern '{match}': {str(e)}") from None

        self._re = compiled
        self._target: Optional[Template] = None if target is None else Template(target)
        self._name: str = name
        self._defaults: Dict[str, str] = defaults
        self._score: Optional[int] = score
        self._negated: bool = negated
        self._enabled: bool = enabled
        self._last: bool = last

    @property
    def name(self) -> str:
        return self._name

    @property
    def enabled(self) -> bool:
        return self._enabled

    @property
    def negated(self) -> bool:
        return self._negated

    @property
    def target(self) -> Optional[str]:
        """The raw target template, if any."""
        if self._target is None:
            return None
        return self._target.template

    @property
    def last(self) -> bool:
        return self._last

    @property
    def score(self) -> Optional[int]:
        return self._score

    @property
    def groups(self) -> Set[str]:
        """Names of all variables this match can provide, by capture group or by default."""
        return {*self._re.groupindex, *self._defaults}

    def _get_target(self, line: str, context: Dict[str, str]) -> str:
        """Return the text to search, which is the line itself unless a target template is configured."""
        if self._target is None:
            return line
        try:
            return self._target.safe_substitute(context)
        except (ValueError, re.error):
            return ""

    def match(self, line: str, context: Dict[str, str]) -> bool:
        """
        Search the target text and report the (possibly negated) outcome.
        The context receives the defaults first, followed by all capture groups that took part in the match.
        """
        subject: str = self._get_target(line, context)
        context.update(self._defaults)
        found = self._re.search(subject)
        if found is None:
            return self._negated
        for key, value in found.groupdict().items():
            if value is not None:
                context[key] = value
        return not self._negated

    def __str__(self) -> str:
        subject: str = "line" if self._target is None else self._target.template
        operator: str = "!" if self._negated else "="
        text: str = f"{self._name}: {subject} {operator}~ '{self._re.pattern}' -> {self._score}"

        flags: str = "".join((
            "" if self._enabled else "X",
            "I" if self._re.flags & re.IGNORECASE else "",
            "." if self._last else "",
        ))
        if flags:
            text += f" [{flags}]"

        defaults: str = ",".join(f"{k}={v}" for k, v in self._defaults.items())
        if defaults:
            text += f" ({defaults})"
        return text


class Ruleset:
    """
    An initial match that gates a list of subsequent matches.
    """

    @dataclass
    class Result:
        address: Optional[IPv4Address]
        score: Optional[int]

    def __init__(self, pre_match: Match) -> None:
        """
        :param pre_match: Match that must succeed for the ruleset to apply to a line.
        """
        self._logger: logging.Logger = logging.getLogger(self.__class__.__name__).getChild(pre_match.name)
        self._debug: bool = self._logger.getEffectiveLevel() <= logging.DEBUG
        self._pre_match: Match = pre_match
        self._rules: List[Match] = []

    @property
    def enabled(self) -> bool:
        return self._pre_match.enabled

    @property
    def last(self) -> bool:
        return self._pre_match.last

    def append(self, rule: Match) -> None:
        """Register a subsequent match, silently dropping disabled ones."""
        if not rule.enabled:
            return
        self._rules.append(rule)

    def validate(self) -> None:
        """Log the ruleset and warn about the first questionable configuration aspect found, if any."""
        self._logger.debug(f"{self._pre_match}")
        for rule in self._rules:
            self._logger.debug(f"-> {rule}")

        every_match: Tuple[Match, ...] = (self._pre_match, *self._rules)
        if self._pre_match.target is not None:
            self._logger.warning(f"Initial match on '{self._pre_match.target}'")
            return
        if all(item.score is None for item in every_match):
            self._logger.warning("No match assigns a score")
            return
        if all("address" not in item.groups for item in every_match):
            self._logger.warning("No match provides an 'address' group")
            return
        if self._pre_match.negated and self._pre_match.groups:
            self._logger.warning("Needless groups in negated match")

    def match(self, line: str) -> Optional[Result]:
        """
        Evaluate the ruleset for a single line.
        Returns None if the initial match fails. Otherwise, all other matches are run in order while summing up
        their scores and sharing one variable context, from which the IP address is finally extracted.
        """

        context: Dict[str, str] = {}
        if not self._pre_match.match(line, context):
            self._logger.debug("Skipped")
            return None

        total: Optional[int] = self._pre_match.score
        for rule in self._rules:
            if not rule.match(line, context):
                continue
            if rule.score is not None:
                total = rule.score if total is None else total + rule.score
            if rule.last:
                break

        if self._debug and context:
            self._logger.debug(str(context))

        if total is None:
            self._logger.debug("Match, no score")
            return Ruleset.Result(address=None, score=None)

        try:
            address: IPv4Address = IPv4Address(context["address"])
        except (KeyError, AddressValueError) as e:
            self._logger.warning(f"Cannot parse 'address' group for match: {repr(e)}")
            return Ruleset.Result(address=None, score=total)

        self._logger.debug(f"Match for {address} with score {total}")
        return Ruleset.Result(address=address, score=total)


class RulesetParser:
    """
    Build the configured rulesets from an INI file.
    Sections named `ruleset.NAME` define the initial match of a ruleset, sections named
    `ruleset.NAME.RULE` add a subsequent match to the previously defined ruleset NAME.
    """

    @classmethod
    def _parse_defaults(cls, raw: str) -> Dict[str, str]:
        """Turn `key = value` lines into a dict, skipping lines without '='."""
        defaults: Dict[str, str] = {}
        for line in raw.splitlines(keepends=False):
            if "=" not in line:
                continue
            key, value = line.split("=", maxsplit=1)
            defaults[key.strip()] = value.strip()
        return defaults

    @classmethod
    def _parse_match(cls, parser: ConfigParser, section: str, name: str) -> Match:
        """Create the match described by the given section, applying fallbacks for absent options."""
        return Match(
            name=name,
            target=parser.get(section, "target", fallback=None),
            match=parser.get(section, "match", fallback=""),
            ignore_case=parser.getboolean(section, "ignore_case", fallback=False),
            negated=parser.getboolean(section, "negated", fallback=False),
            score=parser.getint(section, "score", fallback=None),
            enabled=parser.getboolean(section, "enabled", fallback=True),
            last=parser.getboolean(section, "last", fallback=False),
            defaults=cls._parse_defaults(parser.get(section, "defaults", fallback="")),
        )

    @classmethod
    def _parse(cls, parser: ConfigParser) -> Iterator[Ruleset]:
        """Yield all enabled rulesets after validating them, in order of definition."""
        rulesets: Dict[str, Ruleset] = {}

        for section in parser.sections():
            parts: List[str] = section.split(".")
            is_ruleset: bool = parts[0] == "ruleset"
            if is_ruleset and len(parts) == 2 and parts[1] not in rulesets:
                # Initial match, opening a new ruleset.
                rulesets[parts[1]] = Ruleset(cls._parse_match(parser, section, parts[1]))
            elif is_ruleset and len(parts) == 3 and parts[1] in rulesets:
                # Subsequent match, the ruleset must have been defined before.
                rulesets[parts[1]].append(cls._parse_match(parser, section, parts[2]))
            else:
                raise ValueError(f"Unrecognized section '{section}'")

        for ruleset in rulesets.values():
            if not ruleset.enabled:
                continue
            ruleset.validate()
            yield ruleset

    @classmethod
    def from_ini(cls, filename: Path) -> Iterator[Ruleset]:
        """
        Read the given file and yield its enabled rulesets.
        :raises RuntimeError: If the file cannot be read or parsed.
        """
        parser: ConfigParser = ConfigParser()
        try:
            with filename.open("r") as fp:
                parser.read_file(fp)
            yield from cls._parse(parser)
        except (OSError, ConfigParserError, ValueError) as e:
            raise RuntimeError(f"Cannot parse config '{filename}': {str(e)}") from e


class Worker(ContextManager[None]):
    """
    Context manager running a background thread that periodically transfers score changes to the ipset.
    """

    def __init__(self, counters: ScoreCounter, ipset: IpSet, threshold: int, update_interval: float) -> None:
        """
        :param counters: Source of changed address scores.
        :param ipset: Set to update accordingly.
        :param threshold: Addresses scoring above this value are added, all others removed.
        :param update_interval: Seconds to wait between updates (negative values are treated as 0).
        """
        self._counters: ScoreCounter = counters
        self._ipset: IpSet = ipset
        self._threshold: int = threshold
        self._update_interval: float = max(0.0, update_interval)
        self._shutdown: Event = Event()
        self._thread: Optional[Thread] = None

    def _worker(self) -> None:
        """Thread body: update once at startup, then per interval, and a final time upon shutdown."""
        self._update()
        while True:
            if self._shutdown.wait(self._update_interval):
                break
            self._update()
        self._update()

    def _update(self) -> None:
        """Translate all pending score changes into ipset operations and flush them."""
        for address, score in self._counters.get_updates():
            action = self._ipset.update if score > self._threshold else self._ipset.remove
            action(address)
        self._ipset.apply()

    def __enter__(self) -> None:
        assert self._thread is None
        assert not self._shutdown.is_set()
        thread: Thread = Thread(target=self._worker, name=self.__class__.__name__)
        self._thread = thread
        thread.start()

    def __exit__(self, *args) -> Literal[False]:
        assert self._thread is not None
        assert not self._shutdown.is_set()
        # Wake up the thread and wait for its final update to finish.
        self._shutdown.set()
        self._thread.join()
        self._thread = None
        return False


def _sd_notify(*commands: bytes) -> bool:
    """
    Send the given commands, one per line, as a single datagram to the socket named by $NOTIFY_SOCKET.
    Returns whether the message was sent, i.e., False if there is no socket, no command, or sending failed.
    """

    notify_path: Optional[str] = os.getenv("NOTIFY_SOCKET")
    if not notify_path or not commands:
        return False
    if notify_path.startswith("@"):  # abstract namespace socket
        notify_path = "\0" + notify_path[1:]

    with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
        try:
            sock.connect(notify_path)
            sock.sendall(b"\n".join(commands) + b"\n")
        except OSError:
            return False
        else:
            return True


def _main(logger: logging.Logger, stream: TextIO, systemd: bool,
          ruleset: List[Ruleset], counters: ScoreCounter, ipset: IpSet,
          threshold: int, update_interval: float) -> bool:
    """
    Feed every line of the stream to the rulesets and record the resulting scores, until interrupt or EOF.
    The ipset is updated by a background worker for as long as lines are processed.
    The stream is closed in any case.
    """

    if systemd:
        if _sd_notify(b"READY=1"):
            logger.info("Starting up, systemd READY")
        else:
            logger.warning("Starting up, systemd notify failed")
    else:
        logger.info("Starting up")

    line_logger: logging.Logger = logger.getChild("input")
    total_lines: int = 0

    try:
        with Worker(counters=counters, ipset=ipset, threshold=threshold, update_interval=update_interval):
            # An empty string (as opposed to an empty line) signals EOF.
            for line in iter(stream.readline, ""):
                line_logger.debug(line.rstrip())
                total_lines += 1
                for rule in ruleset:
                    result = rule.match(line)
                    if result is None:
                        continue
                    if result.address is not None and result.score is not None:
                        counters.add(result.address, result.score)
                    if rule.last:
                        break  # a matching 'last' ruleset ends evaluation of this line
    except KeyboardInterrupt:
        logger.info(f"Interrupt after {total_lines} lines")
    else:
        logger.info(f"Shutdown after {total_lines} lines")
    finally:
        stream.close()
    return True


def _open_log_stream(stream: str, create: bool) -> TextIO:
    """
    Open the pipe to read from, creating it on demand.

    Undecodable bytes are replaced instead of raising, as log lines contain arbitrary
    remote-controlled data and a single malformed byte must not abort processing.

    :param stream: Path of the pipe or '-' for stdin.
    :param create: Create the fifo if the path does not exist.
    :raises RuntimeError: If the path cannot be opened or created.
    """
    if stream == "-":
        # Best effort: stdin might not be a reconfigurable text wrapper or might have been read from already.
        reconfigure = getattr(sys.stdin, "reconfigure", None)
        if callable(reconfigure):
            try:
                reconfigure(errors="replace")
            except (OSError, ValueError):
                pass
        return sys.stdin

    sys.stdin.close()

    def _open() -> TextIO:
        # Opened read-write although only read from.
        return os.fdopen(os.open(stream, os.O_RDWR), "r", errors="replace")

    try:
        try:
            return _open()
        except FileNotFoundError:
            if not create:
                raise
            os.mkfifo(stream)
            return _open()
    except OSError as e:
        raise RuntimeError(f"Cannot open '{stream}': {str(e)}") from e


def _setup_logging(level: str, log_time: bool) -> logging.Logger:
    """Configure root logging to stderr with the given level, optionally timestamped, and return our logger."""
    fmt: str = "%(levelname)s %(name)s: %(message)s"
    if log_time:
        fmt = "%(asctime)s " + fmt
    logging.basicConfig(level=logging.getLevelName(level.upper()), stream=sys.stderr, format=fmt)
    return logging.getLogger(__name__.strip("_"))


def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the command line interface, grouped by concern."""
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group("input options")
    input_group.add_argument("--input", type=str, required=True,
                             metavar="STREAM", help="pipe to read from or '-' for stdin")
    input_group.add_argument("--input-create", action="store_true",
                             help="create the fifo with current umask if needed")
    input_group.add_argument("--ruleset", type=Path, required=True,
                             metavar="INI", help="ruleset configuration file")

    ipset_group = parser.add_argument_group("ipset options")
    ipset_group.add_argument("--ipset", required=True, type=str,
                             metavar="NAME", help="name of the ipset to maintain")
    ipset_group.add_argument("--ipset-create", action="store_true",
                             help="create the ipset if needed")
    ipset_group.add_argument("--ipset-create-maxelem", type=int, default=65536,
                             metavar="NUM", help="max ipset size when creating")
    ipset_group.add_argument("--ipset-create-timeout", type=int, default=300,
                             metavar="SEC", help="ipset timeout when creating")
    ipset_group.add_argument("--ipset-exe", type=str, default="/usr/sbin/ipset",
                             metavar="CMD", help="path to ipset command binary")

    update_group = parser.add_argument_group("update options")
    update_group.add_argument("--update-interval", type=float, default=5.0,
                              metavar="SEC", help="how often to apply ipset changes")
    update_group.add_argument("--score-threshold", type=int, default=0,
                              metavar="NUM", help="add address when above, remove otherwise")
    update_group.add_argument("--score-min", type=int, default=0,
                              metavar="NUM", help="minimum value, if negative scores are involved")

    tracking_group = parser.add_argument_group("tracking options")
    tracking_group.add_argument("--score-timeout", type=float, default=300.0,
                                metavar="SEC", help="how long to locally keep address scores")
    tracking_group.add_argument("--score-limit", type=int, default=65536,
                                metavar="LIMIT", help="number of addresses to track locally")
    tracking_group.add_argument("--exclude-local", action="store_true",
                                help="ignore results for reserved local addresses")

    log_group = parser.add_argument_group("log options")
    log_group.add_argument("--systemd", action="store_true",
                           help="try to signal systemd readiness")
    log_group.add_argument("--log-level", type=str, default="info",
                           metavar="LVL", choices=["debug", "info", "warning", "error"], help="log level")
    log_group.add_argument("--log-time", action="store_true",
                           help="prefix log with timestamp")

    parser.add_argument("--dry-run", action="store_true",
                        help="do not actually perform ipset operations")
    return parser


def main() -> int:
    """
    Command line entry point: parse arguments, set up all components, and process the log stream.
    Returns the process exit code, which is 1 if setup fails with a RuntimeError.
    """
    args = _build_arg_parser().parse_args()

    logger: logging.Logger = _setup_logging(args.log_level, args.log_time)
    try:
        # Set up in this order: input stream, ruleset, score tracking, ipset.
        stream = _open_log_stream(stream=args.input, create=args.input_create)
        ruleset = list(RulesetParser.from_ini(filename=args.ruleset))
        # The lower bound for scores is never above the threshold.
        counters = ScoreCounter(timeout=args.score_timeout, score_min=min(args.score_min, args.score_threshold),
                                maxelem=args.score_limit, exclude_local=args.exclude_local)
        ipset = BulkIpSet(ipset_exe=args.ipset_exe, name=args.ipset, dry_run=args.dry_run, create=args.ipset_create,
                          maxelem=args.ipset_create_maxelem, timeout=args.ipset_create_timeout)
        succeeded: bool = _main(
            logger=logger,
            systemd=args.systemd,
            stream=stream,
            ruleset=ruleset,
            counters=counters,
            ipset=ipset,
            threshold=args.score_threshold, update_interval=args.update_interval,
        )
    except RuntimeError as e:
        # Include the traceback only when debugging.
        logger.error(str(e), exc_info=e if logger.getEffectiveLevel() <= logging.DEBUG else None)
        return 1
    return 0 if succeeded else 1


# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())