Skip to content

Commit

Permalink
Shaokai/sa transfer learning (#2389)
Browse files Browse the repository at this point in the history
* Added SA finetune and corresponding project

---------

Co-authored-by: Jessy Lauer <30733203+jeylau@users.noreply.github.com>
Co-authored-by: Timokleia <86098649+Timokleia@users.noreply.github.com>
Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
  • Loading branch information
4 people authored Nov 6, 2023
1 parent 25f8c50 commit 11f12d8
Show file tree
Hide file tree
Showing 30 changed files with 331 additions and 156 deletions.
1 change: 0 additions & 1 deletion deeplabcut/generate_training_dataset/frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def extract_frames(
from deeplabcut.utils import frameselectiontools
from deeplabcut.utils import auxiliaryfunctions


config_file = Path(config).resolve()
cfg = auxiliaryfunctions.read_config(config_file)
print("Config file read successfully.")
Expand Down
35 changes: 31 additions & 4 deletions deeplabcut/generate_training_dataset/trainingsetmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
auxfun_multianimal,
)
from deeplabcut.utils.auxfun_videos import VideoReader
from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.modelzoo.utils import parse_available_supermodels


def comparevideolistsanddatafolders(config):
Expand Down Expand Up @@ -482,7 +484,9 @@ def merge_annotateddatasets(cfg, trainingsetfolder_full):
data = pd.read_hdf(file_path)
conversioncode.guarantee_multiindex_rows(data)
if data.columns.levels[0][0] != cfg["scorer"]:
print(f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)")
print(
f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)"
)
continue
AnnotationData.append(data)
except FileNotFoundError:
Expand Down Expand Up @@ -728,6 +732,7 @@ def create_training_dataset(
net_type=None,
augmenter_type=None,
posecfg_template=None,
superanimal_name="",
):
"""Creates a training dataset.
Expand Down Expand Up @@ -791,6 +796,10 @@ def create_training_dataset(
parameters a previous training iteration. None uses the default
``pose_cfg.yaml``.
superanimal_name: string, optional, default=""
Specify the superanimal name if transfer learning with superanimal is desired. This makes sure the pose config template uses the superanimal configs as a template.
Returns
-------
list(tuple) or None
Expand Down Expand Up @@ -832,23 +841,40 @@ def create_training_dataset(

# Loading metadata from config file:
cfg = auxiliaryfunctions.read_config(config)
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()

if superanimal_name != "":
supermodels = parse_available_supermodels()
posecfg_template = os.path.join(
dlc_root_path,
"pose_estimation_tensorflow",
"superanimal_configs",
supermodels[superanimal_name],
)

if posecfg_template:
if not posecfg_template.endswith("pose_cfg.yaml"):
if (
not posecfg_template.endswith("pose_cfg.yaml")
and not posecfg_template.endswith("superquadruped.yaml")
and not posecfg_template.endswith("supertopview.yaml")
):
raise ValueError(
"posecfg_template argument must contain path to a pose_cfg.yaml file"
)
else:
print("Reloading pose_cfg parameters from " + posecfg_template + "\n")
from deeplabcut.utils.auxiliaryfunctions import read_plainconfig

prior_cfg = read_plainconfig(posecfg_template)
prior_cfg = read_plainconfig(posecfg_template)
if cfg.get("multianimalproject", False):
from deeplabcut.generate_training_dataset.multiple_individuals_trainingsetmanipulation import (
create_multianimaltraining_dataset,
)

create_multianimaltraining_dataset(
config, num_shuffles, Shuffles,
config,
num_shuffles,
Shuffles,
net_type=net_type,
trainIndices=trainIndices,
testIndices=testIndices,
Expand Down Expand Up @@ -880,6 +906,7 @@ def create_training_dataset(
"resnet" in net_type
or "mobilenet" in net_type
or "efficientnet" in net_type
or "dlcrnet" in net_type
):
pass
else:
Expand Down
6 changes: 5 additions & 1 deletion deeplabcut/gui/tabs/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def _set_page(self):
self._generate_layout_attributes(self.layout_attributes)
self.main_layout.addLayout(self.layout_attributes)

self.main_layout.addWidget(_create_label_widget("Optional: frame extraction from a video subset", "font:bold"))
self.main_layout.addWidget(
_create_label_widget(
"Optional: frame extraction from a video subset", "font:bold"
)
)
self.video_selection_widget = VideoSelectionWidget(self.root, self)
self.main_layout.addWidget(self.video_selection_widget)

Expand Down
4 changes: 2 additions & 2 deletions deeplabcut/gui/tracklet_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _prepare_canvas(self, manager, fig):

img = self.video.read_frame()
self.im = self.ax1.imshow(img)
self.scat = self.ax1.scatter([], [], s=self.dotsize**2, picker=True)
self.scat = self.ax1.scatter([], [], s=self.dotsize ** 2, picker=True)
self.scat.set_offsets(manager.xy[:, 0])
self.scat.set_color(self.colors)
self.trails = sum(
Expand Down Expand Up @@ -807,7 +807,7 @@ def on_change(self, val):

def update_dotsize(self, val):
self.dotsize = val
self.scat.set_sizes([self.dotsize**2])
self.scat.set_sizes([self.dotsize ** 2])

@staticmethod
def calc_distance(x1, y1, x2, y2):
Expand Down
8 changes: 6 additions & 2 deletions deeplabcut/gui/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,17 @@ def _update_project_state(self, config, loaded):
def _ask_for_help(self):
dlg = QMessageBox(self)
dlg.setWindowTitle("Ask for help")
dlg.setText('''Ask our community for help on <a href='https://forum.image.sc/tag/deeplabcut'>the forum</a>!''')
dlg.setText(
"""Ask our community for help on <a href='https://forum.image.sc/tag/deeplabcut'>the forum</a>!"""
)
_ = dlg.exec()

def _learn_dlc(self):
dlg = QMessageBox(self)
dlg.setWindowTitle("Learn DLC")
dlg.setText('''Learn DLC with <a href='https://deeplabcut.github.io/DeepLabCut/docs/UseOverviewGuide.html'>our docs and how-to guides</a>!''')
dlg.setText(
"""Learn DLC with <a href='https://deeplabcut.github.io/DeepLabCut/docs/UseOverviewGuide.html'>our docs and how-to guides</a>!"""
)
_ = dlg.exec()

def _create_project(self):
Expand Down
5 changes: 4 additions & 1 deletion deeplabcut/modelzoo/api/spatiotemporal_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ def before_adapt_inference(self, make_video=False, **kwargs):
customized_test_config=self.customized_pose_config,
)
if kwargs.pop("plot_trajectories", True):
_plot_trajectories(datafiles[0])
if len(datafiles) == 0:
print("No data files found for plotting trajectory")
else:
_plot_trajectories(datafiles[0])

if make_video:
deeplabcut.create_labeled_video(
Expand Down
2 changes: 1 addition & 1 deletion deeplabcut/modelzoo/api/superanimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def video_inference(
print("Loading ", video)
vid = VideoWriter(video)
if len(scale_list) == 0:
# spatial pyramid can still be useful for reducing jittering and quantization error
# spatial pyramid can still be useful for reducing jittering and quantization error
scale_list = [vid.height - 50, vid.height, vid.height + 50]
if robust_nframes:
nframes = vid.get_n_frames(robust=True)
Expand Down
12 changes: 11 additions & 1 deletion deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def train(
traintime_resize=False,
video_path="",
superanimal=None,
trim_ends = None # trim the both ends of the video for video adaptation
remove_head=False,
):
# in case there was already a graph
tf.compat.v1.reset_default_graph()
Expand Down Expand Up @@ -93,6 +93,7 @@ def train(
cfg["pairwise_predict"] = True

dataset = PoseDatasetFactory.create(cfg)

batch_spec = get_batch_spec(cfg)
batch, enqueue_op, placeholders = setup_preloading(batch_spec)

Expand All @@ -107,6 +108,7 @@ def train(
if init_weights != "":
cfg["init_weights"] = init_weights
cfg["resume_weights_only"] = True
print("replacing default init weights with: ", init_weights)

stem = Path(cfg["init_weights"]).stem
if "snapshot" in stem and keepdeconvweights:
Expand All @@ -117,6 +119,14 @@ def train(
else:
start_iter = int(stem.split("-")[1])

if remove_head:
# removing the decoding layer from the checkpoint
temp = []
for variable in variables_to_restore:
if "pose" not in variable.name:
temp.append(variable)
variables_to_restore = temp

else:
print("Loading ImageNet-pretrained", net_type)
# loading backbone from ResNet, MobileNet etc.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def make_batch(self, data_item, scale, mirror):

def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale):
dist_thresh = self.cfg["pos_dist_thresh"] * scale
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg["num_joints"]
scmap = np.zeros(np.concatenate([size, np.array([num_joints])]))
locref_size = np.concatenate([size, np.array([num_joints * 2])])
Expand Down Expand Up @@ -260,7 +260,7 @@ def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale)
pt_x = i * self.stride + self.half_stride
dx = j_x - pt_x
dy = j_y - pt_y
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
# print(la.norm(diff))
if dist <= dist_thresh_sq:
scmap[j, i, j_id] = 1
Expand Down
8 changes: 4 additions & 4 deletions deeplabcut/pose_estimation_tensorflow/datasets/pose_imgaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
width = size[1]
height = size[0]
dist_thresh = float((width + height) / 6)
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

std = dist_thresh / 4
# Grid of coordinates
Expand All @@ -503,7 +503,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
map_j = grid.copy()
# Distance between the joint point and each coordinate
dist = np.linalg.norm(grid - (j_y, j_x), axis=2) ** 2
scmap_j = np.exp(-dist / (2 * (std**2)))
scmap_j = np.exp(-dist / (2 * (std ** 2)))
scmap[..., j_id] = scmap_j
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 0] = 1
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 1] = 1
Expand All @@ -528,7 +528,7 @@ def compute_target_part_scoremap_numpy(
self, joint_id, coords, data_item, size, scale
):
dist_thresh = float(self.cfg["pos_dist_thresh"] * scale)
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg["num_joints"]

scmap = np.zeros(np.concatenate([size, np.array([num_joints])]))
Expand All @@ -555,7 +555,7 @@ def compute_target_part_scoremap_numpy(
y = grid.copy()[:, :, 0]
dx = j_x - x * self.stride - self.half_stride
dy = j_y - y * self.stride - self.half_stride
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
mask1 = dist <= dist_thresh_sq
mask2 = (x >= min_x) & (x <= max_x)
mask3 = (y >= min_y) & (y <= max_y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,14 @@ def load_dataset(self):
self.has_gt = has_gt
return data

def _load_pseudo_data_from_h5(self, cfg, threshold=0.5, mask_kpts_below_thresh=False):
def _load_pseudo_data_from_h5(
self, cfg, threshold=0.5, mask_kpts_below_thresh=False
):
gt_file = cfg["pseudo_label"]
assert os.path.exists(gt_file)
path_ = Path(gt_file)
print("Using gt file:", path_.name)
num_kpts = len(cfg['all_joints_names'])
num_kpts = len(cfg["all_joints_names"])
df = pd.read_hdf(gt_file)
video_name = path_.name.split("DLC")[0]
video_root = str(path_.parents[0] / video_name)
Expand Down Expand Up @@ -353,11 +355,13 @@ def get_batch_from_video(self):
batch_joints = []
joint_ids = []
data_items = []
trim_ends = self.cfg.get('trim_ends', None)
trim_ends = self.cfg.get("trim_ends", None)
if trim_ends is None:
trim_ends = 0
# because of the existence of threshold, sampling population is adjusted to len(self.data)
img_idx = np.random.choice(len(self.data) - trim_ends *2, size=self.batch_size, replace=True)
img_idx = np.random.choice(
len(self.data) - trim_ends * 2, size=self.batch_size, replace=True
)
for i in range(self.batch_size):
index = img_idx[i]
offset = trim_ends
Expand Down Expand Up @@ -614,7 +618,7 @@ def compute_target_part_scoremap_numpy(
locref_size = *size, num_joints * 2
locref_map = np.zeros(locref_size)
locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

partaffinityfield_shape = *size, self.cfg["num_limbs"] * 2
partaffinityfield_map = np.zeros(partaffinityfield_shape)
Expand All @@ -640,7 +644,7 @@ def compute_target_part_scoremap_numpy(
dx_ = dx * locref_scale
dy = coords[:, 1] - yy * stride - half_stride
dy_ = dy * locref_scale
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
mask1 = dist <= dist_thresh_sq
mask2 = (xx >= mins[:, 0]) & (xx <= maxs[:, 0])
mask3 = (yy >= mins[:, 1]) & (yy <= maxs[:, 1])
Expand Down Expand Up @@ -747,7 +751,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

partaffinityfield_shape = np.concatenate(
[size, np.array([self.cfg["num_limbs"] * 2])]
Expand Down Expand Up @@ -779,7 +783,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
map_j = grid.copy()
# Distance between the joint point and each coordinate
dist = np.linalg.norm(grid - (j_y, j_x), axis=2) ** 2
scmap_j = np.exp(-dist / (2 * (std**2)))
scmap_j = np.exp(-dist / (2 * (std ** 2)))
scmap[..., j_id] = scmap_j
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 0] = 1
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 1] = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def compute_target_part_scoremap(self, components):
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

width = size[1]
height = size[0]
Expand All @@ -375,7 +375,7 @@ def compute_target_part_scoremap(self, components):
pt_x = i * stride + half_stride
dx = j_x - pt_x
dy = j_y - pt_y
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
# print(la.norm(diff))
if dist <= dist_thresh_sq:
scmap[j, i, j_id] = 1
Expand Down
6 changes: 3 additions & 3 deletions deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def calc_link_probability(self, link):
ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts)
mu = self._kde.mean[ind]
sigma = self._kde.covariance[ind, ind]
z = (link.length**2 - mu) / sigma
z = (link.length ** 2 - mu) / sigma
return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))

@staticmethod
Expand Down Expand Up @@ -809,13 +809,13 @@ def wrapped(i):
if unique is not None:
self.unique[i] = unique
pbar.update()

def from_pickle(self, pickle_path):
with open(pickle_path, "rb") as file:
data = pickle.load(file)
self.unique = data.pop("single", {})
self.assemblies = data

@staticmethod
def parse_metadata(data):
params = dict()
Expand Down
4 changes: 2 additions & 2 deletions deeplabcut/pose_estimation_tensorflow/nnets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def get_batch_spec(cfg):
def make_2d_gaussian_kernel(sigma, size):
sigma = tf.convert_to_tensor(sigma, dtype=tf.float32)
k = tf.range(-size // 2 + 1, size // 2 + 1)
k = tf.cast(k**2, sigma.dtype)
k = tf.nn.softmax(-k / (2 * (sigma**2)))
k = tf.cast(k ** 2, sigma.dtype)
k = tf.nn.softmax(-k / (2 * (sigma ** 2)))
return tf.einsum("i,j->ij", k, k)


Expand Down
Loading

0 comments on commit 11f12d8

Please sign in to comment.