forked from NVIDIA/flownet2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2787b3a
commit 7b7b153
Showing
1 changed file
with
44 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
import numpy as np | ||
import argparse | ||
|
||
from Networks.FlowNet2 import FlowNet2 # the path is depended on where you create this module | ||
from frame_utils import read_gen # the path is depended on where you create this module | ||
|
||
if __name__ == '__main__': | ||
# obtain the necessary args for construct the flownet framework | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).') | ||
parser.add_argument("--rgb_max", type=float, default=255.) | ||
args = parser.parse_args() | ||
|
||
# initial a Net | ||
net = FlowNet2(args).cuda() | ||
# load the state_dict | ||
dict = torch.load("/home/hjj/PycharmProjects/flownet2_pytorch/FlowNet2_checkpoint.pth.tar") | ||
net.load_state_dict(dict["state_dict"]) | ||
|
||
# load the image pair, you can find this operation in dataset.py | ||
pim1 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img0.ppm") | ||
pim2 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img1.ppm") | ||
images = [pim1, pim2] | ||
images = np.array(images).transpose(3, 0, 1, 2) | ||
im = torch.from_numpy(images.astype(np.float32)).unsqueeze(0).cuda() | ||
|
||
# process the image pair to obtian the flow | ||
result = net(im).squeeze() | ||
|
||
|
||
# save flow, I reference the code in scripts/run-flownet.py in flownet2-caffe project | ||
def writeFlow(name, flow): | ||
f = open(name, 'wb') | ||
f.write('PIEH'.encode('utf-8')) | ||
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) | ||
flow = flow.astype(np.float32) | ||
flow.tofile(f) | ||
f.flush() | ||
f.close() | ||
|
||
|
||
data = result.data.cpu().numpy().transpose(1, 2, 0) | ||
writeFlow("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img.flo", data) |