Skip to content

Commit

Permalink
fix VOC references
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisbrown committed Mar 6, 2018
1 parent 66faf9c commit 49a326f
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.autograd import Variable
from data import VOCroot, VOC_CLASSES as labelmap
from data import VOC_ROOT, VOC_CLASSES as labelmap
from PIL import Image
from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
import torch.utils.data as data
from ssd import build_ssd

Expand All @@ -20,12 +20,16 @@
help='Dir to save results')
parser.add_argument('--visual_threshold', default=0.6, type=float,
help='Final confidence threshold')
parser.add_argument('--cuda', default=False, type=bool,
parser.add_argument('--cuda', default=True, type=bool,
help='Use cuda to train model')
parser.add_argument('--voc_root', default=VOCroot, help='Location of VOC root directory')

parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory')
args = parser.parse_args()

if args.cuda and torch.cuda.is_available():
torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
os.mkdir(args.save_folder)

Expand Down Expand Up @@ -71,19 +75,22 @@ def test_net(save_folder, net, cuda, testset, transform, thresh):
j += 1


if __name__ == '__main__':
def test_voc():
# load net
num_classes = len(VOC_CLASSES) + 1 # +1 background
net = build_ssd('test', 300, num_classes) # initialize SSD
net.load_state_dict(torch.load(args.trained_model))
net.eval()
print('Finished loading model!')
# load data
testset = VOCDetection(args.voc_root, [('2007', 'test')], None, AnnotationTransform())
testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform())
if args.cuda:
net = net.cuda()
cudnn.benchmark = True
# evaluation
test_net(args.save_folder, net, args.cuda, testset,
BaseTransform(net.size, (104, 117, 123)),
thresh=args.visual_threshold)

if __name__ == '__main__':
test_voc()

0 comments on commit 49a326f

Please sign in to comment.