onnxruntime: add ROCm support (#454399)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
howard-hinnant-date,
|
||||
libpng,
|
||||
nlohmann_json,
|
||||
perl,
|
||||
pkg-config,
|
||||
python3Packages,
|
||||
re2,
|
||||
@@ -24,8 +25,10 @@
|
||||
pythonSupport ? true,
|
||||
cudaSupport ? config.cudaSupport,
|
||||
ncclSupport ? cudaSupport && cudaPackages.nccl.meta.available,
|
||||
rocmSupport ? config.rocmSupport,
|
||||
withFullProtobuf ? false,
|
||||
cudaPackages ? { },
|
||||
rocmPackages,
|
||||
}@inputs:
|
||||
|
||||
let
|
||||
@@ -121,6 +124,9 @@ effectiveStdenv.mkDerivation rec {
|
||||
]
|
||||
++ lib.optionals isCudaJetson [
|
||||
cudaPackages.autoAddCudaCompatRunpath
|
||||
]
|
||||
++ lib.optionals rocmSupport [
|
||||
perl # for tools/ci_build/hipify-perl
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
@@ -156,6 +162,22 @@ effectiveStdenv.mkDerivation rec {
|
||||
]
|
||||
++ lib.optionals ncclSupport [ nccl ]
|
||||
)
|
||||
++ lib.optionals rocmSupport [
|
||||
rocmPackages.clr
|
||||
rocmPackages.hipblas
|
||||
rocmPackages.hipcub
|
||||
rocmPackages.hipfft
|
||||
rocmPackages.hiprand
|
||||
rocmPackages.hipsparse
|
||||
rocmPackages.rocblas
|
||||
rocmPackages.rocprim
|
||||
rocmPackages.rocrand
|
||||
rocmPackages.rocthrust
|
||||
rocmPackages.miopen
|
||||
rocmPackages.rccl
|
||||
rocmPackages.rocm-smi
|
||||
rocmPackages.roctracer
|
||||
]
|
||||
++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [
|
||||
(darwinMinVersionHook "13.3")
|
||||
];
|
||||
@@ -203,6 +225,7 @@ effectiveStdenv.mkDerivation rec {
|
||||
(lib.cmakeBool "onnxruntime_USE_FULL_PROTOBUF" withFullProtobuf)
|
||||
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
|
||||
(lib.cmakeBool "onnxruntime_USE_NCCL" (cudaSupport && ncclSupport))
|
||||
(lib.cmakeBool "onnxruntime_USE_ROCM" rocmSupport)
|
||||
(lib.cmakeBool "onnxruntime_ENABLE_LTO" (!cudaSupport || cudaPackages.cudaOlder "12.8"))
|
||||
]
|
||||
++ lib.optionals pythonSupport [
|
||||
@@ -213,15 +236,43 @@ effectiveStdenv.mkDerivation rec {
|
||||
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
|
||||
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
|
||||
(lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
|
||||
]
|
||||
++ lib.optionals rocmSupport [
|
||||
# Werror combines with rocprim header issues to cause errors (warp size const deprecation)
|
||||
"--compile-no-warning-as-error"
|
||||
(lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (
|
||||
builtins.concatStringsSep ";" rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets
|
||||
))
|
||||
(lib.cmakeFeature "onnxruntime_ROCM_HOME" "${rocmPackages.clr}")
|
||||
# Incompatible with packaged version, far too slow to build vendored version
|
||||
(lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL" false)
|
||||
(lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE" false)
|
||||
];
|
||||
|
||||
env = lib.optionalAttrs effectiveStdenv.cc.isClang {
|
||||
NIX_CFLAGS_COMPILE = "-Wno-error";
|
||||
};
|
||||
env =
|
||||
lib.optionalAttrs effectiveStdenv.cc.isClang {
|
||||
NIX_CFLAGS_COMPILE = "-Wno-error";
|
||||
}
|
||||
// lib.optionalAttrs rocmSupport {
|
||||
MIOPEN_PATH = rocmPackages.miopen;
|
||||
# HIP steps fail to find ROCm libs when not in HIPFLAGS, causing
|
||||
# fatal error: 'rocrand/rocrand.h' file not found
|
||||
HIPFLAGS = lib.concatMapStringsSep " " (pkg: "-I${lib.getInclude pkg}/include") [
|
||||
rocmPackages.hipblas
|
||||
rocmPackages.hipcub
|
||||
rocmPackages.hiprand
|
||||
rocmPackages.hipsparse
|
||||
rocmPackages.rocblas
|
||||
rocmPackages.rocprim
|
||||
rocmPackages.rocrand
|
||||
rocmPackages.rocthrust
|
||||
];
|
||||
};
|
||||
|
||||
doCheck =
|
||||
!(
|
||||
cudaSupport
|
||||
|| rocmSupport
|
||||
|| builtins.elem effectiveStdenv.buildPlatform.system [
|
||||
# aarch64-linux fails cpuinfo test, because /sys/devices/system/cpu/ does not exist in the sandbox
|
||||
"aarch64-linux"
|
||||
@@ -231,7 +282,7 @@ effectiveStdenv.mkDerivation rec {
|
||||
]
|
||||
);
|
||||
|
||||
requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ];
|
||||
requiredSystemFeatures = lib.optionals (cudaSupport || rocmSupport) [ "big-parallel" ];
|
||||
|
||||
hardeningEnable = lib.optionals (effectiveStdenv.hostPlatform.system == "loongarch64-linux") [
|
||||
"nostrictaliasing"
|
||||
@@ -247,6 +298,9 @@ effectiveStdenv.mkDerivation rec {
|
||||
"GetRuntimePath() const { return PathString(); }" \
|
||||
"GetRuntimePath() const { return PathString(\"$out/lib/\"); }"
|
||||
''
|
||||
+ lib.optionalString rocmSupport ''
|
||||
patchShebangs tools/ci_build/hipify-perl
|
||||
''
|
||||
+ lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") ''
|
||||
# https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691
|
||||
rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc
|
||||
|
||||
Reference in New Issue
Block a user