|
@@ -75,20 +75,36 @@ class Unet2ClothSession(BaseSession):
|
|
|
|
|
|
masks = []
|
|
masks = []
|
|
|
|
|
|
- mask1 = mask.copy()
|
|
|
|
- mask1.putpalette(palette1)
|
|
|
|
- mask1 = mask1.convert("RGB").convert("L")
|
|
|
|
- masks.append(mask1)
|
|
|
|
-
|
|
|
|
- mask2 = mask.copy()
|
|
|
|
- mask2.putpalette(palette2)
|
|
|
|
- mask2 = mask2.convert("RGB").convert("L")
|
|
|
|
- masks.append(mask2)
|
|
|
|
-
|
|
|
|
- mask3 = mask.copy()
|
|
|
|
- mask3.putpalette(palette3)
|
|
|
|
- mask3 = mask3.convert("RGB").convert("L")
|
|
|
|
- masks.append(mask3)
|
|
|
|
|
|
+ cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")
|
|
|
|
+
|
|
|
|
+ def upper_cloth():
|
|
|
|
+ mask1 = mask.copy()
|
|
|
|
+ mask1.putpalette(palette1)
|
|
|
|
+ mask1 = mask1.convert("RGB").convert("L")
|
|
|
|
+ masks.append(mask1)
|
|
|
|
+
|
|
|
|
+ def lower_cloth():
|
|
|
|
+ mask2 = mask.copy()
|
|
|
|
+ mask2.putpalette(palette2)
|
|
|
|
+ mask2 = mask2.convert("RGB").convert("L")
|
|
|
|
+ masks.append(mask2)
|
|
|
|
+
|
|
|
|
+ def full_cloth():
|
|
|
|
+ mask3 = mask.copy()
|
|
|
|
+ mask3.putpalette(palette3)
|
|
|
|
+ mask3 = mask3.convert("RGB").convert("L")
|
|
|
|
+ masks.append(mask3)
|
|
|
|
+
|
|
|
|
+ if cloth_category == "upper":
|
|
|
|
+ upper_cloth()
|
|
|
|
+ elif cloth_category == "lower":
|
|
|
|
+ lower_cloth()
|
|
|
|
+ elif cloth_category == "full":
|
|
|
|
+ full_cloth()
|
|
|
|
+ else:
|
|
|
|
+ upper_cloth()
|
|
|
|
+ lower_cloth()
|
|
|
|
+ full_cloth()
|
|
|
|
|
|
return masks
|
|
return masks
|
|
|
|
|