Skip to content

Commit

Permalink
Supporting different number of landmarks.
Browse files Browse the repository at this point in the history
  • Loading branch information
hhj1897 committed Apr 11, 2022
1 parent 07de821 commit acbb783
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 47 deletions.
6 changes: 5 additions & 1 deletion face_alignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def main() -> None:
help='Weights to be loaded for face alignment, can be either 2DFAN2, 2DFAN4, ' +
'or 2DFAN2_ALT (default=2DFAN2_ALT)')
parser.add_argument('--alignment-alternative-pth', '-ap', default=None,
help='Alternative pth file to be loaded for face alaignment')
help='Alternative pth file to be loaded for face alignment')
parser.add_argument('--alignment-alternative-landmarks', '-al', default=None,
help='Alternative number of landmarks to detect')
parser.add_argument('--alignment-device', '-ad', default='cuda:0',
help='Device to be used for face alignment (default=cuda:0)')
parser.add_argument('--hide-alignment-results', '-ha', help='Do not visualise face alignment results',
Expand Down Expand Up @@ -89,6 +91,8 @@ def main() -> None:
fa_model = FANPredictor.get_model(args.alignment_weights)
if args.alignment_alternative_pth is not None:
fa_model.weights = args.alignment_alternative_pth
if args.alignment_alternative_landmarks is not None:
fa_model.config.num_landmarks = int(args.alignment_alternative_landmarks)
landmark_detector = FANPredictor(device=args.alignment_device, model=fa_model)
print(f"Landmark detector created using FAN ({fa_model.weights}).")
else:
Expand Down
6 changes: 4 additions & 2 deletions ibug/face_alignment/fan/fan.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,16 @@ def __init__(self, config):
self.add_module('bn_end' + str(hg_module),
nn.InstanceNorm2d(self.config.hg_num_features) if self.config.use_instance_norm
else nn.BatchNorm2d(self.config.hg_num_features))
self.add_module('l' + str(hg_module), nn.Conv2d(self.config.hg_num_features, 68,
self.add_module('l' + str(hg_module), nn.Conv2d(self.config.hg_num_features,
self.config.num_landmarks,
kernel_size=1, stride=1, padding=0))

if hg_module < self.config.num_modules - 1:
self.add_module('bl' + str(hg_module), nn.Conv2d(self.config.hg_num_features,
self.config.hg_num_features,
kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module), nn.Conv2d(68, self.config.hg_num_features,
self.add_module('al' + str(hg_module), nn.Conv2d(self.config.num_landmarks,
self.config.hg_num_features,
kernel_size=1, stride=1, padding=0))

def forward(self, x):
Expand Down
9 changes: 6 additions & 3 deletions ibug/face_alignment/fan/fan_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,22 @@ def get_model(name: str = '2dfan2') -> SimpleNamespace:
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=2,
hg_num_features=256, hg_depth=4, use_avg_pool=False,
use_instance_norm=False, stem_conv_kernel_size=7,
stem_conv_stride=2, stem_pool_kernel_size=2))
stem_conv_stride=2, stem_pool_kernel_size=2,
num_landmarks=68))
elif name == '2dfan4':
return SimpleNamespace(weights=os.path.join(os.path.dirname(__file__), 'weights', '2dfan4.pth'),
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=4,
hg_num_features=256, hg_depth=4, use_avg_pool=True,
use_instance_norm=False, stem_conv_kernel_size=7,
stem_conv_stride=2, stem_pool_kernel_size=2))
stem_conv_stride=2, stem_pool_kernel_size=2,
num_landmarks=68))
elif name == '2dfan2_alt':
return SimpleNamespace(weights=os.path.join(os.path.dirname(__file__), 'weights', '2dfan2_alt.pth'),
config=SimpleNamespace(crop_ratio=0.55, input_size=256, num_modules=2,
hg_num_features=256, hg_depth=4, use_avg_pool=False,
use_instance_norm=False, stem_conv_kernel_size=7,
stem_conv_stride=2, stem_pool_kernel_size=2))
stem_conv_stride=2, stem_pool_kernel_size=2,
num_landmarks=68))
else:
raise ValueError('name must be set to either 2dfan2, 2dfan4, or 2dfan2_alt')

Expand Down
83 changes: 42 additions & 41 deletions ibug/face_alignment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,49 @@
from typing import Optional, Sequence, Tuple


__all__ = ['plot_landmarks']
__all__ = ['get_landmark_connectivity', 'plot_landmarks']


