Skip to content

Commit

Permalink
Resolve switching chatglm benchmarking class (openvinotoolkit#600)
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova authored Jul 10, 2024
1 parent da00c67 commit f084b61
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions llm_bench/python/utils/ov_model_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,11 @@ def __init__(
**kwargs,
):
super().__init__(model, config, device, dynamic_shapes, ov_config, model_save_dir, **kwargs)
self.key_value_input_names = ['past_key_values']
self.key_value_output_names = [o.any_name for o in self.model.outputs[1:]]
self.is_v1 = False
if not self.stateful and not self.key_value_input_names:
self.is_v1 = True
self.key_value_input_names = ['past_key_values']
self.key_value_output_names = [o.any_name for o in self.model.outputs[1:]]

def prepare_inputs_for_generation(
self,
Expand All @@ -300,6 +303,13 @@ def prepare_inputs_for_generation(
past: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
if not self.is_v1:
return super().prepare_inputs_for_generation(
input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask,
position_ids=position_ids,
past=past,
**kwargs
)
batch_size, seq_length = input_ids.shape
mask = self.mask_token_id
g_mask = self.gmask_token_id
Expand Down Expand Up @@ -430,6 +440,9 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
**kwargs,
) -> CausalLMOutputWithPast:

if not self.is_v1:
return super().forward(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs)
self.compile()

inputs = {}
Expand Down

0 comments on commit f084b61

Please sign in to comment.