# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import asyncio
import binascii
import io
from abc import ABC, abstractmethod
from typing import Any, Self
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from typing_extensions import Protocol
from gallia.log import get_logger
from gallia.utils import join_host_port
logger = get_logger("gallia.transport.base")
[docs]
class TargetURI:
"""TargetURI represents a target to which gallia can connect.
The target string must conform to a URI is specified by RFC3986.
Basically, this is a wrapper around Python's ``urlparse()`` and
``parse_qs()`` methods. TargetURI provides frequently used properties
for a more userfriendly usage. Instances are meant to be passed to
:meth:`BaseTransport.connect()` of transport implementations.
"""
def __init__(self, raw: str) -> None:
self.raw = raw
self.url = urlparse(raw)
self.qs = parse_qs(self.url.query)
[docs]
@classmethod
def from_parts(
cls,
scheme: str,
host: str,
port: int | None,
args: dict[str, Any],
) -> TargetURI:
"""Constructs a instance of TargetURI with the given arguments.
The ``args`` dict is used for the query string.
"""
netloc = host if port is None else join_host_port(host, port)
return TargetURI(urlunparse((scheme, netloc, "", "", urlencode(args), "")))
@property
def scheme(self) -> str:
"""The URI scheme"""
return self.url.scheme
@property
def hostname(self) -> str | None:
"""The hostname (without port)"""
return self.url.hostname
@property
def port(self) -> int | None:
"""The port number"""
return self.url.port
@property
def netloc(self) -> str:
"""The hostname and the portnumber, separated by a colon."""
return self.url.netloc
@property
def path(self) -> str:
"""The path property of the url."""
return self.url.path
@property
def location(self) -> str:
"""A URI string which only consists of the relevant scheme,
the host and the port.
"""
return f"{self.scheme}://{self.url.netloc}"
@property
def qs_flat(self) -> dict[str, str]:
"""A dict which contains the query string's key/value pairs.
In case a key appears multiple times, this variant only
contains the first found key/value pair. In contrast to
:attr:`qs`, this variant avoids lists and might be easier
to use for some cases.
"""
d = {}
for k, v in self.qs.items():
d[k] = v[0]
return d
def __str__(self) -> str:
return self.raw
class TransportProtocol(Protocol):
    """Structural type describing what the transport mixins in this
    module expect from the class they are mixed into: the three
    attributes below plus accessors for an asyncio stream pair.
    """

    # Lock object of the transport.
    mutex: asyncio.Lock
    # The target the transport belongs to.
    target: TargetURI
    # Closed flag of the transport.
    is_closed: bool

    def get_writer(self) -> asyncio.StreamWriter:
        """Return the stream writer; implementers must override this."""
        raise NotImplementedError

    def get_reader(self) -> asyncio.StreamReader:
        """Return the stream reader; implementers must override this."""
        raise NotImplementedError
