Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from qemu.qmp.protocol import ConnectError, Runstate

from jumpstarter.driver import Driver, export
from jumpstarter.streams.encoding import AutoDecompressIterator


def _vsock_available():
Expand All @@ -42,9 +43,15 @@ class QemuFlasher(FlasherInterface, Driver):

@export
async def flash(self, source, partition: str | None = None):
    """Flash an image to the specified partition.

    Supports transparent decompression of gzip, xz, bz2, and zstd compressed images.
    Compression format is auto-detected from file signature.

    Args:
        source: Resource handle of the image to write; opened via ``self.resource``.
        partition: Name of the target partition, validated by the parent driver
            before the destination file is opened.
    """
    async with await FileWriteStream.from_path(self.parent.validate_partition(partition)) as stream:
        async with self.resource(source) as res:
            # Iterate the resource only through the auto-decompressing wrapper.
            # Looping over ``res`` directly as well would consume chunks that the
            # decompressor never gets to see.
            async for chunk in AutoDecompressIterator(source=res):
                await stream.send(chunk)

@export
Expand Down
140 changes: 139 additions & 1 deletion packages/jumpstarter/jumpstarter/streams/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import lzma
import sys
import zlib
from dataclasses import dataclass
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Callable, Mapping

Expand All @@ -22,6 +23,55 @@ class Compression(StrEnum):
ZSTD = "zstd"


@dataclass(frozen=True)
class FileSignature:
    """Immutable pairing of a compression format with its magic bytes."""

    # Byte sequence expected at the very start of data in this format.
    signature: bytes
    # Compression format identified by ``signature``.
    compression: Compression


# Known magic-byte prefixes, checked in order when sniffing a stream.
# Reference: https://file-extension.net/seeker/
COMPRESSION_SIGNATURES: tuple[FileSignature, ...] = (
    # gzip: ID1, ID2, then compression method 8 (deflate)
    FileSignature(signature=b"\x1f\x8b\x08", compression=Compression.GZIP),
    # xz: 0xFD "7zXZ" 0x00
    FileSignature(signature=b"\xfd\x37\x7a\x58\x5a\x00", compression=Compression.XZ),
    # bzip2: "BZh"
    FileSignature(signature=b"\x42\x5a\x68", compression=Compression.BZ2),
    # zstd frame magic number
    FileSignature(signature=b"\x28\xb5\x2f\xfd", compression=Compression.ZSTD),
)

# Number of leading bytes to collect before sniffing; the longest signature
# above is 6 bytes, so 8 is enough for every listed format.
SIGNATURE_BUFFER_SIZE = 8


def detect_compression_from_signature(data: bytes) -> Compression | None:
    """Identify the compression format announced by the leading bytes of ``data``.

    Args:
        data: The first bytes of the file/stream (SIGNATURE_BUFFER_SIZE bytes
            or more is recommended so every known signature fits).

    Returns:
        The Compression whose signature prefixes ``data``, or None when no
        known signature matches (uncompressed or unknown content).
    """
    matches = (
        candidate.compression
        for candidate in COMPRESSION_SIGNATURES
        if data.startswith(candidate.signature)
    )
    # The first matching entry wins; fall back to None when nothing matches.
    return next(matches, None)


def create_decompressor(compression: Compression) -> Any:
"""Create a decompressor object for the given compression type."""
match compression:
case Compression.GZIP:
return zlib.decompressobj(wbits=47) # Auto-detect gzip/zlib
case Compression.XZ:
return lzma.LZMADecompressor()
case Compression.BZ2:
return bz2.BZ2Decompressor()
case Compression.ZSTD:
return zstd.ZstdDecompressor()


@dataclass(kw_only=True)
class CompressedStream(ObjectStream[bytes]):
stream: AnyByteStream
Expand Down Expand Up @@ -99,3 +149,91 @@ def compress_stream(stream: AnyByteStream, compression: Compression | None) -> A
compressor=zstd.ZstdCompressor(),
decompressor=zstd.ZstdDecompressor(),
)


