xmake.lua 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package("cuda")
  2. set_kind("toolchain")
  3. set_homepage("https://developer.nvidia.com/cuda-zone/")
  4. set_description("CUDA® is a parallel computing platform and programming model developed by NVIDIA for general computing on graphical processing units (GPUs).")
  5. if is_host("windows") then
  6. add_urls("https://developer.download.nvidia.com/compute/cuda/$(version)_windows.exe", {
  7. version = function (version)
  8. local driver_version_map = {
  9. ["12.8.1"] = "572.61",
  10. ["12.6.3"] = "561.17",
  11. }
  12. return format("%s/local_installers/cuda_%s_%s", version, version, driver_version_map[tostring(version)])
  13. end})
  14. add_versions("12.8.1", "19392bbffd0ad4ee7cb295a181e87f682187f17653679c1c548c263b7e1cd9a6")
  15. add_versions("12.6.3", "d73e937c75aaa8114da3aff4eee96f9cae03d4b9d70a30b962ccf3c9b4d7a8e1")
  16. end
  17. add_configs("utils", {description = "Enabled cuda utilities.", default = {}, type = "table"})
  18. add_configs("debug", {description = "Enable debug symbols.", default = false, type = "boolean", readonly = true})
  19. set_policy("package.precompiled", false)
  20. on_fetch(function (package, opt)
  21. if opt.system then
  22. import("detect.sdks.find_cuda")
  23. import("lib.detect.find_library")
  24. local cuda = find_cuda()
  25. if cuda then
  26. local result = {includedirs = cuda.includedirs, linkdirs = cuda.linkdirs, links = {}}
  27. local utils = package:config("utils")
  28. table.insert(utils, package:config("shared") and "cudart" or "cudart_static")
  29. for _, util in ipairs(utils) do
  30. if not find_library(util, cuda.linkdirs) then
  31. wprint(format("The library %s for %s is not found!", util, package:arch()))
  32. return
  33. end
  34. table.insert(result.links, util)
  35. end
  36. return result
  37. end
  38. end
  39. end)
  40. on_load("windows", function (package)
  41. package:mark_as_pathenv("CUDA_PATH")
  42. package:setenv("CUDA_PATH", ".")
  43. end)
  44. on_install("windows|x64", function(package)
  45. import("lib.detect.find_tool")
  46. import("lib.detect.find_directory")
  47. if package:is_plat("windows") then
  48. local z7 = assert(find_tool("7z"), "7z tool not found!")
  49. os.vrunv(z7.program, {"x", "-y", package:originfile()})
  50. -- reference: https://github.com/ScoopInstaller/Main/blob/master/bucket/cuda.json
  51. local names = {"bin", "extras", "include", "lib", "libnvvp", "nvml", "nvvm", "compute-sanitizer"}
  52. for _, dir in ipairs(os.dirs("*")) do
  53. if dir:startswith("cuda_") or dir:startswith("lib") then
  54. for _, name in ipairs(names) do
  55. local util_dir = find_directory(name, path.join(dir, "*"))
  56. if util_dir then
  57. os.vcp(path.join(util_dir, "*"), package:installdir(name))
  58. end
  59. end
  60. end
  61. end
  62. end
  63. end)
  64. on_test(function (package)
  65. if not package:is_cross() then
  66. os.vrun("nvcc -V")
  67. end
  68. end)