diff --git a/pkgs/top-level/cuda-packages.nix b/pkgs/top-level/cuda-packages.nix index 6623c5b43b4f..3c90f71ed35e 100644 --- a/pkgs/top-level/cuda-packages.nix +++ b/pkgs/top-level/cuda-packages.nix @@ -5,10 +5,17 @@ lib, }: let - selectManifests = lib.mapAttrs (name: version: _cuda.manifests.${name}.${version}); + mkCudaPackages = + manifestVersions: + callPackage ../development/cuda-modules { + manifests = _cuda.lib.selectManifests manifestVersions; + }; - cudaPackages_12_6 = callPackage ../development/cuda-modules { - manifests = selectManifests { + cudaPackages_12_6 = + let + inherit (cudaPackages_12_6.backendStdenv) hasJetsonCudaCapability; + in + mkCudaPackages { cublasmp = "0.6.0"; cuda = "12.6.3"; cudnn = "9.13.0"; @@ -22,12 +29,14 @@ let nvjpeg2000 = "0.9.0"; nvpl = "25.5"; nvtiff = "0.5.1"; - tensorrt = if cudaPackages_12_6.backendStdenv.hasJetsonCudaCapability then "10.7.0" else "10.9.0"; + tensorrt = if hasJetsonCudaCapability then "10.7.0" else "10.9.0"; }; - }; - cudaPackages_12_8 = callPackage ../development/cuda-modules { - manifests = selectManifests { + cudaPackages_12_8 = + let + inherit (cudaPackages_12_8.backendStdenv) hasJetsonCudaCapability; + in + mkCudaPackages { cublasmp = "0.6.0"; cuda = "12.8.1"; cudnn = "9.13.0"; @@ -41,12 +50,14 @@ let nvjpeg2000 = "0.9.0"; nvpl = "25.5"; nvtiff = "0.5.1"; - tensorrt = if cudaPackages_12_8.backendStdenv.hasJetsonCudaCapability then "10.7.0" else "10.9.0"; + tensorrt = if hasJetsonCudaCapability then "10.7.0" else "10.9.0"; }; - }; - cudaPackages_12_9 = callPackage ../development/cuda-modules { - manifests = selectManifests { + cudaPackages_12_9 = + let + inherit (cudaPackages_12_9.backendStdenv) hasJetsonCudaCapability; + in + mkCudaPackages { cublasmp = "0.6.0"; cuda = "12.9.1"; cudnn = "9.13.0"; @@ -60,12 +71,14 @@ let nvjpeg2000 = "0.9.0"; nvpl = "25.5"; nvtiff = "0.5.1"; - tensorrt = if cudaPackages_12_9.backendStdenv.hasJetsonCudaCapability then "10.7.0" else "10.9.0"; + tensorrt = if hasJetsonCudaCapability then "10.7.0" else "10.9.0"; }; - }; - cudaPackages_13_0 = callPackage ../development/cuda-modules { - manifests = selectManifests { + cudaPackages_13_0 = + let + inherit (cudaPackages_13_0.backendStdenv) hasJetsonCudaCapability; + in + mkCudaPackages { cublasmp = "0.6.0"; cuda = "13.0.2"; cudnn = "9.13.0"; @@ -79,9 +92,8 @@ let nvjpeg2000 = "0.9.0"; nvpl = "25.5"; nvtiff = "0.5.1"; - tensorrt = if cudaPackages_13_0.backendStdenv.hasJetsonCudaCapability then "10.7.0" else "10.9.0"; + tensorrt = if hasJetsonCudaCapability then "10.7.0" else "10.9.0"; }; - }; in { inherit