Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ def shell(config, lease_name, selector, duration):
with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease:
with lease.serve_unix() as path:
with lease.monitor():
exit_code = launch_shell(path, "remote", config.drivers.allow, config.drivers.unsafe)
exit_code = launch_shell(
path,
"remote",
config.drivers.allow,
config.drivers.unsafe,
use_alternative_endpoints=config.use_alternative_endpoints,
)

sys.exit(exit_code)

Expand Down
37 changes: 32 additions & 5 deletions packages/jumpstarter/jumpstarter/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@
import grpc
from anyio.from_thread import BlockingPortal
from google.protobuf import empty_pb2
from jumpstarter_protocol import jumpstarter_pb2_grpc

from .grpc import SmartExporterStub
from jumpstarter.client import DriverClient
from jumpstarter.common.importlib import import_class
from jumpstarter.exporter.tls import SAN


@asynccontextmanager
async def client_from_path(path: str, portal: BlockingPortal, stack: ExitStack, allow: list[str], unsafe: bool):
async def client_from_path(
path: str,
portal: BlockingPortal,
stack: ExitStack,
allow: list[str],
unsafe: bool,
use_alternative_endpoints: bool = False,
):
async with grpc.aio.secure_channel(
f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS)
) as channel:
yield await client_from_channel(channel, portal, stack, allow, unsafe)
yield await client_from_channel(channel, portal, stack, allow, unsafe, use_alternative_endpoints)


async def client_from_channel(
Expand All @@ -26,13 +34,32 @@ async def client_from_channel(
stack: ExitStack,
allow: list[str],
unsafe: bool,
use_alternative_endpoints: bool = False,
) -> DriverClient:
topo = defaultdict(list)
last_seen = {}
reports = {}
clients = OrderedDict()

response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty())
response = await SmartExporterStub([channel]).GetReport(empty_pb2.Empty())

channels = [channel]
if use_alternative_endpoints:
for endpoint in response.alternative_endpoints:
if endpoint.certificate:
channels.append(
grpc.aio.secure_channel(
endpoint.endpoint,
grpc.ssl_channel_credentials(
root_certificates=endpoint.certificate.encode(),
private_key=endpoint.client_private_key.encode(),
certificate_chain=endpoint.client_certificate.encode(),
),
options=(("grpc.ssl_target_name_override", SAN),),
)
)

stub = SmartExporterStub(list(reversed(channels)))

