onnxruntime: add ROCm support (#454399)
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
howard-hinnant-date,
|
howard-hinnant-date,
|
||||||
libpng,
|
libpng,
|
||||||
nlohmann_json,
|
nlohmann_json,
|
||||||
|
perl,
|
||||||
pkg-config,
|
pkg-config,
|
||||||
python3Packages,
|
python3Packages,
|
||||||
re2,
|
re2,
|
||||||
@@ -24,8 +25,10 @@
|
|||||||
pythonSupport ? true,
|
pythonSupport ? true,
|
||||||
cudaSupport ? config.cudaSupport,
|
cudaSupport ? config.cudaSupport,
|
||||||
ncclSupport ? cudaSupport && cudaPackages.nccl.meta.available,
|
ncclSupport ? cudaSupport && cudaPackages.nccl.meta.available,
|
||||||
|
rocmSupport ? config.rocmSupport,
|
||||||
withFullProtobuf ? false,
|
withFullProtobuf ? false,
|
||||||
cudaPackages ? { },
|
cudaPackages ? { },
|
||||||
|
rocmPackages,
|
||||||
}@inputs:
|
}@inputs:
|
||||||
|
|
||||||
let
|
let
|
||||||
@@ -121,6 +124,9 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
]
|
]
|
||||||
++ lib.optionals isCudaJetson [
|
++ lib.optionals isCudaJetson [
|
||||||
cudaPackages.autoAddCudaCompatRunpath
|
cudaPackages.autoAddCudaCompatRunpath
|
||||||
|
]
|
||||||
|
++ lib.optionals rocmSupport [
|
||||||
|
perl # for tools/ci_build/hipify-perl
|
||||||
];
|
];
|
||||||
|
|
||||||
buildInputs = [
|
buildInputs = [
|
||||||
@@ -156,6 +162,22 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
]
|
]
|
||||||
++ lib.optionals ncclSupport [ nccl ]
|
++ lib.optionals ncclSupport [ nccl ]
|
||||||
)
|
)
|
||||||
|
++ lib.optionals rocmSupport [
|
||||||
|
rocmPackages.clr
|
||||||
|
rocmPackages.hipblas
|
||||||
|
rocmPackages.hipcub
|
||||||
|
rocmPackages.hipfft
|
||||||
|
rocmPackages.hiprand
|
||||||
|
rocmPackages.hipsparse
|
||||||
|
rocmPackages.rocblas
|
||||||
|
rocmPackages.rocprim
|
||||||
|
rocmPackages.rocrand
|
||||||
|
rocmPackages.rocthrust
|
||||||
|
rocmPackages.miopen
|
||||||
|
rocmPackages.rccl
|
||||||
|
rocmPackages.rocm-smi
|
||||||
|
rocmPackages.roctracer
|
||||||
|
]
|
||||||
++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [
|
++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [
|
||||||
(darwinMinVersionHook "13.3")
|
(darwinMinVersionHook "13.3")
|
||||||
];
|
];
|
||||||
@@ -203,6 +225,7 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
(lib.cmakeBool "onnxruntime_USE_FULL_PROTOBUF" withFullProtobuf)
|
(lib.cmakeBool "onnxruntime_USE_FULL_PROTOBUF" withFullProtobuf)
|
||||||
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
|
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
|
||||||
(lib.cmakeBool "onnxruntime_USE_NCCL" (cudaSupport && ncclSupport))
|
(lib.cmakeBool "onnxruntime_USE_NCCL" (cudaSupport && ncclSupport))
|
||||||
|
(lib.cmakeBool "onnxruntime_USE_ROCM" rocmSupport)
|
||||||
(lib.cmakeBool "onnxruntime_ENABLE_LTO" (!cudaSupport || cudaPackages.cudaOlder "12.8"))
|
(lib.cmakeBool "onnxruntime_ENABLE_LTO" (!cudaSupport || cudaPackages.cudaOlder "12.8"))
|
||||||
]
|
]
|
||||||
++ lib.optionals pythonSupport [
|
++ lib.optionals pythonSupport [
|
||||||
@@ -213,15 +236,43 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
|
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
|
||||||
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
|
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
|
||||||
(lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
|
(lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
|
||||||
|
]
|
||||||
|
++ lib.optionals rocmSupport [
|
||||||
|
# Werror combines with rocprim header issues to cause errors (warp size const deprecation)
|
||||||
|
"--compile-no-warning-as-error"
|
||||||
|
(lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (
|
||||||
|
builtins.concatStringsSep ";" rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets
|
||||||
|
))
|
||||||
|
(lib.cmakeFeature "onnxruntime_ROCM_HOME" "${rocmPackages.clr}")
|
||||||
|
# Incompatible with packaged version, far too slow to build vendored version
|
||||||
|
(lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL" false)
|
||||||
|
(lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE" false)
|
||||||
];
|
];
|
||||||
|
|
||||||
env = lib.optionalAttrs effectiveStdenv.cc.isClang {
|
env =
|
||||||
|
lib.optionalAttrs effectiveStdenv.cc.isClang {
|
||||||
NIX_CFLAGS_COMPILE = "-Wno-error";
|
NIX_CFLAGS_COMPILE = "-Wno-error";
|
||||||
|
}
|
||||||
|
// lib.optionalAttrs rocmSupport {
|
||||||
|
MIOPEN_PATH = rocmPackages.miopen;
|
||||||
|
# HIP steps fail to find ROCm libs when not in HIPFLAGS, causing
|
||||||
|
# fatal error: 'rocrand/rocrand.h' file not found
|
||||||
|
HIPFLAGS = lib.concatMapStringsSep " " (pkg: "-I${lib.getInclude pkg}/include") [
|
||||||
|
rocmPackages.hipblas
|
||||||
|
rocmPackages.hipcub
|
||||||
|
rocmPackages.hiprand
|
||||||
|
rocmPackages.hipsparse
|
||||||
|
rocmPackages.rocblas
|
||||||
|
rocmPackages.rocprim
|
||||||
|
rocmPackages.rocrand
|
||||||
|
rocmPackages.rocthrust
|
||||||
|
];
|
||||||
};
|
};
|
||||||
|
|
||||||
doCheck =
|
doCheck =
|
||||||
!(
|
!(
|
||||||
cudaSupport
|
cudaSupport
|
||||||
|
|| rocmSupport
|
||||||
|| builtins.elem effectiveStdenv.buildPlatform.system [
|
|| builtins.elem effectiveStdenv.buildPlatform.system [
|
||||||
# aarch64-linux fails cpuinfo test, because /sys/devices/system/cpu/ does not exist in the sandbox
|
# aarch64-linux fails cpuinfo test, because /sys/devices/system/cpu/ does not exist in the sandbox
|
||||||
"aarch64-linux"
|
"aarch64-linux"
|
||||||
@@ -231,7 +282,7 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ];
|
requiredSystemFeatures = lib.optionals (cudaSupport || rocmSupport) [ "big-parallel" ];
|
||||||
|
|
||||||
hardeningEnable = lib.optionals (effectiveStdenv.hostPlatform.system == "loongarch64-linux") [
|
hardeningEnable = lib.optionals (effectiveStdenv.hostPlatform.system == "loongarch64-linux") [
|
||||||
"nostrictaliasing"
|
"nostrictaliasing"
|
||||||
@@ -247,6 +298,9 @@ effectiveStdenv.mkDerivation rec {
|
|||||||
"GetRuntimePath() const { return PathString(); }" \
|
"GetRuntimePath() const { return PathString(); }" \
|
||||||
"GetRuntimePath() const { return PathString(\"$out/lib/\"); }"
|
"GetRuntimePath() const { return PathString(\"$out/lib/\"); }"
|
||||||
''
|
''
|
||||||
|
+ lib.optionalString rocmSupport ''
|
||||||
|
patchShebangs tools/ci_build/hipify-perl
|
||||||
|
''
|
||||||
+ lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
|
+ lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
|
||||||
# https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
|
# https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
|
||||||
rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc
|
rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc
|
||||||
|
|||||||
Reference in New Issue
Block a user