2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -159,7 +159,7 @@ changes in the squashed commit, which is also fine.

So, in short:

1. Run `pre-commit -a` locally.
1. Run `pre-commit run -a` locally.
2. Run `pytest` locally.
3. Check your local commit messages before pushing.
4. `git push`
98 changes: 84 additions & 14 deletions tests/schemas/test_thermal_model_profile_input.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

from contextlib import nullcontext as does_not_raise
from datetime import datetime
from datetime import UTC, datetime

import numpy as np
import pandas as pd
@@ -72,7 +72,13 @@
None,
),
(
np.array([datetime(2021, 1, 1, 0, 0, 0), datetime(2021, 1, 1, 0, 15, 0), datetime(2021, 1, 1, 0, 30, 0)]),
np.array(
[
datetime(2021, 1, 1, 0, 0, 0),
datetime(2021, 1, 1, 0, 15, 0),
datetime(2021, 1, 1, 0, 30, 0),
]
),
[1, 2, 3],
[1, 2, 3],
None,
@@ -116,43 +122,58 @@
[1, 2, 3],
[1, 2, 3],
None,
pytest.raises(ValueError),
"Could not convert object to NumPy datetime",
pytest.raises(TypeError),
"Input must be a pandas DatetimeIndex or an iterable of datetime objects.",
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
[2, 4, 5],
{"a": 1, "b": 3, "c": 3},
None,
pytest.raises(TypeError),
None,
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
[2, 4, 5],
[2, 4, 5],
{"a": 1, "b": 3, "c": 3},
pytest.raises(TypeError),
None,
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
[[2, 4, 5], [2, 4, 5]],
[2, 3, 4],
None,
pytest.raises(ValueError),
None,
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
(2, 4, 5),
pd.DataFrame([2, 3, 4], [2, 3, 4]),
None,
pytest.raises(ValueError),
"array must be one-dimensional",
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
(2, 4, 5),
[2, 3, 4],
pd.DataFrame([2, 3, 4], [2, 3, 4]),
@@ -164,11 +185,14 @@
[1, 2, 3],
[1, 2, 3],
None,
pytest.raises(ValueError),
"Converting an integer to a NumPy datetime requires a specified unit",
does_not_raise(),
None,
),
(
np.array(["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"], dtype="datetime64[s]"),
np.array(
["2021-01-01 00:00:00", "2021-01-01 00:15:00", "2021-01-01 00:25:00"],
dtype="datetime64[ns]",
),
(2, 4, 5),
[2, 3, 4],
[2, 3, 4, 2, 3, 4],
@@ -178,7 +202,12 @@
],
)
def test_that_the_input_data_for_thermal_model_is_validated_properly(
datetime_index, load_profile, ambient_temperature_profile, top_oil_temperature_profile, expectation, message
datetime_index,
load_profile,
ambient_temperature_profile,
top_oil_temperature_profile,
expectation,
message,
):
"""Test that the InputProfile can be created from two Series."""
with expectation as e:
@@ -235,6 +264,47 @@ def test_from_dataframe_missing_columns():
)

with pytest.raises(
ValueError, match="The dataframe is missing the following required columns: ambient_temperature_profile"
ValueError,
match="The dataframe is missing the following required columns: ambient_temperature_profile",
):
InputProfile.from_dataframe(df_missing_columns)


def test_timezone_gets_retained():
"""Test if timezone handling of input datetime index is properly retained."""
input_df = pd.DataFrame(
{
"datetime_index": pd.date_range("2021-01-01 00:00:00", periods=2, tz="UTC"),
"load_profile": [0.8, 0.9],
"ambient_temperature_profile": [10, 20],
}
)
assert input_df["datetime_index"].dtype == "datetime64[ns, UTC]"
input_profile_df = InputProfile.from_dataframe(input_df)
assert input_profile_df.datetime_index[0].tzinfo == UTC

input_list = [
datetime(2023, 1, 1, 1, 5, tzinfo=UTC),
datetime(2023, 1, 1, 1, 10, tzinfo=UTC),
]
assert input_list[0].tzinfo == UTC
input_list_out = InputProfile.create(
datetime_index=input_list,
load_profile=input_df["load_profile"],
ambient_temperature_profile=input_df["ambient_temperature_profile"],
)
assert input_list_out.datetime_index[0].tzinfo == UTC

input_naive = pd.Series(
[
datetime(2023, 1, 1, 1, 5),
datetime(2023, 1, 1, 1, 10),
],
)
assert input_naive.dtype == "<M8[ns]"
input_naive_out = InputProfile.create(
datetime_index=input_naive,
load_profile=input_df["load_profile"],
ambient_temperature_profile=input_df["ambient_temperature_profile"],
)
assert input_naive_out.datetime_index[0].tzinfo is None
67 changes: 52 additions & 15 deletions transformer_thermal_model/schemas/thermal_model/input_profile.py
@@ -85,6 +85,36 @@ def load_profile_array(self) -> np.typing.NDArray[np.float64]:
"""Require subclasses to define a load_profile property."""
raise NotImplementedError("Subclasses must define a load_profile field or property.")

@classmethod
def _to_datetime_array(cls, obj: pd.DatetimeIndex | pd.Series | list | tuple | np.ndarray) -> np.typing.NDArray:
"""Convert a pandas DatetimeIndex or iterable of datetimes to a NumPy array.

- Preserves timezone if present.
- Handles tz-naive efficiently.
- Falls back for lists or other iterables.
"""
if isinstance(obj, (list, tuple, np.ndarray, pd.Series)):
# Convert list-like to array, preserving tz if present
if all(hasattr(x, "tzinfo") for x in obj):
return np.array(obj, dtype=object)
else:
# Check whether each element of the iterable can be interpreted as a
# datetime (preserving its original type); otherwise raise an error.
try:
[pd.to_datetime(x) for x in obj]
except Exception as e:
raise ValueError("Provided array is not a datetime type") from e
return np.array(obj, dtype=type(obj))
elif isinstance(obj, pd.DatetimeIndex) and obj.tz is not None:
# tz-aware: preserve original timezone
return np.array(obj.to_pydatetime(), dtype=object)
elif isinstance(obj, pd.DatetimeIndex) and obj.tz is None:
# Not tz-aware: convert to normal datetime
return obj.to_numpy(dtype="datetime64[ns]")
Contributor: For which types do you want this return to happen?

Contributor: I prefer the try-except to be removed. Maybe it is better and more readable to do:

if ...
elif ...
else:
    raise TypeError

because that way you know exactly when an error is raised; as it stands, another, unrelated problem could also be caught by the try-except construction.
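A minimal sketch of what that refactor could look like (hypothetical; the explicit datetime64 check and the exact branch order are assumptions, not the final implementation):

from datetime import datetime

import numpy as np
import pandas as pd


def _to_datetime_array(obj):
    # Sketch of the suggested refactor: explicit branches instead of a blanket try/except.
    if isinstance(obj, pd.DatetimeIndex):
        if obj.tz is not None:
            # tz-aware: keep the original timezone by storing python datetimes
            return np.array(obj.to_pydatetime(), dtype=object)
        # tz-naive: plain datetime64 values
        return obj.to_numpy(dtype="datetime64[ns]")
    if isinstance(obj, (list, tuple, np.ndarray, pd.Series)):
        if isinstance(obj, np.ndarray) and np.issubdtype(obj.dtype, np.datetime64):
            # already a datetime64 array: nothing to convert
            return obj
        if all(isinstance(x, datetime) for x in obj):
            # datetime objects (tz-aware or naive): keep them as an object array
            return np.array(obj, dtype=object)
    raise TypeError("Input must be a pandas DatetimeIndex or an iterable of datetime objects.")

This roughly keeps the behaviour the tests above exercise (TypeError for non-datetime input, timezones retained) while making the failure point explicit.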

Contributor: Again: why convert to this specific dtype? Maybe it would be better to use the original dtype?
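For illustration only, a tiny example of the alternative being suggested (the helper name is hypothetical):

import numpy as np
import pandas as pd


# Hypothetical alternative for the tz-naive branch: return the index's own values
# instead of forcing a cast to datetime64[ns]. With pandas >= 2.0, which supports
# non-nanosecond resolutions, a coarser unit such as datetime64[s] would then be
# kept as-is (assumption, not checked against the pandas versions this project supports).
def _naive_index_to_array(index: pd.DatetimeIndex) -> np.ndarray:
    return index.to_numpy()


idx = pd.date_range("2021-01-01", periods=3, freq="15min")
print(_naive_index_to_array(idx).dtype)  # datetime64[ns] for a standard index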

else:
raise TypeError("Input must be a pandas DatetimeIndex or an iterable of datetime objects.")


class InputProfile(BaseInputProfile):
"""Class containing the temperature and load profiles of two winding transformers for the thermal model `Model()`.
@@ -138,9 +168,10 @@ def create(
... ambient_temperature_profile=ambient_temperature_profile,
... )
>>> input_profile
InputProfile(datetime_index=array(['2023-01-01T00:00:00.000000',
'2023-01-01T01:00:00.000000', '2023-01-01T02:00:00.000000'],
dtype='datetime64[us]'), ambient_temperature_profile=array([25. , 24.5, 24. ]),
InputProfile(datetime_index=array([datetime.datetime(2023, 1, 1, 0, 0),
datetime.datetime(2023, 1, 1, 1, 0),
datetime.datetime(2023, 1, 1, 2, 0)],
dtype=object), ambient_temperature_profile=array([25. , 24.5, 24. ]),
top_oil_temperature_profile=None, load_profile=array([0.8, 0.9, 1. ]))

```
@@ -189,15 +220,17 @@ def create(
... top_oil_temperature_profile=top_oil_temperature,
... )
>>> input_profile
InputProfile(datetime_index=array(['2023-01-01T00:00:00.000000', '2023-01-01T01:00:00.000000',
'2023-01-01T02:00:00.000000'], dtype='datetime64[us]'),
InputProfile(datetime_index=array([datetime.datetime(2023, 1, 1, 0, 0),
datetime.datetime(2023, 1, 1, 1, 0),
datetime.datetime(2023, 1, 1, 2, 0)],
dtype=object),
ambient_temperature_profile=array([25. , 24.5, 24. ]),
top_oil_temperature_profile=array([37. , 36.5, 36. ]), load_profile=array([0.8, 0.9, 1. ]))

```
"""
return cls(
datetime_index=np.array(datetime_index, dtype=np.datetime64),
datetime_index=cls._to_datetime_array(datetime_index),
load_profile=np.array(load_profile, dtype=float),
ambient_temperature_profile=np.array(ambient_temperature_profile, dtype=float),
top_oil_temperature_profile=(
@@ -243,7 +276,11 @@ def from_dataframe(cls, df: pd.DataFrame) -> Self:
An InputProfile object.

"""
required_columns = {"datetime_index", "load_profile", "ambient_temperature_profile"}
required_columns = {
"datetime_index",
"load_profile",
"ambient_temperature_profile",
}
missing_columns = required_columns - set(df.columns)
if missing_columns:
raise ValueError(f"The dataframe is missing the following required columns: {', '.join(missing_columns)}")
@@ -252,9 +289,9 @@ def from_dataframe(cls, df: pd.DataFrame) -> Self:
datetime_index=df["datetime_index"].to_numpy(),
load_profile=df["load_profile"].to_numpy(),
ambient_temperature_profile=df["ambient_temperature_profile"].to_numpy(),
top_oil_temperature_profile=df["top_oil_temperature_profile"].to_numpy()
if "top_oil_temperature_profile" in df.columns
else None,
top_oil_temperature_profile=(
df["top_oil_temperature_profile"].to_numpy() if "top_oil_temperature_profile" in df.columns else None
),
)

model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -341,7 +378,7 @@ def create(
... )
>>> input_profile
ThreeWindingInputProfile(datetime_index=array(['2023-01-01T00:00:00.000000', '2023-01-01T01:00:00.000000',
'2023-01-01T02:00:00.000000'], dtype='datetime64[us]'),
'2023-01-01T02:00:00.000000'], dtype='datetime64[ns]'),
ambient_temperature_profile=array([25. , 24.5, 24. ]),
top_oil_temperature_profile=None,
load_profile_high_voltage_side=array([0.8, 0.9, 1. ]),
@@ -351,14 +388,14 @@
```
"""
return cls(
datetime_index=np.array(datetime_index, dtype=np.datetime64),
datetime_index=cls._to_datetime_array(datetime_index),
ambient_temperature_profile=np.array(ambient_temperature_profile, dtype=float),
load_profile_high_voltage_side=np.array(load_profile_high_voltage_side, dtype=float),
load_profile_middle_voltage_side=np.array(load_profile_middle_voltage_side, dtype=float),
load_profile_low_voltage_side=np.array(load_profile_low_voltage_side, dtype=float),
top_oil_temperature_profile=np.array(top_oil_temperature_profile, dtype=float)
if top_oil_temperature_profile is not None
else None,
top_oil_temperature_profile=(
np.array(top_oil_temperature_profile, dtype=float) if top_oil_temperature_profile is not None else None
),
)

@model_validator(mode="after")