
Commit

0618 New TCSNet
zaq851017 committed Jun 18, 2021
1 parent e4179cb commit e0fd791
Showing 4 changed files with 75 additions and 2 deletions.
4 changes: 4 additions & 0 deletions all_model.py
@@ -74,6 +74,10 @@ def WHICH_MODEL(config, frame_continue_num):
net = DeepLabV3Plus_LSTM(1, len(frame_continue_num), config.backbone)
model_name = "DeepLabV3Plus_LSTM"+"_"+config.backbone
print(model_name)
elif config.which_model == -22:
net = New_DeepLabV3Plus_LSTM(1, len(frame_continue_num), config.backbone)
model_name = "New_DeepLabV3Plus_LSTM"+"_"+config.backbone
print(model_name)
elif config.which_model == 23:
net = DeepLabV3_LSTM(1, len(frame_continue_num), config.backbone)
model_name = "DeepLabV3_LSTM"+"_"+config.backbone
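
For reference, the new -22 branch builds the same network as the direct construction below; the frame-offset list and backbone are illustrative values, not taken from this commit, and only the length of the list matters here.

from network.UnetLSTM import New_DeepLabV3Plus_LSTM

frame_continue_num = [1, 2, 3, 4, 5]  # assumed neighbouring-frame offsets
net = New_DeepLabV3Plus_LSTM(1, len(frame_continue_num), "resnet34")  # what the -22 branch instantiates
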
43 changes: 41 additions & 2 deletions network/CLSTM.py
@@ -172,7 +172,46 @@ def forward(self, previous_list, current_frame, next_list):
# xreverse = torch.cat((x3, x2), dim=1)
yforward = self.forward_net(xforward)
yreverse = self.reverse_net(xreverse)
# assumes y is BatchSize x NumClasses x 240 x 240

ycat = torch.cat((yforward[-1], yreverse[-1]), dim=1)
y = self.conv(ycat)
return y

class New_BDCLSTM(nn.Module):
# Constructor
def __init__(self, input_channels=64, hidden_channels=[64],
kernel_size=5, bias=True, num_classes=1):

super(New_BDCLSTM, self).__init__()
self.forward_net = CLSTM(
input_channels, hidden_channels, kernel_size, bias)
self.conv1 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv2 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv3 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv4 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv5 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.final_conv = nn.Conv2d(5, num_classes, kernel_size=1)
# Forward propagation
# x --> BatchSize x NumChannels x Height x Width
# BatchSize x 64 x 240 x 240
def forward(self, previous_list, current_frame, next_list):
concanate_frame = torch.tensor([]).cuda()
for i in range(len(previous_list)):
concanate_frame = torch.cat((concanate_frame, previous_list[i].unsqueeze(dim = 1)), dim = 1)
concanate_frame= torch.cat( (concanate_frame, current_frame.unsqueeze(dim = 1)), dim = 1)
for i in range(len(next_list)):
concanate_frame = torch.cat((concanate_frame, next_list[i].unsqueeze(dim = 1)), dim = 1)
yforward = self.forward_net(concanate_frame)
y1 = self.conv1(yforward[0])
y2 = self.conv2(yforward[1])
y3 = self.conv3(yforward[2])
y4 = self.conv4(yforward[3])
y5 = self.conv5(yforward[4])
total_y = torch.cat((y1, y2, y3, y4, y5), dim = 1)
current_y = self.final_conv(total_y)
return current_y, total_y
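
A minimal shape-check sketch for New_BDCLSTM, assuming a CUDA device (the forward pass allocates tensors with .cuda()) and five frames in total, which matches the five 1x1 heads feeding final_conv; all sizes below are illustrative, not taken from the repository.

import torch
from network.CLSTM import New_BDCLSTM

lstm = New_BDCLSTM(input_channels=3, hidden_channels=[8]).cuda()
b, c, h, w = 2, 3, 240, 240                                 # illustrative batch/channel/spatial sizes
prev = [torch.randn(b, c, h, w).cuda() for _ in range(2)]   # two previous frames
cur = torch.randn(b, c, h, w).cuda()                        # centre frame
nxt = [torch.randn(b, c, h, w).cuda() for _ in range(2)]    # two next frames
current_y, total_y = lstm(prev, cur, nxt)
# expected: current_y is b x 1 x h x w (fused centre prediction), total_y is b x 5 x h x w (one map per step)
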
26 changes: 26 additions & 0 deletions network/UnetLSTM.py
@@ -4,6 +4,7 @@
import warnings
import segmentation_models_pytorch as smp
from network.CLSTM import BDCLSTM
from network.CLSTM import New_BDCLSTM
class Unet_LSTM(nn.Module):
def __init__(self, num_classes, continue_num = 8, backbone = "resnet34"):
super().__init__()
@@ -181,4 +182,29 @@ def forward(self, input, other_frame):
temporal_mask = torch.cat((temporal_mask, predict_now), dim = 1)
for p_n in (predict_next):
temporal_mask = torch.cat((temporal_mask, p_n), dim = 1)
return temporal_mask, final_predict

class New_DeepLabV3Plus_LSTM(nn.Module):
def __init__(self, num_classes, continue_num = 8, backbone = "resnet34"):
super().__init__()
self.unet1 = smp.DeepLabV3Plus(
encoder_name=backbone, # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=3, # model output channels (number of classes in your dataset)
)
self.len = continue_num
self.lstm = New_BDCLSTM(input_channels = 3, hidden_channels=[8])
def forward(self, input, other_frame):
predict_pre = []
predict_next = []
temporal_mask = torch.tensor([]).cuda()
for i in range(int(self.len / 2)):
temp = self.unet1(other_frame[:,i:i+1,:,:,:].squeeze(dim = 1))
predict_pre.append(temp)
for i in range(int(self.len / 2+1), self.len):
temp = self.unet1(other_frame[:,i:i+1,:,:,:].squeeze(dim = 1))
predict_next.append(temp)
predict_now = self.unet1(other_frame[:,self.len // 2:self.len // 2+1,:,:,:].squeeze(dim = 1))
final_predict, temporal_mask = self.lstm(predict_pre, predict_now, predict_next)
return temporal_mask, final_predict
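
A hypothetical end-to-end forward pass for New_DeepLabV3Plus_LSTM: continue_num=5 and the 256x256 resolution are assumptions chosen to match the five per-step heads in New_BDCLSTM, a CUDA device is required by the .cuda() calls in forward(), the ImageNet encoder weights are downloaded on first use, and the first argument is unused by the code shown.

import torch
from network.UnetLSTM import New_DeepLabV3Plus_LSTM

net = New_DeepLabV3Plus_LSTM(num_classes=1, continue_num=5, backbone="resnet34").cuda()
b, t, h, w = 1, 5, 256, 256                      # t matches continue_num
frames = torch.randn(b, t, 3, h, w).cuda()       # centre frame plus neighbours, stacked on dim 1
temporal_mask, final_predict = net(frames[:, t // 2], frames)
# expected: temporal_mask is b x 5 x h x w (per-step maps), final_predict is b x 1 x h x w (fused centre map)
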
4 changes: 4 additions & 0 deletions predict_src/postprocess_src.py
@@ -231,6 +231,10 @@ def MERGE_VIDEO(config):
os.system("rm -r "+full_path_3)
full_path_3 = os.path.join(full_path, num_files+"/temporal_mask")
os.system("rm -r "+full_path_3)
# full_path_3 = os.path.join(full_path, num_files+"/forfilm")
# os.system("rm -r "+full_path_3)
# full_path_3 = os.path.join(full_path, num_files+"/vol_mask")
# os.system("rm -r "+full_path_3)
def test_wo_postprocess(config, test_loader, net):
if not os.path.isdir(config.output_path):
print("os.makedirs "+ config.output_path)
