Skip to content

Commit

Permalink
fix: fix the error message for prompt optimizer UI notebook (GoogleCl…
Browse files Browse the repository at this point in the history
…oudPlatform#1448)

Co-authored-by: hootan <hootan@google.com>
  • Loading branch information
nhootan and hootan-na authored Nov 21, 2024
1 parent 17e8cb1 commit 5a13e8a
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions gemini/prompts/prompt_optimizer/vapo_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=too-many-lines

"""Utility functions and classes for the VAPO notebook."""
import csv
import io
Expand All @@ -20,7 +22,7 @@
import re
import string
import subprocess
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Union

from IPython.core.display import DisplayHandle
from IPython.display import HTML, display
Expand Down Expand Up @@ -357,7 +359,7 @@ def __init__(self, params: dict[str, str]) -> None:
self.instruction_df = None
self.demo_df = None

# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments,too-many-arguments
def update_progress(
self,
progress_bar: widgets.IntProgress | None,
Expand Down Expand Up @@ -448,11 +450,12 @@ def monitor_progress(self, job: aiplatform.CustomJob) -> bool:
for err_file in [
f"{self.output_path}/instruction/error.json",
f"{self.output_path}/demonstration/error.json",
f"{self.output_path}/error.json",
]:
if gfile.exists(err_file):
with gfile.GFile(err_file, "r") as f:
error_json = json.load(f)
errors.append(f"Detailed error: {error_json}")
errors.append(f"Detailed error: {error_json['Error']}")
errors.append(
f"Please feel free to send {err_file} to the VAPO team to help"
" resolving the issue."
Expand All @@ -466,6 +469,7 @@ def monitor_progress(self, job: aiplatform.CustomJob) -> bool:
"Please consider rerunning to make sure the failure is intransient."
)
err = "\n".join(errors)
err = err.replace("\n", "<br>")
self.status_display.update(HTML(f'<span style="color: red;">{err}</span>'))
else:
self.status_display.update(
Expand Down Expand Up @@ -699,9 +703,9 @@ def get_auth_token() -> str:

def init_new_model(
model_name: str,
generation_config: GenerationConfig = None,
safety_settings: List[SafetySetting] = None,
**kwargs,
generation_config: GenerationConfig | None = None,
safety_settings: list[SafetySetting] | None = None,
**kwargs: Any,
) -> GenerativeModel:
"""Initialize a new model with configurable generation and safety settings."""

Expand Down Expand Up @@ -746,10 +750,10 @@ def init_new_model(
async def async_generate(
prompt: str,
model: GenerativeModel,
function_handler: Optional[Dict[str, Callable]] = None,
tools: Optional[Tool] = None,
tool_config: Optional[ToolConfig] = None,
**kwargs,
function_handler: dict[str, Callable] | None = None,
tools: Tool | None = None,
tool_config: ToolConfig | None = None,
**kwargs: Any,
) -> Union[str, None]:
"""Generates a response from the model, optionally handling function calls."""

Expand Down Expand Up @@ -805,20 +809,21 @@ async def async_generate(
) # More robust text extraction
return None

except Exception as e:
except Exception as e: # pylint: disable=broad-except
print(f"Error calling the model: {e}") # Include the actual error message
return "Could not call the model. Please try it again in a few minutes."


# pylint: disable=too-many-positional-arguments,too-many-arguments
def evaluate_task(
df: pd.DataFrame,
prompt_col: str,
reference_col: str,
response_col: str,
experiment_name: str,
eval_metrics: List[str],
eval_metrics: list[str],
eval_sample_n: int,
) -> Dict[str, float]:
) -> dict[str, float]:
"""Evaluate task using Vertex AI Evaluation."""

# Generate a unique id for the experiment run
Expand Down Expand Up @@ -856,7 +861,7 @@ def evaluate_task(


def print_df_rows(
df: pd.DataFrame, columns: Optional[List[str]] = None, n: int = 3
df: pd.DataFrame, columns: list[str] | None = None, n: int = 3
) -> None:
"""Print a subset of rows from a DataFrame."""

Expand All @@ -865,7 +870,10 @@ def print_df_rows(
df = df[columns]

# Style definitions for improved readability
base_style = "font-family: monospace; font-size: 14px; white-space: pre-wrap; width: auto; overflow-x: auto;"
base_style = (
"font-family: monospace; font-size: 14px; white-space: pre-wrap; width:"
" auto; overflow-x: auto;"
)
header_style = base_style + "font-weight: bold;"

# Iterate through the specified number of rows
Expand All @@ -874,7 +882,9 @@ def print_df_rows(
for column in df.columns:
display(
HTML(
f"<span style='{header_style}'>{column.replace('_', ' ').title()}: </span>"
"<span"
f" style='{header_style}'>{column.replace('_', ' ').title()}:"
" </span>"
)
)
display(
Expand All @@ -884,8 +894,8 @@ def print_df_rows(


def plot_eval_metrics(
eval_results: List[tuple[str, Dict[str, float]]],
metrics: Optional[List[str]] = None,
eval_results: list[tuple[str, dict[str, float]]],
metrics: list[str] | None = None,
) -> None:
"""Plot a bar plot for the evaluation results."""

Expand Down Expand Up @@ -926,7 +936,7 @@ def plot_eval_metrics(
fig.show()


def create_target_column(row: Dict[str, Any]) -> str:
def create_target_column(row: dict[str, Any]) -> str:
"""Creates a JSON string representing tool calls from input row."""

tool_calls = (
Expand All @@ -938,12 +948,13 @@ def create_target_column(row: Dict[str, Any]) -> str:
return json.dumps({"content": "", "tool_calls": tool_calls})


def tool_config_to_dict(tool_config: Optional[ToolConfig]) -> Optional[Dict[str, Any]]:
def tool_config_to_dict(tool_config: ToolConfig | None) -> dict[str, Any] | None:
"""Converts a ToolConfig object to a dictionary."""

if tool_config is None:
return None

# pylint: disable=protected-access
config = tool_config._gapic_tool_config.function_calling_config
return {
"function_calling_config": {
Expand All @@ -953,7 +964,7 @@ def tool_config_to_dict(tool_config: Optional[ToolConfig]) -> Optional[Dict[str,
}


def replace_type_key(data: Dict[str, Any]) -> Dict[str, Any]:
def replace_type_key(data: dict[str, Any]) -> dict[str, Any]:
"""Recursively replaces "type_" with "type" in a dictionary or list."""

return {"type" if k == "type_" else k: replace_type_key(v) for k, v in data.items()}
Expand Down

0 comments on commit 5a13e8a

Please sign in to comment.