Source code for gallia.transports.tcp

# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import asyncio

from gallia.log import get_logger
from gallia.transports.base import BaseTransport, LinesTransportMixin, TargetURI

logger = get_logger("gallia.transport.tcp")


class TCPTransport(BaseTransport, scheme="tcp"):
    """Transport that moves raw bytes over an asyncio TCP stream pair.

    The class is declared with ``scheme="tcp"``. Instances wrap an
    already-open ``StreamReader``/``StreamWriter`` pair; use
    :meth:`connect` to open a new connection from a target URI.
    """

    def __init__(
        self,
        target: TargetURI,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Store the stream pair belonging to ``target``.

        :param target: URI handed on to the base class initializer.
        :param reader: Stream used by :meth:`read`.
        :param writer: Stream used by :meth:`write` and :meth:`close`.
        """
        super().__init__(target)
        self.reader = reader
        self.writer = writer

    @classmethod
    async def connect(
        cls,
        target: str | TargetURI,
        timeout: float | None = None,
    ) -> TCPTransport:
        """Open a TCP connection to ``target`` and wrap it in a transport.

        :param target: A ``TargetURI`` or a string that is converted to one.
        :param timeout: Seconds allowed for establishing the connection,
                        passed to ``asyncio.wait_for``; ``None`` waits
                        indefinitely.
        :return: A new instance of ``cls`` bound to the opened streams.
        """
        if isinstance(target, TargetURI):
            uri = target
        else:
            uri = TargetURI(target)
        cls.check_scheme(uri)

        pending_connection = asyncio.open_connection(uri.hostname, uri.port)
        stream_reader, stream_writer = await asyncio.wait_for(
            pending_connection, timeout
        )
        return cls(uri, stream_reader, stream_writer)

    async def close(self) -> None:
        """Close the writer stream once; further calls do nothing."""
        if not self.is_closed:
            # Set the flag before awaiting so that a second call made while
            # the shutdown is still in progress returns immediately.
            self.is_closed = True
            self.writer.close()
            await self.writer.wait_closed()

    async def write(
        self,
        data: bytes,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> int:
        """Send ``data`` and wait until the writer has drained.

        :param data: Payload to send; its hex form is logged at trace level.
        :param timeout: Seconds allowed for the drain, passed to
                        ``asyncio.wait_for``; ``None`` waits indefinitely.
        :param tags: Optional log tags; ``"write"`` is appended to them.
        :return: ``len(data)``.
        """
        log_tags = ["write"] if tags is None else tags + ["write"]
        logger.trace(data.hex(), extra={"tags": log_tags})

        self.writer.write(data)
        await asyncio.wait_for(self.writer.drain(), timeout)
        return len(data)

    async def read(
        self,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> bytes:
        """Perform one read of at most ``self.BUFSIZE`` bytes.

        :param timeout: Seconds allowed for the read, passed to
                        ``asyncio.wait_for``; ``None`` waits indefinitely.
        :param tags: Optional log tags; ``"read"`` is appended to them.
        :return: The bytes returned by the underlying ``StreamReader.read``.
        """
        # BUFSIZE is not defined in this class; presumably inherited from
        # BaseTransport -- confirm there.
        received = await asyncio.wait_for(self.reader.read(self.BUFSIZE), timeout)
        log_tags = ["read"] if tags is None else tags + ["read"]
        logger.trace(received.hex(), extra={"tags": log_tags})
        return received
class TCPLinesTransport(LinesTransportMixin, TCPTransport, scheme="tcp-lines"):
    """TCP transport combined with ``LinesTransportMixin``.

    The class is declared with ``scheme="tcp-lines"`` and exposes the
    stream pair held by ``TCPTransport`` through the two accessors below.
    """

    def get_reader(self) -> asyncio.StreamReader:
        """Return the ``StreamReader`` stored on this transport."""
        return self.reader

    def get_writer(self) -> asyncio.StreamWriter:
        """Return the ``StreamWriter`` stored on this transport."""
        return self.writer