def get_landmark_connectivity(num_landmarks):
if num_landmarks == 68:
return ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12),
(12, 13), (13, 14), (14, 15), (15, 16), (17, 18), (18, 19), (19, 20), (20, 21), (22, 23), (23, 24),
(24, 25), (25, 26), (27, 28), (28, 29), (29, 30), (31, 32), (32, 33), (33, 34), (34, 35), (36, 37),
(37, 38), (38, 39), (40, 41), (41, 36), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 42),
(48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58),
(58, 59), (59, 48), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 60),
(39, 40))
elif num_landmarks == 100:
return ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12),
(12, 13), (13, 14), (14, 15), (15, 16), (17, 18), (18, 19), (19, 20), (20, 21), (22, 23), (23, 24),
(24, 25), (25, 26), (68, 69), (69, 70), (70, 71), (72, 73), (73, 74), (74, 75), (36, 76), (76, 37),
(37, 77), (77, 38), (38, 78), (78, 39), (39, 40), (40, 79), (79, 41), (41, 36), (42, 80), (80, 43),
(43, 81), (81, 44), (44, 82), (82, 45), (45, 46), (46, 83), (83, 47), (47, 42), (27, 28), (28, 29),
(29, 30), (30, 33), (31, 32), (32, 33), (33, 34), (34, 35), (84, 85), (86, 87), (48, 49), (49, 88),
(88, 50), (50, 51), (51, 52), (52, 89), (89, 53), (53, 54), (54, 55), (55, 90), (90, 56), (56, 57),
(57, 58), (58, 91), (91, 59), (59, 48), (60, 92), (92, 93), (93, 61), (61, 62), (62, 63), (63, 94),
(94, 95), (95, 64), (64, 96), (96, 97), (97, 65), (65, 66), (66, 67), (67, 98), (98, 99), (99, 60),
(17, 68), (21, 71), (22, 72), (26, 75))
else:
return None


def plot_landmarks(image: np.ndarray, landmarks: np.ndarray, landmark_scores: Optional[Sequence[float]] = None,
threshold: float = 0.2, line_colour: Tuple[int, int, int] = (0, 255, 0),
pts_colour: Tuple[int, int, int] = (0, 0, 255),
line_thickness: int = 1, pts_radius: int = 1) -> None:
if landmarks.shape[0] > 0:
if landmark_scores is None:
landmark_scores = np.full(shape=(landmarks.shape[0],), fill_value=threshold + 1)
if landmarks.shape[0] == 68:
for idx in range(len(landmarks) - 1):
if idx not in [16, 21, 26, 30, 35, 41, 47, 59]:
if landmark_scores[idx] >= threshold and landmark_scores[idx + 1] >= threshold:
cv2.line(image, tuple(landmarks[idx].astype(int).tolist()),
tuple(landmarks[idx + 1].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
if idx == 30:
if landmark_scores[30] >= threshold and landmark_scores[33] >= threshold:
cv2.line(image, tuple(landmarks[30].astype(int).tolist()),
tuple(landmarks[33].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
elif idx == 36:
if landmark_scores[36] >= threshold and landmark_scores[41] >= threshold:
cv2.line(image, tuple(landmarks[36].astype(int).tolist()),
tuple(landmarks[41].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
elif idx == 42:
if landmark_scores[42] >= threshold and landmark_scores[47] >= threshold:
cv2.line(image, tuple(landmarks[42].astype(int).tolist()),
tuple(landmarks[47].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
elif idx == 48:
if landmark_scores[48] >= threshold and landmark_scores[59] >= threshold:
cv2.line(image, tuple(landmarks[48].astype(int).tolist()),
tuple(landmarks[59].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
elif idx == 60:
if landmark_scores[60] >= threshold and landmark_scores[67] >= threshold:
cv2.line(image, tuple(landmarks[60].astype(int).tolist()),
tuple(landmarks[67].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
for landmark, score in zip(landmarks, landmark_scores):
if score >= threshold:
cv2.circle(image, tuple(landmark.astype(int).tolist()), pts_radius, pts_colour, -1)
pts_colour: Tuple[int, int, int] = (0, 0, 255), line_thickness: int = 1, pts_radius: int = 1,
landmark_connectivity: Optional[Sequence[Sequence[int]]] = None) -> None:
num_landmarks = len(landmarks)
if landmark_scores is None:
landmark_scores = np.full((num_landmarks,), threshold + 1.0, dtype=float)
if landmark_connectivity is None:
landmark_connectivity = get_landmark_connectivity(len(landmarks))
if landmark_connectivity is not None:
for (idx1, idx2) in landmark_connectivity:
if (idx1 < num_landmarks and idx2 < num_landmarks and
landmark_scores[idx1] >= threshold and landmark_scores[idx2] >= threshold):
cv2.line(image, tuple(landmarks[idx1].astype(int).tolist()),
tuple(landmarks[idx2].astype(int).tolist()),
color=line_colour, thickness=line_thickness, lineType=cv2.LINE_AA)
for landmark, score in zip(landmarks, landmark_scores):
if score >= threshold:
cv2.circle(image, tuple(landmark.astype(int).tolist()), pts_radius, pts_colour, -1)

0 comments on commit acbb783

Please sign in to comment.