-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 0cc074f
Showing
78 changed files
with
16,590 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
*~ | ||
*#* | ||
*pyc | ||
#.* | ||
.DS_Store | ||
*.out | ||
*.gif | ||
reference_model/ | ||
*.mp4 | ||
logs_demo/ | ||
__pycache__/ | ||
*.avi | ||
checkpoints/ | ||
logs_train/ | ||
core | ||
debug_temp_output/ | ||
baseline/RAFT/ | ||
*.png | ||
results/ | ||
demos/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Install pytorch 1.13.1 cuda 11.6 | ||
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime | ||
|
||
MAINTAINER Wanwen Chen | ||
|
||
ARG https_proxy | ||
ARG http_proxy | ||
|
||
ENV TZ=America/Los_Angeles \ | ||
DEBIAN_FRONTEND=noninteractive | ||
|
||
RUN apt-get update && \ | ||
apt-get install -y \ | ||
build-essential \ | ||
cmake \ | ||
git \ | ||
wget \ | ||
unzip \ | ||
yasm \ | ||
pkg-config \ | ||
curl \ | ||
vim | ||
|
||
RUN apt-get install -y \ | ||
python3-dev \ | ||
python3-numpy \ | ||
python3-pip | ||
|
||
RUN pip3 install --upgrade pip | ||
|
||
# Install dependency | ||
RUN pip3 --no-cache-dir install \ | ||
numpy \ | ||
SimpleITK \ | ||
scipy \ | ||
pillow | ||
|
||
RUN pip3 install einops==0.6.1 \ | ||
scikit-learn==1.3.0 \ | ||
matplotlib==3.7.3 \ | ||
opencv-python==4.8.0.76 \ | ||
scikit-image==0.21.0 \ | ||
albumentations==1.3.1 \ | ||
tensorboardX==2.6.2.2 \ | ||
fire==0.5.0 \ | ||
moviepy==1.0.3 \ | ||
prettytable==3.9.0 | ||
|
||
RUN pip3 install tensorboard | ||
|
||
# Fix opencv lib error | ||
RUN apt-get update | ||
RUN apt-get install ffmpeg libsm6 libxext6 -y | ||
|
||
# Set the library path to use cuda and cupti | ||
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH | ||
|
||
|
||
RUN apt-get update |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Adam W. Harley | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Online Point Tracking in Ultrasound | ||
|
||
This is the official code release for PIPsUS: Self-Supervised Dense Point Tracking in Ultrasound | ||
**[[Paper](https://arxiv.org/abs/2403.04969)]** | ||
|
||
The paper is accepted by The 5th International Workshop of Advances in Simplifying Medical UltraSound (ASMUS) - a workshop held in conjunction with MICCAI 2024! | ||
|
||
## Requirements | ||
|
||
The lines below should set up a fresh environment with everything you need: | ||
|
||
``` | ||
conda create -n pips2 python=3.8 | ||
conda activate pips2 | ||
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia | ||
pip install -r requirements.txt | ||
pip install SimpleITK | ||
``` | ||
|
||
|
||
## Citation | ||
|
||
|
||
## Acknowledgement | ||
|
||
We would like to thanks the authors of [PIPs++](https://arxiv.org/abs/2307.15055) and [RAFT](https://arxiv.org/pdf/2003.12039) for open-sourcing their codes and models. This work is built on top of their contribution. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
|
||
from skimage.feature import match_template | ||
|
||
|
||
|
||
def ncc_matching(image1, image2, kps, patch_size, search_size): | ||
# Fast Normalized Cross-Correlation | ||
kp_num = kps.shape[0] | ||
next_kps = np.zeros((kp_num, 2)) | ||
# zero padding for the image | ||
pad_size = search_size + patch_size | ||
image1 = np.pad(image1, ((pad_size, pad_size), (pad_size, pad_size)), 'constant', constant_values=(0, 0)) | ||
image2 = np.pad(image2, ((pad_size, pad_size), (pad_size, pad_size)), 'constant', constant_values=(0, 0)) | ||
half_patch_size = patch_size // 2 | ||
|
||
for i in range(kp_num): | ||
x = kps[i,0] + pad_size | ||
y = kps[i,1] + pad_size | ||
y = int(y) | ||
x = int(x) | ||
template_patch = image1[y-half_patch_size:y+half_patch_size+1, x-half_patch_size:x+half_patch_size+1] | ||
search_patch = image2[y-search_size-half_patch_size:y+search_size+half_patch_size+1, x-search_size-half_patch_size:x+search_size+half_patch_size+1] | ||
if template_patch.shape[0] == 0 or template_patch.shape[1] == 0 or search_patch.shape[0] == 0 or search_patch.shape[1] == 0: | ||
# print("Error: template patch size is 0") | ||
# print("x: %d, y: %d" %(x, y)) | ||
# print(kps[i]) | ||
# maybe the keypoint is out of the image, just skip | ||
next_kps[i] = kps[i] | ||
continue | ||
signal = match_template(search_patch, template_patch, pad_input=True) | ||
|
||
ij = np.unravel_index(np.argmax(signal), signal.shape) | ||
x_target, y_target = ij[::-1] | ||
center_x = search_size+half_patch_size | ||
center_y = search_size+half_patch_size | ||
next_kps[i,0] = x_target - center_x + x - pad_size | ||
next_kps[i,1] = y_target - center_y + y - pad_size | ||
|
||
return next_kps | ||
|
||
# def ncc(patch1, patch2): | ||
# norm_patch1 = norm_data(patch1.flatten()) | ||
# norm_patch2 = norm_data(patch2.flatten()) | ||
# if norm_patch1.shape[0] != norm_patch2.shape[0]: | ||
# print("Error: patch size not match %d %d" %(norm_patch1.shape[0], norm_patch2.shape[0])) | ||
# return 0 | ||
# result = np.dot(norm_patch1, norm_patch2) / norm_patch1.shape[0] | ||
# return result | ||
|
||
|
||
# def norm_data(data): | ||
# mean_data = np.mean(data) | ||
# std_data = np.std(data) | ||
# data = (data - mean_data) / (std_data + 1e-6) | ||
# # data = data / std_data | ||
# return data | ||
|
||
|
||
|
||
# if __name__ == '__main__': | ||
# data = np.random.rand(100, 100) | ||
# flow = ncc_matching(data, data, np.array([[50,50]]), 8, 8) | ||
# print(flow) | ||
# print(ncc(data, data)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import time | ||
import numpy as np | ||
import saverloader | ||
from nets.pips2 import Pips | ||
import utils.improc | ||
from utils.basic import print_, print_stats | ||
import torch | ||
from tensorboardX import SummaryWriter | ||
import torch.nn.functional as F | ||
from fire import Fire | ||
import sys | ||
import cv2 | ||
from pathlib import Path | ||
|
||
from ultrasound.pseudo_label import extract_keypoints, cvt_opencv_kps_to_numpy | ||
|
||
|
||
def read_mp4(fn): | ||
vidcap = cv2.VideoCapture(fn) | ||
frames = [] | ||
while(vidcap.isOpened()): | ||
ret, frame = vidcap.read() | ||
if ret == False: | ||
break | ||
frames.append(frame) | ||
vidcap.release() | ||
return frames | ||
|
||
def run_model(model, rgbs, S_max=128, N=64, iters=16, sw=None, counter=0): | ||
rgbs = rgbs.cuda().float() # B, S, C, H, W | ||
|
||
B, S, C, H, W = rgbs.shape | ||
assert(B==1) | ||
|
||
# # pick N points to track; we'll use a uniform grid | ||
# N_ = np.sqrt(N).round().astype(np.int32) | ||
# grid_y, grid_x = utils.basic.meshgrid2d(B, N_, N_, stack=False, norm=False, device='cuda') | ||
# grid_y = 8 + grid_y.reshape(B, -1)/float(N_-1) * (H-16) | ||
# grid_x = 8 + grid_x.reshape(B, -1)/float(N_-1) * (W-16) | ||
# xy0 = torch.stack([grid_x, grid_y], dim=-1) # B, N_*N_, 2 | ||
|
||
# get init keypoints | ||
kps = extract_keypoints(rgbs[0,0].permute(1,2,0), keypoint_type='sift') | ||
xy0 = cvt_opencv_kps_to_numpy(kps) # N x 2 | ||
xy0 = torch.from_numpy(xy0).float().cuda().unsqueeze(0) # 1 x N x 2 | ||
|
||
# zero-vel init | ||
trajs_e = xy0.unsqueeze(1).repeat(1,S,1,1) | ||
|
||
iter_start_time = time.time() | ||
|
||
preds, preds_anim, _, _ = model(trajs_e, rgbs, iters=iters, feat_init=None, beautify=True) | ||
trajs_e = preds[-1] | ||
|
||
iter_time = time.time()-iter_start_time | ||
print('inference time: %.2f seconds (%.1f fps)' % (iter_time, S/iter_time)) | ||
|
||
if sw is not None and sw.save_this: | ||
rgbs_prep = utils.improc.preprocess_color(rgbs) | ||
sw.summ_traj2ds_on_rgbs('outputs/trajs_on_rgbs', trajs_e[0:1], utils.improc.preprocess_color(rgbs[0:1]), cmap='hot', linewidth=1, show_dots=False) | ||
rgb_save = sw.summ_traj2ds_on_rgbs('outputs/trajs_on_rgbs', trajs_e[0:1], utils.improc.preprocess_color(rgbs[0:1]), cmap='hot', linewidth=1, show_dots=False, only_return=True) | ||
# print('rgb_save', rgb_save.shape) | ||
# print(rgb_save.dtype) | ||
rgb_save = rgb_save[0] | ||
rgb_save = rgb_save.permute(0,2,3,1) | ||
# save the video | ||
out = cv2.VideoWriter('inplane_iter6_sift_1s_vis_' + str(counter).zfill(3) + '.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 4, (W, H)) | ||
for i in range(rgb_save.shape[0]): | ||
out.write(cv2.cvtColor(rgb_save[i].numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)) | ||
out.release() | ||
return trajs_e | ||
|
||
|
||
def main( | ||
filename='./stock_videos/camel.mp4', | ||
S=48, # seqlen | ||
N=1024, # number of points per clip | ||
stride=8, # spatial stride of the model | ||
timestride=1, # temporal stride of the model | ||
iters=16, # inference steps of the model | ||
image_size=(512,896), # input resolution | ||
max_iters=4, # number of clips to run | ||
shuffle=False, # dataset shuffling | ||
log_freq=1, # how often to make image summaries | ||
log_dir='./logs_demo', | ||
init_dir='./reference_model', | ||
device_ids=[0], | ||
): | ||
|
||
# the idea in this file is to run the model on a demo video, | ||
# and return some visualizations | ||
|
||
exp_name = 'de00' # copy from dev repo | ||
|
||
print('filename', filename) | ||
name = Path(filename).stem | ||
print('name', name) | ||
|
||
rgbs = read_mp4(filename) | ||
rgbs = np.stack(rgbs, axis=0) # S,H,W,3 | ||
rgbs = rgbs[:,:,:,::-1].copy() # BGR->RGB | ||
rgbs = rgbs[::timestride] | ||
S_here,H,W,C = rgbs.shape | ||
print('rgbs', rgbs.shape) | ||
|
||
# autogen a name | ||
model_name = "%s_%d_%d_%s" % (name, S, N, exp_name) | ||
import datetime | ||
model_date = datetime.datetime.now().strftime('%H:%M:%S') | ||
model_name = model_name # + '_' + model_date ## this will cause OS error in windows | ||
print('model_name', model_name) | ||
|
||
log_dir = 'logs_demo' | ||
writer_t = SummaryWriter(log_dir + '/' + model_name + '/t', max_queue=10, flush_secs=60) | ||
|
||
global_step = 0 | ||
|
||
model = Pips(stride=8).cuda() | ||
parameters = list(model.parameters()) | ||
if init_dir: | ||
_ = saverloader.load(init_dir, model) | ||
global_step = 0 | ||
model.eval() | ||
|
||
idx = list(range(0, max(S_here-S,1), S)) | ||
if max_iters: | ||
idx = idx[:max_iters] | ||
|
||
for si in idx: | ||
global_step += 1 | ||
|
||
iter_start_time = time.time() | ||
|
||
sw_t = utils.improc.Summ_writer( | ||
writer=writer_t, | ||
global_step=global_step, | ||
log_freq=log_freq, | ||
fps=16, | ||
scalar_freq=int(log_freq/2), | ||
just_gif=True) | ||
|
||
rgb_seq = rgbs[si:si+S] | ||
rgb_seq = torch.from_numpy(rgb_seq).permute(0,3,1,2).to(torch.float32) # S,3,H,W | ||
rgb_seq = F.interpolate(rgb_seq, image_size, mode='bilinear').unsqueeze(0) # 1,S,3,H,W | ||
|
||
with torch.no_grad(): | ||
trajs_e = run_model(model, rgb_seq, S_max=S, N=N, iters=iters, sw=sw_t, counter=si) | ||
|
||
iter_time = time.time()-iter_start_time | ||
|
||
print('%s; step %06d/%d; itime %.2f' % ( | ||
model_name, global_step, max_iters, iter_time)) | ||
|
||
|
||
|
||
writer_t.close() | ||
|
||
if __name__ == '__main__': | ||
# Fire(main) | ||
main(filename='case1_082042_0001.mp4', S=5, image_size=(727, 698), iters=6) | ||
|
||
# inference time: | ||
# ...found checkpoint ./reference_model\model-000200000.pth | ||
# inference time: 8.90 seconds (4.5 fps) | ||
# track_40_1024_de00; step 000001/4; itime 209.44 | ||
|
||
# inplane | ||
# inference time: 8.33 seconds (4.8 fps) | ||
# inplane_124611_40_1024_de00; step 000001/4; itime 199.35 | ||
# inference time: 6.84 seconds (5.9 fps) | ||
# inplane_124611_40_1024_de00; step 000002/4; itime 191.09 | ||
# inference time: 6.84 seconds (5.8 fps) | ||
# inplane_124611_40_1024_de00; step 000003/4; itime 190.66 |
Oops, something went wrong.