Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shaokai/sa transfer learning #2389

Merged
merged 27 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
44e4754
Added SA finetune and corresponding project
yeshaokai Aug 9, 2023
3c745ea
removed artifacts
yeshaokai Aug 9, 2023
78b2591
removed more artifacts
yeshaokai Aug 9, 2023
bde5778
Update deeplabcut/pose_estimation_tensorflow/training.py
jeylau Aug 17, 2023
b4b0b59
Remove unused argument
jeylau Aug 17, 2023
8fcde4f
removed example project
yeshaokai Sep 12, 2023
6bc53dc
SA transfer learning
yeshaokai Sep 12, 2023
fd8a926
Update ModelZoo.md
yeshaokai Sep 12, 2023
cf38fd4
Update ModelZoo.md
yeshaokai Sep 12, 2023
5c1ff15
Update ModelZoo.md
yeshaokai Sep 12, 2023
960df24
added transfer learning flag
yeshaokai Sep 12, 2023
df4734c
Merge branch 'shaokai/SA_transfer_learning' of github.com:yeshaokai/D…
yeshaokai Sep 12, 2023
9245284
Update ModelZoo.md
yeshaokai Sep 12, 2023
400bdeb
Update ModelZoo.md
yeshaokai Sep 12, 2023
807aa5e
Fixes
yeshaokai Sep 14, 2023
b880c22
Merge branch 'shaokai/SA_transfer_learning' of github.com:yeshaokai/D…
yeshaokai Sep 14, 2023
688c30b
fixed typo
yeshaokai Sep 14, 2023
7bf0dd8
updates
yeshaokai Sep 18, 2023
c0d06ac
Update ModelZoo.md
Timokleia Sep 18, 2023
5a2b606
Merge pull request #3 from Timokleia/patch-1
yeshaokai Sep 18, 2023
c1201c9
update path for the testscript
yeshaokai Sep 18, 2023
e078e1e
removed debug
yeshaokai Sep 18, 2023
6a56255
Update ModelZoo.md
MMathisLab Sep 18, 2023
7c979e4
Update examples/testscript_superanimal_transfer_learning.py
yeshaokai Sep 21, 2023
0a70992
revert change to config
yeshaokai Sep 21, 2023
5653cec
ran black and removed unused parameters
yeshaokai Oct 31, 2023
5170ed3
Merge branch 'main' into shaokai/SA_transfer_learning
jeylau Nov 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion deeplabcut/generate_training_dataset/frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def extract_frames(
from deeplabcut.utils import frameselectiontools
from deeplabcut.utils import auxiliaryfunctions


config_file = Path(config).resolve()
cfg = auxiliaryfunctions.read_config(config_file)
print("Config file read successfully.")
Expand Down
35 changes: 31 additions & 4 deletions deeplabcut/generate_training_dataset/trainingsetmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
auxfun_multianimal,
)
from deeplabcut.utils.auxfun_videos import VideoReader
from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.modelzoo.utils import parse_available_supermodels