for index, report in enumerate(response.reports):
topo[index] = []
Expand All @@ -52,7 +79,7 @@ async def client_from_channel(
client = client_class(
uuid=UUID(report.uuid),
labels=report.labels,
channel=channel,
stub=stub,
portal=portal,
stack=stack.enter_context(ExitStack()),
children={reports[k].labels["jumpstarter.dev/name"]: clients[k] for k in topo[index]},
Expand Down
17 changes: 8 additions & 9 deletions packages/jumpstarter/jumpstarter/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any

from anyio import create_task_group
from google.protobuf import empty_pb2
from grpc import StatusCode
from grpc.aio import AioRpcError, Channel
from grpc.aio import AioRpcError
from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc

from jumpstarter.common import Metadata
Expand Down Expand Up @@ -60,16 +61,14 @@ class AsyncDriverClient(
Backing implementation of blocking driver client.
"""

channel: Channel
stub: Any

log_level: str = "INFO"
logger: logging.Logger = field(init=False)

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel)
router_pb2_grpc.RouterServiceStub.__init__(self, self.channel)
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(self.log_level)

Expand All @@ -89,7 +88,7 @@ async def call_async(self, method, *args):
)

try:
response = await self.DriverCall(request)
response = await self.stub.DriverCall(request)
except AioRpcError as e:
match e.code():
case StatusCode.UNIMPLEMENTED:
Expand All @@ -113,7 +112,7 @@ async def streamingcall_async(self, method, *args):
)

try:
async for response in self.StreamingDriverCall(request):
async for response in self.stub.StreamingDriverCall(request):
yield decode_value(response.result)
except AioRpcError as e:
match e.code():
Expand All @@ -128,7 +127,7 @@ async def streamingcall_async(self, method, *args):

@asynccontextmanager
async def stream_async(self, method):
context = self.Stream(
context = self.stub.Stream(
metadata=StreamRequestMetadata.model_construct(request=DriverStreamRequest(uuid=self.uuid, method=method))
.model_dump(mode="json", round_trip=True)
.items(),
Expand All @@ -142,7 +141,7 @@ async def resource_async(
self,
stream,
):
context = self.Stream(
context = self.stub.Stream(
metadata=StreamRequestMetadata.model_construct(request=ResourceStreamRequest(uuid=self.uuid))
.model_dump(mode="json", round_trip=True)
.items(),
Expand All @@ -160,7 +159,7 @@ def __log(self, level: int, msg: str):
@asynccontextmanager
async def log_stream_async(self):
async def log_stream():
async for response in self.LogStream(empty_pb2.Empty()):
async for response in self.stub.LogStream(empty_pb2.Empty()):
self.__log(logging.getLevelName(response.severity), response.message)

async with create_task_group() as tg:
Expand Down
30 changes: 28 additions & 2 deletions packages/jumpstarter/jumpstarter/client/grpc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass, field
from collections import OrderedDict
from dataclasses import InitVar, dataclass, field
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any

import yaml
from google.protobuf import duration_pb2, field_mask_pb2, json_format
from grpc import ChannelConnectivity
from grpc.aio import Channel
from jumpstarter_protocol import client_pb2, client_pb2_grpc, kubernetes_pb2
from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2, router_pb2_grpc
from pydantic import BaseModel, ConfigDict, Field, field_serializer

from jumpstarter.common.grpc import translate_grpc_exceptions
Expand Down Expand Up @@ -250,3 +254,25 @@ async def DeleteLease(self, *, name: str):
name="namespaces/{}/leases/{}".format(self.namespace, name),
)
)


@dataclass(frozen=True, slots=True)
class SmartExporterStub:
channels: InitVar[list[Channel]]

__stubs: dict[Channel, Any] = field(init=False, default_factory=OrderedDict)

def __post_init__(self, channels):
for channel in channels:
stub = SimpleNamespace()
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(stub, channel)
router_pb2_grpc.RouterServiceStub.__init__(stub, channel)
self.__stubs[channel] = stub

def __getattr__(self, name):
for channel, stub in self.__stubs.items():
# find the first channel that's ready
if channel.get_state(try_to_connect=True) == ChannelConnectivity.READY:
return getattr(stub, name)
# or fallback to the last channel (via router)
return getattr(next(reversed(self.__stubs.values())), name)
10 changes: 9 additions & 1 deletion packages/jumpstarter/jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager):
controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False)
tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
grpc_options: dict[str, Any] = field(default_factory=dict)
use_alternative_endpoints: bool = False

def __post_init__(self):
if hasattr(super(), "__post_init__"):
Expand Down Expand Up @@ -184,7 +185,14 @@ async def _monitor():
@asynccontextmanager
async def connect_async(self, stack):
async with self.serve_unix_async() as path:
async with client_from_path(path, self.portal, stack, allow=self.allow, unsafe=self.unsafe) as client:
async with client_from_path(
path,
self.portal,
stack,
allow=self.allow,
unsafe=self.unsafe,
use_alternative_endpoints=self.use_alternative_endpoints,
) as client:
yield client

@contextmanager
Expand Down
23 changes: 20 additions & 3 deletions packages/jumpstarter/jumpstarter/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from jumpstarter.client import client_from_path
from jumpstarter.config.client import _allow_from_env
from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JUMPSTARTER_HOST
from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JMP_USE_ALT_ENDPOINTS, JUMPSTARTER_HOST
from jumpstarter.driver import Driver
from jumpstarter.exporter import Session

Expand Down Expand Up @@ -52,7 +52,16 @@ async def env_async(portal, stack):

allow, unsafe = _allow_from_env()

async with client_from_path(host, portal, stack, allow=allow, unsafe=unsafe) as client:
use_alternative_endpoints = os.environ.get(JMP_USE_ALT_ENDPOINTS, "0") == "1"

async with client_from_path(
host,
portal,
stack,
allow=allow,
unsafe=unsafe,
use_alternative_endpoints=use_alternative_endpoints,
) as client:
try:
yield client
finally:
Expand Down Expand Up @@ -80,7 +89,14 @@ def env():
PROMPT_CWD = "\\W"


def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int:
def launch_shell(
host: str,
context: str,
allow: list[str],
unsafe: bool,
*,
use_alternative_endpoints: bool = False,
) -> int:
"""Launch a shell with a custom prompt indicating the exporter type.

Args:
Expand All @@ -103,6 +119,7 @@ def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int
| {
JUMPSTARTER_HOST: host,
JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow),
JMP_USE_ALT_ENDPOINTS: "1" if use_alternative_endpoints else "0",
"PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ",
},
)
Expand Down
16 changes: 14 additions & 2 deletions packages/jumpstarter/jumpstarter/common/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,21 @@

def test_launch_shell(tmp_path, monkeypatch):
monkeypatch.setenv("SHELL", shutil.which("true"))
exit_code = launch_shell(host=str(tmp_path / "test.sock"), context="remote", allow=["*"], unsafe=False)
exit_code = launch_shell(
host=str(tmp_path / "test.sock"),
context="remote",
allow=["*"],
unsafe=False,
use_alternative_endpoints=False,
)
assert exit_code == 0

monkeypatch.setenv("SHELL", shutil.which("false"))
exit_code = launch_shell(host=str(tmp_path / "test.sock"), context="remote", allow=["*"], unsafe=False)
exit_code = launch_shell(
host=str(tmp_path / "test.sock"),
context="remote",
allow=["*"],
unsafe=False,
use_alternative_endpoints=False,
)
assert exit_code == 1
22 changes: 20 additions & 2 deletions packages/jumpstarter/jumpstarter/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ClientConfigV1Alpha1(BaseModel):
token: str
grpcOptions: dict[str, str | int] | None = Field(default_factory=dict)

use_alternative_endpoints: bool = Field(alias="useAlternativeEndpoints", default=False)

drivers: ClientConfigV1Alpha1Drivers

async def channel(self):
Expand Down Expand Up @@ -172,6 +174,7 @@ async def lease_async(
release=release_lease,
tls_config=self.tls,
grpc_options=self.grpcOptions,
use_alternative_endpoints=self.use_alternative_endpoints,
) as lease:
yield lease

Expand Down Expand Up @@ -235,12 +238,27 @@ def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path:
else:
config.path = Path(path)
with config.path.open(mode="w") as f:
yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), f, sort_keys=False)
yaml.safe_dump(
config.model_dump(
mode="json",
exclude={"path", "alias"},
by_alias=True,
),
f,
sort_keys=False,
)
return config.path

