Skip to content

Commit

Permalink
update colorization
Browse files Browse the repository at this point in the history
  • Loading branch information
柏灌 committed May 19, 2022
1 parent 22d2419 commit 242584d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
7 changes: 4 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def generate_mask(H, W, img=None):
parser.add_argument('--tile_size', type=int, default=0, help='tile size for SR to avoid OOM')
parser.add_argument('--indir', type=str, default='examples/imgs', help='input folder')
parser.add_argument('--outdir', type=str, default='results/outs-BFR', help='output folder')
parser.add_argument('--ext', type=str, default='.jpg', help='extension of output')
args = parser.parse_args()

#model = {'name':'GPEN-BFR-512', 'size':512, 'channel_multiplier':2, 'narrow':1}
Expand Down Expand Up @@ -115,12 +116,12 @@ def generate_mask(H, W, img=None):
img_out, orig_faces, enhanced_faces = processer.process(img, aligned=args.aligned)

img = cv2.resize(img, img_out.shape[:2][::-1])
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_COMP.jpg'), np.hstack((img, img_out)))
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_GPEN.jpg'), img_out)
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+f'_COMP{args.ext}'), np.hstack((img, img_out)))
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+f'_GPEN{args.ext}'), img_out)

if args.save_face:
for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
of = cv2.resize(of, ef.shape[:2])
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_face%02d'%m+'.jpg'), np.hstack((of, ef)))
cv2.imwrite(os.path.join(args.outdir, '.'.join(filename.split('.')[:-1])+'_face%02d'%m+args.ext), np.hstack((of, ef)))

if n%10==0: print(n, filename)
14 changes: 14 additions & 0 deletions face_colorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,31 @@
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
@author: yangxy (yangtao9009@gmail.com)
'''
import cv2
from face_model.face_gan import FaceGAN

class FaceColorization(object):
def __init__(self, base_dir='./', in_size=1024, out_size=1024, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
self.facegan = FaceGAN(base_dir, in_size, out_size, model, channel_multiplier, narrow, key, device=device)

def post_process(self, gray, out):
out_rs = cv2.resize(out, gray.shape[:2][::-1])
gray_yuv = cv2.cvtColor(gray, cv2.COLOR_BGR2YUV)
out_yuv = cv2.cvtColor(out_rs, cv2.COLOR_BGR2YUV)

out_yuv[:, :, 0] = gray_yuv[:, :, 0]
final = cv2.cvtColor(out_yuv, cv2.COLOR_YUV2BGR)

return final

# make sure the face image is well aligned. Please refer to face_enhancement.py
def process(self, gray, aligned=True):
# colorize the face
out = self.facegan.process(gray)

if gray.shape[:2] != out.shape[:2]:
out = self.post_process(gray, out)

return out, [gray], [out]


0 comments on commit 242584d

Please sign in to comment.