From 2f2e78a57dfc9b522e5ca448d6753c6cc1a4e698 Mon Sep 17 00:00:00 2001 From: original-doc Date: Sat, 17 Jan 2026 20:59:22 -0800 Subject: [PATCH 1/4] win11 adapted --- csrc/mlp_cuda.cu | 2 + csrc/multi_tensor_axpby_kernel.cu | 20 +++-- csrc/multi_tensor_scale_kernel.cu | 7 +- setup.py | 17 ++-- windows_install.md | 134 ++++++++++++++++++++++++++++++ 5 files changed, 165 insertions(+), 15 deletions(-) create mode 100644 windows_install.md diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index 4a870da4d..07c29aa62 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -10,6 +10,8 @@ #include #include +typedef unsigned int uint; + #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 // includes cublaslt #include diff --git a/csrc/multi_tensor_axpby_kernel.cu b/csrc/multi_tensor_axpby_kernel.cu index 7d56488a7..d97c48e3e 100644 --- a/csrc/multi_tensor_axpby_kernel.cu +++ b/csrc/multi_tensor_axpby_kernel.cu @@ -9,7 +9,7 @@ #include "multi_tensor_apply.cuh" #include "type_shim.h" - +#include #define BLOCK_SIZE 512 #define ILP 4 @@ -61,9 +61,12 @@ struct AxpbyFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); - if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]); - if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]); + if (arg_to_check == -1) finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); + // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && std::isfinite(static_cast(r_y[ii]))); + if (arg_to_check == 0) finite = finite && (fabsf((float)r_x[ii]) <= 3.40282e+38f); + // if (arg_to_check == 0) finite = finite && std::isfinite(static_cast(r_x[ii])); + if (arg_to_check == 1) finite = finite && (fabsf((float)r_y[ii]) <= 3.40282e+38f); + // if (arg_to_check == 1) finite = finite && std::isfinite(static_cast(r_y[ii])); } // store load_store(out, r_out, i_start, 0); @@ -84,9 +87,12 @@ struct AxpbyFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); - if (arg_to_check == -1) finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if (arg_to_check == 0) finite = finite && isfinite(r_x[ii]); - if (arg_to_check == 1) finite = finite && isfinite(r_y[ii]); + if (arg_to_check == -1) finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); + // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && std::isfinite(static_cast(r_y[ii]))); + if (arg_to_check == 0) finite = finite && (fabsf((float)r_x[ii]) <= 3.40282e+38f); + // if (arg_to_check == 0) finite = finite && std::isfinite(static_cast(r_x[ii])); + if (arg_to_check == 1) finite = finite && (fabsf((float)r_y[ii]) <= 3.40282e+38f); + // if (arg_to_check == 1) finite = finite && std::isfinite(static_cast(r_y[ii])); } // see note in multi_tensor_scale_kernel.cu #pragma unroll diff --git a/csrc/multi_tensor_scale_kernel.cu b/csrc/multi_tensor_scale_kernel.cu index dc25be105..f10b264b8 100644 --- a/csrc/multi_tensor_scale_kernel.cu +++ b/csrc/multi_tensor_scale_kernel.cu @@ -11,6 +11,7 @@ #include "multi_tensor_apply.cuh" #include "type_shim.h" +#include #define BLOCK_SIZE 512 #define ILP 4 @@ -58,7 +59,8 @@ struct ScaleFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && 
isfinite(r_in[ii]); + finite = finite && (fabsf((float)r_in[ii]) <= 3.40282e+38f); + // finite = finite && std::isfinite(static_cast(r_in[ii])); } // store load_store(out, r_out, i_start, 0); @@ -80,7 +82,8 @@ struct ScaleFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && isfinite(r_in[ii]); + finite = finite && (fabsf((float)r_in[ii]) <= 3.40282e+38f); + // finite = finite && std::isfinite(static_cast(r_in[ii])); } #pragma unroll for (int ii = 0; ii < ILP; ii++) { diff --git a/setup.py b/setup.py index 32b218d0d..4978e372b 100644 --- a/setup.py +++ b/setup.py @@ -238,12 +238,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "csrc/update_scale_hysteresis.cu", ], extra_compile_args={ - "cxx": ["-O3"], + "cxx": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], "nvcc": [ "-lineinfo", "-O3", # '--resource-usage', "--use_fast_math", + "-D_DISABLE_EXTENDED_ALIGNED_STORAGE", ], }, ) @@ -274,9 +275,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int CUDAExtension( name="mlp_cuda", sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"], + libraries=["cublas", "cublasLt"], extra_compile_args={ - "cxx": ["-O3"], - "nvcc": ["-O3"], + "cxx": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], + "nvcc": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], }, ) ) @@ -284,9 +286,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int CUDAExtension( name="fused_dense_cuda", sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"], + libraries=["cublas", "cublasLt"], extra_compile_args={ - "cxx": ["-O3"], - "nvcc": ["-O3"], + "cxx": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], + "nvcc": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], }, ) ) @@ -405,8 +408,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "csrc/megatron/fused_weight_gradient_dense_cuda.cu", "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", ], + libraries=["cublas", "cublasLt"], extra_compile_args={ - "cxx": ["-O3"], + "cxx": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], "nvcc": [ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", @@ -414,6 +418,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", + "-D_DISABLE_EXTENDED_ALIGNED_STORAGE", ], }, ) diff --git a/windows_install.md b/windows_install.md new file mode 100644 index 000000000..eab5d3355 --- /dev/null +++ b/windows_install.md @@ -0,0 +1,134 @@ +# Build Fixes for NVIDIA Apex on Windows 11 (CUDA 12.8 / MSVC 2022) + +## Installation Command + +Make sure you run below commands in **x64 Native Tools Command Prompt for VS 2022** (use search in the win11 to find it). Before install it, make sure your environment has the necessary dependencies like `Pytorch` and `ninja`. + +```bash +git clone https://github.com/NVIDIA/apex.git +cd apex +set APEX_CPP_EXT=1 +set APEX_CUDA_EXT=1 +set DISTUTILS_USE_SDK=1 +pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation ./ +``` + +--- + +> **Note:** Building NVIDIA Apex on Windows is challenging and may find different errors on different devices. This guide documents a successful build on Win11 RTX5070 (sm_120) with CUDA 12.8. 
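
Once `pip install` completes, it is worth confirming that the compiled extensions actually import before moving on. The snippet below is a quick sanity check, not part of the patch itself; the module names come from the `setup.py` excerpts later in this guide, plus `amp_C`, which is assumed here to be the extension that the multi-tensor kernels build into. Adjust the list to whatever extensions your build enabled.

```python
# check_apex_build.py -- hypothetical helper script, not part of this patch series
import importlib

# Extension names as used in this guide's setup.py excerpts; edit to match your build.
modules = ["apex", "amp_C", "mlp_cuda", "fused_dense_cuda", "fused_weight_gradient_mlp_cuda"]

for name in modules:
    try:
        importlib.import_module(name)
        print(f"OK      {name}")
    except ImportError as err:  # a missing .pyd or an unresolved cuBLAS symbol shows up here
        print(f"FAILED  {name}: {err}")
```

Unresolved-symbol failures at this step usually point back to the cuBLAS linking fix described in the sections below.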
+ +--- + +## Build Environment + +| Component | Version | +|-----------|---------| +| **OS** | Windows 11 | +| **CUDA Toolkit** | 12.8 (Blackwell / SM_100 / SM_120) | +| **CUDA Path** | `E:\CUDA128` | +| **Compiler** | MSVC 2022 (Visual Studio Build Tools) | +| **Python** | 3.10 | +| **PyTorch** | 2.9.1+cu128 | +| **Build Flags** | `APEX_CPP_EXT=1`, `APEX_CUDA_EXT=1` | + +### NVCC Version Info + +``` +nvcc: NVIDIA (R) Cuda compiler driver +Copyright (c) 2005-2025 NVIDIA Corporation +Built on Wed_Jan_15_19:38:46_Pacific_Standard_Time_2025 +Cuda compilation tools, release 12.8, V12.8.61 +Build cuda_12.8.r12.8/compiler.35404655_0 +``` + +--- + +## Summary of Changes + +This patch addresses **three primary categories** of build failures encountered on Windows: + +1. Standard type definitions +2. MSVC-specific compiler flags for memory alignment +3. Explicit library linking for cuBLAS + +--- + +## 1. `setup.py` Configuration + +### Changes + +Added `libraries=["cublas", "cublasLt"]` and `extra_compile_args` with `-D_DISABLE_EXTENDED_ALIGNED_STORAGE` to several CUDA extensions. + +### Affected Extensions + +- `mlp_cuda` +- `fused_dense_cuda` +- `fused_weight_gradient_mlp_cuda` +- *(And potentially others using cuBLAS or aligned storage)* + +### Code Diff + +```python +ext_modules.append( + CUDAExtension( + name="module_name", + sources=["..."], + # Fix 1: Explicitly link cuBLAS for Windows + libraries=["cublas", "cublasLt"], + extra_compile_args={ + # Fix 2: Disable extended aligned storage to fix VS2019+ static assertion errors + "cxx": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"], + "nvcc": ["-O3", "-D_DISABLE_EXTENDED_ALIGNED_STORAGE", ...], + }, + ) +) +``` + +### Reasoning + +| Issue | Explanation | +|-------|-------------| +| **Linker Errors (`LNK2001`)** | Unlike Linux, the Windows build environment does not automatically link `cublas.lib` and `cublasLt.lib` when these headers are used. Explicit linking resolves unresolved external symbols for `cublasGemmEx`, `cublasLtMatmul`, etc. | +| **Alignment Errors** | Visual Studio 2017 (15.8 update) and later changed how `std::aligned_storage` works, causing compliance standard errors with older CUDA headers. The flag `_DISABLE_EXTENDED_ALIGNED_STORAGE` restores the necessary behavior for compilation to succeed. | + +--- + +## 2. Source Code Fixes (`csrc/`) + +### A. Type Definition Fix (`uint`) + +**File:** `csrc/mlp_cuda.cu` + +**Change:** Replaced `uint` with `unsigned int`. + +**Reasoning:** The type alias `uint` is standard in Linux system headers but is **not defined** by default in the MSVC (Windows) environment. Using the standard C++ type `unsigned int` ensures cross-platform compatibility. + +--- + +### B. Device Function Compatibility (`isfinite`) + +**Files:** +- `csrc/multi_tensor_scale_kernel.cu` +- `csrc/multi_tensor_axpby_kernel.cu` + +**Change:** Replaced the `isfinite()` check with a robust floating-point check using `fabsf`. Affected variables including `r_in[ii]`, `r_x[ii]` and `r_y[ii]`. + +```cpp +// Before +finite = finite && (isfinite(r_in[ii])); ... + +// After +finite = finite && (fabsf((float)r_in[ii]) <= 3.40282e+38f); ... +// Checks if value is within finite float range +``` + +**Reasoning:** On Windows NVCC, `isfinite` often resolves to the host-only C++ standard library function (`std::isfinite`) rather than the device intrinsic, causing a *"calling a host function from a device function"* error. 
Replacing it with `fabsf` (which is correctly mapped to a device intrinsic) bypasses this restriction while maintaining logical correctness. + +--- + + + + +## License + +Follow the original [NVIDIA Apex License](https://github.com/NVIDIA/apex/blob/master/LICENSE). \ No newline at end of file From 800bf1a0aff79b23769fa097b626b7b1a8ef3f55 Mon Sep 17 00:00:00 2001 From: original-doc Date: Sat, 17 Jan 2026 23:19:06 -0800 Subject: [PATCH 2/4] win11 adapted --- csrc/mlp_cuda.cu | 1 + setup.py | 7 ++++++- windows_install.md | 12 ++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index 07c29aa62..c2cc0db05 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -6,6 +6,7 @@ #include #include + /* Includes, cuda */ #include #include diff --git a/setup.py b/setup.py index 4978e372b..91011a22e 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,11 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int return False return True - +import sys +nvcc_args = [] +if sys.platform == 'win32': + nvcc_args.append("-D_WIN32=1") + if not torch.cuda.is_available(): # https://github.com/NVIDIA/apex/issues/486 # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), @@ -1052,6 +1056,7 @@ def compile_new(*args, **kwargs): return objects + setup( name="apex", version="0.1", diff --git a/windows_install.md b/windows_install.md index eab5d3355..17e0c4d92 100644 --- a/windows_install.md +++ b/windows_install.md @@ -13,6 +13,18 @@ set DISTUTILS_USE_SDK=1 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation ./ ``` +## Trouble shooting +If you encounter trouble with `compiled_autograd.h(1134 / 1108 / 1181)`, based on the [Pytorch issue #148317](https://github.com/pytorch/pytorch/issues/148317#issuecomment-3344732754), you may need to navigate to `\anaconda\envs\basic\lib\site-packages\torch\include\torch\csrc\dynamo\compiled_autograd.h`to Line 1134, and change it from: +```python +} else if constexpr (::std::is_same_v) { + return at::StringType::get(); +``` +to +```python +// } else if constexpr (::std::is_same_v) { +// return at::StringType::get(); +``` + --- > **Note:** Building NVIDIA Apex on Windows is challenging and may find different errors on different devices. This guide documents a successful build on Win11 RTX5070 (sm_120) with CUDA 12.8. 
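
A design note on the `setup.py` changes above, before the next patch in the series: the cuBLAS libraries and the `-D_DISABLE_EXTENDED_ALIGNED_STORAGE` define are added unconditionally, so Linux builds receive them as well. The sketch below is not part of any patch here; it shows one way the same settings could instead be gated on `sys.platform`, in the spirit of the short-lived `sys.platform == 'win32'` check that the patch above adds and the next patch removes. The helper name is made up for illustration.

```python
# Sketch: apply the Windows-specific build settings only on win32.
import sys

WIN_DEFINE = "-D_DISABLE_EXTENDED_ALIGNED_STORAGE"

def with_windows_tweaks(cxx_args, nvcc_args, libraries):
    """Return (cxx, nvcc, libraries), adding the Windows fixes only on win32."""
    if sys.platform == "win32":
        cxx_args = cxx_args + [WIN_DEFINE]
        nvcc_args = nvcc_args + [WIN_DEFINE]
        libraries = libraries + ["cublas", "cublasLt"]
    return cxx_args, nvcc_args, libraries

cxx, nvcc, libs = with_windows_tweaks(["-O3"], ["-O3"], [])
# CUDAExtension(name=..., sources=[...], libraries=libs,
#               extra_compile_args={"cxx": cxx, "nvcc": nvcc})
```

The unconditional form used in the patch should also build on Linux, where the extra define is simply unused, so gating is a style choice rather than a requirement.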
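
Once the series is applied and the build succeeds, the patched finiteness check can be exercised directly. The snippet below is a minimal smoke test, not part of the patches. It assumes a CUDA-capable GPU, the `apex.multi_tensor_apply.multi_tensor_applier` helper, and that `multi_tensor_scale_kernel.cu` is built into the `amp_C` extension; it checks that `multi_tensor_scale` scales its inputs and leaves the overflow flag unset for finite values.

```python
# Smoke test for the patched multi_tensor_scale path (sketch; requires a CUDA GPU).
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

device = "cuda"
src = [torch.randn(1024, device=device), torch.randn(257, device=device)]
dst = [torch.empty_like(t) for t in src]
overflow_buf = torch.zeros(1, dtype=torch.int, device=device)  # kernel sets this to 1 on non-finite input

multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 0.5)

assert all(torch.allclose(d, 0.5 * s) for s, d in zip(src, dst))
assert overflow_buf.item() == 0, "finite inputs should not trip the overflow flag"
print("multi_tensor_scale OK")
```
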
From ce883ae9256967e522a927cf299afd99090bb757 Mon Sep 17 00:00:00 2001 From: original-doc Date: Sat, 17 Jan 2026 23:24:41 -0800 Subject: [PATCH 3/4] win11 adapted --- setup.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 91011a22e..a8e3ace03 100644 --- a/setup.py +++ b/setup.py @@ -113,11 +113,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int return False return True -import sys -nvcc_args = [] -if sys.platform == 'win32': - nvcc_args.append("-D_WIN32=1") - + if not torch.cuda.is_available(): # https://github.com/NVIDIA/apex/issues/486 # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), From 4f5f7b01e98ae40fe8ec8b7d04d83ce121a8f80c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 07:50:46 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- csrc/mlp_cuda.cu | 1 - csrc/multi_tensor_axpby_kernel.cu | 15 ++++++++++----- csrc/multi_tensor_scale_kernel.cu | 2 +- setup.py | 1 - 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index c2cc0db05..07c29aa62 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -6,7 +6,6 @@ #include #include - /* Includes, cuda */ #include #include diff --git a/csrc/multi_tensor_axpby_kernel.cu b/csrc/multi_tensor_axpby_kernel.cu index d97c48e3e..2ef56e1fe 100644 --- a/csrc/multi_tensor_axpby_kernel.cu +++ b/csrc/multi_tensor_axpby_kernel.cu @@ -7,9 +7,10 @@ #include +#include + #include "multi_tensor_apply.cuh" #include "type_shim.h" -#include #define BLOCK_SIZE 512 #define ILP 4 @@ -61,8 +62,10 @@ struct AxpbyFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); - if (arg_to_check == -1) finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); - // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && std::isfinite(static_cast(r_y[ii]))); + if (arg_to_check == -1) + finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); + // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && + // std::isfinite(static_cast(r_y[ii]))); if (arg_to_check == 0) finite = finite && (fabsf((float)r_x[ii]) <= 3.40282e+38f); // if (arg_to_check == 0) finite = finite && std::isfinite(static_cast(r_x[ii])); if (arg_to_check == 1) finite = finite && (fabsf((float)r_y[ii]) <= 3.40282e+38f); @@ -87,8 +90,10 @@ struct AxpbyFunctor { #pragma unroll for (int ii = 0; ii < ILP; ii++) { r_out[ii] = a * static_cast(r_x[ii]) + b * static_cast(r_y[ii]); - if (arg_to_check == -1) finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); - // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && std::isfinite(static_cast(r_y[ii]))); + if (arg_to_check == -1) + finite = finite && ((fabsf((float)r_x[ii]) <= 3.40282e+38f) && (fabsf((float)r_y[ii]) <= 3.40282e+38f)); + // if (arg_to_check == -1) finite = finite && (std::isfinite(static_cast(r_x[ii])) && + // std::isfinite(static_cast(r_y[ii]))); if (arg_to_check == 0) finite = finite && (fabsf((float)r_x[ii]) <= 3.40282e+38f); // if (arg_to_check == 0) finite = finite && 
std::isfinite(static_cast(r_x[ii])); if (arg_to_check == 1) finite = finite && (fabsf((float)r_y[ii]) <= 3.40282e+38f); diff --git a/csrc/multi_tensor_scale_kernel.cu b/csrc/multi_tensor_scale_kernel.cu index f10b264b8..9b395be09 100644 --- a/csrc/multi_tensor_scale_kernel.cu +++ b/csrc/multi_tensor_scale_kernel.cu @@ -7,11 +7,11 @@ #include // Stringstream is a big hammer, but I want to rely on operator<< for dtype. +#include #include #include "multi_tensor_apply.cuh" #include "type_shim.h" -#include #define BLOCK_SIZE 512 #define ILP 4 diff --git a/setup.py b/setup.py index a8e3ace03..4978e372b 100644 --- a/setup.py +++ b/setup.py @@ -1052,7 +1052,6 @@ def compile_new(*args, **kwargs): return objects - setup( name="apex", version="0.1",