Commit 23d1586 (parent 5004ce5): 9 changed files with 2,241 additions and 0 deletions.
@@ -0,0 +1,202 @@
import torch
import torch.nn as nn


class AddCoords(nn.Module):
    """
    Concatenate normalized coordinate channels (and optionally a radius
    channel) to the input tensor along the channel dimension.
    """
    def __init__(self, rank, with_r=False):
        super(AddCoords, self).__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, input_tensor):
        if self.rank == 1:
            batch_size, channels, dim_x = input_tensor.shape
            xx_range = torch.arange(dim_x, dtype=torch.int32)
            xx_channel = xx_range[None, None, :]

            # Normalize to [0, 1], then zero-center to [-1, 1]
            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            xx_channel = xx_channel.repeat(batch_size, 1, 1)

            if input_tensor.is_cuda:
                xx_channel = xx_channel.cuda()
            out = torch.cat([input_tensor, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))

        elif self.rank == 2:
            batch_size, channels, dim_y, dim_x = input_tensor.shape

            xx_range = torch.arange(dim_y, dtype=torch.int32)
            yy_range = torch.arange(dim_x, dtype=torch.int32)

            xx_channel, yy_channel = torch.meshgrid([xx_range, yy_range])

            # Normalize to [0, 1], then zero-center to [-1, 1]
            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.view(1, 1, xx_channel.shape[0], xx_channel.shape[1])
            yy_channel = yy_channel.view(1, 1, yy_channel.shape[0], yy_channel.shape[1])
            xx_channel = xx_channel.repeat(batch_size, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size, 1, 1, 1)
            # A 1x1 feature map has no meaningful coordinates (the normalization
            # above would divide by zero), so fall back to zeros
            if dim_y == 1 and dim_x == 1:
                xx_channel = torch.zeros([batch_size, 1, 1, 1])
                yy_channel = torch.zeros([batch_size, 1, 1, 1])
            if input_tensor.is_cuda:
                xx_channel = xx_channel.cuda()
                yy_channel = yy_channel.cuda()

            out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))

        elif self.rank == 3:
            batch_size, channels, dim_z, dim_y, dim_x = input_tensor.shape

            # Following the rank-2 convention, xx/yy/zz index the first, second
            # and third spatial axes, so the meshgrid matches (dim_z, dim_y, dim_x)
            xx_range = torch.arange(dim_z, dtype=torch.int32)
            yy_range = torch.arange(dim_y, dtype=torch.int32)
            zz_range = torch.arange(dim_x, dtype=torch.int32)

            xx_channel, yy_channel, zz_channel = torch.meshgrid([xx_range,
                                                                 yy_range,
                                                                 zz_range])

            # Normalize each axis by its own extent, then zero-center to [-1, 1]
            xx_channel = xx_channel.float() / (dim_z - 1)
            yy_channel = yy_channel.float() / (dim_y - 1)
            zz_channel = zz_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1
            zz_channel = zz_channel * 2 - 1

            xx_channel = xx_channel.view(1, 1, xx_channel.shape[0],
                                         xx_channel.shape[1],
                                         xx_channel.shape[2])
            yy_channel = yy_channel.view(1, 1, yy_channel.shape[0],
                                         yy_channel.shape[1],
                                         yy_channel.shape[2])
            zz_channel = zz_channel.view(1, 1, zz_channel.shape[0],
                                         zz_channel.shape[1],
                                         zz_channel.shape[2])
            xx_channel = xx_channel.repeat(batch_size, 1, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size, 1, 1, 1, 1)
            zz_channel = zz_channel.repeat(batch_size, 1, 1, 1, 1)

            if input_tensor.is_cuda:
                xx_channel = xx_channel.cuda()
                yy_channel = yy_channel.cuda()
                zz_channel = zz_channel.cuda()
            out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) +
                                torch.pow(yy_channel - 0.5, 2) +
                                torch.pow(zz_channel - 0.5, 2))
        else:
            raise NotImplementedError

        if self.with_r:
            out = torch.cat([out, rr], dim=1)

        return out

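As a quick illustration of what AddCoords produces, here is a minimal shape check for the rank-2 case; the tensor sizes are arbitrary and chosen only for this sketch:

# Illustrative shape check; sizes are arbitrary.
x = torch.randn(4, 8, 32, 32)               # (batch, channels, H, W)

add_xy = AddCoords(rank=2, with_r=False)
print(add_xy(x).shape)                      # torch.Size([4, 10, 32, 32]): two coordinate channels appended

add_xyr = AddCoords(rank=2, with_r=True)
print(add_xyr(x).shape)                     # torch.Size([4, 11, 32, 32]): plus one radius channel
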
class CoordConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False):
        super(CoordConv1d, self).__init__()
        self.rank = 1
        self.add_coords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.conv(out)
        return out


class CoordConvTranspose1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, dilation=1, groups=1, bias=True, with_r=False):
        super(CoordConvTranspose1d, self).__init__()
        self.rank = 1
        self.add_coords = AddCoords(self.rank, with_r)
        self.convT = nn.ConvTranspose1d(in_channels + self.rank + int(with_r), out_channels,
                                        kernel_size, stride, padding, output_padding, groups,
                                        bias, dilation)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.convT(out)
        return out


class CoordConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False):
        super(CoordConv2d, self).__init__()
        self.rank = 2
        self.add_coords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.conv(out)
        return out


class CoordConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, with_r=False):
        super(CoordConvTranspose2d, self).__init__()
        self.rank = 2
        self.add_coords = AddCoords(self.rank, with_r)
        self.convT = nn.ConvTranspose2d(in_channels + self.rank + int(with_r), out_channels,
                                        kernel_size, stride, padding, output_padding,
                                        groups, bias, dilation)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.convT(out)
        return out


class CoordConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, with_r=False):
        super(CoordConv3d, self).__init__()
        self.rank = 3
        self.add_coords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels,
                              kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.conv(out)
        return out


class CoordConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1, with_r=False):
        super(CoordConvTranspose3d, self).__init__()
        self.rank = 3
        self.add_coords = AddCoords(self.rank, with_r)
        self.convT = nn.ConvTranspose3d(in_channels + self.rank + int(with_r), out_channels,
                                        kernel_size, stride, padding, output_padding, groups,
                                        bias, dilation)

    def forward(self, input_tensor):
        out = self.add_coords(input_tensor)
        out = self.convT(out)
        return out
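These layers are meant to be used like their torch.nn counterparts, with the coordinate (and optional radius) channels appended internally, so in_channels refers only to the incoming tensor. A usage sketch; channel counts, kernel sizes and spatial sizes are illustrative only:

# Hypothetical usage with illustrative sizes.
layer = CoordConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, with_r=True)
x = torch.randn(2, 3, 64, 64)
y = layer(x)
print(y.shape)        # torch.Size([2, 16, 64, 64]); padding=1 keeps the spatial size

# Transposed variant, here upsampling by a factor of 2
up = CoordConvTranspose2d(in_channels=16, out_channels=8, kernel_size=4, stride=2, padding=1)
print(up(y).shape)    # torch.Size([2, 8, 128, 128])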
@@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 26 17:12:50 2020
@author: giles
"""
from __future__ import print_function, division
import os
import pickle
import random

import torch
import pandas as pd
import torchvision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from utils import RotationMatrix6D, pose_from_filename

class ImagePairDataset(Dataset):
    def __init__(self, dir_path):
        self.name_pair = self.dataNamePair(dir_path)
        self.length = len(self.name_pair)
        self.path = dir_path

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        filename_list = self.name_pair[index]
        img1, img2, poseA, poseB = self.imgPreprocess(filename_list)
        groundtruth_pose = self.groundTruthTensor(filename_list)
        input1 = self.catImgPose(img1, poseA)

        data = {'input1': input1,
                'img1': img1,
                'img2': img2,
                'poseA': poseA,
                'poseB': poseB,
                'groundtruth_pose': groundtruth_pose
                }

        return data

    def dataNamePair(self, datadir):
        # Group image files by the prefix before the first "_" (one group per
        # object), then build every unordered pair of images within a group.
        im_list = sorted(os.listdir(datadir))  # sort so same-prefix files are adjacent
        name_list = []
        index_list = []
        current_list = []
        for i in range(len(im_list)):
            name = im_list[i]
            split = name.split("_")
            if split[0] not in name_list:
                index_list.append(current_list)
                current_list = []
                name_list.append(split[0])
            current_list.append(i)
            if i == len(im_list) - 1:
                index_list.append(current_list)
        index_list.pop(0)  # drop the empty list appended on the first iteration

        name_pair = []
        for cur_list in index_list:
            length = len(cur_list)
            for j in range(length - 1):          # index1 = cur_list[j]
                for k in range(length - 1 - j):  # index2 = cur_list[k + j + 1]
                    name1 = im_list[cur_list[j]]
                    name2 = im_list[cur_list[k + j + 1]]
                    name_pair.append([name1, name2])

        return name_pair

    def imgPreprocess(self, filename_list):
        # filename_list is a single [name1, name2] pair of image filenames.
        trans = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        tensor = torch.tensor(())
        image1 = tensor.new_zeros((1, 3, 128, 128))
        image2 = tensor.new_zeros((1, 3, 128, 128))
        poseB = tensor.new_zeros((1, 12))
        poseA = tensor.new_zeros((1, 12))

        image1[0] = trans(Image.open(os.path.join(self.path, filename_list[0])))
        image2[0] = trans(Image.open(os.path.join(self.path, filename_list[1])))

        R1 = pose_from_filename(os.path.splitext(filename_list[0])[0])
        poseA[0] = torch.from_numpy(np.reshape(R1, (12, 1)))[:, 0]
        R2 = pose_from_filename(os.path.splitext(filename_list[1])[0])
        poseB[0] = torch.from_numpy(np.reshape(R2, (12, 1)))[:, 0]

        return image1, image2, poseA, poseB

    def groundTruthTensor(self, filename_list):
        # filename_list is one [name1, name2] pair, so the ground truth is the
        # single relative pose between the two views.
        gt = np.zeros((1, 9))
        gt[0, :] = RotationMatrix6D(filename_list[0], filename_list[1])

        return torch.from_numpy(gt)

    def catImgPose(self, img, pose):
        # Broadcast the 12-dim pose over the spatial grid and concatenate it
        # with the image channels.
        pose = pose[:, :, None, None]
        pose = pose.repeat(1, 1, 128, 128)
        input1 = torch.cat((img, pose), dim=1)

        return input1

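A minimal sketch of how ImagePairDataset might be consumed; the directory path and batch size are hypothetical, and the image filenames are assumed to follow the <prefix>_<frame> pattern that dataNamePair relies on:

# Hypothetical usage; "./data/pairs" and the batch size are placeholders.
dataset = ImagePairDataset(dir_path="./data/pairs")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
# The DataLoader adds a leading batch dimension to every tensor in the dict.
print(batch['input1'].shape)            # e.g. torch.Size([4, 1, 15, 128, 128]): 3 image + 12 pose channels
print(batch['groundtruth_pose'].shape)  # e.g. torch.Size([4, 1, 9])
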
class SRNDataset(Dataset):

    def __init__(self, split, path='./data/SRN/cars_train', pickle_file='./data/cars.pickle', imgsize=128):
        self.imgsize = imgsize
        self.path = path
        super().__init__()
        self.pickle_file = pickle.load(open(pickle_file, 'rb'))

        all_the_vid = sorted(list(self.pickle_file.keys()))

        # Deterministic 90/10 train/val split over the object ids
        random.seed(0)
        random.shuffle(all_the_vid)
        self.split = split
        if split == 'train':
            self.ids = all_the_vid[:int(len(all_the_vid) * 0.9)]
        else:
            self.ids = all_the_vid[int(len(all_the_vid) * 0.9):]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        item = self.ids[idx]

        intrinsics_filename = os.path.join(self.path, item, 'intrinsics', self.pickle_file[item][0][:-4] + ".txt")  # currently unused

        # Sample two random views of the same object
        indices = random.sample(self.pickle_file[item], k=2)

        imgs = []
        poses = []
        for i in indices:
            img_filename = os.path.join(self.path, item, 'rgb', i)
            img = Image.open(img_filename)
            if self.imgsize != 128:
                img = img.resize((self.imgsize, self.imgsize))
            # Scale pixel values to [-1, 1] and keep only the RGB channels
            img = np.array(img) / 255 * 2 - 1
            img = img.transpose(2, 0, 1)[:3].astype(np.float32)
            imgs.append(img)

            pose_filename = os.path.join(self.path, item, 'pose', i[:-4] + ".txt")
            with open(pose_filename) as file:
                pose = np.array(file.read().strip().split()).astype(float).reshape((4, 4))
            # Keep the top 3x4 of the 4x4 pose matrix, flattened to 12 values
            pose = pose[:3, :].reshape((12,))
            poses.append(pose)

        imgs = np.stack(imgs, 0)
        poses = np.stack(poses, 0)

        return imgs.astype(np.float32), poses.astype(np.float32)
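A similar sketch for SRNDataset, assuming the SRN cars images and the pickle index actually exist at the default paths used above:

# Sketch only; assumes ./data/SRN/cars_train and ./data/cars.pickle exist as in the defaults.
train_set = SRNDataset(split='train')
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)

imgs, poses = next(iter(train_loader))
print(imgs.shape)     # (8, 2, 3, 128, 128): two views per object
print(poses.shape)    # (8, 2, 12): flattened 3x4 pose per view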