TorchHook: A PyTorch hooks manager, providing convenient interfaces to capture feature maps and debug models.

zzaiyan/TorchHook

TorchHook

License: MIT

English Blog | Chinese Blog | Chinese Documentation

TorchHook is a lightweight, easy-to-use Python library for extracting intermediate features from PyTorch models. It provides a clean API for managing PyTorch hooks, so you can capture layer outputs without writing the hook boilerplate yourself.

Key Features

  • Easy Hook Registration: Quickly register hooks for desired model layers by name or object.
  • Flexible Feature Extraction: Retrieve captured features easily.
  • Highly Customizable: Define custom hook logic or output transformations.
  • Resource Management: Automatic cleanup of registered hooks.
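For context, here is a minimal sketch of the boilerplate that plain PyTorch requires to capture layer outputs, which is the kind of code a hook manager is meant to replace. It uses only `torch.nn.Module.register_forward_hook` and `get_submodule`; the small `nn.Sequential` model, the `captured` dict, and the `make_hook` helper are hypothetical names chosen for illustration and are not part of TorchHook.

```python
import torch
import torch.nn as nn

# Small stand-in model (hypothetical; any nn.Module works the same way).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)
model.eval()

captured = {}  # layer name -> list of captured outputs


def make_hook(name):
    # Forward hooks receive (module, inputs, output).
    def hook(module, inputs, output):
        captured.setdefault(name, []).append(output.detach())
    return hook


# Look up layers by dotted name and keep the handles for later removal.
handles = []
for name in ["0", "4"]:
    layer = model.get_submodule(name)
    handles.append(layer.register_forward_hook(make_hook(name)))

with torch.no_grad():
    model(torch.randn(2, 3, 16, 16))

# Remove every hook so later forward passes are not affected.
for handle in handles:
    handle.remove()

conv_shape = tuple(captured["0"][0].shape)
fc_shape = tuple(captured["4"][0].shape)
print(conv_shape, fc_shape)
```

Each captured layer needs its own closure, its own handle, and explicit removal, which is the bookkeeping the features above aim to hide.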

Installation

pip install torchhook

Or install from the local source:

git clone https://github.com/zzaiyan/TorchHook.git
cd TorchHook
pip install .

Quick Start

import torch
import torchvision.models as models
from torchhook import HookManager

# 1. Load your model
model = models.resnet18()
model.eval()

# 2. Initialize HookManager
hook_manager = HookManager(model, max_size=1) # Keep only the latest feature per hook

# 3. Register layers
hook_manager.add(layer_name='conv1')
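# Note: torchvision's BasicBlock calls this ReLU twice per forward pass,
# so with max_size=1 only the second (post-residual) activation is kept.
hook_manager.add(layer_name='layer4.1.relu')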
hook_manager.add(layer_name='fully_connected', layer=model.fc) # Optional: pass layer object

# 4. Forward pass
dummy_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(dummy_input)

# 5. Get features
features_conv1 = hook_manager.get('conv1')
features_relu = hook_manager.get('layer4.1.relu')
all_features = hook_manager.get_all() # Get all features as a dict

print(f"Conv1 feature shape: {features_conv1[0].shape}")
print(f"Layer 4.1 ReLU feature shape: {features_relu[0].shape}")

# 6. Summary (Optional)
hook_manager.summary()

# 7. Clean up hooks (Important!)
hook_manager.clear_hooks()
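To illustrate why steps 2 and 7 matter, the sketch below reproduces both ideas in plain PyTorch: a bounded buffer that keeps only the most recent output, and hook removal. It is an assumption about the general technique, not TorchHook's actual implementation; the `deque(maxlen=1)` buffer and the variable names are hypothetical.

```python
from collections import deque

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
buffer = deque(maxlen=1)  # bounded buffer: holds only the most recent output


def hook(module, inputs, output):
    buffer.append(output.detach())


handle = layer.register_forward_hook(hook)

with torch.no_grad():
    layer(torch.zeros(1, 4))
    last = layer(torch.ones(3, 4))

kept = len(buffer)                             # older output was evicted
latest_matches = torch.equal(buffer[0], last)  # buffer holds the last call

# Once the handle is removed, the hook no longer fires.
handle.remove()
buffer.clear()
with torch.no_grad():
    layer(torch.ones(1, 4))
after_removal = len(buffer)
print(kept, latest_matches, after_removal)
```

A hook that is never removed keeps running on every forward pass and keeps references to the stored tensors alive, which is why cleanup is flagged as important.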

For advanced usage, such as custom hooks and output transformations, see the blog posts: English | Chinese
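As a rough idea of what an output transformation can look like, the following plain-PyTorch sketch pools each feature map before storing it, which reduces memory use when only per-channel statistics are needed. It does not use TorchHook's own customization API (see the blog posts for that); `pooled_hook` and `features` are hypothetical names.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
features = []


def pooled_hook(module, inputs, output):
    # Transformation: global average pool over the spatial dimensions,
    # turning (N, C, H, W) into (N, C) before the tensor is stored.
    features.append(output.detach().mean(dim=(2, 3)).cpu())


handle = conv.register_forward_hook(pooled_hook)
try:
    with torch.no_grad():
        conv(torch.randn(2, 3, 32, 32))
finally:
    # Always remove the hook, even if the forward pass raises.
    handle.remove()

pooled_shape = tuple(features[0].shape)
print(pooled_shape)
```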
