Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/power_grid_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
This module contains functions that may be useful when working with the power-grid-model library.
"""

import io
import json
import math
import tempfile
import warnings
from pathlib import Path
from typing import cast as cast_type
from typing import IO, Any, cast as cast_type

import numpy as np

Expand Down Expand Up @@ -471,3 +472,58 @@ def _make_test_case( # noqa: PLR0913
params_json = json.dumps(params, indent=2)
(output_path / "params.json").write_text(data=params_json, encoding="utf-8")
(output_path / "params.json.license").write_text(data=LICENSE_TEXT, encoding="utf-8")


def msgpack_deserialize_from_stream(
    stream: IO[bytes],
    data_filter: ComponentAttributeMapping = None,
) -> Dataset:
    """
    Load and deserialize msgpack data from a binary stream into a new dataset.

    Args:
        stream: the binary IO stream to read the serialized msgpack data from.
        data_filter: an optional filter of components and attributes to deserialize.

    Raises:
        TypeError: if the object is not a stream, or is a text (non-binary) stream.
        io.UnsupportedOperation: if the stream is not readable.
        ValueError: if the data is inconsistent with the rest of the dataset or a component is unknown.
        PowerGridError: if there was an internal error.

    Returns:
        The deserialized dataset in Power grid model input format.
    """
    # `IO[Any]` is a typing construct, not a runtime class, so the former
    # identity check `stream is IO[Any]` could never be true; probe for the
    # stream protocol instead so non-stream arguments actually get rejected.
    if not hasattr(stream, "read"):
        raise TypeError("Expected a stream.")
    if isinstance(stream, io.TextIOBase):
        raise TypeError("Expected a binary stream.")
    if not stream.readable():
        raise io.UnsupportedOperation("Stream is not readable.")
    return msgpack_deserialize(stream.read(), data_filter=data_filter)


def msgpack_serialize_to_stream(
    stream: IO[bytes],
    data: Dataset,
    dataset_type: DatasetType | None = None,
    use_compact_list: bool = False,
):
    """
    Serialize a dataset to msgpack in the most recent format and write it to a binary stream.

    Args:
        stream: the binary IO stream to write the serialized data to.
        data: a single or batch dataset for power-grid-model.
        dataset_type: the type of the dataset; deduced from the data if not provided.
        use_compact_list: write components on a single line.

    Raises:
        TypeError: if the object is not a stream, or is a text (non-binary) stream.
        io.UnsupportedOperation: if the stream is not writable.
    """
    # `IO[Any]` is a typing construct, not a runtime class, so the former
    # identity check `stream is IO[Any]` could never be true; probe for the
    # stream protocol instead so non-stream arguments actually get rejected.
    if not hasattr(stream, "write"):
        raise TypeError("Expected a stream.")
    if isinstance(stream, io.TextIOBase):
        raise TypeError("Expected a binary stream.")
    if not stream.writable():
        raise io.UnsupportedOperation("Stream is not writable.")

    data = _map_to_component_types(data)
    result = msgpack_serialize(data=data, dataset_type=dataset_type, use_compact_list=use_compact_list)
    stream.write(result)
73 changes: 71 additions & 2 deletions tests/unit/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# SPDX-FileCopyrightText: 2025 Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
# SPDX-FileCopyrightText: Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
#
# SPDX-License-Identifier: MPL-2.0

import io
import json
import re
from collections.abc import Mapping
from io import BytesIO, TextIOBase, UnsupportedOperation
from typing import Any

import msgpack
Expand All @@ -15,7 +18,32 @@
from power_grid_model._core.utils import get_dataset_type, is_columnar, is_sparse
from power_grid_model.data_types import BatchDataset, Dataset, DenseBatchData, SingleComponentData, SingleDataset
from power_grid_model.enum import ComponentAttributeFilterOptions
from power_grid_model.utils import json_deserialize, json_serialize, msgpack_deserialize, msgpack_serialize
from power_grid_model.utils import (
json_deserialize,
json_serialize,
msgpack_deserialize,
msgpack_deserialize_from_stream,
msgpack_serialize,
msgpack_serialize_to_stream,
)


class FakeRawIO(io.RawIOBase):
def __init__(self, initial_bytes: bytes | bytearray = b"2"):
super().__init__()
self._buf = bytearray(initial_bytes)
self._pos = 0
self._closed = False

# --- Capability flags ---
def readable(self) -> bool:
return False

def writable(self) -> bool:
return False

def seekable(self) -> bool:
return True


def to_json(data, raw_buffer: bool = False, indent: int | None = None):
Expand Down Expand Up @@ -790,3 +818,44 @@ def test_serialize_deserialize_double_round_trip(deserialize, serialize, seriali

np.testing.assert_array_equal(nan_a, nan_b)
np.testing.assert_allclose(field_result_a[~nan_a], field_result_b[~nan_b], rtol=1e-15)


def test_messagepack_round_trip_with_stream(serialized_data):
    """A dataset written via the stream serializer deserializes back unchanged."""
    reference: Dataset = msgpack_deserialize(to_msgpack(serialized_data))

    buffer = BytesIO()
    msgpack_serialize_to_stream(buffer, reference, dataset_type=serialized_data["type"])
    buffer.seek(0)

    round_tripped = msgpack_deserialize_from_stream(buffer)
    assert str(round_tripped) == str(reference)


def test_messagepack_to_stream_text_type_error(serialized_data):
    """Serializing to a text (non-binary) stream raises a TypeError."""
    dataset: Dataset = msgpack_deserialize(to_msgpack(serialized_data))

    text_stream = TextIOBase()
    with pytest.raises(TypeError, match=re.escape("Expected a binary stream.")):
        msgpack_serialize_to_stream(text_stream, dataset, dataset_type=serialized_data["type"])


def test_messagepack_from_stream_text_type_error():
    """Deserializing from a text (non-binary) stream raises a TypeError."""
    text_stream = TextIOBase()
    with pytest.raises(TypeError, match=re.escape("Expected a binary stream.")):
        _ = msgpack_deserialize_from_stream(text_stream)


def test_messagepack_from_stream_readable_error():
    """A non-readable binary stream raises UnsupportedOperation on deserialize."""
    unreadable = FakeRawIO(initial_bytes=b"bla")
    with pytest.raises(UnsupportedOperation, match=re.escape("Stream is not readable.")):
        _ = msgpack_deserialize_from_stream(unreadable)


def test_messagepack_to_stream_writable_error(serialized_data):
    """A non-writable binary stream raises UnsupportedOperation on serialize."""
    dataset: Dataset = msgpack_deserialize(to_msgpack(serialized_data))

    unwritable = FakeRawIO(initial_bytes=b"bla")
    with pytest.raises(UnsupportedOperation, match=re.escape("Stream is not writable.")):
        msgpack_serialize_to_stream(unwritable, dataset, dataset_type=serialized_data["type"])
Loading