@classmethod
def dump_yaml(cls, config: Self) -> str:
return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False)
return yaml.safe_dump(
config.model_dump(
mode="json",
exclude={"path", "alias"},
by_alias=True,
),
sort_keys=False,
)

@classmethod
def exists(cls, alias: str) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions packages/jumpstarter/jumpstarter/config/client_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def test_client_config_save(monkeypatch: pytest.MonkeyPatch):
insecure: false
token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz
grpcOptions: {}
useAlternativeEndpoints: false
drivers:
allow:
- jumpstarter.drivers.*
Expand Down Expand Up @@ -243,6 +244,7 @@ def test_client_config_save_explicit_path():
insecure: false
token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz
grpcOptions: {}
useAlternativeEndpoints: false
drivers:
allow:
- jumpstarter.drivers.*
Expand Down Expand Up @@ -277,6 +279,7 @@ def test_client_config_save_unsafe_drivers():
insecure: false
token: dGhpc2lzYXRva2VuLTEyMzQxMjM0MTIzNEyMzQtc2Rxd3Jxd2VycXdlcnF3ZXJxd2VyLTEyMzQxMjM0MTIz
grpcOptions: {}
useAlternativeEndpoints: false
drivers:
allow: []
unsafe: true
Expand Down
1 change: 1 addition & 0 deletions packages/jumpstarter/jumpstarter/config/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
JMP_DRIVERS_ALLOW = "JMP_DRIVERS_ALLOW"
JUMPSTARTER_HOST = "JUMPSTARTER_HOST"
JMP_LEASE = "JMP_LEASE"
JMP_USE_ALT_ENDPOINTS = "JMP_USE_ALT_ENDPOINTS"
Loading