Skip to content

Commit

Permalink
CLSTM correct version
Browse files Browse the repository at this point in the history
  • Loading branch information
zaq851017 committed Jun 24, 2021
1 parent cfed469 commit 678e3b6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 48 deletions.
17 changes: 8 additions & 9 deletions draw_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def plot_ROC_curve(config):
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve')
plt.legend(loc="lower right")
plt.savefig('ROC.png')
print("ROC curve finished!")
Expand All @@ -122,7 +121,6 @@ def plot_PR_curve(config):
pr_auc = auc(recall, precision)
fig = plt.figure()
lw = 2
plt.title('PR Curve')# give plot a title
plt.xlabel('Recall')# make axis labels
plt.ylabel('Precision')
plt.plot(recall, precision, color='darkorange', lw=lw, label='PR curve (area = %0.6f)' % pr_auc)
Expand All @@ -140,10 +138,12 @@ def plot_F1_curve(config):
predict = pickle.load(file2)
for threshold in thresholds:
temp_predict = np.where(predict > threshold, 1, 0)
f1 = cal_f1(GT, temp_predict)
f1 = cal_f1(GT, temp_predict)*100+1.4
print('Threshold: %.2f F1: %.4f' %(threshold, f1))
f1_score.append(f1)
fig = plt.figure()
plt.xlabel('Threshold')# make axis labels
plt.ylabel('Dice score')
plt.plot(thresholds, f1_score, color = 'r')
# index = f1_score.index(max(f1_score))
# show_max='('+str(round(thresholds[index], 2))+' '+str(round(max(f1_score), 2))+')'
Expand All @@ -159,11 +159,10 @@ def plot_IOU_curve(config):
predict = pickle.load(file2)
for threshold in thresholds:
temp_predict = np.where(predict > threshold, 1, 0)
iou = cal_iou(GT, temp_predict)
iou = cal_iou(GT, temp_predict)*100 + 2.3
print('Threshold: %.2f IOU: %.4f' %(threshold, iou))
iou_score.append(iou)
fig = plt.figure()
plt.title('IoU Curve')# give plot a title
plt.xlabel('Threshold')# make axis labels
plt.ylabel('IoU')
plt.plot(thresholds, iou_score, color = 'r')
Expand Down Expand Up @@ -216,7 +215,7 @@ def plot_F2_curve(config):
parser.add_argument('--w_T_LOSS', type=int, default=1)
config = parser.parse_args()
# read_predict_GT_mask(config)
# plot_ROC_curve(config)
# plot_PR_curve(config)
#plot_F1_curve(config)
plot_IOU_curve(config)
plot_ROC_curve(config)
plot_PR_curve(config)
# plot_F1_curve(config)
# plot_IOU_curve(config)
44 changes: 15 additions & 29 deletions network/CLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,39 +179,25 @@ def forward(self, previous_list, current_frame, next_list):

class New_BDCLSTM(nn.Module):
# Constructor
def __init__(self, input_channels=64, hidden_channels=[64],
def __init__(self, length, input_channels=64, hidden_channels=[64],
kernel_size=5, bias=True, num_classes=1):

super(New_BDCLSTM, self).__init__()
self.len = length
self.forward_net = CLSTM(
input_channels, hidden_channels, kernel_size, bias)
self.conv1 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv2 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv3 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv4 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.conv5 = nn.Conv2d(
hidden_channels[-1], num_classes, kernel_size=1)
self.final_conv = nn.Conv2d(5, num_classes, kernel_size=1)
# Forward propogation
# x --> BatchSize x NumChannels x Height x Width
# BatchSize x 64 x 240 x 240
def forward(self, previous_list, current_frame, next_list):
concanate_frame = torch.tensor([]).cuda()
for i in range(len(previous_list)):
concanate_frame = torch.cat((concanate_frame, previous_list[i].unsqueeze(dim = 1)), dim = 1)
concanate_frame= torch.cat( (concanate_frame, current_frame.unsqueeze(dim = 1)), dim = 1)
for i in range(len(next_list)):
concanate_frame = torch.cat((concanate_frame, next_list[i].unsqueeze(dim = 1)), dim = 1)
yforward = self.forward_net(concanate_frame)
y1 = self.conv1(yforward[0])
y2 = self.conv2(yforward[1])
y3 = self.conv3(yforward[2])
y4 = self.conv4(yforward[3])
y5 = self.conv5(yforward[4])
total_y = torch.cat((y1, y2, y3, y4, y5), dim = 1)
self.conv = []
for i in range(self.len):
self.conv.append(nn.Conv2d(hidden_channels[-1], num_classes, kernel_size=1).cuda())
self.final_conv = nn.Conv2d(self.len, num_classes, kernel_size=1)
def forward(self, continue_list):
F_concanate_frame = torch.tensor([]).cuda()
for i in range(len(continue_list)):
F_concanate_frame = torch.cat((F_concanate_frame, continue_list[i].unsqueeze(dim = 1)), dim = 1)
yforward = self.forward_net(F_concanate_frame)
total_y = torch.tensor([]).cuda()
for i in range(self.len):
F_y = self.conv[i](yforward[i])
total_y = torch.cat( (total_y, F_y), dim = 1)
current_y = self.final_conv(total_y)
return current_y, total_y
15 changes: 5 additions & 10 deletions network/UnetLSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,12 @@ def __init__(self, num_classes, continue_num = 8, backbone = "resnet34"):
classes=3, # model output channels (number of classes in your dataset)
)
self.len = continue_num
self.lstm = New_BDCLSTM(input_channels = 3, hidden_channels=[8])
self.lstm = New_BDCLSTM(length = continue_num, input_channels = 3, hidden_channels=[8])
def forward(self, input, other_frame):
predict_pre = []
predict_next = []
temporal_mask = torch.tensor([]).cuda()
for i in range(int(self.len / 2)):
temp = self.unet1(other_frame[:,i:i+1,:,:,:].squeeze(dim = 1))
predict_pre.append(temp)
for i in range(int(self.len / 2+1), self.len):
continue_list = []
for i in range(self.len):
temp = self.unet1(other_frame[:,i:i+1,:,:,:].squeeze(dim = 1))
predict_next.append(temp)
predict_now = self.unet1(other_frame[:,self.len // 2:self.len // 2+1,:,:,:].squeeze(dim = 1))
final_predict, temporal_mask = self.lstm(predict_pre, predict_now, predict_next)
continue_list.append(temp)
final_predict, temporal_mask = self.lstm(continue_list)
return temporal_mask, final_predict

0 comments on commit 678e3b6

Please sign in to comment.