Browse Source

Merge pull request #757 from Yahweasel/main

Daniel Gatis 3 months ago
parent
commit
d6ac01f9f9
3 changed files with 20 additions and 1 deletions
  1. 14 1
      README.md
  2. 5 0
      rembg/sessions/base.py
  3. 1 0
      setup.py

+ 14 - 1
README.md

@@ -109,7 +109,7 @@ pip install rembg[cpu] # for library
 pip install "rembg[cpu,cli]" # for library + cli
 ```
 
-### GPU support:
+### GPU support (NVidia/Cuda):
 
 First of all, you need to check if your system supports the `onnxruntime-gpu`.
 
@@ -128,6 +128,19 @@ pip install "rembg[gpu,cli]" # for library + cli
 
 Nvidia GPU may require onnxruntime-gpu, cuda, and cudnn-devel. [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689830314) . If rembg[gpu] doesn't work and you can't install cuda or cudnn-devel, use rembg[cpu] and onnxruntime instead.
 
+### GPU support (AMD/ROCM):
+
+ROCM support requires the `onnxruntime-rocm` package. Install it following
+[AMD's documentation](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-onnx.html).
+
+If `onnxruntime-rocm` is installed and working, install the `rembg[rocm]`
+version of rembg:
+
+```bash
+pip install "rembg[rocm]" # for library
+pip install "rembg[rocm,cli]" # for library + cli
+```
+
 ## Usage as a cli
 
 After the installation step you can use rembg just typing `rembg` in your terminal window.

+ 5 - 0
rembg/sessions/base.py

@@ -20,6 +20,11 @@ class BaseSession:
             and "CUDAExecutionProvider" in ort.get_available_providers()
         ):
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        elif (
+            device_type[0:3] == "GPU"
+            and "ROCMExecutionProvider" in ort.get_available_providers()
+        ):
+            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
         else:
             providers = ["CPUExecutionProvider"]
 

+ 1 - 0
setup.py

@@ -38,6 +38,7 @@ extras_require = {
     ],
     "cpu": ["onnxruntime"],
     "gpu": ["onnxruntime-gpu"],
+    "rocm": ["onnxruntime-rocm"],
     "cli": [
         "aiohttp",
         "asyncer",