@dataclass(kw_only=True)
class AutoDecompressIterator(AsyncIterator[bytes]):
    """An async iterator that auto-detects and decompresses compressed data.

    This wraps an async iterator of bytes and transparently decompresses
    gzip, xz, bz2, or zstd compressed data based on file signature detection.
    Uncompressed data passes through unchanged.

    Any failure reported by the underlying decompressor, including input that
    ends before the compressed stream is complete, is raised as RuntimeError.
    """

    # Upstream iterator yielding the raw (possibly compressed) chunks.
    source: AsyncIterator[bytes]
    # Active decompressor; None for pass-through data or once finished.
    _decompressor: Any = field(init=False, default=None)
    # Detected format, kept for error messages.
    _compression: Compression | None = field(init=False, default=None)
    # True once signature detection has run.
    _detected: bool = field(init=False, default=False)
    # Bytes read during detection that still have to be emitted.
    _buffer: bytes = field(init=False, default=b"")
    # True once ``source`` has raised StopAsyncIteration.
    _exhausted: bool = field(init=False, default=False)

    def _call_decompressor(self, method_name: str, *args) -> bytes:
        """Call a decompressor method, converting its errors to RuntimeError.

        Args:
            method_name: decompressor method to call
            *args: Arguments to the method

        Raises:
            RuntimeError: If the decompressor rejects the data.
        """
        try:
            method = getattr(self._decompressor, method_name)
            return method(*args)
        # EOFError is what lzma/bz2/zstd raise when fed data after the end of
        # the compressed stream; report it like any other decode failure.
        except (zlib.error, lzma.LZMAError, OSError, EOFError, zstd.ZstdError) as e:
            raise RuntimeError(
                f"Failed to decompress {self._compression}: {e}"
            ) from e

    async def _detect_compression(self) -> None:
        """Buffer the head of the stream and select a decompressor from it."""
        # Buffer data until we have enough for detection
        while len(self._buffer) < SIGNATURE_BUFFER_SIZE and not self._exhausted:
            try:
                self._buffer += await self.source.__anext__()
            except StopAsyncIteration:
                self._exhausted = True

        # Detect compression from buffered data
        compression = detect_compression_from_signature(self._buffer)
        if compression is not None:
            self._compression = compression
            self._decompressor = create_decompressor(compression)

        self._detected = True

    def _finish(self) -> bytes:
        """Flush and release the decompressor after the source is exhausted.

        Returns:
            Any bytes still held by the decompressor (may be empty).

        Raises:
            RuntimeError: If the compressed stream ended prematurely.
        """
        if self._decompressor is None:
            return b""
        try:
            tail = b""
            # Only the zlib object (gzip) offers flush(); the others emit
            # everything from decompress().
            if hasattr(self._decompressor, "flush"):
                tail = self._call_decompressor("flush")
            # All supported decompressors expose ``eof``; if it is still False
            # the input stopped before the end-of-stream marker.
            if not getattr(self._decompressor, "eof", True):
                raise RuntimeError(
                    f"Failed to decompress {self._compression}: compressed stream is truncated"
                )
            return tail
        finally:
            # Release the decompressor so later calls just stop iteration.
            self._decompressor = None

    async def __anext__(self) -> bytes:
        """Return the next chunk of (decompressed) data.

        Raises:
            StopAsyncIteration: When the source is exhausted and fully drained.
            RuntimeError: If decompression fails.
        """
        # First call: detect compression format
        if not self._detected:
            await self._detect_compression()

        while True:
            if self._buffer:
                # Bytes consumed during detection are emitted first.
                chunk, self._buffer = self._buffer, b""
            elif self._exhausted:
                # Reached on every path, including when the source ran dry
                # during detection, so the decompressor is always flushed.
                tail = self._finish()
                if tail:
                    return tail
                raise StopAsyncIteration
            else:
                try:
                    chunk = await self.source.__anext__()
                except StopAsyncIteration:
                    self._exhausted = True
                    continue

            if self._decompressor is None:
                # Uncompressed data passes through unchanged.
                return chunk
            if not chunk:
                # Nothing to decode; feeding b"" after end-of-stream would
                # make some decompressors raise.
                continue

            data = self._call_decompressor("decompress", chunk)
            if data:
                return data
            # The decompressor needs more input before producing output.

    def __aiter__(self) -> AsyncIterator[bytes]:
        return self
148 changes: 147 additions & 1 deletion packages/jumpstarter/jumpstarter/streams/encoding_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import bz2
import gzip
import lzma
import os
import sys
from io import BytesIO

import pytest
from anyio import EndOfStream, create_memory_object_stream
from anyio.streams.stapled import StapledObjectStream

from .encoding import compress_stream
if sys.version_info >= (3, 14):
from compression import zstd
else:
from backports import zstd

from .encoding import (
COMPRESSION_SIGNATURES,
AutoDecompressIterator,
Compression,
compress_stream,
detect_compression_from_signature,
)

pytestmark = pytest.mark.anyio

Expand All @@ -28,3 +44,133 @@ async def test_compress_stream(compression):
except EndOfStream:
break
assert result.getvalue() == b"hello"


def _get_signature(compression: Compression) -> bytes:
    """Return the magic bytes registered for ``compression``.

    Raises:
        ValueError: If COMPRESSION_SIGNATURES has no entry for it.
    """
    found = next(
        (entry.signature for entry in COMPRESSION_SIGNATURES if entry.compression == compression),
        None,
    )
    if found is None:
        raise ValueError(f"No signature found for {compression}")
    return found