[docs]
class BaseTransport(ABC):
"""BaseTransport is the base class providing the required
interface for all transports used by gallia.
A transport usually is some kind of network protocol which
carries an application level protocol. A good example is
DoIP carrying UDS requests which acts as a minimal middleware
on top of TCP.
This class is to be used as a subclass with all abstractmethods
implemented and the SCHEME property filled.
A few methods provide a ``tags`` argument. The debug logs of these
calls include these tags in the ``tags`` property of the relevant
:class:`gallia.log.PenlogRecord`.
"""
#: The scheme for the implemented protocol, e.g. "doip".
SCHEME: str = ""
#: The buffersize of the transport. Might be used in read() calls.
#: Defaults to :const:`io.DEFAULT_BUFFER_SIZE`.
BUFSIZE: int = io.DEFAULT_BUFFER_SIZE
def __init__(self, target: TargetURI) -> None:
self.mutex = asyncio.Lock()
self.target = target
self.is_closed = False
def __init_subclass__(
cls,
/,
scheme: str,
bufsize: int = io.DEFAULT_BUFFER_SIZE,
**kwargs: Any,
) -> None:
super().__init_subclass__(**kwargs)
cls.SCHEME = scheme
cls.BUFSIZE = bufsize
[docs]
@classmethod
def check_scheme(cls, target: TargetURI) -> None:
"""Checks if the provided URI has the correct scheme."""
if target.scheme != cls.SCHEME:
raise ValueError(f"invalid scheme: {target.scheme}; expected: {cls.SCHEME}")
[docs]
@classmethod
@abstractmethod
async def connect(
cls,
target: str | TargetURI,
timeout: float | None = None,
) -> Self:
"""Classmethod to connect the transport to a relevant target.
The target argument is a URI, such as `doip://192.0.2.2:13400?src_addr=0xf4&dst_addr=0x1d"`
An instance of the relevant transport class is returned.
"""
[docs]
@abstractmethod
async def close(self) -> None:
"""Terminates the connection and clean up all allocated ressources."""
[docs]
async def reconnect(self, timeout: float | None = None) -> Self:
"""Closes the connection to the target and reconnects. A new
instance of this class is returned rendering the old one
obsolete. This method is safe for concurrent use.
"""
async with self.mutex:
try:
await self.close()
except ConnectionError as e:
logger.warning(f"close() failed during reconnect ({e}); ignoring")
return await self.connect(self.target)
[docs]
@abstractmethod
async def read(
self,
timeout: float | None = None,
tags: list[str] | None = None,
) -> bytes:
"""Reads one message and returns its raw byte representation.
An example for one message is 'one line, terminated by newline'
for a TCP transport yielding lines.
"""
[docs]
@abstractmethod
async def write(
self,
data: bytes,
timeout: float | None = None,
tags: list[str] | None = None,
) -> int:
"""Writes one message and return the number of written bytes."""
[docs]
async def request(
self,
data: bytes,
timeout: float | None = None,
tags: list[str] | None = None,
) -> bytes:
"""Chains a :meth:`write()` call with a :meth:`read()` call.
The call is protected by a mutex and is thus safe for concurrent
use.
"""
async with self.mutex:
return await self.request_unsafe(data, timeout, tags)
[docs]
async def request_unsafe(
self,
data: bytes,
timeout: float | None = None,
tags: list[str] | None = None,
) -> bytes:
"""Chains a :meth:`write()` call with a :meth:`read()` call.
The call is **not** protected by a mutex. Only use this method
when you know what you are doing.
"""
await self.write(data, timeout, tags)
return await self.read(timeout, tags)
class LinesTransportMixin:
    """Mixin implementing ``read()`` and ``write()`` on top of an asyncio
    stream pair, where one message is one line of hex digits terminated
    by a newline.

    The class it is mixed into has to satisfy :class:`TransportProtocol`,
    i.e. provide ``get_reader()`` and ``get_writer()``.
    """

    async def write(
        self: TransportProtocol,
        data: bytes,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> int:
        """Send ``data`` as one hex encoded line.

        ``timeout`` limits the wait for the writer to drain. The return
        value is the length of ``data``, not of the encoded line.
        """
        log_tags = ["write"] if tags is None else tags + ["write"]
        # "0a" is the hex form of the newline which terminates the line.
        logger.trace(data.hex() + "0a", extra={"tags": log_tags})

        stream = self.get_writer()
        line = binascii.hexlify(data) + b"\n"
        stream.write(line)
        await asyncio.wait_for(stream.drain(), timeout)
        return len(data)

    async def read(
        self: TransportProtocol,
        timeout: float | None = None,
        tags: list[str] | None = None,
    ) -> bytes:
        """Receive one line and return the bytes its hex digits encode.

        ``timeout`` limits the wait for the line. Surrounding whitespace,
        including the line terminator, is stripped before decoding.
        """
        raw_line = await asyncio.wait_for(self.get_reader().readline(), timeout)
        hex_text = raw_line.decode().strip()

        log_tags = ["read"] if tags is None else tags + ["read"]
        logger.trace(hex_text + "0a", extra={"tags": log_tags})
        return binascii.unhexlify(hex_text)