2
0

xmake.lua 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  package("libtorch")
      -- Project metadata.
      set_homepage("https://pytorch.org/")
      set_description("An open source machine learning framework that accelerates the path from research prototyping to production deployment.")
      set_license("BSD-3-Clause")

      -- Sources are fetched from git; each release tag is pinned to a commit.
      add_urls("https://github.com/pytorch/pytorch.git")
      add_versions("v1.8.0",  "37c1f4a7fef115d719104e871d0cf39434aa9d56")
      add_versions("v1.8.1",  "56b43f4fec1f76953f15a627694d4bba34588969")
      add_versions("v1.8.2",  "e0495a7aa104471d95dc85a1b8f6473fbcc427a8")
      add_versions("v1.9.0",  "d69c22dd61a2f006dcfe1e3ea8468a3ecaf931aa")
      add_versions("v1.9.1",  "dfbd030854359207cb3040b864614affeace11ce")
      add_versions("v1.11.0", "bc2c6edaf163b1a1330e37a6e34caf8c553e4755")
      add_versions("v1.12.1", "664058fa83f1d8eede5d66418abff6e20bd76ca8")
      add_versions("v2.1.0",  "7bcf7da3a268b435777fe87c7794c382f444e86d")
      add_versions("v2.1.2",  "a8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb")
      add_versions("v2.2.2",  "39901f229520a5256505ec24782f716ee7ddc843")
      add_versions("v2.3.1",  "63d5e9221bedd1546b7d364b5ce4171547db12a9")
      add_versions("v2.4.0",  "d990dada86a8ad94882b5c23e859b88c0c255bda")
      add_versions("v2.5.0",  "32f585d9346e316e554c8d9bf7548af9f62141fc")

      -- Version-specific source patches (patch file + sha256).
      add_patches("1.9.x", "patches/1.9.0/gcc11.patch", "4191bb3296f18f040c230d7c5364fb160871962d6278e4ae0f8bc481f27d8e4b")
      add_patches("1.11.0", "patches/1.11.0/gcc11.patch", "1404b0bc6ce7433ecdc59d3412e3d9ed507bb5fd2cd59134a254d7d4a8d73012")
      -- Fix compile on macOS. Refer to https://github.com/pytorch/pytorch/pull/80916
      add_patches("1.12.1", "patches/1.12.1/clang.patch", "cdc3e00b2fea847678b1bcc6b25a4dbd924578d8fb25d40543521a09aab2f7d4")
      add_patches("1.12.1", "patches/1.12.1/vs2022.patch", "5a31b9772793c943ca752c92d6415293f7b3863813ca8c5eb9d92a6156afd21d")
      add_patches("2.2.2", "patches/2.2.2/pocketfft.patch", "8b756d867fb60839dcaeb1ee0bdf4189ee95e7f5c6f3810f8cbc8f6a5fae60e9")

      -- User-facing build options.
      add_configs("shared", {
          description = "Build shared library.",
          default = true,
          type = "boolean"
      })
      add_configs("python", {
          description = "Build python interface.",
          default = false,
          type = "boolean"
      })
      add_configs("openmp", {
          description = "Use OpenMP for parallel code.",
          default = true,
          type = "boolean"
      })
      add_configs("cuda", {
          description = "Enable CUDA support.",
          default = false,
          type = "boolean"
      })
      -- https://github.com/pytorch/pytorch/issues/24186 only ninja is supported on windows
      add_configs("ninja", {
          description = "Use ninja as build tool.",
          default = is_plat("windows"),
          type = "boolean"
      })
      add_configs("blas", {
          description = "Set BLAS vendor.",
          default = "openblas",
          type = "string",
          values = {"mkl", "openblas", "eigen"}
      })
      add_configs("pybind11", {
          description = "Use pybind11 from xrepo.",
          default = false,
          type = "boolean"
      })
      add_configs("protobuf-cpp", {
          description = "Use protobuf from xrepo.",
          default = false,
          type = "boolean"
      })
      -- The "distributed" option is only offered off macOS.
      if not is_plat("macosx") then
          add_configs("distributed", {
              description = "Enable distributed support.",
              default = false,
              type = "boolean"
          })
      end

      -- Build-time tools that are always required.
      add_deps("cmake")
      add_deps("python 3.x", {kind = "binary"})

      -- Header search paths exported to consumers.
      add_includedirs("include", "include/torch/csrc/api/include")

      if is_plat("linux") then
          add_syslinks("rt")
      end

      -- enable long paths for git submodule on windows
      -- (set_policy is probed first so the call is skipped when it is not defined)
      if is_host("windows") and set_policy then
          set_policy("platform.longpaths", true)
      end
  -- Resolve the optional dependencies from the selected package configs.
  on_load("windows|x64", "macosx", "linux", function (package)
      -- Add the given dependency (plus optional dep options) only when the
      -- named config option is enabled.
      local function require_when(option, ...)
          if package:config(option) then
              package:add("deps", ...)
          end
      end

      require_when("ninja", "ninja")
      require_when("openmp", "openmp")

      if package:config("cuda") then
          local cuda_utils = {"nvrtc", "cudnn", "cufft", "curand", "cublas", "cudart_static"}
          package:add("deps", "cuda", {configs = {utils = cuda_utils}})
          -- nvtx is only pulled in as a separate dependency before 2.5.0.
          if package:version():lt("2.5.0") then
              package:add("deps", "nvtx")
          end
      end

      require_when("distributed", "libuv")

      -- The BLAS vendor name doubles as the dependency name; skipped on macOS.
      if not package:is_plat("macosx") then
          local blas_vendor = package:config("blas")
          if blas_vendor then
              package:add("deps", blas_vendor)
          end
      end

      require_when("pybind11", "pybind11")
      require_when("protobuf-cpp", "protobuf-cpp")
  end)
  -- Configure, build and install libtorch with CMake, then post-process the
  -- install tree so that static linking works.
  --
  -- Steps, in order:
  --   1. reject MSVC older than VS2019 (when the version can be determined)
  --   2. register link flags / link names for the torch libraries
  --   3. patch a few third-party CMake files and c10/macros/Macros.h in the source tree
  --   4. install the python modules needed by the build via pip
  --   5. assemble the CMake options and environment, then run cmake.install
  --   6. copy static libs that CMake does not install and patch TorchConfig.cmake
  on_install("windows|x64", "macosx", "linux", function (package)
      import("package.tools.cmake")
      import("core.tool.toolchain")

      if package:is_plat("windows") then
          local vs = toolchain.load("msvc"):config("vs")
          -- Guard against a missing / non-numeric value so that the check does
          -- not die with an opaque "compare nil with number" error.
          local vs_year = vs and tonumber(vs)
          if vs_year and vs_year < 2019 then
              raise("Your compiler is too old to use this library.")
          end
      end

      -- tackle link flags
      local libnames = {"torch", "torch_cpu"}
      if package:config("cuda") then
          table.insert(libnames, "torch_cuda")
      end
      table.insert(libnames, "c10")
      if package:config("cuda") then
          table.insert(libnames, "c10_cuda")
      end
      if not package:is_plat("windows") and package:config("shared") then
          local libdir = package:installdir("lib")
          package:add("ldflags", "-Wl,-rpath," .. libdir)
          local suffix = ""
          if package:is_plat("linux") then
              suffix = ".so"
          elseif package:is_plat("macosx") then
              suffix = ".dylib"
          end
          -- Shared libraries are passed by full path; on linux each one is
          -- prefixed with --no-as-needed.
          local prefix = package:is_plat("linux") and "-Wl,--no-as-needed," or ""
          for _, lib in ipairs(libnames) do
              -- Join the file name onto the lib directory instead of handing
              -- file-name fragments to installdir(), which returns a directory.
              -- NOTE(review): installdir() appears to create missing directories,
              -- so installdir("lib", "lib") could leave a stray lib/lib folder - confirm.
              package:add("ldflags", prefix .. path.join(libdir, "lib" .. lib .. suffix))
          end
      else
          for _, lib in ipairs(libnames) do
              package:add("links", lib)
          end
      end
      if not package:config("shared") then
          for _, lib in ipairs({"nnpack", "pytorch_qnnpack", "qnnpack", "XNNPACK", "caffe2_protos", "protobuf-lite", "protobuf", "protoc", "onnx", "onnx_proto", "foxi_loader", "pthreadpool", "eigen_blas", "fbgemm", "cpuinfo", "clog", "dnnl_graph", "dnnl", "mkldnn", "sleef", "asmjit", "fmt", "kineto"}) do
              package:add("links", lib)
          end
      end

      -- some patches to the third-party cmake files
      io.replace("cmake/MiscCheck.cmake", "if(UNIX)", "if(TRUE)", {plain = true})
      io.replace("third_party/fbgemm/CMakeLists.txt", "PRIVATE FBGEMM_STATIC", "PUBLIC FBGEMM_STATIC", {plain = true})
      io.replace("third_party/fbgemm/CMakeLists.txt", "-Werror", "", {plain = true})
      -- Lua-pattern replace (no plain flag): removes every install(DIRECTORY ...) call.
      io.replace("third_party/protobuf/cmake/install.cmake", "install%(DIRECTORY.-%)", "")
      if package:is_plat("windows") then
          if package:config("vs_runtime"):startswith("MD") then
              io.replace("third_party/fbgemm/CMakeLists.txt", "MT", "MD", {plain = true})
              io.replace("c10/macros/Macros.h", "extern \"C\" {\nC10_IMPORT", "extern \"C\" {\n__declspec(dllimport)", {plain = true})
          else
              io.replace("CMakeLists.txt", "\"NOT BUILD_SHARED_LIBS\" OFF", "\"NOT BUILD_SHARED_LIBS\" ON", {plain = true})
              io.replace("c10/macros/Macros.h", "extern \"C\" {\nC10_IMPORT", "extern \"C\" {", {plain = true})
          end
      end

      -- prepare python
      local python_exe = package:is_plat("windows") and "python" or "python3"
      os.vrun(python_exe .. " -m pip install typing_extensions pyyaml")
      local configs = {"-DUSE_MPI=OFF",
                       "-DUSE_NUMA=OFF",
                       "-DUSE_MAGMA=OFF",
                       "-DBUILD_TEST=OFF",
                       "-DATEN_NO_TEST=ON"}
      if package:config("python") then
          table.insert(configs, "-DBUILD_PYTHON=ON")
          os.vrun(python_exe .. " -m pip install numpy")
      else
          table.insert(configs, "-DBUILD_PYTHON=OFF")
          table.insert(configs, "-DUSE_NUMPY=OFF")
      end

      -- prepare for installation
      local envs = cmake.buildenvs(package)
      if not package:is_plat("macosx") then
          local blas = package:config("blas")
          if blas == "mkl" then
              table.insert(configs, "-DBLAS=MKL")
              local mkl = package:dep("mkl"):fetch()
              table.insert(configs, "-DINTEL_MKL_DIR=" .. path.directory(mkl.sysincludedirs[1]))
          elseif blas == "openblas" then
              table.insert(configs, "-DBLAS=OpenBLAS")
              envs.OpenBLAS_HOME = package:dep("openblas"):installdir()
          elseif blas == "eigen" then
              table.insert(configs, "-DBLAS=Eigen")
          end
      end
      if package:config("distributed") then
          envs.libuv_ROOT = package:dep("libuv"):installdir()
      end
      table.insert(configs, "-DCMAKE_BUILD_TYPE=" .. (package:debug() and "Debug" or "Release"))
      table.insert(configs, "-DBUILD_SHARED_LIBS=" .. (package:config("shared") and "ON" or "OFF"))
      table.insert(configs, "-DUSE_CUDA=" .. (package:config("cuda") and "ON" or "OFF"))
      table.insert(configs, "-DUSE_OPENMP=" .. (package:config("openmp") and "ON" or "OFF"))
      table.insert(configs, "-DUSE_DISTRIBUTED=" .. (package:config("distributed") and "ON" or "OFF"))
      table.insert(configs, "-DUSE_SYSTEM_PYBIND11=" .. (package:config("pybind11") and "ON" or "OFF"))
      -- Inverted on purpose: using the xrepo protobuf disables the bundled one.
      table.insert(configs, "-DBUILD_CUSTOM_PROTOBUF=" .. (package:config("protobuf-cpp") and "OFF" or "ON"))

      -- Python's print() terminates its output with a newline; strip trailing
      -- whitespace so it does not end up inside the CMake option value.
      local pythonpath = os.iorun(python_exe .. " -c \"import sys; print(sys.executable)\"")
      pythonpath = (pythonpath:gsub("%s+$", ""))
      table.insert(configs, "-DPYTHON_EXECUTABLE=" .. pythonpath)

      if package:is_plat("windows") then
          local static_runtime = package:config("vs_runtime"):startswith("MT")
          table.insert(configs, "-DCAFFE2_USE_MSVC_STATIC_RUNTIME=" .. (static_runtime and "ON" or "OFF"))
          table.insert(configs, "-DCPUINFO_RUNTIME_TYPE=" .. (static_runtime and "static" or "shared"))
          local vs_sdkver = toolchain.load("msvc"):config("vs_sdkver")
          if vs_sdkver then
              -- Third numeric component of the SDK version, e.g. 10.0.<build>.0
              local build_str = string.match(vs_sdkver, "%d+%.%d+%.(%d+)%.?%d*")
              local build_ver = build_str and tonumber(build_str)
              -- Only enforce the minimum when the build number could be parsed;
              -- an unparsable string previously crashed on a nil comparison.
              if build_ver then
                  assert(build_ver >= 18362, "libtorch requires Windows SDK to be at least 10.0.18362.0")
              end
              table.insert(configs, "-DCMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION=" .. vs_sdkver)
              table.insert(configs, "-DCMAKE_SYSTEM_VERSION=" .. vs_sdkver)
          end
      end

      local opt = {envs = envs}
      if package:config("ninja") then
          opt.cmake_generator = "Ninja"
      end
      cmake.install(package, configs, opt)

      -- These libs are not installed by cmake but are required for static link.
      local cp_libs = {"libonnx", "libonnx_proto"}
      if package:version():eq("v1.11.0") then
          table.insert(cp_libs, "libbreakpad")
          table.insert(cp_libs, "libbreakpad_common")
      end
      local static_lib_suffix = ".a"
      if package:is_plat("windows") then
          static_lib_suffix = ".lib"
      end
      for _, libname in ipairs(cp_libs) do
          -- trycp: best-effort copy, a missing library is not treated as an error here.
          os.trycp(path.join(package:buildir(), "lib", libname .. static_lib_suffix), package:installdir("lib"))
      end

      -- Following patches are needed for static link.
      -- The file path is joined onto the containing directory rather than
      -- passing the file name itself to installdir().
      local torch_config = path.join(package:installdir("share", "cmake", "Torch"), "TorchConfig.cmake")
      io.replace(
          torch_config,
          "append_torchlib_if_found(dnnl mkldnn)",
          "append_torchlib_if_found(dnnl_graph dnnl mkldnn)",
          {plain = true}
      )
      if package:version():eq("v1.11.0") then
          io.replace(
              torch_config,
              "append_torchlib_if_found(sleef asmjit)",
              "append_torchlib_if_found(sleef asmjit)\n append_torchlib_if_found(breakpad breakpad_common)",
              {plain = true}
          )
      end
  end)
  -- Smoke test: compile a small C++17 snippet that uses the torch headers.
  on_test(function (package)
      local snippet = [[
          void test() {
              auto a = torch::ones(3);
              auto b = torch::tensor({1, 2, 3});
              auto c = torch::dot(a, b);
          }
      ]]
      local check_opts = {configs = {languages = "c++17"}, includes = "torch/torch.h"}
      -- Forward both results so a failure message from the check still reaches assert.
      local ok, errors = package:check_cxxsnippets({test = snippet}, check_opts)
      assert(ok, errors)
  end)