Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 37 additions & 35 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ requires = [
"ninja>=1.11.0",
"pyyaml>=6.0",
"cffi>=1.15.1",
"torch>=2.10.0.dev,<2.11.0; platform_machine != 'aarch64' or (platform_machine == 'aarch64' and 'tegra' not in platform_release)",
"torch>=2.8.0,<2.9.0; platform_machine == 'aarch64' and 'tegra' in platform_release",
"torch>=2.10.0.dev,<2.11.0",
"pybind11==2.6.2",
]
build-backend = "setuptools.build_meta"
Expand All @@ -33,7 +32,7 @@ classifiers = [
"Topic :: Software Development :: Libraries",
]
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.9"
requires-python = ">=3.10"
keywords = [
"pytorch",
"torch",
Expand Down Expand Up @@ -100,12 +99,10 @@ index-strategy = "unsafe-best-match"

[tool.uv.sources]
torch = [
{ index = "pytorch-nightly-cu130", marker = "platform_machine != 'aarch64' or (platform_machine == 'aarch64' and 'tegra' not in platform_release)" },
{ index = "jetson-containers", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release" },
{ index = "pytorch-nightly-cu130" },
]
torchvision = [
{ index = "pytorch-nightly-cu130", marker = "platform_machine != 'aarch64' or (platform_machine == 'aarch64' and 'tegra' not in platform_release)" },
{ index = "jetson-containers", marker = "platform_machine == 'aarch64' and 'tegra' in platform_release" },
{ index = "pytorch-nightly-cu130" },
]

[[tool.uv.index]]
Expand All @@ -114,50 +111,55 @@ url = "https://download.pytorch.org/whl/nightly/cu130"
explicit = false

[[tool.uv.index]]
name = "pytorch-nightly-cu129"
url = "https://download.pytorch.org/whl/nightly/cu129"
name = "pytorch-nightly-cu128"
url = "https://download.pytorch.org/whl/nightly/cu128"
explicit = false

[[tool.uv.index]]
name = "jetson-containers"
url = "https://pypi.jetson-ai-lab.io/jp6/cu126"
name = "pytorch-nightly-cu126"
url = "https://download.pytorch.org/whl/nightly/cu126"
explicit = false

[[tool.uv.index]]
name = "nvidia"
url = "https://pypi.nvidia.com"
name = "pytorch-test-cu130"
url = "https://download.pytorch.org/whl/test/cu130"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-nightly-cu124"
# url = "https://download.pytorch.org/whl/nightly/cu124"
# explicit = true
[[tool.uv.index]]
name = "pytorch-test-cu128"
url = "https://download.pytorch.org/whl/test/cu128"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-nightly-cu118"
# url = "https://download.pytorch.org/whl/nightly/cu118"
# explicit = true
[[tool.uv.index]]
name = "pytorch-test-cu126"
url = "https://download.pytorch.org/whl/test/cu126"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-test-cu124"
# url = "https://download.pytorch.org/whl/test/cu124"
# explicit = false
[[tool.uv.index]]
name = "pytorch-release-cu130"
url = "https://download.pytorch.org/whl/release/cu130"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-test-cu118"
# url = "https://download.pytorch.org/whl/test/cu118"
# explicit = false
[[tool.uv.index]]
name = "pytorch-release-cu128"
url = "https://download.pytorch.org/whl/release/cu128"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-release-cu124"
# url = "https://download.pytorch.org/whl/cu124"
# explicit = false
[[tool.uv.index]]
name = "pytorch-release-cu126"
url = "https://download.pytorch.org/whl/release/cu126"
explicit = false

# [[tool.uv.index]]
# name = "pytorch-release-cu118"
# url = "https://download.pytorch.org/whl/cu118"
# name = "jetson-containers"
# url = "https://pypi.jetson-ai-lab.io/jp6/cu126"
# explicit = false

[[tool.uv.index]]
name = "nvidia"
url = "https://pypi.nvidia.com"
explicit = false


[tool.ruff]
# NOTE: Synchronize the ignores with .flake8
Expand Down
122 changes: 68 additions & 54 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,71 +725,85 @@ def run(self):
with open(os.path.join(get_root_dir(), "README.md"), "r", encoding="utf-8") as fh:
long_description = fh.read()

base_requirements = [
"packaging>=23",
"typing-extensions>=4.7.0",
"dllist",
"psutil",
# dummy package as a workaround (WAR) for the tensorrt dependency on nvidia-cuda-runtime-cu13
"nvidia-cuda-runtime-cu13==0.0.0a0",
]

def get_jetpack_requirements(base_requirements):
requirements = base_requirements + ["numpy<2.0.0"]
if IS_DLFW_CI:
return requirements
else:
return requirements + ["torch>=2.8.0,<2.9.0", "tensorrt>=10.3.0,<10.4.0"]

def get_requirements():
if IS_JETPACK:
requirements = get_jetpack_requirements()
elif IS_SBSA:
requirements = get_sbsa_requirements()

def get_sbsa_requirements(base_requirements):
requirements = base_requirements + ["numpy"]
if IS_DLFW_CI:
return requirements
else:
# standard linux and windows requirements
requirements = base_requirements + ["numpy"]
if not IS_DLFW_CI:
requirements = requirements + ["torch>=2.10.0.dev,<2.11.0"]
if USE_TRT_RTX:
# TensorRT does not currently publish wheels for Tegra, so Thor uses the local tensorrt install from the tarball.
# Also, because Thor reuses the SBSA torch_tensorrt wheel, the SBSA wheel build must include only the tensorrt dependency.
return requirements + [
"torch>=2.10.0.dev,<2.11.0",
"tensorrt>=10.14.1,<10.15.0",
]


def get_x86_64_requirements(base_requirements):
requirements = base_requirements + ["numpy"]

if IS_DLFW_CI:
return requirements
else:
requirements = requirements + ["torch>=2.10.0.dev,<2.11.0"]
if USE_TRT_RTX:
return requirements + [
"tensorrt_rtx>=1.2.0.54",
]
else:
requirements = requirements + [
"tensorrt>=10.14.1,<10.15.0",
]
cuda_version = torch.version.cuda
if cuda_version.startswith("12"):
# Using tensorrt>=10.14.1,<10.15.0 directly in a cu12* env would pull in both tensorrt_cu12 and tensorrt_cu13,
# which conflicts because cuda-toolkit 13 is also pulled in, so we pin tensorrt_cu12 explicitly here
tensorrt_prefix = "tensorrt-cu12"
requirements = requirements + [
f"{tensorrt_prefix}>=10.14.1,<10.15.0",
f"{tensorrt_prefix}-bindings>=10.14.1,<10.15.0",
f"{tensorrt_prefix}-libs>=10.14.1,<10.15.0",
]
elif cuda_version.startswith("13"):
tensorrt_prefix = "tensorrt-cu13"
requirements = requirements + [
"tensorrt_rtx>=1.2.0.54",
f"{tensorrt_prefix}>=10.14.1,<10.15.0,!=10.14.1.48",
f"{tensorrt_prefix}-bindings>=10.14.1,<10.15.0,!=10.14.1.48",
f"{tensorrt_prefix}-libs>=10.14.1,<10.15.0,!=10.14.1.48",
]
else:
cuda_version = torch.version.cuda
if cuda_version.startswith("12"):
# Using tensorrt>=10.14.1,<10.15.0 directly in a cu12* env would pull in both tensorrt_cu12 and tensorrt_cu13,
# which conflicts because cuda-toolkit 13 is also pulled in, so we pin tensorrt_cu12 explicitly here
tensorrt_prefix = "tensorrt-cu12"
requirements = requirements + [
f"{tensorrt_prefix}>=10.14.1,<10.15.0",
f"{tensorrt_prefix}-bindings>=10.14.1,<10.15.0",
f"{tensorrt_prefix}-libs>=10.14.1,<10.15.0",
]
elif cuda_version.startswith("13"):
tensorrt_prefix = "tensorrt-cu13"
requirements = requirements + [
f"{tensorrt_prefix}>=10.14.1,<10.15.0,!=10.14.1.48",
f"{tensorrt_prefix}-bindings>=10.14.1,<10.15.0,!=10.14.1.48",
f"{tensorrt_prefix}-libs>=10.14.1,<10.15.0,!=10.14.1.48",
]
else:
raise ValueError(f"Unsupported CUDA version: {cuda_version}")
return requirements
raise ValueError(f"Unsupported CUDA version: {cuda_version}")


def get_jetpack_requirements():
jetpack_requirements = base_requirements + ["numpy<2.0.0"]
if IS_DLFW_CI:
return jetpack_requirements
return jetpack_requirements + ["torch>=2.8.0,<2.9.0", "tensorrt>=10.3.0,<10.4.0"]
return requirements


def get_sbsa_requirements():
sbsa_requirements = base_requirements + ["numpy"]
if IS_DLFW_CI:
return sbsa_requirements
# TensorRT does not currently publish wheels for Tegra, so Thor uses the local tensorrt install from the tarball.
# Also, because Thor reuses the SBSA torch_tensorrt wheel, the SBSA wheel build must include only the tensorrt dependency.
return sbsa_requirements + [
"torch>=2.10.0.dev,<2.11.0",
"tensorrt>=10.14.1,<10.15.0",
def get_requirements():
base_requirements = [
"packaging>=23",
"typing-extensions>=4.7.0",
"dllist",
"psutil",
# dummy package as a workaround (WAR) for the tensorrt dependency on nvidia-cuda-runtime-cu13
"nvidia-cuda-runtime-cu13==0.0.0a0",
]

if IS_JETPACK:
requirements = get_jetpack_requirements(base_requirements)
elif IS_SBSA:
requirements = get_sbsa_requirements(base_requirements)
else:
# standard linux and windows requirements
requirements = get_x86_64_requirements(base_requirements)
return requirements


setup(
name="torch_tensorrt",
Expand Down
Loading
Loading