xmake.lua 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. package("libtorch")
  2. set_homepage("https://pytorch.org/")
  3. set_description("An open source machine learning framework that accelerates the path from research prototyping to production deployment.")
  4. set_license("BSD-3-Clause")
  5. add_urls("https://github.com/pytorch/pytorch.git")
  6. add_versions("v1.8.0", "37c1f4a7fef115d719104e871d0cf39434aa9d56")
  7. add_versions("v1.8.1", "56b43f4fec1f76953f15a627694d4bba34588969")
  8. add_versions("v1.8.2", "e0495a7aa104471d95dc85a1b8f6473fbcc427a8")
  9. add_versions("v1.9.0", "d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa")
  10. add_versions("v1.9.1", "dfbd030854359207cb3040b864614affeace11ce")
  11. add_versions("v1.11.0", "bc2c6edaf163b1a1330e37a6e34caf8c553e4755")
  12. add_patches("1.9.x", path.join(os.scriptdir(), "patches", "1.9.0", "gcc11.patch"), "4191bb3296f18f040c230d7c5364fb160871962d6278e4ae0f8bc481f27d8e4b")
  13. add_patches("1.11.0", path.join(os.scriptdir(), "patches", "1.11.0", "gcc11.patch"), "1404b0bc6ce7433ecdc59d3412e3d9ed507bb5fd2cd59134a254d7d4a8d73012")
  14. add_configs("shared", {description = "Build shared library.", default = true, type = "boolean"})
  15. add_configs("python", {description = "Build python interface.", default = false, type = "boolean"})
  16. add_configs("openmp", {description = "Use OpenMP for parallel code.", default = true, type = "boolean"})
  17. add_configs("cuda", {description = "Enable CUDA support.", default = false, type = "boolean"})
  18. add_configs("ninja", {description = "Use ninja as build tool.", default = false, type = "boolean"})
  19. add_configs("blas", {description = "Set BLAS vendor.", default = "openblas", type = "string", values = {"mkl", "openblas", "eigen"}})
  20. if not is_plat("macosx") then
  21. add_configs("distributed", {description = "Enable distributed support.", default = false, type = "boolean"})
  22. end
  23. add_deps("cmake")
  24. add_deps("python 3.x", {kind = "binary"})
  25. add_includedirs("include")
  26. add_includedirs("include/torch/csrc/api/include")
  27. if is_plat("linux") then
  28. add_syslinks("rt")
  29. end
  30. -- enable long paths for git submodule on windows
  31. if is_host("windows") and set_policy then
  32. set_policy("platform.longpaths", true)
  33. end
  34. on_load("windows|x64", "macosx", "linux", function (package)
  35. if package:config("ninja") then
  36. package:add("deps", "ninja")
  37. end
  38. if package:config("openmp") then
  39. package:add("deps", "openmp")
  40. end
  41. if package:config("cuda") then
  42. package:add("deps", "cuda", {configs = {utils = {"nvrtc", "cudnn", "cufft", "curand", "cublas", "cudart_static"}}})
  43. package:add("deps", "nvtx")
  44. end
  45. if package:config("distributed") then
  46. package:add("deps", "libuv")
  47. end
  48. if not package:is_plat("macosx") and package:config("blas") then
  49. package:add("deps", package:config("blas"))
  50. end
  51. end)
  52. on_install("windows|x64", "macosx", "linux", function (package)
  53. import("package.tools.cmake")
  54. if package:is_plat("windows") then
  55. local vs = import("core.tool.toolchain").load("msvc"):config("vs")
  56. if tonumber(vs) < 2019 then
  57. raise("Your compiler is too old to use this library.")
  58. end
  59. end
  60. -- tackle link flags
  61. local libnames = {"torch", "torch_cpu"}
  62. if package:config("cuda") then
  63. table.insert(libnames, "torch_cuda")
  64. end
  65. table.insert(libnames, "c10")
  66. if package:config("cuda") then
  67. table.insert(libnames, "c10_cuda")
  68. end
  69. local suffix = ""
  70. if not package:is_plat("windows") and package:config("shared") then
  71. package:add("ldflags", "-Wl,-rpath," .. package:installdir("lib"))
  72. if package:is_plat("linux") then
  73. suffix = ".so"
  74. elseif package:is_plat("macosx") then
  75. suffix = ".dylib"
  76. end
  77. for _, lib in ipairs(libnames) do
  78. package:add("ldflags", (package:is_plat("linux") and "-Wl,--no-as-needed," or "") .. package:installdir("lib", "lib") .. lib .. suffix)
  79. end
  80. else
  81. for _, lib in ipairs(libnames) do
  82. package:add("links", lib)
  83. end
  84. end
  85. if not package:config("shared") then
  86. for _, lib in ipairs({"nnpack", "pytorch_qnnpack", "qnnpack", "XNNPACK", "caffe2_protos", "protobuf-lite", "protobuf", "protoc", "onnx", "onnx_proto", "foxi_loader", "pthreadpool", "eigen_blas", "fbgemm", "cpuinfo", "clog", "dnnl", "mkldnn", "sleef", "asmjit", "fmt", "kineto"}) do
  87. package:add("links", lib)
  88. end
  89. end
  90. -- some patches to the third-party cmake files
  91. io.replace("third_party/fbgemm/CMakeLists.txt", "PRIVATE FBGEMM_STATIC", "PUBLIC FBGEMM_STATIC", {plain = true})
  92. io.replace("third_party/protobuf/cmake/install.cmake", "install%(DIRECTORY.-%)", "")
  93. if package:is_plat("windows") and package:config("vs_runtime"):startswith("MD") then
  94. io.replace("third_party/fbgemm/CMakeLists.txt", "MT", "MD", {plain = true})
  95. end
  96. -- prepare python
  97. os.vrun("python -m pip install typing_extensions pyyaml")
  98. local configs = {"-DUSE_MPI=OFF",
  99. "-DCMAKE_INSTALL_LIBDIR=lib",
  100. "-DBUILD_TEST=OFF",
  101. "-DATEN_NO_TEST=ON"}
  102. if package:config("python") then
  103. table.insert(configs, "-DBUILD_PYTHON=ON")
  104. os.vrun("python -m pip install numpy")
  105. else
  106. table.insert(configs, "-DBUILD_PYTHON=OFF")
  107. table.insert(configs, "-DUSE_NUMPY=OFF")
  108. end
  109. -- prepare for installation
  110. local envs = cmake.buildenvs(package, {cmake_generator = "Ninja"})
  111. if not package:is_plat("macosx") then
  112. if package:config("blas") == "mkl" then
  113. table.insert(configs, "-DBLAS=MKL")
  114. local mkl = package:dep("mkl"):fetch()
  115. table.insert(configs, "-DINTEL_MKL_DIR=" .. path.directory(mkl.sysincludedirs[1]))
  116. elseif package:config("blas") == "openblas" then
  117. table.insert(configs, "-DBLAS=OpenBLAS")
  118. envs.OpenBLAS_HOME = package:dep("openblas"):installdir()
  119. elseif package:config("blas") == "eigen" then
  120. table.insert(configs, "-DBLAS=Eigen")
  121. end
  122. end
  123. if package:config("distributed") then
  124. envs.libuv_ROOT = package:dep("libuv"):installdir()
  125. end
  126. table.insert(configs, "-DCMAKE_BUILD_TYPE=" .. (package:debug() and "Debug" or "Release"))
  127. table.insert(configs, "-DBUILD_SHARED_LIBS=" .. (package:config("shared") and "ON" or "OFF"))
  128. table.insert(configs, "-DUSE_CUDA=" .. (package:config("cuda") and "ON" or "OFF"))
  129. table.insert(configs, "-DUSE_OPENMP=" .. (package:config("openmp") and "ON" or "OFF"))
  130. table.insert(configs, "-DUSE_DISTRIBUTED=" .. (package:config("distributed") and "ON" or "OFF"))
  131. if package:is_plat("windows") then
  132. table.insert(configs, "-DCAFFE2_USE_MSVC_STATIC_RUNTIME=" .. (package:config("vs_runtime"):startswith("MT") and "ON" or "OFF"))
  133. end
  134. local opt = {envs = envs}
  135. if package:config("ninja") then
  136. opt.cmake_generator = "Ninja"
  137. end
  138. cmake.install(package, configs, opt)
  139. end)
  140. on_test(function (package)
  141. assert(package:check_cxxsnippets({test = [[
  142. void test() {
  143. auto a = torch::ones(3);
  144. auto b = torch::tensor({1, 2, 3});
  145. auto c = torch::dot(a, b);
  146. }
  147. ]]}, {configs = {languages = "c++14"}, includes = "torch/torch.h"}))
  148. end)