Skip to content

Commit

Permalink
Merge pull request #60 from zeke/update-predictor-to-use-new-cog
Browse files Browse the repository at this point in the history
update Predictor for use with newer Cog versions
  • Loading branch information
YuanxunLu authored May 17, 2022
2 parents 1529d9a + a2f24cf commit 950c24a
Showing 1 changed file with 7 additions and 18 deletions.
25 changes: 7 additions & 18 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import yaml
import tempfile
import argparse
from pathlib import Path
from skimage.io import imread
import numpy as np
import librosa
Expand All @@ -14,7 +13,7 @@
from collections import OrderedDict
import cv2
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
import cog
from cog import BasePredictor, Input, Path
import scipy.io as sio
import albumentations as A
from options.test_audio2feature_options import TestOptions as FeatureOptions
Expand All @@ -31,35 +30,25 @@
warnings.filterwarnings("ignore")


class Predictor(cog.Predictor):
class Predictor(BasePredictor):
def setup(self):
self.parser = argparse.ArgumentParser()
self.parser.add_argument('--id', default='May', help="person name, e.g. Obama1, Obama2, May, Nadella, McStay")
self.parser.add_argument('--driving_audio', default='data/Input/00083.wav', help="path to driving audio")
self.parser.add_argument('--save_intermediates', default=0, help="whether to save intermediate results")

@cog.input(
"driving_audio",
type=Path,
help="driving audio, if the file is more than 20 seconds, only the first 20 seconds will be processed for "
"video generation",
)
@cog.input(
"talking_head",
type=str,
options=['May', 'Obama1', 'Obama2', 'Nadella', 'McStay'],
default='May',
help="choose a talking head"
)
def predict(self, driving_audio, talking_head='May'):
def predict(self,
driving_audio: Path = Input(description='driving audio, if the file is more than 20 seconds, only the first 20 seconds will be processed for video generation'),
talking_head: str = Input(description="choose a talking head", choices=['May', 'Obama1', 'Obama2', 'Nadella', 'McStay'], default='May')
) -> Path:

############################### I/O Settings ##############################
# load config files
opt = self.parser.parse_args('')
opt.driving_audio = str(driving_audio)
opt.id = talking_head
with open(join('config', opt.id + '.yaml')) as f:
config = yaml.load(f)
config = yaml.safe_load(f)
data_root = join('data', opt.id)

############################ Hyper Parameters #############################
Expand Down

0 comments on commit 950c24a

Please sign in to comment.