From f2cea90f2f100618fed786d9ed30358e2532ef57 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 19 Mar 2025 14:56:13 -0400 Subject: [PATCH 1/9] Allow serving alternative endpoints on exporter --- .../jumpstarter/jumpstarter/client/client.py | 24 ++++- .../jumpstarter/jumpstarter/client/core.py | 17 ++-- .../jumpstarter/jumpstarter/client/grpc.py | 30 +++++- .../jumpstarter/config/exporter.py | 5 +- .../jumpstarter/exporter/exporter.py | 3 +- .../jumpstarter/exporter/session.py | 16 ++- .../jumpstarter/jumpstarter/exporter/tls.py | 98 +++++++++++++++++++ packages/jumpstarter/pyproject.toml | 2 +- uv.lock | 4 +- 9 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 packages/jumpstarter/jumpstarter/exporter/tls.py diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index b994cc725..8658a9c33 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -6,8 +6,8 @@ import grpc from anyio.from_thread import BlockingPortal from google.protobuf import empty_pb2 -from jumpstarter_protocol import jumpstarter_pb2_grpc +from .grpc import SmartExporterStub from jumpstarter.client import DriverClient from jumpstarter.common.importlib import import_class @@ -26,13 +26,31 @@ async def client_from_channel( stack: ExitStack, allow: list[str], unsafe: bool, + use_alternative_endpoints: bool = True, ) -> DriverClient: topo = defaultdict(list) last_seen = {} reports = {} clients = OrderedDict() - response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty()) + response = await SmartExporterStub([channel]).GetReport(empty_pb2.Empty()) + + channels = [channel] + if use_alternative_endpoints: + for endpoint in response.alternative_endpoints: + if endpoint.certificate: + channels.append( + grpc.aio.secure_channel( + endpoint.endpoint, + grpc.ssl_channel_credentials( + root_certificates=endpoint.certificate.encode(), + private_key=endpoint.client_private_key.encode(), + certificate_chain=endpoint.client_certificate.encode(), + ), + ) + ) + + stub = SmartExporterStub(list(reversed(channels))) for index, report in enumerate(response.reports): topo[index] = [] @@ -52,7 +70,7 @@ async def client_from_channel( client = client_class( uuid=UUID(report.uuid), labels=report.labels, - channel=channel, + stub=stub, portal=portal, stack=stack.enter_context(ExitStack()), children={reports[k].labels["jumpstarter.dev/name"]: clients[k] for k in topo[index]}, diff --git a/packages/jumpstarter/jumpstarter/client/core.py b/packages/jumpstarter/jumpstarter/client/core.py index 91a125233..248c7125e 100644 --- a/packages/jumpstarter/jumpstarter/client/core.py +++ b/packages/jumpstarter/jumpstarter/client/core.py @@ -5,11 +5,12 @@ import logging from contextlib import asynccontextmanager from dataclasses import dataclass, field +from typing import Any from anyio import create_task_group from google.protobuf import empty_pb2 from grpc import StatusCode -from grpc.aio import AioRpcError, Channel +from grpc.aio import AioRpcError from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc from jumpstarter.common import Metadata @@ -60,7 +61,7 @@ class AsyncDriverClient( Backing implementation of blocking driver client. """ - channel: Channel + stub: Any log_level: str = "INFO" logger: logging.Logger = field(init=False) @@ -68,8 +69,6 @@ class AsyncDriverClient( def __post_init__(self): if hasattr(super(), "__post_init__"): super().__post_init__() - jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel) - router_pb2_grpc.RouterServiceStub.__init__(self, self.channel) self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(self.log_level) @@ -89,7 +88,7 @@ async def call_async(self, method, *args): ) try: - response = await self.DriverCall(request) + response = await self.stub.DriverCall(request) except AioRpcError as e: match e.code(): case StatusCode.UNIMPLEMENTED: @@ -113,7 +112,7 @@ async def streamingcall_async(self, method, *args): ) try: - async for response in self.StreamingDriverCall(request): + async for response in self.stub.StreamingDriverCall(request): yield decode_value(response.result) except AioRpcError as e: match e.code(): @@ -128,7 +127,7 @@ async def streamingcall_async(self, method, *args): @asynccontextmanager async def stream_async(self, method): - context = self.Stream( + context = self.stub.Stream( metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method)) .model_dump(mode="json", round_trip=True) .items(), @@ -142,7 +141,7 @@ async def resource_async( self, stream, ): - context = self.Stream( + context = self.stub.Stream( metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid)) .model_dump(mode="json", round_trip=True) .items(), @@ -160,7 +159,7 @@ def __log(self, level: int, msg: str): @asynccontextmanager async def log_stream_async(self): async def log_stream(): - async for response in self.LogStream(empty_pb2.Empty()): + async for response in self.stub.LogStream(empty_pb2.Empty()): self.__log(logging.getLevelName(response.severity), response.message) async with create_task_group() as tg: diff --git a/packages/jumpstarter/jumpstarter/client/grpc.py b/packages/jumpstarter/jumpstarter/client/grpc.py index 3ea50856c..bf87d375f 100644 --- a/packages/jumpstarter/jumpstarter/client/grpc.py +++ b/packages/jumpstarter/jumpstarter/client/grpc.py @@ -1,12 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections import OrderedDict +from dataclasses import InitVar, dataclass, field from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any import yaml from google.protobuf import duration_pb2, field_mask_pb2, json_format +from grpc import ChannelConnectivity from grpc.aio import Channel -from jumpstarter_protocol import client_pb2, client_pb2_grpc, kubernetes_pb2 +from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2, router_pb2_grpc from pydantic import BaseModel, ConfigDict, Field, field_serializer from jumpstarter.common.grpc import translate_grpc_exceptions @@ -250,3 +254,25 @@ async def DeleteLease(self, *, name: str): name="namespaces/{}/leases/{}".format(self.namespace, name), ) ) + + +@dataclass(frozen=True, slots=True) +class SmartExporterStub: + channels: InitVar[list[Channel]] + + __stubs: dict[Channel, Any] = field(init=False, default_factory=OrderedDict) + + def __post_init__(self, channels): + for channel in channels: + stub = SimpleNamespace() + jumpstarter_pb2_grpc.ExporterServiceStub.__init__(stub, channel) + router_pb2_grpc.RouterServiceStub.__init__(stub, channel) + self.__stubs[channel] = stub + + def __getattr__(self, name): + for channel, stub in self.__stubs.items(): + # find the first channel that's ready + if channel.get_state(try_to_connect=True) == ChannelConnectivity.READY: + return getattr(stub, name) + # or fallback to the last channel (via router) + return getattr(next(reversed(self.__stubs.values())), name) diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index f4240bc67..60213fd83 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager, contextmanager, suppress from pathlib import Path -from typing import Any, ClassVar, Literal, Optional, Self +from typing import Any, ClassVar, List, Literal, Optional, Self import grpc import yaml @@ -83,6 +83,8 @@ class ExporterConfigV1Alpha1(BaseModel): token: str grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) + alternative_endpoints: List[str] = Field(default_factory=list) + export: dict[str, ExporterConfigV1Alpha1DriverInstance] = Field(default_factory=dict) path: Path | None = Field(default=None) @@ -171,6 +173,7 @@ def channel_factory(): device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, tls=self.tls, grpc_options=self.grpcOptions, + alternative_endpoints=self.alternative_endpoints, ) as exporter: await exporter.serve() diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index 428292f89..af37c8c66 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -25,6 +25,7 @@ class Exporter(AbstractAsyncContextManager, Metadata): channel_factory: Callable[[], grpc.aio.Channel] device_factory: Callable[[], Driver] lease_name: str = field(init=False, default="") + alternative_endpoints: list[str] = field(default_factory=list) tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, str] = field(default_factory=dict) @@ -50,7 +51,7 @@ async def session(self): labels=self.labels, root_device=self.device_factory(), ) as session: - async with session.serve_unix_async() as path: + async with session.serve_unix_async(alternative_endpoints=self.alternative_endpoints) as path: async with grpc.aio.secure_channel( f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) ) as channel: diff --git a/packages/jumpstarter/jumpstarter/exporter/session.py b/packages/jumpstarter/jumpstarter/exporter/session.py index 55d52f471..0f7503262 100644 --- a/packages/jumpstarter/jumpstarter/exporter/session.py +++ b/packages/jumpstarter/jumpstarter/exporter/session.py @@ -15,6 +15,7 @@ ) from .logging import LogHandler +from .tls import with_alternative_endpoints from jumpstarter.common import Metadata, TemporarySocket from jumpstarter.common.streams import StreamRequestMetadata from jumpstarter.driver import Driver @@ -53,12 +54,16 @@ def __init__(self, *args, root_device, **kwargs): self._logging_queue = deque(maxlen=32) self._logging_handler = LogHandler(self._logging_queue) + self._alternative_endpoints = [] @asynccontextmanager - async def serve_port_async(self, port): + async def serve_ports_async(self, port, alternative_endpoints: list[str] | None = None): server = grpc.aio.server() server.add_insecure_port(port) + if alternative_endpoints is not None: + self._alternative_endpoints = with_alternative_endpoints(server, alternative_endpoints) + jumpstarter_pb2_grpc.add_ExporterServiceServicer_to_server(self, server) router_pb2_grpc.add_RouterServiceServicer_to_server(self, server) @@ -69,15 +74,15 @@ async def serve_port_async(self, port): await server.stop(grace=None) @asynccontextmanager - async def serve_unix_async(self): + async def serve_unix_async(self, alternative_endpoints: list[str] | None = None): with TemporarySocket() as path: - async with self.serve_port_async(f"unix://{path}"): + async with self.serve_ports_async(f"unix://{path}", alternative_endpoints): yield path @contextmanager - def serve_unix(self): + def serve_unix(self, alternative_endpoints: list[str] | None = None): with start_blocking_portal() as portal: - with portal.wrap_async_context_manager(self.serve_unix_async()) as path: + with portal.wrap_async_context_manager(self.serve_unix_async(alternative_endpoints)) as path: yield path def __getitem__(self, key: UUID): @@ -92,6 +97,7 @@ async def GetReport(self, request, context): instance.report(parent=parent, name=name) for (_, parent, name, instance) in self.root_device.enumerate() ], + alternative_endpoints=self._alternative_endpoints, ) async def DriverCall(self, request, context): diff --git a/packages/jumpstarter/jumpstarter/exporter/tls.py b/packages/jumpstarter/jumpstarter/exporter/tls.py new file mode 100644 index 000000000..a6bc512ca --- /dev/null +++ b/packages/jumpstarter/jumpstarter/exporter/tls.py @@ -0,0 +1,98 @@ +from datetime import datetime, timedelta +from ipaddress import IPv4Address, IPv6Address, ip_address + +import grpc +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from jumpstarter_protocol import jumpstarter_pb2 + + +def parse_endpoint(endpoint): + host, sep, port = endpoint.rpartition(":") + + if sep == "": + raise ValueError("port not specified in endpoint {}".format(endpoint)) + + host = host.strip("[]") # strip brackets from ipv6 addresses + + try: + port = int(port) + if port < 0 or port > 65535: + raise ValueError("port number {} out of range".format(port)) + except ValueError as e: + raise ValueError("invalid port {} in endpoint {}".format(port, endpoint)) from e + + try: + return ip_address(host), port + except ValueError: + return host, port + + +def with_alternative_endpoints(server, endpoints: list[str]): + sans = [] + for endpoint in endpoints: + host, port = parse_endpoint(endpoint) + match host: + case str(): + sans.append(x509.DNSName(host)) + case IPv4Address() | IPv6Address(): + sans.append(x509.IPAddress(host)) + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) + + crt = ( + x509.CertificateBuilder() + .subject_name(x509.Name([])) + .issuer_name(x509.Name([])) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now()) + .not_valid_after(datetime.now() + timedelta(days=365)) + .add_extension(x509.SubjectAlternativeName(sans), critical=False) + .sign(private_key=key, algorithm=hashes.SHA256(), backend=default_backend()) + ) + client_crt = ( + x509.CertificateBuilder() + .subject_name(x509.Name([])) + .issuer_name(x509.Name([])) + .public_key(client_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now()) + .not_valid_after(datetime.now() + timedelta(days=365)) + .sign(private_key=client_key, algorithm=hashes.SHA256(), backend=default_backend()) + ) + + pem_crt = crt.public_bytes(serialization.Encoding.PEM) + pem_key = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + pem_client_crt = client_crt.public_bytes(serialization.Encoding.PEM) + pem_client_key = client_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + server_credentials = grpc.ssl_server_credentials( + [(pem_key, pem_crt)], root_certificates=pem_client_crt, require_client_auth=True + ) + + endpoints_pb = [] + for endpoint in endpoints: + server.add_secure_port(endpoint, server_credentials) + endpoints_pb.append( + jumpstarter_pb2.Endpoint( + endpoint=endpoint, + certificate=pem_crt, + client_certificate=pem_client_crt, + client_private_key=pem_client_key, + ), + ) + + return endpoints_pb diff --git a/packages/jumpstarter/pyproject.toml b/packages/jumpstarter/pyproject.toml index 36d995918..308466533 100644 --- a/packages/jumpstarter/pyproject.toml +++ b/packages/jumpstarter/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "anyio>=4.4.0,!=4.6.2", "aiohttp>=3.10.5", "tqdm>=4.66.5", + "cryptography>=43.0.3", "pydantic>=2.8.2" ] @@ -25,7 +26,6 @@ dev = [ "pytest-cov>=6.0.0", "pytest-anyio>=0.0.0", "pytest-asyncio>=0.0.0", - "cryptography>=43.0.3", "jumpstarter-driver-power", "jumpstarter-driver-network", "jumpstarter-driver-composite" diff --git a/uv.lock b/uv.lock index 156912ff1..ccc7ae9e6 100644 --- a/uv.lock +++ b/uv.lock @@ -922,6 +922,7 @@ source = { editable = "packages/jumpstarter" } dependencies = [ { name = "aiohttp" }, { name = "anyio" }, + { name = "cryptography" }, { name = "jumpstarter-protocol" }, { name = "pydantic" }, { name = "pyyaml" }, @@ -930,7 +931,6 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "cryptography" }, { name = "jumpstarter-driver-composite" }, { name = "jumpstarter-driver-network" }, { name = "jumpstarter-driver-power" }, @@ -944,6 +944,7 @@ dev = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.10.5" }, { name = "anyio", specifier = ">=4.4.0,!=4.6.2" }, + { name = "cryptography", specifier = ">=43.0.3" }, { name = "jumpstarter-protocol", editable = "packages/jumpstarter-protocol" }, { name = "pydantic", specifier = ">=2.8.2" }, { name = "pyyaml", specifier = ">=6.0.2" }, @@ -952,7 +953,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "cryptography", specifier = ">=43.0.3" }, { name = "jumpstarter-driver-composite", editable = "packages/jumpstarter-driver-composite" }, { name = "jumpstarter-driver-network", editable = "packages/jumpstarter-driver-network" }, { name = "jumpstarter-driver-power", editable = "packages/jumpstarter-driver-power" }, From 7cc466bde4cd4e8b007d8b0afb0e4ba0c5c46ca2 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 11:41:31 -0400 Subject: [PATCH 2/9] Make useAlternativeEndpoints configurable --- packages/jumpstarter/jumpstarter/client/client.py | 2 +- packages/jumpstarter/jumpstarter/client/lease.py | 10 +++++++++- packages/jumpstarter/jumpstarter/config/client.py | 3 +++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index 8658a9c33..e4732c146 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -26,7 +26,7 @@ async def client_from_channel( stack: ExitStack, allow: list[str], unsafe: bool, - use_alternative_endpoints: bool = True, + use_alternative_endpoints: bool = False, ) -> DriverClient: topo = defaultdict(list) last_seen = {} diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index 2902447e5..a306ade54 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -41,6 +41,7 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager): controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False) tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, Any] = field(default_factory=dict) + use_alternative_endpoints: bool = False def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -184,7 +185,14 @@ async def _monitor(): @asynccontextmanager async def connect_async(self, stack): async with self.serve_unix_async() as path: - async with client_from_path(path, self.portal, stack, allow=self.allow, unsafe=self.unsafe) as client: + async with client_from_path( + path, + self.portal, + stack, + allow=self.allow, + unsafe=self.unsafe, + use_alternative_endpoints=self.use_alternative_endpoints, + ) as client: yield client @contextmanager diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index 7fba240fe..e278fc0fd 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -50,6 +50,8 @@ class ClientConfigV1Alpha1(BaseModel): token: str grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) + useAlternativeEndpoints: bool = False + drivers: ClientConfigV1Alpha1Drivers async def channel(self): @@ -172,6 +174,7 @@ async def lease_async( release=release_lease, tls_config=self.tls, grpc_options=self.grpcOptions, + use_alternative_endpoints=self.useAlternativeEndpoints, ) as lease: yield lease From 2665c8ab6bc74ed523468a8c42355641668eb84c Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 11:55:58 -0400 Subject: [PATCH 3/9] Pass use_alternative_endpoints to shell --- .../jumpstarter-cli/jumpstarter_cli/shell.py | 8 ++++++- .../jumpstarter/jumpstarter/client/client.py | 11 ++++++++-- .../jumpstarter/jumpstarter/common/utils.py | 22 ++++++++++++++++--- .../jumpstarter/jumpstarter/config/env.py | 1 + 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index c37e2a765..f8e9e3dc7 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -31,7 +31,13 @@ def shell(config, lease_name, selector, duration): with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease: with lease.serve_unix() as path: with lease.monitor(): - exit_code = launch_shell(path, "remote", config.drivers.allow, config.drivers.unsafe) + exit_code = launch_shell( + path, + "remote", + config.drivers.allow, + config.drivers.unsafe, + use_alternative_endpoints=config.use_alternative_endpoints, + ) sys.exit(exit_code) diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index e4732c146..5d39bd878 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -13,11 +13,18 @@ @asynccontextmanager -async def client_from_path(path: str, portal: BlockingPortal, stack: ExitStack, allow: list[str], unsafe: bool): +async def client_from_path( + path: str, + portal: BlockingPortal, + stack: ExitStack, + allow: list[str], + unsafe: bool, + use_alternative_endpoints: bool = False, +): async with grpc.aio.secure_channel( f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) ) as channel: - yield await client_from_channel(channel, portal, stack, allow, unsafe) + yield await client_from_channel(channel, portal, stack, allow, unsafe, use_alternative_endpoints) async def client_from_channel( diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index 832dce4e3..f989661b4 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -7,7 +7,7 @@ from jumpstarter.client import client_from_path from jumpstarter.config.client import _allow_from_env -from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JUMPSTARTER_HOST +from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JMP_USE_ALTERNATIVE_ENDPOINTS, JUMPSTARTER_HOST from jumpstarter.driver import Driver from jumpstarter.exporter import Session @@ -52,7 +52,16 @@ async def env_async(portal, stack): allow, unsafe = _allow_from_env() - async with client_from_path(host, portal, stack, allow=allow, unsafe=unsafe) as client: + use_alternative_endpoints = os.environ.get(JMP_USE_ALTERNATIVE_ENDPOINTS, "0") == "1" + + async with client_from_path( + host, + portal, + stack, + allow=allow, + unsafe=unsafe, + use_alternative_endpoints=use_alternative_endpoints, + ) as client: try: yield client finally: @@ -80,7 +89,13 @@ def env(): PROMPT_CWD = "\\W" -def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int: +def launch_shell( + host: str, + context: str, + allow: list[str], + unsafe: bool, + use_alternative_endpoints: bool, +) -> int: """Launch a shell with a custom prompt indicating the exporter type. Args: @@ -103,6 +118,7 @@ def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int | { JUMPSTARTER_HOST: host, JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), + JMP_USE_ALTERNATIVE_ENDPOINTS: "1" if use_alternative_endpoints else "0", "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", }, ) diff --git a/packages/jumpstarter/jumpstarter/config/env.py b/packages/jumpstarter/jumpstarter/config/env.py index 8a23678de..3bb625b63 100644 --- a/packages/jumpstarter/jumpstarter/config/env.py +++ b/packages/jumpstarter/jumpstarter/config/env.py @@ -7,3 +7,4 @@ JMP_DRIVERS_ALLOW = "JMP_DRIVERS_ALLOW" JUMPSTARTER_HOST = "JUMPSTARTER_HOST" JMP_LEASE = "JMP_LEASE" +JMP_USE_ALTERNATIVE_ENDPOINTS = "JMP_USE_ALTERNATIVE_ENDPOINTS" From 3a5eddf9777b523add07e04ef12cafbc19551252 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 14:23:05 -0400 Subject: [PATCH 4/9] Fix tests --- .../jumpstarter/jumpstarter/common/utils_test.py | 16 ++++++++++++++-- .../jumpstarter/config/client_config_test.py | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/utils_test.py b/packages/jumpstarter/jumpstarter/common/utils_test.py index ece4c3239..fa270d73a 100644 --- a/packages/jumpstarter/jumpstarter/common/utils_test.py +++ b/packages/jumpstarter/jumpstarter/common/utils_test.py @@ -5,9 +5,21 @@ def test_launch_shell(tmp_path, monkeypatch): monkeypatch.setenv("SHELL", shutil.which("true")) - exit_code = launch_shell(host=str(tmp_path / "test.sock"), context="remote", allow=["*"], unsafe=False) + exit_code = launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_alternative_endpoints=False, + ) assert exit_code == 0 monkeypatch.setenv("SHELL", shutil.which("false")) - exit_code = launch_shell(host=str(tmp_path / "test.sock"), context="remote", allow=["*"], unsafe=False) + exit_code = launch_shell( + host=str(tmp_path / "test.sock"), + context="remote", + allow=["*"], + unsafe=False, + use_alternative_endpoints=False, + ) assert exit_code == 1 diff --git a/packages/jumpstarter/jumpstarter/config/client_config_test.py b/packages/jumpstarter/jumpstarter/config/client_config_test.py index a83d0d89e..bd11ce81f 100644 --- a/packages/jumpstarter/jumpstarter/config/client_config_test.py +++ b/packages/jumpstarter/jumpstarter/config/client_config_test.py @@ -207,6 +207,7 @@ def test_client_config_save(monkeypatch: pytest.MonkeyPatch): insecure: false token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz grpcOptions: {} +useAlternativeEndpoints: false drivers: allow: - jumpstarter.drivers.* @@ -243,6 +244,7 @@ def test_client_config_save_explicit_path(): insecure: false token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz grpcOptions: {} +useAlternativeEndpoints: false drivers: allow: - jumpstarter.drivers.* @@ -277,6 +279,7 @@ def test_client_config_save_unsafe_drivers(): insecure: false token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz grpcOptions: {} +useAlternativeEndpoints: false drivers: allow: [] unsafe: true From ae31977dfb9344d701b629902b30f0cc4e939ea1 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 15:04:53 -0400 Subject: [PATCH 5/9] Set default for launch_shell parameter use_alternative_endpoints --- packages/jumpstarter/jumpstarter/common/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index f989661b4..c17eab574 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -94,7 +94,8 @@ def launch_shell( context: str, allow: list[str], unsafe: bool, - use_alternative_endpoints: bool, + *, + use_alternative_endpoints: bool = False, ) -> int: """Launch a shell with a custom prompt indicating the exporter type. From dfdb875e653eaa17afe6430e4e46fb38e2416e89 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 15:06:19 -0400 Subject: [PATCH 6/9] Shorted JMP_USE_ALTERNATIVE_ENDPOINTS to JMP_USE_ALT_ENDPOINTS --- packages/jumpstarter/jumpstarter/common/utils.py | 6 +++--- packages/jumpstarter/jumpstarter/config/env.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index c17eab574..523fb106e 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -7,7 +7,7 @@ from jumpstarter.client import client_from_path from jumpstarter.config.client import _allow_from_env -from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JMP_USE_ALTERNATIVE_ENDPOINTS, JUMPSTARTER_HOST +from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JMP_USE_ALT_ENDPOINTS, JUMPSTARTER_HOST from jumpstarter.driver import Driver from jumpstarter.exporter import Session @@ -52,7 +52,7 @@ async def env_async(portal, stack): allow, unsafe = _allow_from_env() - use_alternative_endpoints = os.environ.get(JMP_USE_ALTERNATIVE_ENDPOINTS, "0") == "1" + use_alternative_endpoints = os.environ.get(JMP_USE_ALT_ENDPOINTS, "0") == "1" async with client_from_path( host, @@ -119,7 +119,7 @@ def launch_shell( | { JUMPSTARTER_HOST: host, JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), - JMP_USE_ALTERNATIVE_ENDPOINTS: "1" if use_alternative_endpoints else "0", + JMP_USE_ALT_ENDPOINTS: "1" if use_alternative_endpoints else "0", "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", }, ) diff --git a/packages/jumpstarter/jumpstarter/config/env.py b/packages/jumpstarter/jumpstarter/config/env.py index 3bb625b63..674884bf8 100644 --- a/packages/jumpstarter/jumpstarter/config/env.py +++ b/packages/jumpstarter/jumpstarter/config/env.py @@ -7,4 +7,4 @@ JMP_DRIVERS_ALLOW = "JMP_DRIVERS_ALLOW" JUMPSTARTER_HOST = "JUMPSTARTER_HOST" JMP_LEASE = "JMP_LEASE" -JMP_USE_ALTERNATIVE_ENDPOINTS = "JMP_USE_ALTERNATIVE_ENDPOINTS" +JMP_USE_ALT_ENDPOINTS = "JMP_USE_ALT_ENDPOINTS" From f88c708603c93614a92e21c4a8c1bc55a6393c92 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 15:08:59 -0400 Subject: [PATCH 7/9] Keep useAlternativeEndpoints in snake case --- .../jumpstarter/jumpstarter/config/client.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index e278fc0fd..36b493a3d 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -50,7 +50,7 @@ class ClientConfigV1Alpha1(BaseModel): token: str grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) - useAlternativeEndpoints: bool = False + use_alternative_endpoints: bool = Field(alias="useAlternativeEndpoints", default=False) drivers: ClientConfigV1Alpha1Drivers @@ -174,7 +174,7 @@ async def lease_async( release=release_lease, tls_config=self.tls, grpc_options=self.grpcOptions, - use_alternative_endpoints=self.useAlternativeEndpoints, + use_alternative_endpoints=self.use_alternative_endpoints, ) as lease: yield lease @@ -238,12 +238,27 @@ def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path: else: config.path = Path(path) with config.path.open(mode="w") as f: - yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), f, sort_keys=False) + yaml.safe_dump( + config.model_dump( + mode="json", + exclude={"path", "alias"}, + by_alias=True, + ), + f, + sort_keys=False, + ) return config.path @classmethod def dump_yaml(cls, config: Self) -> str: - return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False) + return yaml.safe_dump( + config.model_dump( + mode="json", + exclude={"path", "alias"}, + by_alias=True, + ), + sort_keys=False, + ) @classmethod def exists(cls, alias: str) -> bool: From 011248da9ba40ec4994c493015405e35e6b18020 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 15:14:20 -0400 Subject: [PATCH 8/9] Use dummy san for certificates --- .../jumpstarter/jumpstarter/client/client.py | 2 ++ .../jumpstarter/jumpstarter/exporter/tls.py | 33 ++----------------- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index 5d39bd878..9ea5080da 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -10,6 +10,7 @@ from .grpc import SmartExporterStub from jumpstarter.client import DriverClient from jumpstarter.common.importlib import import_class +from jumpstarter.exporter.tls import SAN @asynccontextmanager @@ -54,6 +55,7 @@ async def client_from_channel( private_key=endpoint.client_private_key.encode(), certificate_chain=endpoint.client_certificate.encode(), ), + options=(("grpc.ssl_target_name_override", SAN),), ) ) diff --git a/packages/jumpstarter/jumpstarter/exporter/tls.py b/packages/jumpstarter/jumpstarter/exporter/tls.py index a6bc512ca..2411ee03d 100644 --- a/packages/jumpstarter/jumpstarter/exporter/tls.py +++ b/packages/jumpstarter/jumpstarter/exporter/tls.py @@ -1,5 +1,4 @@ from datetime import datetime, timedelta -from ipaddress import IPv4Address, IPv6Address, ip_address import grpc from cryptography import x509 @@ -8,38 +7,10 @@ from cryptography.hazmat.primitives.asymmetric import rsa from jumpstarter_protocol import jumpstarter_pb2 - -def parse_endpoint(endpoint): - host, sep, port = endpoint.rpartition(":") - - if sep == "": - raise ValueError("port not specified in endpoint {}".format(endpoint)) - - host = host.strip("[]") # strip brackets from ipv6 addresses - - try: - port = int(port) - if port < 0 or port > 65535: - raise ValueError("port number {} out of range".format(port)) - except ValueError as e: - raise ValueError("invalid port {} in endpoint {}".format(port, endpoint)) from e - - try: - return ip_address(host), port - except ValueError: - return host, port +SAN = "localhost" def with_alternative_endpoints(server, endpoints: list[str]): - sans = [] - for endpoint in endpoints: - host, port = parse_endpoint(endpoint) - match host: - case str(): - sans.append(x509.DNSName(host)) - case IPv4Address() | IPv6Address(): - sans.append(x509.IPAddress(host)) - key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) @@ -51,7 +22,7 @@ def with_alternative_endpoints(server, endpoints: list[str]): .serial_number(x509.random_serial_number()) .not_valid_before(datetime.now()) .not_valid_after(datetime.now() + timedelta(days=365)) - .add_extension(x509.SubjectAlternativeName(sans), critical=False) + .add_extension(x509.SubjectAlternativeName([x509.DNSName(SAN)]), critical=False) .sign(private_key=key, algorithm=hashes.SHA256(), backend=default_backend()) ) client_crt = ( From b8dd7580449fad2778463972f9d8dd600bd44ed5 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Fri, 21 Mar 2025 15:26:51 -0400 Subject: [PATCH 9/9] Document the design of alternative endpoints --- packages/jumpstarter/jumpstarter/exporter/tls.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/packages/jumpstarter/jumpstarter/exporter/tls.py b/packages/jumpstarter/jumpstarter/exporter/tls.py index 2411ee03d..cce5fc236 100644 --- a/packages/jumpstarter/jumpstarter/exporter/tls.py +++ b/packages/jumpstarter/jumpstarter/exporter/tls.py @@ -11,6 +11,21 @@ def with_alternative_endpoints(server, endpoints: list[str]): + """ + Listen on alternative endpoints directly without going through the router + + Useful when the network bandwidth/latency between the clients/exporters and the router is suboptimal, + yet direct connectivity between the clients and exporters can be established, e.g. the exporters have + public ip addresses, or they are in the same subnet. + + Since the direct traffic can transit through untrusted networks, it's encrypted and authenticated with + mTLS. The client would attempt the first connection through the router, a trusted channel, on which the + exporter would provide the client with its own certificate, and a client certificate/key pair for client + authentication. All certificates are selfsigned as they are only ever explicitly trusted by the client + and the exporter for the duration of a single lease. Future connections would be attempted on alternative + endpoints first and fallback to the router if none works. + """ + key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) client_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())