|
@@ -6,7 +6,7 @@ from torchvision import models
|
|
|
|
|
|
class REBNCONV(nn.Module):
|
|
class REBNCONV(nn.Module):
|
|
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
|
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
|
- super(REBNCONV, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.conv_s1 = nn.Conv2d(
|
|
self.conv_s1 = nn.Conv2d(
|
|
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
|
|
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
|
|
@@ -33,7 +33,7 @@ def _upsample_like(src, tar):
|
|
### RSU-7 ###
|
|
### RSU-7 ###
|
|
class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
|
class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
- super(RSU7, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
|
|
|
|
@@ -110,7 +110,7 @@ class RSU7(nn.Module): # UNet07DRES(nn.Module):
|
|
### RSU-6 ###
|
|
### RSU-6 ###
|
|
class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
|
class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
- super(RSU6, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
|
|
|
|
@@ -178,7 +178,7 @@ class RSU6(nn.Module): # UNet06DRES(nn.Module):
|
|
### RSU-5 ###
|
|
### RSU-5 ###
|
|
class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
|
class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
- super(RSU5, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
|
|
|
|
@@ -236,7 +236,7 @@ class RSU5(nn.Module): # UNet05DRES(nn.Module):
|
|
### RSU-4 ###
|
|
### RSU-4 ###
|
|
class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
|
class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
- super(RSU4, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
|
|
|
|
@@ -284,7 +284,7 @@ class RSU4(nn.Module): # UNet04DRES(nn.Module):
|
|
### RSU-4F ###
|
|
### RSU-4F ###
|
|
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
|
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
|
- super(RSU4F, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
|
|
|
|
|
@@ -320,7 +320,7 @@ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
|
|
##### U^2-Net ####
|
|
##### U^2-Net ####
|
|
class U2NET(nn.Module):
|
|
class U2NET(nn.Module):
|
|
def __init__(self, in_ch=3, out_ch=1):
|
|
def __init__(self, in_ch=3, out_ch=1):
|
|
- super(U2NET, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.stage1 = RSU7(in_ch, 32, 64)
|
|
self.stage1 = RSU7(in_ch, 32, 64)
|
|
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
|
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
|
@@ -432,7 +432,7 @@ class U2NET(nn.Module):
|
|
### U^2-Net small ###
|
|
### U^2-Net small ###
|
|
class U2NETP(nn.Module):
|
|
class U2NETP(nn.Module):
|
|
def __init__(self, in_ch=3, out_ch=1):
|
|
def __init__(self, in_ch=3, out_ch=1):
|
|
- super(U2NETP, self).__init__()
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
self.stage1 = RSU7(in_ch, 16, 64)
|
|
self.stage1 = RSU7(in_ch, 16, 64)
|
|
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
|
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|