소스 검색

Add cloth category selection feature to u2net_cloth_seg (#485)

szriru 2 년 전
부모
커밋
f0019d723b
1개의 변경된 파일30개의 추가작업 그리고 14개의 파일을 삭제
  1. 30 14
      rembg/sessions/u2net_cloth_seg.py

+ 30 - 14
rembg/sessions/u2net_cloth_seg.py

@@ -75,20 +75,36 @@ class Unet2ClothSession(BaseSession):
 
         masks = []
 
-        mask1 = mask.copy()
-        mask1.putpalette(palette1)
-        mask1 = mask1.convert("RGB").convert("L")
-        masks.append(mask1)
-
-        mask2 = mask.copy()
-        mask2.putpalette(palette2)
-        mask2 = mask2.convert("RGB").convert("L")
-        masks.append(mask2)
-
-        mask3 = mask.copy()
-        mask3.putpalette(palette3)
-        mask3 = mask3.convert("RGB").convert("L")
-        masks.append(mask3)
+        cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
+
+        def upper_cloth():
+            mask1 = mask.copy()
+            mask1.putpalette(palette1)
+            mask1 = mask1.convert("RGB").convert("L")
+            masks.append(mask1)
+      
+        def lower_cloth():
+            mask2 = mask.copy()
+            mask2.putpalette(palette2)
+            mask2 = mask2.convert("RGB").convert("L")
+            masks.append(mask2)
+        
+        def full_cloth():
+            mask3 = mask.copy()
+            mask3.putpalette(palette3)
+            mask3 = mask3.convert("RGB").convert("L")
+            masks.append(mask3)
+
+        if cloth_category == "upper":
+            upper_cloth()
+        elif cloth_category == "lower":
+            lower_cloth()
+        elif cloth_category == "full":
+            full_cloth()
+        else:
+            upper_cloth()
+            lower_cloth()
+            full_cloth()
 
         return masks