Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix metadata processing in chat adapter #2040

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 173 additions & 34 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import textwrap
from collections.abc import Mapping
from itertools import chain
from typing import Any, Dict, List, Literal, NamedTuple, Union, get_args, get_origin
from typing import Any, Dict, List, Literal, NamedTuple, Union, Type, get_args, get_origin

import pydantic
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo

from dsp.adapters.base_template import Field
from dspy.adapters.base import Adapter
from dspy.adapters.utils import find_enum_member, format_field_value
from dspy.signatures.field import OutputField
Expand All @@ -29,9 +30,26 @@ class FieldInfoWithName(NamedTuple):
# Built-in field indicating that a chat turn has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())

# Constraints that can be applied to numeric fields.
PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"}


class ChatAdapter(Adapter):

def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
"""
Creates a formatted list of messages to pass to the LLM as a prompt.

Args:
signature (Signature): The signature of the task.
demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the
task (i.e., mapping from input fields to output fields).
inputs: A dictionary containing the input fields for the task.

Returns:
A list of messages to pass to the LLM as a prompt. Each message is a dictionary with two keys: "role" (i.e.,
whether the message is from the user or the assistant) and "content" (i.e., the message text).
"""
messages: list[dict[str, Any]] = []

# Extract demos where some of the output_fields are not filled in.
Expand All @@ -58,7 +76,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
messages.append(format_turn(signature, inputs, role="user"))
return messages

def parse(self, signature, completion):
def parse(self, signature: Signature, completion: str, _parse_values: bool = True):
sections = [(None, [])]

for line in completion.splitlines():
Expand All @@ -74,10 +92,10 @@ def parse(self, signature, completion):
for k, v in sections:
if (k not in fields) and (k in signature.output_fields):
try:
fields[k] = parse_value(v, signature.output_fields[k].annotation)
fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
except Exception as e:
raise ValueError(
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n\n{v}\n"
)

if fields.keys() != signature.output_fields.keys():
Expand All @@ -86,7 +104,29 @@ def parse(self, signature, completion):
return fields

