
Commit

fix benchmark non standard model (huggingface#5801)

patrickvonplaten authored Jul 16, 2020
1 parent 8ce610b commit aefc0c0
Showing 2 changed files with 4 additions and 4 deletions.
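The one-line change, repeated in four places, fixes two bugs in the same expression: the attribute probe used the singular "architecture", which standard PretrainedConfig objects never define, so hasattr was always False and the concrete model class listed in the config was never used; and the length test required more than one entry, which would have skipped the common case of a config naming exactly one architecture. A minimal sketch of the corrected lookup, assuming a checkpoint whose config.json lists its architectures (bert-base-uncased here is purely illustrative):

import transformers
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")

# Corrected check: plural attribute name, and a single listed architecture is enough.
has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0

if has_model_class_in_config:
    model_class = config.architectures[0]           # e.g. "BertForMaskedLM"
    model_cls = getattr(transformers, model_class)  # resolve the concrete class by name
    model = model_cls(config)                       # concrete head instead of the bare base model
    print(type(model).__name__)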
4 changes: 2 additions & 2 deletions src/transformers/benchmark/benchmark.py
@@ -88,7 +88,7 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
         if self.args.torchscript:
             config.torchscript = True

-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
@@ -138,7 +138,7 @@ def encoder_forward():
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]

-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = config.architectures[0]
4 changes: 2 additions & 2 deletions src/transformers/benchmark/benchmark_tf.py
@@ -132,7 +132,7 @@ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_len
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")

-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
@@ -172,7 +172,7 @@ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length:
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")

-        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        has_model_class_in_config = hasattr(config, "architectures") and len(config.architectures) > 0
         if not self.args.only_pretrain_model and has_model_class_in_config:
             try:
                 model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
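For context, these checks run when the benchmark utilities prepare a model for timing. A minimal invocation that exercises the fixed PyTorch path; the model name, batch size, and sequence length are illustrative:

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# With only_pretrain_model left at its default (False), the fixed check lets
# the benchmark build the concrete class named in config.architectures
# rather than always falling back to the base model.
args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"],
    batch_sizes=[8],
    sequence_lengths=[128],
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()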
