Source code for gallia.log

# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import atexit
import datetime
import gzip
import io
import logging
import mmap
import os
import shutil
import socket
import sys
import tempfile
import traceback
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum, IntEnum, unique
from logging.handlers import QueueHandler, QueueListener
from pathlib import Path
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Any, BinaryIO, Self, TextIO, TypeAlias, cast

import msgspec
import zstandard

if TYPE_CHECKING:
    from logging import _ExcInfoType


tz = datetime.datetime.now(datetime.UTC).tzinfo


@unique
class ColorMode(Enum):
    """ColorMode is used as an argument to :func:`resolve_color_mode`."""

    #: Colors are always turned on.
    ALWAYS = "always"
    #: Colors are turned off if the target
    #: stream (e.g. stderr) is not a tty.
    AUTO = "auto"
    #: No colors are used. In other words,
    #: no ANSI escape codes are included.
    NEVER = "never"
def resolve_color_mode(mode: ColorMode, stream: TextIO = sys.stderr) -> bool:
    """Resolves whether colored output should be used for the console log handler.

    :param mode: The available options are described in :class:`ColorMode`.
    :param stream: Used as a reference for :attr:`ColorMode.AUTO`.
    """
    match mode:
        case ColorMode.ALWAYS:
            return True
        case ColorMode.AUTO:
            if os.getenv("NO_COLOR") is not None:
                return False
            else:
                return stream.isatty()
        case ColorMode.NEVER:
            return False
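
# Illustrative doctest-style sketch (not part of the original module):
# ColorMode.AUTO honors the NO_COLOR convention and otherwise falls back to a
# tty check on the given stream; the other two modes are unconditional.
#
#     >>> resolve_color_mode(ColorMode.NEVER)
#     False
#     >>> resolve_color_mode(ColorMode.ALWAYS)
#     True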
# https://stackoverflow.com/a/35804945
def _add_logging_level(level_name: str, level_num: int) -> None:
    method_name = level_name.lower()

    if hasattr(logging, level_name):
        raise AttributeError(f"{level_name} already defined in logging module")
    if hasattr(logging, method_name):
        raise AttributeError(f"{method_name} already defined in logging module")
    if hasattr(logging.getLoggerClass(), method_name):
        raise AttributeError(f"{method_name} already defined in logger class")

    # This method was inspired by the answers to Stack Overflow post
    # http://stackoverflow.com/q/2183233/2988730, especially
    # http://stackoverflow.com/a/13638084/2988730
    def for_level(self, message, *args, **kwargs):  # type: ignore
        if self.isEnabledFor(level_num):
            self._log(
                level_num,
                message,
                args,
                **kwargs,
            )

    def to_root(message, *args, **kwargs):  # type: ignore
        logging.log(level_num, message, *args, **kwargs)

    logging.addLevelName(level_num, level_name)
    setattr(logging, level_name, level_num)
    setattr(logging.getLoggerClass(), method_name, for_level)
    setattr(logging, method_name, to_root)


_add_logging_level("TRACE", 5)
_add_logging_level("NOTICE", 25)
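
# After the two calls above, TRACE (5) and NOTICE (25) are registered with the
# stdlib logging module. Illustrative doctest-style sketch (not part of the
# original module):
#
#     >>> logging.getLevelName(5)
#     'TRACE'
#     >>> logging.TRACE < logging.DEBUG < logging.NOTICE < logging.WARNING
#     True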
@unique
class Loglevel(IntEnum):
    """A wrapper around the constants exposed by python's ``logging``
    module. Since gallia adds two additional loglevels (``NOTICE`` and
    ``TRACE``), this class provides a type safe way to access the
    loglevels. The level ``NOTICE`` was added to conform better to
    RFC3164. Subsequently, ``TRACE`` was added to have a facility for
    optional debug messages.

    Loglevel describes python specific values for loglevels which are
    required to integrate with the python ecosystem. For generic
    priority values, see :class:`PenlogPriority`.
    """

    CRITICAL = logging.CRITICAL
    ERROR = logging.ERROR
    WARNING = logging.WARNING
    NOTICE = logging.NOTICE  # type: ignore
    INFO = logging.INFO
    DEBUG = logging.DEBUG
    TRACE = logging.TRACE  # type: ignore
@unique
class PenlogPriority(IntEnum):
    """PenlogPriority holds the values which are written to json log
    records. These values conform to RFC3164 with the addition of
    ``TRACE``. Since Python uses different int values for the
    loglevels, there are two enums in gallia describing loglevels.
    PenlogPriority describes generic priority values which are
    included in json log records.
    """

    EMERGENCY = 0
    ALERT = 1
    CRITICAL = 2
    ERROR = 3
    WARNING = 4
    NOTICE = 5
    INFO = 6
    DEBUG = 7
    TRACE = 8
    @classmethod
    def from_str(cls, string: str) -> PenlogPriority:
        """Converts a string to an instance of PenlogPriority.
        ``string`` can be a numeric value (0 to 8 inclusive)
        or a string with a case insensitive name of the level
        (e.g. ``debug``).
        """
        if string.isnumeric():
            return cls(int(string, 0))

        match string.lower():
            case "emergency":
                return cls.EMERGENCY
            case "alert":
                return cls.ALERT
            case "critical":
                return cls.CRITICAL
            case "error":
                return cls.ERROR
            case "warning":
                return cls.WARNING
            case "notice":
                return cls.NOTICE
            case "info":
                return cls.INFO
            case "debug":
                return cls.DEBUG
            case "trace":
                return cls.TRACE
            case _:
                raise ValueError(f"{string} not a valid priority")
    @classmethod
    def from_level(cls, value: int) -> PenlogPriority:
        """Converts an int value (e.g. from python's logging module)
        to an instance of this class.
        """
        match value:
            case Loglevel.TRACE:
                return cls.TRACE
            case Loglevel.DEBUG:
                return cls.DEBUG
            case Loglevel.INFO:
                return cls.INFO
            case Loglevel.NOTICE:
                return cls.NOTICE
            case Loglevel.WARNING:
                return cls.WARNING
            case Loglevel.ERROR:
                return cls.ERROR
            case Loglevel.CRITICAL:
                return cls.CRITICAL
            case _:
                raise ValueError("invalid value")
    def to_level(self) -> Loglevel:
        """Converts an instance of PenlogPriority to :class:`Loglevel`."""
        match self:
            case self.TRACE:
                return Loglevel.TRACE
            case self.DEBUG:
                return Loglevel.DEBUG
            case self.INFO:
                return Loglevel.INFO
            case self.NOTICE:
                return Loglevel.NOTICE
            case self.WARNING:
                return Loglevel.WARNING
            case self.ERROR:
                return Loglevel.ERROR
            case self.CRITICAL:
                return Loglevel.CRITICAL
            case _:
                raise ValueError("invalid value")
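
# Illustrative doctest-style sketch (not part of the original module): the two
# enums convert between each other via from_str, from_level, and to_level.
#
#     >>> PenlogPriority.from_str("debug")
#     <PenlogPriority.DEBUG: 7>
#     >>> PenlogPriority.from_level(Loglevel.NOTICE) is PenlogPriority.NOTICE
#     True
#     >>> PenlogPriority.NOTICE.to_level() is Loglevel.NOTICE
#     True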
def setup_logging(
    level: Loglevel | None = None,
    color_mode: ColorMode = ColorMode.AUTO,
    no_volatile_info: bool = False,
    logger_name: str = "gallia",
) -> None:
    """Enable and configure gallia's logging system.
    If this function is not called as early as possible,
    the logging system is in an undefined state and
    might not behave as expected. Always use this function
    to initialize gallia's logging. For instance,
    ``setup_logging()`` initializes a QueueHandler to
    avoid blocking calls during logging.

    :param level: The loglevel to enable for the console
                  handler. If this argument is None, the
                  env variable ``GALLIA_LOGLEVEL`` (see
                  :doc:`../env`) is read.
    :param color_mode: The color mode to use for the console.
    :param no_volatile_info: Disable volatile console output, i.e.
                             INFO lines that overwrite themselves.
    :param logger_name: The name of the logger to configure.
    """
    if level is None:
        # FIXME why is this here and not in config?
        if (raw := os.getenv("GALLIA_LOGLEVEL")) is not None:
            level = PenlogPriority.from_str(raw).to_level()
        else:
            level = Loglevel.DEBUG

    # These are slow and not used by gallia.
    logging.logMultiprocessing = False
    logging.logThreads = False
    logging.logProcesses = False

    logger = logging.getLogger(logger_name)
    # LogLevel cannot be 0 (NOTSET), because only the root logger sends it to its handlers then
    logger.setLevel(1)

    # Clean up potentially existing handlers and create a new async QueueHandler for stderr output
    while len(logger.handlers) > 0:
        logger.handlers[0].close()
        logger.removeHandler(logger.handlers[0])

    colored = resolve_color_mode(color_mode)
    add_stderr_log_handler(logger_name, level, no_volatile_info, colored)
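
# Illustrative sketch (not part of the original module): typical
# initialization at program start, using the module's default logger name.
#
#     setup_logging(level=Loglevel.INFO)
#     log = get_logger("gallia")
#     log.info("console logging is configured")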
def add_stderr_log_handler(
    logger_name: str, level: Loglevel, no_volatile_info: bool, colored: bool
) -> None:
    queue: Queue[Any] = Queue()
    logger = logging.getLogger(logger_name)
    logger.addHandler(QueueHandler(queue))

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(level)
    console_formatter = _ConsoleFormatter()
    console_formatter.colored = colored
    stderr_handler.terminator = ""  # We manually handle the terminator while formatting

    if no_volatile_info is False:
        console_formatter.volatile_info = True

    stderr_handler.setFormatter(console_formatter)

    queue_listener = QueueListener(
        queue,
        *[stderr_handler],
        respect_handler_level=True,
    )
    queue_listener.start()
    atexit.register(queue_listener.stop)


def add_zst_log_handler(
    logger_name: str, filepath: Path, file_log_level: Loglevel
) -> logging.Handler:
    queue: Queue[Any] = Queue()
    logger = get_logger(logger_name)
    logger.addHandler(QueueHandler(queue))

    zstd_handler = _ZstdFileHandler(
        filepath,
        level=file_log_level,
    )
    zstd_handler.setLevel(file_log_level)
    zstd_handler.setFormatter(_JSONFormatter())

    queue_listener = QueueListener(
        queue,
        *[zstd_handler],
        respect_handler_level=True,
    )
    queue_listener.start()
    atexit.register(queue_listener.stop)
    return zstd_handler


class _PenlogRecordV1(msgspec.Struct, omit_defaults=True):
    component: str
    host: str
    data: str
    timestamp: str
    priority: int
    type: str | None = None
    tags: list[str] | None = None
    line: str | None = None
    stacktrace: str | None = None


class _PenlogRecordV2(msgspec.Struct, omit_defaults=True, tag=2, tag_field="version"):
    module: str
    host: str
    data: str
    datetime: str
    priority: int
    tags: list[str] | None = None
    line: str | None = None
    stacktrace: str | None = None
    _python_level_no: int | None = None
    _python_level_name: str | None = None
    _python_func_name: str | None = None


_PenlogRecord: TypeAlias = _PenlogRecordV1 | _PenlogRecordV2


def _colorize_msg(data: str, levelno: int) -> tuple[str, int]:
    if not sys.stderr.isatty():
        return data, 0

    out = ""
    match levelno:
        case Loglevel.TRACE:
            style = _Color.GRAY.value
        case Loglevel.DEBUG:
            style = _Color.GRAY.value
        case Loglevel.INFO:
            style = _Color.NOP.value
        case Loglevel.NOTICE:
            style = _Color.BOLD.value
        case Loglevel.WARNING:
            style = _Color.YELLOW.value
        case Loglevel.ERROR:
            style = _Color.RED.value
        case Loglevel.CRITICAL:
            style = _Color.RED.value + _Color.BOLD.value
        case _:
            style = _Color.NOP.value

    out += style
    out += data
    out += _Color.RESET.value

    return out, len(style)


def _format_record(  # noqa: PLR0913
    dt: datetime.datetime,
    name: str,
    data: str,
    levelno: int,
    tags: list[str] | None,
    stacktrace: str | None,
    colored: bool = False,
    volatile_info: bool = False,
) -> str:
    msg = "\33[2K"
    extra_len = 4
    msg += dt.strftime("%b %d %H:%M:%S.%f")[:-3]
    msg += " "
    msg += name
    if tags is not None and len(tags) > 0:
        msg += f" [{', '.join(tags)}]"
    msg += ": "

    if colored:
        tmp_msg, extra_len_tmp = _colorize_msg(data, levelno)
        msg += tmp_msg
        extra_len += extra_len_tmp
    else:
        msg += data

    if volatile_info and levelno <= Loglevel.INFO:
        terminal_width, _ = shutil.get_terminal_size()
        msg = msg[: terminal_width + extra_len - 1]  # Adapt length to invisible ANSI colors
        msg += _Color.RESET.value
        msg += "\r"
    else:
        msg += "\n"

    if stacktrace is not None:
        msg += "\n"
        msg += stacktrace

    return msg
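
# Illustrative sketch (not part of the original module): wiring a compressed
# json logfile next to the stderr handler; "log.json.zst" is an arbitrary
# example path.
#
#     handler = add_zst_log_handler("gallia", Path("log.json.zst"), Loglevel.TRACE)
#     get_logger("gallia").warning("goes to stderr and, as json, into the .zst file")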
@dataclass
class PenlogRecord:
    module: str
    host: str
    data: str
    datetime: datetime.datetime
    # FIXME: Enums are slow.
    priority: PenlogPriority
    tags: list[str] | None = None
    colored: bool = False
    line: str | None = None
    stacktrace: str | None = None
    _python_level_no: int | None = None
    _python_level_name: str | None = None
    _python_func_name: str | None = None

    def __str__(self) -> str:
        return _format_record(
            dt=self.datetime,
            name=self.module,
            data=self.data,
            levelno=self._python_level_no
            if self._python_level_no is not None
            else self.priority.to_level(),
            tags=self.tags,
            stacktrace=self.stacktrace,
            colored=self.colored,
        )

    @classmethod
    def parse_priority(cls, data: bytes) -> int | None:
        if not data.startswith(b"<"):
            return None

        prio_str = data[1 : data.index(b">")]
        return int(prio_str)

    @classmethod
    def parse_json(cls, data: bytes) -> Self:
        if data.startswith(b"<"):
            data = data[data.index(b">") + 1 :]

        # PenlogRecordV1 has no version field, thus the tagged
        # union based approach does not work.
        record: _PenlogRecord
        try:
            record = msgspec.json.decode(data, type=_PenlogRecordV2)
        except msgspec.ValidationError:
            record = msgspec.json.decode(data, type=_PenlogRecordV1)

        match record:
            case _PenlogRecordV1():
                try:
                    dt = datetime.datetime.fromisoformat(record.timestamp)
                except ValueError:
                    # Workaround for broken ISO strings. Go produced broken strings. :)
                    # We have some old logfiles with this shortcoming.
                    datestr, _ = record.timestamp.split(".", 2)
                    dt = datetime.datetime.strptime(datestr, "%Y-%m-%dT%H:%M:%S")

                if record.tags is not None:
                    tags = record.tags
                else:
                    tags = []

                if record.type is not None:
                    tags += [record.type]

                return cls(
                    module=record.component,
                    host=record.host,
                    data=record.data,
                    datetime=dt,
                    priority=PenlogPriority(record.priority),
                    tags=tags,
                    line=record.line,
                    stacktrace=record.stacktrace,
                )
            case _PenlogRecordV2():
                return cls(
                    module=record.module,
                    host=record.host,
                    data=record.data,
                    datetime=datetime.datetime.fromisoformat(record.datetime),
                    priority=PenlogPriority(record.priority),
                    tags=record.tags,
                    line=record.line,
                    stacktrace=record.stacktrace,
                    _python_level_no=record._python_level_no,
                    _python_level_name=record._python_level_name,
                    _python_func_name=record._python_func_name,
                )

        raise ValueError("unknown record version")

    def to_log_record(self) -> logging.LogRecord:
        level = self.priority.to_level()
        timestamp = self.datetime.timestamp()
        msecs = (timestamp - int(timestamp)) * 1000

        lineno = 0
        pathname = ""
        if (line := self.line) is not None:
            pathname, lineno_str = line.rsplit(":", 1)
            lineno = int(lineno_str)

        return logging.makeLogRecord(
            {
                "name": self.module,
                "priority": self.priority,
                "levelno": level,
                "levelname": logging.getLevelName(level),
                "msg": self.data,
                "pathname": pathname,
                "lineno": lineno,
                "created": timestamp,
                "msecs": msecs,
                "host": self.host,
                "tags": self.tags,
            }
        )
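
# Illustrative sketch (not part of the original module): a PenlogRecord can be
# parsed straight from one line of a logfile; the record content is made up.
#
#     raw = (
#         b'<6>{"version": 2, "module": "gallia", "host": "example", '
#         b'"data": "hello", "datetime": "2024-01-01T00:00:00+00:00", "priority": 6}'
#     )
#     record = PenlogRecord.parse_json(raw)
#     assert record.priority is PenlogPriority.INFO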
class PenlogReader:
    def __init__(self, path: Path) -> None:
        self.path = path if str(path) != "-" else Path("/dev/stdin")
        self.raw_file = self._prepare_for_mmap(self.path)
        self.file_mmap = mmap.mmap(self.raw_file.fileno(), 0, access=mmap.ACCESS_READ)
        self._current_line = b""
        self._current_record: PenlogRecord | None = None
        self._current_record_index = 0
        self._parsed = False
        self._record_offsets: list[int] = []

    def _test_mmap(self, path: Path) -> bool:
        with path.open("rb") as f:
            try:
                mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
                return True
            except ValueError:
                return False

    def _prepare_for_mmap(self, path: Path) -> BinaryIO:
        if path.is_file() and path.suffix in [".zst", ".gz"]:
            tmpfile = tempfile.TemporaryFile()

            match path.suffix:
                case ".zst":
                    with self.path.open("rb") as f:
                        decomp = zstandard.ZstdDecompressor()
                        decomp.copy_stream(f, tmpfile)
                case ".gz":
                    with gzip.open(self.path, "rb") as f:
                        shutil.copyfileobj(f, tmpfile)

            tmpfile.flush()
            return cast(BinaryIO, tmpfile)

        if path.is_fifo() or self._test_mmap(path) is False:
            tmpfile = tempfile.TemporaryFile()
            with path.open("rb") as f:
                shutil.copyfileobj(f, tmpfile)
            tmpfile.flush()
            return cast(BinaryIO, tmpfile)

        return self.path.open("rb")

    def _parse_file_structure(self) -> None:
        old_offset = self.file_mmap.tell()

        while True:
            self._record_offsets.append(self.file_mmap.tell())
            line = self.file_mmap.readline()
            if line == b"":
                # The last newline char is not relevant, since
                # no data is following.
                del self._record_offsets[-1]
                break

        self.file_mmap.seek(old_offset)
        self._parsed = True

    def _lookup_offset(self, index: int) -> int:
        if index == 0:
            return 0

        if not self._parsed:
            self._parse_file_structure()

        return self._record_offsets[index]

    @property
    def file_size(self) -> int:
        old_offset = self.file_mmap.tell()
        self.file_mmap.seek(0, io.SEEK_END)
        size = self.file_mmap.tell()
        self.file_mmap.seek(old_offset)
        return size

    @property
    def current_record(self) -> PenlogRecord:
        if self._current_record is not None:
            return self._current_record
        return PenlogRecord.parse_json(self._current_line)

    @property
    def current_priority(self) -> int:
        prio = PenlogRecord.parse_priority(self._current_line)
        if prio is None:
            self._current_record = PenlogRecord.parse_json(self._current_line)
            prio = self._current_record.priority
        return prio

    def seek_to_record(self, n: int) -> None:
        self.file_mmap.seek(self._lookup_offset(n))
        self._current_record_index = n

    def seek_to_current_record(self) -> None:
        self.file_mmap.seek(self._lookup_offset(self._current_record_index))

    def seek_to_previous_record(self) -> None:
        self._current_record_index -= 1
        self.seek_to_record(self._current_record_index)

    def seek_to_next_record(self) -> None:
        self._current_record_index += 1
        self.seek_to_record(self._current_record_index)

    def records(
        self,
        priority: PenlogPriority = PenlogPriority.TRACE,
        offset: int = 0,
        reverse: bool = False,
    ) -> Iterator[PenlogRecord]:
        self.seek_to_record(offset)

        if reverse is False:
            while True:
                if self.readline() == b"":
                    break
                if self.current_priority <= priority:
                    yield self.current_record
        else:
            while True:
                self.readline()
                if self.current_priority <= priority:
                    yield self.current_record
                try:
                    self.seek_to_previous_record()
                except IndexError:
                    break

    def readline(self) -> bytes:
        self._current_record = None
        self._current_line = self.file_mmap.readline()
        return self._current_line

    def close(self) -> None:
        self.file_mmap.close()
        self.raw_file.close()

    def __len__(self) -> int:
        if not self._parsed:
            self._parse_file_structure()
        return len(self._record_offsets)

    def __enter__(self) -> PenlogReader:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        if exc_type is not None:
            self.close()


@unique
class _Color(Enum):
    NOP = ""
    RESET = "\033[0m"
    BOLD = "\033[1m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    PURPLE = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"
    GRAY = "\033[0;38;5;245m"


class _JSONFormatter(logging.Formatter):
    def __init__(self) -> None:
        super().__init__()
        self.hostname = socket.gethostname()

    def format(self, record: logging.LogRecord) -> str:
        tags = record.__dict__["tags"] if "tags" in record.__dict__ else None
        stacktrace = self.formatException(record.exc_info) if record.exc_info else None

        penlog_record = _PenlogRecordV2(
            module=record.name,
            host=self.hostname,
            data=record.getMessage(),
            priority=PenlogPriority.from_level(record.levelno).value,
            datetime=datetime.datetime.fromtimestamp(record.created, tz=tz).isoformat(),
            line=f"{record.pathname}:{record.lineno}",
            stacktrace=stacktrace,
            tags=tags,
            _python_level_no=record.levelno,
            _python_level_name=record.levelname,
            _python_func_name=record.funcName,
        )
        return msgspec.json.encode(penlog_record).decode()


class _ConsoleFormatter(logging.Formatter):
    colored: bool = False
    volatile_info: bool = False

    def format(
        self,
        record: logging.LogRecord,
    ) -> str:
        stacktrace = None

        if record.exc_info:
            exc_type, exc_value, exc_traceback = record.exc_info
            assert exc_type
            assert exc_value
            assert exc_traceback

            stacktrace = "\n"
            stacktrace += "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))

        return _format_record(
            dt=datetime.datetime.fromtimestamp(record.created),
            name=record.name,
            data=record.getMessage(),
            levelno=record.levelno,
            tags=record.__dict__["tags"] if "tags" in record.__dict__ else None,
            stacktrace=stacktrace,
            colored=self.colored,
            volatile_info=self.volatile_info,
        )


class _ZstdFileHandler(logging.Handler):
    def __init__(self, path: Path, level: int | str = logging.NOTSET) -> None:
        super().__init__(level)
        self.file = zstandard.open(
            filename=path,
            mode="wb",
            cctx=zstandard.ZstdCompressor(
                write_checksum=True,
                write_content_size=True,
                threads=-1,
            ),
        )

    def close(self) -> None:
        self.file.flush()
        self.file.close()

    def emit(self, record: logging.LogRecord) -> None:
        prio = PenlogPriority.from_level(record.levelno).value
        data = f"<{prio}>{self.format(record)}"
        if not data.endswith("\n"):
            data += "\n"
        self.file.write(data.encode())
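
# Illustrative sketch (not part of the original module): reading records back
# from a logfile; PenlogReader transparently handles plain, .gz, and .zst
# files. "log.json.zst" is an arbitrary example path.
#
#     reader = PenlogReader(Path("log.json.zst"))
#     for record in reader.records(priority=PenlogPriority.INFO):
#         print(record)
#     reader.close()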
class Logger(logging.Logger):
    def trace(
        self,
        msg: Any,
        *args: Any,
        exc_info: _ExcInfoType = None,
        stack_info: bool = False,
        extra: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        if self.isEnabledFor(Loglevel.TRACE):
            self._log(
                Loglevel.TRACE,
                msg,
                args,
                exc_info=exc_info,
                extra=extra,
                stack_info=stack_info,
                **kwargs,
            )

    def notice(
        self,
        msg: Any,
        *args: Any,
        exc_info: _ExcInfoType = None,
        stack_info: bool = False,
        extra: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        if self.isEnabledFor(Loglevel.NOTICE):
            self._log(
                Loglevel.NOTICE,
                msg,
                args,
                exc_info=exc_info,
                extra=extra,
                stack_info=stack_info,
                **kwargs,
            )

    def result(
        self,
        msg: Any,
        *args: Any,
        exc_info: _ExcInfoType = None,
        stack_info: bool = False,
        extra: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        extra = extra if extra is not None else {}
        extra["tags"] = ["result"]

        if self.isEnabledFor(Loglevel.NOTICE):
            self._log(
                Loglevel.NOTICE,
                msg,
                args,
                exc_info=exc_info,
                extra=extra,
                stack_info=stack_info,
                **kwargs,
            )
logging.setLoggerClass(Logger)


def get_logger(name: str) -> Logger:
    return cast(Logger, logging.getLogger(name))
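
# Illustrative sketch (not part of the original module): with the Logger
# subclass installed via logging.setLoggerClass, the custom helpers are
# available on every logger obtained through get_logger.
#
#     log = get_logger("gallia.example")
#     log.notice("maps to RFC3164's NOTICE priority")
#     log.trace("very verbose detail")
#     log.result("additionally tagged with ['result'] in the json output")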