def comparevideolistsanddatafolders(config):
Expand Down Expand Up @@ -482,7 +484,9 @@ def merge_annotateddatasets(cfg, trainingsetfolder_full):
data = pd.read_hdf(file_path)
conversioncode.guarantee_multiindex_rows(data)
if data.columns.levels[0][0] != cfg["scorer"]:
print(f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)")
print(
f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)"
)
continue
AnnotationData.append(data)
except FileNotFoundError:
Expand Down Expand Up @@ -728,6 +732,7 @@ def create_training_dataset(
net_type=None,
augmenter_type=None,
posecfg_template=None,
superanimal_name="",
):
"""Creates a training dataset.

Expand Down Expand Up @@ -791,6 +796,10 @@ def create_training_dataset(
parameters a previous training iteration. None uses the default
``pose_cfg.yaml``.

superanimal_name: string, optional, default=""
Specify the superanimal name is transfer learning with superanimal is desired. This makes sure the pose config template uses superanimal configs as template


Returns
-------
list(tuple) or None
Expand Down Expand Up @@ -832,23 +841,40 @@ def create_training_dataset(

# Loading metadata from config file:
cfg = auxiliaryfunctions.read_config(config)
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()

if superanimal_name != "":
supermodels = parse_available_supermodels()
posecfg_template = os.path.join(
dlc_root_path,
"pose_estimation_tensorflow",
"superanimal_configs",
supermodels[superanimal_name],
)

if posecfg_template:
if not posecfg_template.endswith("pose_cfg.yaml"):
if (
not posecfg_template.endswith("pose_cfg.yaml")
and not posecfg_template.endswith("superquadruped.yaml")
and not posecfg_template.endswith("supertopview.yaml")
):
raise ValueError(
"posecfg_template argument must contain path to a pose_cfg.yaml file"
)
else:
print("Reloading pose_cfg parameters from " + posecfg_template + "\n")
from deeplabcut.utils.auxiliaryfunctions import read_plainconfig

prior_cfg = read_plainconfig(posecfg_template)
prior_cfg = read_plainconfig(posecfg_template)
if cfg.get("multianimalproject", False):
from deeplabcut.generate_training_dataset.multiple_individuals_trainingsetmanipulation import (
create_multianimaltraining_dataset,
)

create_multianimaltraining_dataset(
config, num_shuffles, Shuffles,
config,
num_shuffles,
Shuffles,
net_type=net_type,
trainIndices=trainIndices,
testIndices=testIndices,
Expand Down Expand Up @@ -880,6 +906,7 @@ def create_training_dataset(
"resnet" in net_type
or "mobilenet" in net_type
or "efficientnet" in net_type
or "dlcrnet" in net_type
):
pass
else:
Expand Down
6 changes: 5 additions & 1 deletion deeplabcut/gui/tabs/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def _set_page(self):
self._generate_layout_attributes(self.layout_attributes)
self.main_layout.addLayout(self.layout_attributes)

self.main_layout.addWidget(_create_label_widget("Optional: frame extraction from a video subset", "font:bold"))
self.main_layout.addWidget(
_create_label_widget(
"Optional: frame extraction from a video subset", "font:bold"
)
)
self.video_selection_widget = VideoSelectionWidget(self.root, self)
self.main_layout.addWidget(self.video_selection_widget)

Expand Down
4 changes: 2 additions & 2 deletions deeplabcut/gui/tracklet_toolbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _prepare_canvas(self, manager, fig):

img = self.video.read_frame()
self.im = self.ax1.imshow(img)
self.scat = self.ax1.scatter([], [], s=self.dotsize**2, picker=True)
self.scat = self.ax1.scatter([], [], s=self.dotsize ** 2, picker=True)
self.scat.set_offsets(manager.xy[:, 0])
self.scat.set_color(self.colors)
self.trails = sum(
Expand Down Expand Up @@ -807,7 +807,7 @@ def on_change(self, val):

def update_dotsize(self, val):
self.dotsize = val
self.scat.set_sizes([self.dotsize**2])
self.scat.set_sizes([self.dotsize ** 2])

@staticmethod
def calc_distance(x1, y1, x2, y2):
Expand Down
8 changes: 6 additions & 2 deletions deeplabcut/gui/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,17 @@ def _update_project_state(self, config, loaded):
def _ask_for_help(self):
dlg = QMessageBox(self)
dlg.setWindowTitle("Ask for help")
dlg.setText('''Ask our community for help on <a href='https://forum.image.sc/tag/deeplabcut'>the forum</a>!''')
dlg.setText(
"""Ask our community for help on <a href='https://forum.image.sc/tag/deeplabcut'>the forum</a>!"""
)
_ = dlg.exec()

def _learn_dlc(self):
dlg = QMessageBox(self)
dlg.setWindowTitle("Learn DLC")
dlg.setText('''Learn DLC with <a href='https://deeplabcut.github.io/DeepLabCut/docs/UseOverviewGuide.html'>our docs and how-to guides</a>!''')
dlg.setText(
"""Learn DLC with <a href='https://deeplabcut.github.io/DeepLabCut/docs/UseOverviewGuide.html'>our docs and how-to guides</a>!"""
)
_ = dlg.exec()

def _create_project(self):
Expand Down
5 changes: 4 additions & 1 deletion deeplabcut/modelzoo/api/spatiotemporal_adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ def before_adapt_inference(self, make_video=False, **kwargs):
customized_test_config=self.customized_pose_config,
)
if kwargs.pop("plot_trajectories", True):
_plot_trajectories(datafiles[0])
if len(datafiles) == 0:
print("No data files found for plotting trajectory")
else:
_plot_trajectories(datafiles[0])

if make_video:
deeplabcut.create_labeled_video(
Expand Down
2 changes: 1 addition & 1 deletion deeplabcut/modelzoo/api/superanimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def video_inference(
print("Loading ", video)
vid = VideoWriter(video)
if len(scale_list) == 0:
# spatial pyramid can still be useful for reducing jittering and quantization error
# spatial pyramid can still be useful for reducing jittering and quantization error
scale_list = [vid.height - 50, vid.height, vid.height + 50]
if robust_nframes:
nframes = vid.get_n_frames(robust=True)
Expand Down
12 changes: 11 additions & 1 deletion deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def train(
traintime_resize=False,
video_path="",
superanimal=None,
trim_ends = None # trim the both ends of the video for video adaptation
remove_head=False,
):
# in case there was already a graph
tf.compat.v1.reset_default_graph()
Expand Down Expand Up @@ -93,6 +93,7 @@ def train(
cfg["pairwise_predict"] = True

dataset = PoseDatasetFactory.create(cfg)

batch_spec = get_batch_spec(cfg)
batch, enqueue_op, placeholders = setup_preloading(batch_spec)

Expand All @@ -107,6 +108,7 @@ def train(
if init_weights != "":
cfg["init_weights"] = init_weights
cfg["resume_weights_only"] = True
print("replacing default init weights with: ", init_weights)

stem = Path(cfg["init_weights"]).stem
if "snapshot" in stem and keepdeconvweights:
Expand All @@ -117,6 +119,14 @@ def train(
else:
start_iter = int(stem.split("-")[1])

if remove_head:
# removing the decoding layer from the checkpoint
temp = []
for variable in variables_to_restore:
if "pose" not in variable.name:
temp.append(variable)
variables_to_restore = temp

else:
print("Loading ImageNet-pretrained", net_type)
# loading backbone from ResNet, MobileNet etc.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def make_batch(self, data_item, scale, mirror):

def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale):
dist_thresh = self.cfg["pos_dist_thresh"] * scale
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg["num_joints"]
scmap = np.zeros(np.concatenate([size, np.array([num_joints])]))
locref_size = np.concatenate([size, np.array([num_joints * 2])])
Expand Down Expand Up @@ -260,7 +260,7 @@ def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale)
pt_x = i * self.stride + self.half_stride
dx = j_x - pt_x
dy = j_y - pt_y
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
# print(la.norm(diff))
if dist <= dist_thresh_sq:
scmap[j, i, j_id] = 1
Expand Down
8 changes: 4 additions & 4 deletions deeplabcut/pose_estimation_tensorflow/datasets/pose_imgaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
width = size[1]
height = size[0]
dist_thresh = float((width + height) / 6)
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

std = dist_thresh / 4
# Grid of coordinates
Expand All @@ -503,7 +503,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
map_j = grid.copy()
# Distance between the joint point and each coordinate
dist = np.linalg.norm(grid - (j_y, j_x), axis=2) ** 2
scmap_j = np.exp(-dist / (2 * (std**2)))
scmap_j = np.exp(-dist / (2 * (std ** 2)))
scmap[..., j_id] = scmap_j
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 0] = 1
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 1] = 1
Expand All @@ -528,7 +528,7 @@ def compute_target_part_scoremap_numpy(
self, joint_id, coords, data_item, size, scale
):
dist_thresh = float(self.cfg["pos_dist_thresh"] * scale)
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg["num_joints"]

scmap = np.zeros(np.concatenate([size, np.array([num_joints])]))
Expand All @@ -555,7 +555,7 @@ def compute_target_part_scoremap_numpy(
y = grid.copy()[:, :, 0]
dx = j_x - x * self.stride - self.half_stride
dy = j_y - y * self.stride - self.half_stride
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
mask1 = dist <= dist_thresh_sq
mask2 = (x >= min_x) & (x <= max_x)
mask3 = (y >= min_y) & (y <= max_y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,14 @@ def load_dataset(self):
self.has_gt = has_gt
return data

def _load_pseudo_data_from_h5(self, cfg, threshold=0.5, mask_kpts_below_thresh=False):
def _load_pseudo_data_from_h5(
self, cfg, threshold=0.5, mask_kpts_below_thresh=False
):
gt_file = cfg["pseudo_label"]
assert os.path.exists(gt_file)
path_ = Path(gt_file)
print("Using gt file:", path_.name)
num_kpts = len(cfg['all_joints_names'])
num_kpts = len(cfg["all_joints_names"])
df = pd.read_hdf(gt_file)
video_name = path_.name.split("DLC")[0]
video_root = str(path_.parents[0] / video_name)
Expand Down Expand Up @@ -353,11 +355,13 @@ def get_batch_from_video(self):
batch_joints = []
joint_ids = []
data_items = []
trim_ends = self.cfg.get('trim_ends', None)
trim_ends = self.cfg.get("trim_ends", None)
if trim_ends is None:
trim_ends = 0
# because of the existence of threshold, sampling population is adjusted to len(self.data)
img_idx = np.random.choice(len(self.data) - trim_ends *2, size=self.batch_size, replace=True)
img_idx = np.random.choice(
len(self.data) - trim_ends * 2, size=self.batch_size, replace=True
)
for i in range(self.batch_size):
index = img_idx[i]
offset = trim_ends
Expand Down Expand Up @@ -614,7 +618,7 @@ def compute_target_part_scoremap_numpy(
locref_size = *size, num_joints * 2
locref_map = np.zeros(locref_size)
locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

partaffinityfield_shape = *size, self.cfg["num_limbs"] * 2
partaffinityfield_map = np.zeros(partaffinityfield_shape)
Expand All @@ -640,7 +644,7 @@ def compute_target_part_scoremap_numpy(
dx_ = dx * locref_scale
dy = coords[:, 1] - yy * stride - half_stride
dy_ = dy * locref_scale
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
mask1 = dist <= dist_thresh_sq
mask2 = (xx >= mins[:, 0]) & (xx <= maxs[:, 0])
mask3 = (yy >= mins[:, 1]) & (yy <= maxs[:, 1])
Expand Down Expand Up @@ -747,7 +751,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

partaffinityfield_shape = np.concatenate(
[size, np.array([self.cfg["num_limbs"] * 2])]
Expand Down Expand Up @@ -779,7 +783,7 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
map_j = grid.copy()
# Distance between the joint point and each coordinate
dist = np.linalg.norm(grid - (j_y, j_x), axis=2) ** 2
scmap_j = np.exp(-dist / (2 * (std**2)))
scmap_j = np.exp(-dist / (2 * (std ** 2)))
scmap[..., j_id] = scmap_j
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 0] = 1
locref_mask[dist <= dist_thresh_sq, j_id * 2 + 1] = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def compute_target_part_scoremap(self, components):
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg["locref_stdev"]
dist_thresh_sq = dist_thresh**2
dist_thresh_sq = dist_thresh ** 2

width = size[1]
height = size[0]
Expand All @@ -375,7 +375,7 @@ def compute_target_part_scoremap(self, components):
pt_x = i * stride + half_stride
dx = j_x - pt_x
dy = j_y - pt_y
dist = dx**2 + dy**2
dist = dx ** 2 + dy ** 2
# print(la.norm(diff))
if dist <= dist_thresh_sq:
scmap[j, i, j_id] = 1
Expand Down
6 changes: 3 additions & 3 deletions deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def calc_link_probability(self, link):
ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts)
mu = self._kde.mean[ind]
sigma = self._kde.covariance[ind, ind]
z = (link.length**2 - mu) / sigma
z = (link.length ** 2 - mu) / sigma
return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))

@staticmethod
Expand Down Expand Up @@ -809,13 +809,13 @@ def wrapped(i):
if unique is not None:
self.unique[i] = unique
pbar.update()

def from_pickle(self, pickle_path):
with open(pickle_path, "rb") as file:
data = pickle.load(file)
self.unique = data.pop("single", {})
self.assemblies = data

@staticmethod
def parse_metadata(data):
params = dict()
Expand Down
4 changes: 2 additions & 2 deletions deeplabcut/pose_estimation_tensorflow/nnets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def get_batch_spec(cfg):
def make_2d_gaussian_kernel(sigma, size):
sigma = tf.convert_to_tensor(sigma, dtype=tf.float32)
k = tf.range(-size // 2 + 1, size // 2 + 1)
k = tf.cast(k**2, sigma.dtype)
k = tf.nn.softmax(-k / (2 * (sigma**2)))
k = tf.cast(k ** 2, sigma.dtype)
k = tf.nn.softmax(-k / (2 * (sigma ** 2)))
return tf.einsum("i,j->ij", k, k)


Expand Down
Loading
Loading