
Commit 85ac104

Merge releases/2024/5 into master (openvinotoolkit#1159)
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Wovchena and ilya-lavrenov authored Nov 6, 2024
1 parent a99dc93 commit 85ac104
Showing 5 changed files with 17 additions and 17 deletions.
1 change: 0 additions & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -760,7 +760,6 @@ jobs:
           <<< $'Describe the images?' | tee py.txt
         env:
           PYTHONPATH: "./build/"
-      - run: diff cpp.txt py.txt
       - name: Run visual_language_chat C++ sample with 2 prompts - tiny-random-minicpmv-2_6
         run: >
           source ./ov/setupvars.sh
5 changes: 5 additions & 0 deletions src/cpp/src/logit_processor.hpp
@@ -200,6 +200,11 @@ class RepetitionPenaltyTransform : public IPenaltyTransformer {
         }
         for (const auto& input_id_pair : *m_unique_generated_token_ids) {
             const auto& input_id = input_id_pair.first;
+            if (1 == m_unique_prompt_token_ids->count(input_id)) {
+                // repetition_penalty was already accounted for by the for
+                // loop above.
+                continue;
+            }
             OPENVINO_ASSERT((input_id >= 0) && (input_id < vocab_size), "input_ids token out of bounds");
             if (logits.m_data[input_id] >= 0) {
                 logits.m_data[input_id] /= m_penalty;
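For clarity, here is a minimal Python sketch of the two-pass repetition penalty this hunk guards; the divide-positive/multiply-negative rule mirrors the C++ above, and all names are illustrative rather than the library's own:

def apply_repetition_penalty(logits, prompt_ids, generated_ids, penalty):
    # Penalize each repeated token exactly once, whether it appeared in
    # the prompt or in the generated continuation.
    def penalize(token_id):
        if logits[token_id] >= 0:
            logits[token_id] /= penalty
        else:
            logits[token_id] *= penalty

    for token_id in prompt_ids:        # first pass: prompt tokens
        penalize(token_id)
    for token_id in generated_ids:     # second pass: generated tokens
        if token_id in prompt_ids:
            # Already penalized in the first pass; skipping avoids the
            # double penalty (dividing by penalty squared) this hunk fixes.
            continue
        penalize(token_id)
    return logits

Without the guard, a token present in both sets would have its logit scaled by the penalty twice.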
8 changes: 7 additions & 1 deletion src/python/py_llm_pipeline.cpp
@@ -53,7 +53,13 @@ py::object call_common_generate(
     const pyutils::PyBindStreamerVariant& py_streamer,
     const py::kwargs& kwargs
 ) {
-    auto updated_config = pyutils::update_config_from_kwargs(config, kwargs);
+    ov::genai::GenerationConfig default_config;
+    if (config.has_value()) {
+        default_config = *config;
+    } else {
+        default_config = pipe.get_generation_config();
+    }
+    auto updated_config = pyutils::update_config_from_kwargs(default_config, kwargs);
     py::object results;
     EncodedInputs tensor_data;
     StreamerVariant streamer = pyutils::pystreamer_to_streamer(py_streamer);
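With this change, keyword arguments are layered onto the caller's config when one is passed, and onto the pipeline's stored default generation config otherwise. A hedged usage sketch of the resulting Python-side behavior, assuming the usual LLMPipeline constructor and generate signature (the model path is a placeholder):

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("model_dir", "CPU")  # placeholder path

# No config argument: max_new_tokens is applied on top of the pipeline's
# default generation config rather than a freshly constructed one.
print(pipe.generate("What is OpenVINO?", max_new_tokens=30))

# Explicit config: kwargs are applied on top of the passed config.
config = ov_genai.GenerationConfig()
config.max_new_tokens = 100
print(pipe.generate("What is OpenVINO?", config, do_sample=False))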
13 changes: 4 additions & 9 deletions tests/python_tests/test_chat_generate_api.py
@@ -1,9 +1,6 @@
 # Copyright (C) 2023-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import math
-import openvino
-import openvino_tokenizers
 import openvino_genai as ov_genai
 import pytest
 from typing import Dict, Tuple
@@ -19,8 +16,8 @@
 
 
 configs = [
-    dict(max_new_tokens=20),
-    dict(num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0)
+    dict(do_sample=False, max_new_tokens=20),
+    dict(do_sample=False, num_beam_groups=3, num_beams=15, num_return_sequences=1, max_new_tokens=10, diversity_penalty=1.0)
 ]
 
 
@@ -37,7 +34,6 @@
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_chat_compare_with_HF(model_descr, generation_config: Dict):
-    device = 'CPU'
     chat_history_hf = []
     chat_history_ov = []
     chat_prompt = ''
@@ -53,7 +49,7 @@ def test_chat_compare_with_HF(model_descr, generation_config: Dict):
         chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
         tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
 
-        answer = model_opt.generate(**tokenized, **generation_config, do_sample=False, repetition_penalty = None)
+        answer = model_opt.generate(**tokenized, **generation_config)
         answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True)
         chat_history_hf.append({'role': 'assistant', 'content': answer_str})
@@ -74,7 +70,6 @@
 @pytest.mark.nightly
 def test_chat_compare_text_history_with_HF(model_descr, generation_config: Dict):
     # compares with HF when history in ov_genai is saved as text
-    device = 'CPU'
     chat_history_hf = []
     chat_history_ov = []
     chat_prompt = ''
@@ -90,7 +85,7 @@ def test_chat_compare_text_history_with_HF(model_descr, generation_config: Dict):
         chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
         tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
 
-        answer = model_opt.generate(**tokenized, **generation_config, do_sample=False, repetition_penalty = None)
+        answer = model_opt.generate(**tokenized, **generation_config)
         answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True)
         chat_history_hf.append({'role': 'assistant', 'content': answer_str})
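The configs change above makes greedy decoding explicit in the shared dicts instead of hard-coding do_sample=False at the HF call site, so both stacks receive identical settings. A standalone sketch of the two decoding modes those dicts exercise (the model directory is an assumption):

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("model_dir", "CPU")  # assumed model directory

greedy = dict(do_sample=False, max_new_tokens=20)
grouped_beam = dict(do_sample=False, num_beam_groups=3, num_beams=15,
                    num_return_sequences=1, max_new_tokens=10,
                    diversity_penalty=1.0)

pipe.start_chat()
print(pipe.generate("Why is the sun yellow?", **greedy))
print(pipe.generate("Tell me more about it.", **grouped_beam))
pipe.finish_chat()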
7 changes: 1 addition & 6 deletions tools/who_what_benchmark/whowhatbench/wwb.py
@@ -8,7 +8,6 @@
 import logging
 import os
 
-import openvino_genai
 import pandas as pd
 from datasets import load_dataset
 from diffusers import DiffusionPipeline
@@ -385,11 +384,7 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
 
 
 def genai_gen_answer(model, tokenizer, question, max_new_tokens, skip_question):
-    config = openvino_genai.GenerationConfig()
-    config.max_new_tokens = max_new_tokens
-    config.do_sample = False
-    out = model.generate(question, config)
-    return out
+    return model.generate(question, do_sample=False, max_new_tokens=max_new_tokens)
 
 
 def genai_gen_image(model, prompt, num_inference_steps, generator=None):
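This simplification leans on the py_llm_pipeline.cpp change above: keyword arguments are now merged into the pipeline's default config, making the explicit GenerationConfig object redundant. The two call styles below should be equivalent (the model path is a placeholder):

import openvino_genai

model = openvino_genai.LLMPipeline("model_dir", "CPU")  # placeholder path

# Old style: build a config object explicitly.
config = openvino_genai.GenerationConfig()
config.max_new_tokens = 128
config.do_sample = False
out_old = model.generate("What is OpenVINO?", config)

# New style: pass the same settings as keyword arguments.
out_new = model.generate("What is OpenVINO?", do_sample=False, max_new_tokens=128)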