class TestDetectCompressionFromSignature:
    """Checks for magic-byte based compression detection."""

    @pytest.mark.parametrize(
        "compression",
        [Compression.GZIP, Compression.XZ, Compression.BZ2, Compression.ZSTD],
    )
    def test_detect_from_signature(self, compression):
        """A signature followed by arbitrary bytes maps back to its format."""
        # Random trailing bytes stand in for real file content.
        payload = _get_signature(compression) + os.urandom(4)
        detected = detect_compression_from_signature(payload)
        assert detected == compression

    def test_detect_uncompressed(self):
        """Plain text matches no known signature."""
        detected = detect_compression_from_signature(b"hello world")
        assert detected is None

    def test_detect_empty(self):
        """Empty input matches no known signature."""
        detected = detect_compression_from_signature(b"")
        assert detected is None

    def test_detect_too_short(self):
        """Truncated signatures must not be reported as a match."""
        partial_signatures = (
            b"\x1f",  # gzip partial
            b"\xfd\x37\x7a",  # xz partial
        )
        for partial in partial_signatures:
            assert detect_compression_from_signature(partial) is None

    def test_detect_from_real_gzip_data(self):
        """Output of gzip.compress is recognised as gzip."""
        blob = gzip.compress(b"test data")
        assert detect_compression_from_signature(blob) == Compression.GZIP

    def test_detect_from_real_xz_data(self):
        """Output of lzma.compress in XZ format is recognised as xz."""
        blob = lzma.compress(b"test data", format=lzma.FORMAT_XZ)
        assert detect_compression_from_signature(blob) == Compression.XZ

    def test_detect_from_real_bz2_data(self):
        """Output of bz2.compress is recognised as bz2."""
        blob = bz2.compress(b"test data")
        assert detect_compression_from_signature(blob) == Compression.BZ2

    def test_detect_from_real_zstd_data(self):
        """Output of zstd.compress is recognised as zstd."""
        blob = zstd.compress(b"test data")
        assert detect_compression_from_signature(blob) == Compression.ZSTD


class TestAutoDecompressIterator:
    """Checks for the auto-detecting, decompressing async iterator."""

    async def _async_iter_from_bytes(self, data: bytes, chunk_size: int):
        """Yield ``data`` in consecutive slices of at most ``chunk_size`` bytes."""
        for start in range(0, len(data), chunk_size):
            stop = start + chunk_size
            yield data[start:stop]

    async def _decompress_and_check(self, compressed: bytes, expected: bytes, chunk_size: int = 16):
        """Run ``compressed`` through the iterator and compare the joined output."""
        feed = self._async_iter_from_bytes(compressed, chunk_size)
        produced = [piece async for piece in AutoDecompressIterator(source=feed)]
        assert b"".join(produced) == expected

    async def test_passthrough_uncompressed(self):
        """Data without a known signature comes out byte-for-byte."""
        plain = b"hello world, this is uncompressed data"
        await self._decompress_and_check(plain, plain)

    async def test_decompress_gzip(self):
        """A gzip stream is decoded back to its original bytes."""
        plain = b"hello world, this is gzip compressed data"
        await self._decompress_and_check(gzip.compress(plain), plain)

    async def test_decompress_xz(self):
        """An xz stream is decoded back to its original bytes."""
        plain = b"hello world, this is xz compressed data"
        await self._decompress_and_check(lzma.compress(plain, format=lzma.FORMAT_XZ), plain)

    async def test_decompress_bz2(self):
        """A bz2 stream is decoded back to its original bytes."""
        plain = b"hello world, this is bz2 compressed data"
        await self._decompress_and_check(bz2.compress(plain), plain)

    async def test_decompress_zstd(self):
        """A zstd stream is decoded back to its original bytes."""
        plain = b"hello world, this is zstd compressed data"
        await self._decompress_and_check(zstd.compress(plain), plain)

    async def test_small_chunks(self):
        """One-byte chunks still decode correctly."""
        plain = b"hello world"
        await self._decompress_and_check(gzip.compress(plain), plain, chunk_size=1)

    async def test_empty_input(self):
        """A source that yields nothing produces no output chunks."""

        async def empty_iter():
            # The unreachable yield makes this an async generator that
            # finishes without producing a value.
            return
            yield

        produced = [piece async for piece in AutoDecompressIterator(source=empty_iter())]
        assert produced == []

    async def test_large_data(self):
        """1MB of data survives a round trip in 64KiB chunks."""
        plain = b"x" * 1024 * 1024
        await self._decompress_and_check(gzip.compress(plain), plain, chunk_size=65536)

    async def test_corrupted_gzip(self):
        """A gzip signature followed by garbage raises RuntimeError with a clear message."""
        # Valid signature, invalid payload.
        corrupted = b"\x1f\x8b\x08" + b"corrupted data here"
        feed = self._async_iter_from_bytes(corrupted, 16)

        with pytest.raises(RuntimeError, match=r"Failed to decompress gzip:.*"):
            async for _ in AutoDecompressIterator(source=feed):
                pass
Loading