Source code for gallia.transports.unix

# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import sys
from typing import Self

assert sys.platform.startswith("linux"), "unsupported platform"

from gallia.log import get_logger
from gallia.transports.base import BaseTransport, LinesTransportMixin, TargetURI

logger = get_logger(__name__)


class UnixTransport(BaseTransport, scheme="unix"):
    """Transport over a Unix domain stream socket (``unix://<path>``).

    The target URI's ``path`` names the socket file. All I/O goes through the
    ``asyncio.StreamReader`` / ``asyncio.StreamWriter`` pair returned by
    ``asyncio.open_unix_connection``.
    """

    # Upper bound on the number of bytes a single read() call returns. A
    # bounded read returns as soon as data is available instead of waiting
    # for the peer to close the connection (EOF).
    _READ_CHUNK_SIZE = 4096

    def __init__(
        self,
        target: TargetURI,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Wrap an already-connected stream pair.

        :param target: The parsed target URI this transport is connected to.
        :param reader: Stream reader of the open Unix socket connection.
        :param writer: Stream writer of the open Unix socket connection.
        """
        super().__init__(target)
        self.reader = reader
        self.writer = writer

    @classmethod
    async def connect(cls, target: str | TargetURI, timeout: float | None = None) -> Self:
        """Open a connection to the Unix socket named by ``target``.

        :param target: Target URI (or its string form); its scheme is validated
            via ``check_scheme`` and its ``path`` is used as the socket path.
        :param timeout: Seconds to wait for the connection; ``None`` waits
            indefinitely.
        :return: A new transport instance wrapping the connected streams.
        :raises TimeoutError: If the connection is not established in time.
        """
        t = target if isinstance(target, TargetURI) else TargetURI(target)
        cls.check_scheme(t)

        reader, writer = await asyncio.wait_for(asyncio.open_unix_connection(t.path), timeout)
        return cls(t, reader, writer)

    async def close(self) -> None:
        """Close the socket connection; subsequent calls are no-ops."""
        if self.is_closed:
            return
        # Mark as closed before awaiting so a second close() (possibly issued
        # while wait_closed() is pending) returns early instead of re-closing.
        self.is_closed = True
        self.writer.close()
        await self.writer.wait_closed()

    async def write(
        self,
        data: bytes,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> int:
        """Send ``data`` and wait until the write buffer has been drained.

        :param data: Raw bytes to send.
        :param timeout: Seconds to wait for the drain; ``None`` waits
            indefinitely.
        :param tags: Extra log tags; ``"write"`` is always appended.
        :return: The number of bytes passed in (``len(data)``).
        :raises TimeoutError: If draining does not finish in time.
        """
        t = tags + ["write"] if tags is not None else ["write"]
        logger.trace(data.hex(), extra={"tags": t})

        self.writer.write(data)
        await asyncio.wait_for(self.writer.drain(), timeout)
        return len(data)

    async def read(
        self,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> bytes:
        """Read up to ``_READ_CHUNK_SIZE`` bytes from the socket.

        Returns as soon as any data is available; returns ``b""`` at EOF.

        :param timeout: Seconds to wait for data; ``None`` waits indefinitely.
        :param tags: Extra log tags; ``"read"`` is always appended.
        :return: The bytes received.
        :raises TimeoutError: If no data arrives within ``timeout``.
        """
        # Previously `self.reader.read()` (read until EOF) ignored `timeout`
        # and blocked until the peer closed the connection.
        data = await asyncio.wait_for(self.reader.read(self._READ_CHUNK_SIZE), timeout)

        t = tags + ["read"] if tags is not None else ["read"]
        logger.trace(data.hex(), extra={"tags": t})
        return data
class UnixLinesTransport(LinesTransportMixin, UnixTransport, scheme="unix-lines"):
    """Line-oriented variant of ``UnixTransport`` (``unix-lines://<path>``).

    Connection handling is inherited from ``UnixTransport``. The two accessors
    below expose that connection's streams, presumably as the hooks
    ``LinesTransportMixin`` uses for its line framing (confirm against the
    mixin).
    """

    def get_writer(self) -> asyncio.StreamWriter:
        """Return the stream writer of the underlying Unix socket."""
        return self.writer

    def get_reader(self) -> asyncio.StreamReader:
        """Return the stream reader of the underlying Unix socket."""
        return self.reader