Skip to content

Commit

Permalink
feat: Add progress bar for generating inference.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663828395
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 16, 2024
1 parent 3974aec commit b78714f
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions vertexai/preview/evaluation/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,28 +324,34 @@ def _generate_response_from_gemini_model(
constants.Dataset.COMPLETED_PROMPT_COLUMN
in evaluation_run_config.dataset.columns
):
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
for _, row in df.iterrows():
tasks.append(
executor.submit(
with tqdm(total=len(df)) as pbar:
with futures.ThreadPoolExecutor(
max_workers=constants.MAX_WORKERS
) as executor:
for _, row in df.iterrows():
task = executor.submit(
_generate_response_from_gemini,
prompt=row[constants.Dataset.COMPLETED_PROMPT_COLUMN],
model=model,
)
)
task.add_done_callback(lambda _: pbar.update(1))
tasks.append(task)
else:
content_column_name = evaluation_run_config.column_map[
constants.Dataset.CONTENT_COLUMN
]
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
for _, row in df.iterrows():
tasks.append(
executor.submit(
with tqdm(total=len(df)) as pbar:
with futures.ThreadPoolExecutor(
max_workers=constants.MAX_WORKERS
) as executor:
for _, row in df.iterrows():
task = executor.submit(
_generate_response_from_gemini,
prompt=row[content_column_name],
model=model,
)
)
task.add_done_callback(lambda _: pbar.update(1))
tasks.append(task)
responses = [future.result() for future in tasks]
if is_baseline_model:
evaluation_run_config.dataset = df.assign(baseline_model_response=responses)
Expand Down Expand Up @@ -384,13 +390,14 @@ def _generate_response_from_custom_model_fn(
constants.Dataset.COMPLETED_PROMPT_COLUMN
in evaluation_run_config.dataset.columns
):
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for _, row in df.iterrows():
tasks.append(
executor.submit(
with tqdm(total=len(df)) as pbar:
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for _, row in df.iterrows():
task = executor.submit(
model_fn, row[constants.Dataset.COMPLETED_PROMPT_COLUMN]
)
)
task.add_done_callback(lambda _: pbar.update(1))
tasks.append(task)
else:
content_column_name = evaluation_run_config.column_map[
constants.Dataset.CONTENT_COLUMN
Expand Down

0 comments on commit b78714f

Please sign in to comment.