
Commit

feat: Add Prompt class support for configs and Prompt.generate_content wrapper

PiperOrigin-RevId: 665012897
matthew29tang authored and copybara-github committed Aug 19, 2024
1 parent fd38b49 commit 7f1e031
Showing 1 changed file with 254 additions and 6 deletions.
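
For context, the sketch below shows the workflow this commit enables: generation configs are stored on the Prompt object, and the new Prompt.generate_content wrapper builds a GenerativeModel from them. The import path, model name, and parameter values are illustrative assumptions, not taken from the commit itself.

```
# Minimal sketch (assumed import path and placeholder values).
from vertexai.generative_models import GenerationConfig
from vertexai.generative_models._prompts import Prompt  # assumed import path at this commit

# Assumes vertexai.init(project=..., location=...) has already been called.
prompt = Prompt(
    prompt_data="Summarize {topic} in one sentence.",
    variables=[{"topic": "prompt templates"}],
    model_name="gemini-1.0-pro-002",  # placeholder model name
    generation_config=GenerationConfig(temperature=0.1),
    system_instruction="Answer briefly.",
)

# The wrapper reuses the saved model_name, generation_config, and system_instruction.
response = prompt.generate_content(
    contents=prompt.assemble_contents(**prompt.variables[0])
)
```
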
260 changes: 254 additions & 6 deletions vertexai/generative_models/_prompts.py
@@ -16,20 +16,34 @@
#

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai.generative_models import (
Content,
Image,
Part,
GenerativeModel,
GenerationConfig,
SafetySetting,
Tool,
ToolConfig,
)
from vertexai.generative_models._generative_models import (
_to_content,
_validate_generate_content_parameters,
_reconcile_model_name,
_get_resource_name_from_model_name,
ContentsType,
GenerationConfigType,
GenerationResponse,
PartsType,
SafetySettingsType,
)

import re
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Union,
@@ -55,29 +69,72 @@ class Prompt:
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
variables=[{"name": "Alice", "day": "Monday"}]
generation_config=GenerationConfig(
temperature=0.1,
top_p=0.95,
top_k=20,
candidate_count=1,
max_output_tokens=100,
stop_sequences=["\n\n\n"],
),
model_name="gemini-1.0-pro-002",
safety_settings=[SafetySetting(
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
method=SafetySetting.HarmBlockMethod.SEVERITY,
)],
system_instruction="Please answer in a short sentence.",
)
# Generate content using the assembled prompt.
model.generate_content(contents=prompt.assemble_contents(**prompt.variables[0]))
prompt.generate_content(
contents=prompt.assemble_contents(**prompt.variables[0])
)
```
"""

def __init__(
self,
prompt_data: PartsType,
*,
variables: Optional[List[Dict[str, PartsType]]] = None,
generation_config: Optional[GenerationConfig] = None,
model_name: Optional[str] = None,
safety_settings: Optional[SafetySetting] = None,
system_instruction: Optional[PartsType] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
):
"""Initializes the Prompt with a given prompt, and variables.
Args:
prompt: A PartsType prompt which may be a template with variables or a prompt with no variables.
variables: A list of dictionaries containing the variable names and values.
generation_config: A GenerationConfig object containing parameters for generation.
model_name: Model Garden model resource name.
Alternatively, a tuned model endpoint resource name can be provided.
safety_settings: A SafetySetting object containing safety settings for generation.
system_instruction: A PartsType object representing the system instruction.
tools: A list of Tool objects for function calling.
tool_config: A ToolConfig object for function calling.
"""
self._prompt_data = None
self._variables = None
self._model_name = None
self._generation_config = None
self._safety_settings = None
self._system_instruction = None
self._tools = None
self._tool_config = None

self.prompt_data = prompt_data
self.variables = variables if variables else [{}]
self.model_name = model_name
self.generation_config = generation_config
self.safety_settings = safety_settings
self.system_instruction = system_instruction
self.tools = tools
self.tool_config = tool_config

@property
def prompt_data(self) -> PartsType:
@@ -87,14 +144,38 @@ def prompt_data(self) -> PartsType:
def variables(self) -> Optional[List[Dict[str, PartsType]]]:
return self._variables

@property
def generation_config(self) -> Optional[GenerationConfig]:
return self._generation_config

@property
def model_name(self) -> Optional[str]:
return self._model_name

@property
def safety_settings(self) -> Optional[List[SafetySetting]]:
return self._safety_settings

@property
def system_instruction(self) -> Optional[PartsType]:
return self._system_instruction

@property
def tools(self) -> Optional[List[Tool]]:
return self._tools

@property
def tool_config(self) -> Optional[ToolConfig]:
return self._tool_config

@prompt_data.setter
def prompt_data(self, prompt_data: PartsType) -> None:
"""Overwrites the existing saved local prompt_data.
Args:
prompt_data: A PartsType prompt.
"""
Prompt._validate_prompt_data(prompt_data)
self._validate_parts_type_data(prompt_data)
self._prompt_data = prompt_data

@variables.setter
@@ -114,6 +195,98 @@ def variables(self, variables: List[Dict[str, PartsType]]) -> None:
f"Variables must be a list of dictionaries, not {type(variables)}"
)

@model_name.setter
def model_name(self, model_name: Optional[str]) -> None:
"""Overwrites the existing saved local model_name."""
if model_name:
self._model_name = Prompt._format_model_resource_name(model_name)
else:
self._model_name = None

def _format_model_resource_name(model_name: Optional[str]) -> str:
"""Formats the model resource name."""
project = aiplatform_initializer.global_config.project
location = aiplatform_initializer.global_config.location
model_name = _reconcile_model_name(model_name, project, location)

prediction_resource_name = _get_resource_name_from_model_name(
model_name, project, location
)
return prediction_resource_name

def _validate_configs(
self,
generation_config: Optional[GenerationConfig] = None,
safety_settings: Optional[SafetySetting] = None,
system_instruction: Optional[PartsType] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
):
generation_config = generation_config or self._generation_config
safety_settings = safety_settings or self._safety_settings
tools = tools or self._tools
tool_config = tool_config or self._tool_config
system_instruction = system_instruction or self._system_instruction
return _validate_generate_content_parameters(
contents="test",
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
tools=tools,
tool_config=tool_config,
)

@generation_config.setter
def generation_config(self, generation_config: Optional[GenerationConfig]) -> None:
"""Overwrites the existing saved local generation_config.
Args:
generation_config: A GenerationConfig object containing parameters for generation.
"""
self._validate_configs(generation_config=generation_config)
self._generation_config = generation_config

@safety_settings.setter
def safety_settings(self, safety_settings: Optional[SafetySetting]) -> None:
"""Overwrites the existing saved local safety_settings.
Args:
safety_settings: A SafetySetting object containing safety settings for generation.
"""
self._validate_configs(safety_settings=safety_settings)
self._safety_settings = safety_settings

@system_instruction.setter
def system_instruction(self, system_instruction: Optional[PartsType]) -> None:
"""Overwrites the existing saved local system_instruction.
Args:
system_instruction: A PartsType object representing the system instruction.
"""
if system_instruction:
self._validate_parts_type_data(system_instruction)
self._system_instruction = system_instruction

@tools.setter
def tools(self, tools: Optional[List[Tool]]) -> None:
"""Overwrites the existing saved local tools.
Args:
tools: A list of Tool objects for function calling.
"""
self._validate_configs(tools=tools)
self._tools = tools

@tool_config.setter
def tool_config(self, tool_config: Optional[ToolConfig] = None) -> None:
"""Overwrites the existing saved local tool_config.
Args:
tool_config: A ToolConfig object for function calling.
"""
self._validate_configs(tool_config=tool_config)
self._tool_config = tool_config

def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> None:
"""Formats the variables values to be List[Part].
@@ -134,7 +307,7 @@ def _format_variable_value_to_parts(variables_dict: Dict[str, PartsType]) -> Non
content = Content._from_gapic(_to_content(value=variables_dict[key]))
variables_dict[key] = content.parts

def _validate_prompt_data(prompt_data: Any) -> None:
def _validate_parts_type_data(self, data: Any) -> None:
"""
Args:
data: The prompt input to validate
@@ -143,11 +316,11 @@ def _validate_prompt_data(prompt_data: Any) -> None:
TypeError: If data is not a PartsType object.
"""
# Disallow Content as prompt_data.
if isinstance(prompt_data, Content):
if isinstance(data, Content):
raise TypeError("Prompt data must be a PartsType object, not Content")

# Rely on type checks in _to_content.
_to_content(value=prompt_data)
_to_content(value=data)

def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
"""Returns the prompt data, as a List[Content], assembled with variables if applicable.
@@ -176,7 +349,7 @@ def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
Prompt._format_variable_value_to_parts(variables_dict)

# Step 2) Assemble the prompt.
# prompt_data must have been previously validated using _validate_prompt_data.
# prompt_data must have been previously validated using _validate_parts_type_data.
assembled_prompt = []
assembled_variables_cnt = {}
if isinstance(self.prompt_data, list):
@@ -288,6 +461,81 @@ def _assemble_single_str(

return assembled_data

def generate_content(
self,
contents: ContentsType,
*,
generation_config: Optional[GenerationConfigType] = None,
safety_settings: Optional[SafetySettingsType] = None,
model_name: Optional[str] = None,
tools: Optional[List["Tool"]] = None,
tool_config: Optional["ToolConfig"] = None,
stream: bool = False,
system_instruction: Optional[PartsType] = None,
) -> Union["GenerationResponse", Iterable["GenerationResponse"],]:
"""Generates content using the saved Prompt configs.
Args:
contents: Contents to send to the model.
Supports either a list of Content objects (passing a multi-turn conversation)
or a value that can be converted to a single Content object (passing a single message).
Supports
* str, Image, Part,
* List[Union[str, Image, Part]],
* List[Content]
generation_config: Parameters for the generation.
model_name: Prediction model resource name.
safety_settings: Safety settings as a mapping from HarmCategory to HarmBlockThreshold.
tools: A list of tools (functions) that the model can try calling.
tool_config: Config shared for all tools provided in the request.
stream: Whether to stream the response.
system_instruction: System instruction to pass to the model.
Returns:
A single GenerationResponse object if stream == False
A stream of GenerationResponse objects if stream == True
Usage:
```
prompt = Prompt(
prompt_data="Hello, {name}! Today is {day}. How are you?",
variables={"name": "Alice", "day": "Monday"},
generation_config=GenerationConfig(temperature=0.1,),
system_instruction="Please answer in a short sentence.",
model_name="gemini-1.0-pro-002",
)
prompt.generate_content(
contents=prompt.assemble_contents(**prompt.variables[0])
)
```
"""

generation_config = generation_config or self.generation_config
safety_settings = safety_settings or self.safety_settings
model_name = model_name or self.model_name
tools = tools or self.tools
tool_config = tool_config or self.tool_config
system_instruction = system_instruction or self.system_instruction

if not model_name:
raise ValueError(
"Model name must be specified to use Prompt.generate_content()"
)
model_name = Prompt._format_model_resource_name(model_name)

model = GenerativeModel(
model_name=model_name, system_instruction=system_instruction
)
return model.generate_content(
contents=contents,
generation_config=generation_config,
safety_settings=safety_settings,
tools=tools,
tool_config=tool_config,
stream=stream,
)

def get_unassembled_prompt_data(self) -> PartsType:
"""Returns the prompt data, without any variables replaced."""
return self.prompt_data
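
Because `variables` is a list of dictionaries, one template can be reused across several variable sets. A hedged sketch of that pattern follows (same assumed import path and placeholder model name as above; not code from the commit):

```
# Illustrative only: reuse one template across several variable sets.
prompt = Prompt(
    prompt_data="Hello, {name}! Today is {day}. How are you?",
    variables=[
        {"name": "Alice", "day": "Monday"},
        {"name": "Bob", "day": "Tuesday"},
    ],
    model_name="gemini-1.0-pro-002",  # placeholder model name
)

for variable_set in prompt.variables:
    # assemble_contents substitutes {name} and {day}; generate_content
    # falls back to the model_name saved on the Prompt.
    response = prompt.generate_content(
        contents=prompt.assemble_contents(**variable_set)
    )
    print(response.text)  # assumes a non-streaming GenerationResponse
```
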