# TODO(PR): Looks ok?
def format_finetune_data(self, signature, demos, inputs, outputs):
def format_finetune_data(
self,
signature: Signature,
demos: List[Dict[str, Any]],
inputs: Dict[str, Any],
outputs: Dict[str, Any]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Formats the data for fine-tuning an LLM on a task.

Args:
signature (Signature): The signature of the task.
demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the
task (i.e., mapping from input fields to output fields).
inputs: A dictionary containing the input fields for the task.
outputs: A dictionary containing the output fields for the task.

Returns:
A dictionary containing the formatted data for fine-tuning an LLM on the task. The dictionary has a single
key, "messages", which maps to a list of messages to pass to the LLM as a prompt. Each message is a
dictionary with two keys: "role" (i.e., whether the message is from the user or the assistant) and
"content" (i.e., the message text).
"""
# Get system + user messages
messages = self.format(signature, demos, inputs)

Expand All @@ -99,30 +139,54 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
# Wrap the messages in a dictionary with a "messages" key
return dict(messages=messages)

def format_turn(self, signature, values, role, incomplete=False):
def format_turn(
self,
signature: Signature,
values: Dict[str, Any],
role: str,
incomplete: bool = False,
) -> Dict[str, Any]:
"""
Formats a single turn in a chat thread.

Args:
signature (Signature): The signature of the task.
values (Dict[str, Any]): A dictionary mapping field names to corresponding values.
role (str): The role of the message, which can be either "user" or "assistant".
incomplete (bool): If True, indicates that output field values are present in the set of specified values.
If False, indicates that values only contains input field values.

Returns:
A dictionary representing a single turn in a chat thread. The dictionary has two keys: "role" (i.e., whether
the message is from the user or the assistant) and "content" (i.e., the message text).
"""
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values, role):
def format_fields(self, signature: Signature, values: Dict[str, Any], role: str) -> str:
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in signature.fields.items()
if field_name in values
}

return format_fields(fields_with_values)


def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text: bool = True) -> Union[str, List[dict]]:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
into a single string, which is is a multiline string if there are multiple fields.
Creates a formatted representation of the fields and their values.

Formats the values of the specified fields according to the field's DSPy type (input or output), annotation (e.g. str,
int, etc.), and the type of the value itself. Joins the formatted values into a single string, which is is a multiline
string if there are multiple fields.

Args:
fields_with_values: A dictionary mapping information about a field to its corresponding
value.
fields_with_values (Dict[FieldInforWithName, Any]): A dictionary mapping information about a field to its
corresponding value.
assume_text (bool): If True, assumes that the values are text and formats them as such. If False, formats the
values as a list of dictionaries.

Returns:
The joined formatted values of the fields, represented as a string or a list of dicts
"""
Expand All @@ -143,7 +207,17 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
return output


def parse_value(value, annotation):
def parse_value(value: Any, annotation: Type) -> Any:
"""
Parses a value according to the specified annotation.

Args:
value: The value to parse.
annotation: The type to which the value should be parsed.

Returns:
The parsed value.
"""
if annotation is str:
return str(value)

Expand All @@ -163,22 +237,24 @@ def parse_value(value, annotation):
return TypeAdapter(annotation).validate_python(parsed_value)


def format_turn(signature, values, role, incomplete=False):
def format_turn(signature: Signature, values: Dict[str, Any], role: str, incomplete: bool = False) -> Dict[str, Any]:
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Constructs a new message ("turn") to append to a chat thread.

The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified
DSPy signature.

Args:
signature: The DSPy signature to which future LLM responses should conform.
values: A dictionary mapping field names (from the DSPy signature) to corresponding values
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
``values``. If False, indicates that ``values`` only contains input field values.
`values`. If False, indicates that `values` only contains input field values.

Returns:
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
`role` ("user" or "assistant") and `content` (the message text).
"""
fields_to_collapse = []
content = []
Expand Down Expand Up @@ -229,8 +305,9 @@ def type_info(v):
{
"type": "text",
"text": "Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`.",
+ ", then ".join(
f"[[ ## {f} ## ]]{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for [[ ## completed ## ]].",
}
)

Expand Down Expand Up @@ -267,7 +344,8 @@ def type_info(v):
return {"role": role, "content": collapsed_messages}


def get_annotation_name(annotation):
def get_annotation_name(annotation: Type) -> str:
"""Returns the name of the annotation as a string."""
origin = get_origin(annotation)
args = get_args(annotation)
if origin is None:
Expand All @@ -280,38 +358,97 @@ def get_annotation_name(annotation):
return f"{get_annotation_name(origin)}[{args_str}]"


def enumerate_fields(fields: dict) -> str:
def _format_constraint(name: str, value: Union[str, float]) -> str:
"""
Formats a constraint for a numeric field.

Args:
name: The name of the constraint.
value: The value of the constraint.

Returns:
The formatted constraint as a string.
"""
constraints = {
'gt': f"greater than {value}",
'lt': f"less than {value}",
'ge': f"greater than or equal to {value}",
'le': f"less than or equal to {value}",
'multiple_of': f"a multiple of {value}",
'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed"
}
return constraints.get(name, f"{name}={value}")


def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str:
"""
Formats a summary of the metadata for a field."""
if not field.metadata:
return ""
metadata_parts = [str(meta) for meta in field.metadata]
if metadata_parts:
return f" [Metadata: {'; '.join(metadata_parts)}]"
return ""


def format_metadata_constraints(field: Field) -> str:
"""Formats the constraints for a field."""
if not hasattr(field, 'metadata') or not field.metadata:
return ""
formatted_constraints = []
for meta in field.metadata:
constraint_names = [name for name in dir(meta) if not name.startswith('_')]
for name in constraint_names:
if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS:
value = getattr(meta, name)
formatted_constraints.append(_format_constraint(name, value))
if not formatted_constraints:
return ""
elif len(formatted_constraints) == 1:
return f" that is {formatted_constraints[0]}."
else:
*front, last = formatted_constraints
return f" that is {', '.join(front)} and {last}."


def enumerate_fields(fields: Dict[str, FieldInfo]) -> str:
"""Enumerates the fields in a signature."""
parts = []
for idx, (k, v) in enumerate(fields.items()):
parts.append(f"{idx+1}. `{k}`")
parts[-1] += f" ({get_annotation_name(v.annotation)})"
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""

metadata_info = format_metadata_summary(v)
if metadata_info:
parts[-1] += metadata_info
return "\n".join(parts).strip()


def move_type_to_front(d):
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]:
"""Moves the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence."""
if isinstance(d, Mapping):
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))}
elif isinstance(d, list):
return [move_type_to_front(item) for item in d]
return d


def prepare_schema(type_):
schema = pydantic.TypeAdapter(type_).json_schema()
def prepare_schema(type_: Type) -> Dict[str, Any]:
"""Prepares a JSON schema for a given type."""
schema: Dict[str, Any] = pydantic.TypeAdapter(type_).json_schema()
schema = move_type_to_front(schema)
return schema


def prepare_instructions(signature: SignatureMeta):
def prepare_instructions(signature: SignatureMeta) -> str:
"""Prepares the instructions for a signature."""
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
def field_metadata(field_name: str, field_info: FieldInfo) -> str:
"""Creates a formatted representation of a field's information and metadata."""
field_type = field_info.annotation

if get_dspy_field_type(field_info) == "input" or field_type is str:
Expand All @@ -320,6 +457,8 @@ def field_metadata(field_name, field_info):
desc = "must be True or False"
elif field_type in (int, float):
desc = f"must be a single {field_type.__name__} value"
metadata_info = format_metadata_constraints(field_info)
if metadata_info: desc += metadata_info
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
desc = f"must be one of: {'; '.join(field_type.__members__)}"
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
Expand All @@ -331,7 +470,8 @@ def field_metadata(field_name, field_info):
desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"

def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]) -> str:
"""Formats the fields from the signature for the instructions."""
return format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
Expand All @@ -346,5 +486,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
instructions = textwrap.dedent(signature.instructions)
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")

return "\n\n".join(parts).strip()
return "\n\n".join(parts).strip()
Loading