diff --git a/pkgs/by-name/on/onnxruntime/package.nix b/pkgs/by-name/on/onnxruntime/package.nix index 05ecb2453419..bd91c2d2584f 100644 --- a/pkgs/by-name/on/onnxruntime/package.nix +++ b/pkgs/by-name/on/onnxruntime/package.nix @@ -14,6 +14,7 @@ howard-hinnant-date, libpng, nlohmann_json, + perl, pkg-config, python3Packages, re2, @@ -24,8 +25,10 @@ pythonSupport ? true, cudaSupport ? config.cudaSupport, ncclSupport ? cudaSupport && cudaPackages.nccl.meta.available, + rocmSupport ? config.rocmSupport, withFullProtobuf ? false, cudaPackages ? { }, + rocmPackages, }@inputs: let @@ -121,6 +124,9 @@ effectiveStdenv.mkDerivation rec { ] ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath + ] + ++ lib.optionals rocmSupport [ + perl # for tools/ci_build/hipify-perl ]; buildInputs = [ @@ -156,6 +162,22 @@ effectiveStdenv.mkDerivation rec { ] ++ lib.optionals ncclSupport [ nccl ] ) + ++ lib.optionals rocmSupport [ + rocmPackages.clr + rocmPackages.hipblas + rocmPackages.hipcub + rocmPackages.hipfft + rocmPackages.hiprand + rocmPackages.hipsparse + rocmPackages.rocblas + rocmPackages.rocprim + rocmPackages.rocrand + rocmPackages.rocthrust + rocmPackages.miopen + rocmPackages.rccl + rocmPackages.rocm-smi + rocmPackages.roctracer + ] ++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [ (darwinMinVersionHook "13.3") ]; @@ -203,6 +225,7 @@ effectiveStdenv.mkDerivation rec { (lib.cmakeBool "onnxruntime_USE_FULL_PROTOBUF" withFullProtobuf) (lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport) (lib.cmakeBool "onnxruntime_USE_NCCL" (cudaSupport && ncclSupport)) + (lib.cmakeBool "onnxruntime_USE_ROCM" rocmSupport) (lib.cmakeBool "onnxruntime_ENABLE_LTO" (!cudaSupport || cudaPackages.cudaOlder "12.8")) ] ++ lib.optionals pythonSupport [ @@ -213,15 +236,43 @@ effectiveStdenv.mkDerivation rec { (lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}") (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString) (lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1") + ] + ++ lib.optionals rocmSupport [ + # Werror combines with rocprim header issues to cause errors (warp size const deprecation) + "--compile-no-warning-as-error" + (lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" ( + builtins.concatStringsSep ";" rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets + )) + (lib.cmakeFeature "onnxruntime_ROCM_HOME" "${rocmPackages.clr}") + # Incompatible with packaged version, far too slow to build vendored version + (lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL" false) + (lib.cmakeBool "onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE" false) ]; - env = lib.optionalAttrs effectiveStdenv.cc.isClang { - NIX_CFLAGS_COMPILE = "-Wno-error"; - }; + env = + lib.optionalAttrs effectiveStdenv.cc.isClang { + NIX_CFLAGS_COMPILE = "-Wno-error"; + } + // lib.optionalAttrs rocmSupport { + MIOPEN_PATH = rocmPackages.miopen; + # HIP steps fail to find ROCm libs when not in HIPFLAGS, causing + # fatal error: 'rocrand/rocrand.h' file not found + HIPFLAGS = lib.concatMapStringsSep " " (pkg: "-I${lib.getInclude pkg}/include") [ + rocmPackages.hipblas + rocmPackages.hipcub + rocmPackages.hiprand + rocmPackages.hipsparse + rocmPackages.rocblas + rocmPackages.rocprim + rocmPackages.rocrand + rocmPackages.rocthrust + ]; + }; doCheck = !( cudaSupport + || rocmSupport || builtins.elem effectiveStdenv.buildPlatform.system [ # aarch64-linux fails cpuinfo test, because /sys/devices/system/cpu/ does not exist in the sandbox "aarch64-linux" @@ -231,7 +282,7 @@ effectiveStdenv.mkDerivation rec { ] ); - requiredSystemFeatures = lib.optionals cudaSupport [ "big-parallel" ]; + requiredSystemFeatures = lib.optionals (cudaSupport || rocmSupport) [ "big-parallel" ]; hardeningEnable = lib.optionals (effectiveStdenv.hostPlatform.system == "loongarch64-linux") [ "nostrictaliasing" @@ -247,6 +298,9 @@ effectiveStdenv.mkDerivation rec { "GetRuntimePath() const { return PathString(); }" \ "GetRuntimePath() const { return PathString(\"$out/lib/\"); }" '' + + lib.optionalString rocmSupport '' + patchShebangs tools/ci_build/hipify-perl + '' + lib.optionalString (effectiveStdenv.hostPlatform.system == "aarch64-linux") '' # https://github.com/NixOS/nixpkgs/pull/226734#issuecomment-1663028691 rm -v onnxruntime/test/optimizer/nhwc_transformer_test.cc