Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
reset_ops_id: bool = True,
track_meta: bool = False,
weights_only: bool = True,
in_memory: bool = False,
) -> None:
"""
Args:
Expand Down Expand Up @@ -273,6 +274,10 @@ def __init__(
other safe objects. Setting this to `False` is required for loading `MetaTensor`
objects saved with `track_meta=True`, however this creates the possibility of remote
code execution through `torch.load` so be aware of the security implications of doing so.
in_memory: if `True`, keep the pre-processed data in an in-memory dictionary after first access.
This combines the benefits of persistent storage (data survives restarts) with faster RAM access.
When data is accessed, it is first loaded from disk cache and then stored in memory.
Defaults to `False`.

Raises:
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
Expand All @@ -299,6 +304,13 @@ def __init__(
)
self.track_meta = track_meta
self.weights_only = weights_only
self.in_memory = in_memory
self._memory_cache: dict[str, Any] = {}

@property
def memory_cache_size(self) -> int:
    """Return how many pre-processed items are currently resident in RAM.

    Counts the entries of the internal ``_memory_cache`` dictionary; it is
    ``0`` until items have been accessed (or after ``set_data`` resets it).
    """
    cache = self._memory_cache
    return len(cache)

def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
"""Get hashable transforms, and then hash them. Hashable transforms
Expand Down Expand Up @@ -326,6 +338,7 @@ def set_data(self, data: Sequence):

"""
self.data = data
self._memory_cache = {}
if self.cache_dir is not None and self.cache_dir.exists():
shutil.rmtree(self.cache_dir, ignore_errors=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -389,14 +402,24 @@ def _cachecheck(self, item_transformed):

"""
hashfile = None
# compute cache key once for both disk and memory caching
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
cache_key = f"{data_item_md5}.pt"

if self.cache_dir is not None:
data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
data_item_md5 += self.transform_hash
hashfile = self.cache_dir / f"{data_item_md5}.pt"
hashfile = self.cache_dir / cache_key

# check in-memory cache first
if self.in_memory and cache_key in self._memory_cache:
return self._memory_cache[cache_key]

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile, weights_only=self.weights_only)
_item_transformed = torch.load(hashfile, weights_only=self.weights_only)
if self.in_memory:
self._memory_cache[cache_key] = _item_transformed
return _item_transformed
except PermissionError as e:
if sys.platform != "win32":
raise e
Expand All @@ -409,6 +432,8 @@ def _cachecheck(self, item_transformed):

_item_transformed = self._pre_transform(deepcopy(item_transformed)) # keep the original hashed
if hashfile is None:
if self.in_memory:
self._memory_cache[cache_key] = _item_transformed
return _item_transformed
try:
# NOTE: Writing to a temporary directory and then using a nearly atomic rename operation
Expand All @@ -431,6 +456,8 @@ def _cachecheck(self, item_transformed):
pass
except PermissionError: # project-monai/monai issue #3613
pass
if self.in_memory:
self._memory_cache[cache_key] = _item_transformed
return _item_transformed

def _transform(self, index: int):
Expand Down
106 changes: 106 additions & 0 deletions tests/data/test_persistentdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import tempfile
import unittest
from pathlib import Path

import nibabel as nib
import numpy as np
Expand Down Expand Up @@ -200,6 +201,111 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er
im = test_dataset[0]["image"]
self.assertIsInstance(im, expected_type)

def test_in_memory_cache(self):
    """Test in_memory caching feature that combines persistent storage with RAM caching."""
    items = [[list(range(i))] for i in range(5)]

    with tempfile.TemporaryDirectory() as tempdir:
        # populate the on-disk cache with a dataset that keeps nothing in RAM
        disk_only = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=False)
        _ = list(disk_only)

        # second dataset over the same cache_dir, this time with RAM caching on
        hybrid = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

        # nothing accessed yet, so the RAM cache starts out empty
        self.assertEqual(hybrid.memory_cache_size, 0)

        # each first access loads from disk and adds one entry to the RAM cache
        for idx, expected_size in ((0, 1), (1, 2)):
            _ = hybrid[idx]
            self.assertEqual(hybrid.memory_cache_size, expected_size)

        # iterating everything fills the RAM cache completely
        _ = list(hybrid)
        self.assertEqual(hybrid.memory_cache_size, 5)

        # a repeated access is served from RAM and yields the same result
        first = hybrid[0]
        second = hybrid[0]
        self.assertEqual(first, second)

        # replacing the underlying data invalidates the RAM cache
        hybrid.set_data(items[:3])
        self.assertEqual(hybrid.memory_cache_size, 0)

def test_in_memory_without_cache_dir(self):
    """Test in_memory caching works even without a cache_dir (pure RAM cache)."""
    items = [[list(range(i))] for i in range(3)]

    ram_only = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=None, in_memory=True)

    # no access yet -> the RAM cache holds nothing
    self.assertEqual(ram_only.memory_cache_size, 0)

    # a single access caches exactly one item in RAM
    _ = ram_only[0]
    self.assertEqual(ram_only.memory_cache_size, 1)

    # a full pass caches every item
    _ = list(ram_only)
    self.assertEqual(ram_only.memory_cache_size, 3)

def test_automatic_hybrid_caching(self):
    """
    Test that in_memory=True provides automatic hybrid caching:
    - ALL samples automatically persist to disk
    - ALL samples automatically cache to RAM after first access
    - No manual specification of which samples go where (unlike torchdatasets)
    - Simulates restart scenario: disk cache survives, RAM cache rebuilds automatically
    """
    items = [[list(range(i))] for i in range(5)]

    with tempfile.TemporaryDirectory() as tempdir:
        # === First "session": populate both disk and RAM cache ===
        ds1 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

        # Access all items - should automatically cache to BOTH disk AND RAM
        for i in range(len(items)):
            _ = ds1[i]

        # Verify: ALL samples are in RAM (automatic, no manual specification)
        self.assertEqual(ds1.memory_cache_size, 5)

        # Verify: ALL samples are on disk (count .pt files)
        cache_files = list(Path(tempdir).glob("*.pt"))
        self.assertEqual(len(cache_files), 5)

        # === Simulate "restart": new dataset instance, same cache_dir ===
        # This is the key benefit over CacheDataset - disk cache survives restart
        ds2 = PersistentDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, in_memory=True)

        # RAM cache starts empty (simulating fresh process)
        self.assertEqual(ds2.memory_cache_size, 0)

        # Access all items - should load from disk and automatically cache to RAM
        results = [ds2[i] for i in range(len(items))]

        # Verify: ALL samples now in RAM again (automatic rebuild from disk)
        self.assertEqual(ds2.memory_cache_size, 5)

        # Verify: results match the TRANSFORMED data. _InplaceXform adds
        # np.pi to the first element of each non-empty sample, so the
        # expectation must account for that rather than the raw input
        # (comparing against `[list(range(i))]` would fail for i >= 1).
        for i, result in enumerate(results):
            expected = [list(range(i))]
            if expected[0]:
                expected[0][0] = expected[0][0] + np.pi
            self.assertEqual(result, expected)

        # === Verify RAM cache provides fast repeated access ===
        # Accessing same items again should hit RAM cache (same objects)
        for i in range(len(items)):
            result1 = ds2[i]
            result2 = ds2[i]
            # Should return equivalent results
            self.assertEqual(result1, result2)

        # RAM cache size unchanged (no duplicate entries)
        self.assertEqual(ds2.memory_cache_size, 5)


if __name__ == "__main__":
unittest.main()
Loading