forked from Sinaptik-AI/pandas-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(Judge): implementation of judge agent to validate code matches t… (
Sinaptik-AI#1238) * feat(Judge): implementation of judge agent to validate code matches the user query * fix: ruff errors * feat(JudgeAgent): make judge agent using memory from chat agent * chore add datetime in prompt * add documentation * docs(judge): update judge documentation
- Loading branch information
1 parent
2c14f15
commit ab0d685
Showing
22 changed files
with
1,029 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
--- | ||
title: "Judge Agent" | ||
description: "Enhance the PandasAI library with the JudgeAgent that evaluates the generated code" | ||
--- | ||
|
||
## Introduction to the Judge Agent | ||
|
||
The `JudgeAgent` extends the capabilities of the PandasAI library by adding an extra judgement in agents pipeline that validates the code generated against the query | ||
|
||
> **Note:** Usage of the Judge Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE). | ||
## Instantiating the Judge Agent | ||
|
||
JudgeAgent can be used both as a standalone agent and in conjunction with other agents. To use it with other agents, pass JudgeAgent as a parameter to them. | ||
|
||
### Using with other agents | ||
|
||
```python | ||
import os | ||
|
||
import pandas as pd | ||
|
||
from pandasai.agent.agent import Agent | ||
from pandasai.ee.agents.judge_agent import JudgeAgent | ||
|
||
os.environ["PANDASAI_API_KEY"] = "$2a****************************" | ||
|
||
github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv") | ||
|
||
judge = JudgeAgent() | ||
agent = Agent([github_stars], judge=judge) | ||
|
||
print(agent.chat("return total stars count")) | ||
``` | ||
|
||
### Using as a standalone | ||
|
||
```python | ||
import os | ||
|
||
import pandas as pd | ||
|
||
from pandasai.ee.agents.judge_agent import JudgeAgent | ||
from pandasai.llm.openai import OpenAI | ||
|
||
# can be used with all LLM's | ||
llm = OpenAI("openai_key") | ||
judge_agent = JudgeAgent(config={"llm": llm}) | ||
judge_agent.evaluate( | ||
query="return total github star count for year 2023", | ||
code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" | ||
data = execute_sql_query(sql_query) | ||
plt.plot(data['starred_at_by_month'], data['user_count']) | ||
plt.xlabel('Month') | ||
plt.ylabel('User Count') | ||
plt.title('GitHub Star Count Per Month - Year 2023') | ||
plt.legend(loc='best') | ||
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png') | ||
result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'} | ||
""", | ||
) | ||
``` | ||
|
||
Judge Agent integration with other agents also gives the flexibility to use different LLM's |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
|
||
import pandas as pd | ||
|
||
from pandasai.agent.agent import Agent | ||
from pandasai.ee.agents.judge_agent import JudgeAgent | ||
from pandasai.llm.openai import OpenAI | ||
|
||
os.environ["PANDASAI_API_KEY"] = "$2a****************************" | ||
|
||
github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv") | ||
|
||
judge = JudgeAgent() | ||
agent = Agent([github_stars], judge=judge) | ||
|
||
print(agent.chat("return total stars count")) | ||
|
||
|
||
# Using Judge standalone | ||
llm = OpenAI("openai_key") | ||
judge_agent = JudgeAgent(config={"llm": llm}) | ||
judge_agent.evaluate( | ||
query="return total github star count for year 2023", | ||
code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" | ||
data = execute_sql_query(sql_query) | ||
plt.plot(data['starred_at_by_month'], data['user_count']) | ||
plt.xlabel('Month') | ||
plt.ylabel('User Count') | ||
plt.title('GitHub Star Count Per Month - Year 2023') | ||
plt.legend(loc='best') | ||
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png') | ||
result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'} | ||
""", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.pipeline import Pipeline | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class BaseJudge: | ||
context: PipelineContext | ||
pipeline: Pipeline | ||
logger: Logger | ||
|
||
def __init__( | ||
self, | ||
pipeline: Pipeline, | ||
) -> None: | ||
self.pipeline = pipeline | ||
|
||
def evaluate(self, query: str, code: str) -> bool: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Optional, Union | ||
|
||
from pandasai.agent.base_judge import BaseJudge | ||
from pandasai.config import load_config_from_json | ||
from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline | ||
from pandasai.pipelines.abstract_pipeline import AbstractPipeline | ||
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
from pandasai.schemas.df_config import Config | ||
|
||
|
||
class JudgeAgent(BaseJudge): | ||
def __init__( | ||
self, | ||
config: Optional[Union[Config, dict]] = None, | ||
pipeline: AbstractPipeline = None, | ||
) -> None: | ||
context = None | ||
if config: | ||
if isinstance(config, dict): | ||
config = Config(**load_config_from_json(config)) | ||
|
||
context = PipelineContext(None, config) | ||
|
||
pipeline = pipeline or JudgePipeline(context=context) | ||
super().__init__(pipeline) | ||
|
||
def evaluate(self, query: str, code: str) -> bool: | ||
input_data = JudgePipelineInput(query, code) | ||
return self.pipeline.run(input_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Optional | ||
|
||
from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( | ||
JudgePromptGeneration, | ||
) | ||
from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.helpers.query_exec_tracker import QueryExecTracker | ||
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput | ||
from pandasai.pipelines.pipeline import Pipeline | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class JudgePipeline: | ||
def __init__( | ||
self, | ||
context: Optional[PipelineContext] = None, | ||
logger: Optional[Logger] = None, | ||
query_exec_tracker: QueryExecTracker = None, | ||
): | ||
self.query_exec_tracker = query_exec_tracker | ||
|
||
self.pipeline = Pipeline( | ||
context=context, | ||
logger=logger, | ||
query_exec_tracker=self.query_exec_tracker, | ||
steps=[ | ||
JudgePromptGeneration(), | ||
LLMCall(), | ||
], | ||
) | ||
|
||
def run(self, input: JudgePipelineInput): | ||
return self.pipeline.run(input) |
50 changes: 50 additions & 0 deletions
50
pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import datetime | ||
from typing import Any | ||
|
||
from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.base_logic_unit import BaseLogicUnit | ||
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput | ||
from pandasai.pipelines.logic_unit_output import LogicUnitOutput | ||
|
||
|
||
class JudgePromptGeneration(BaseLogicUnit): | ||
""" | ||
Code Prompt Generation Stage | ||
""" | ||
|
||
pass | ||
|
||
def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any: | ||
""" | ||
This method will return output according to | ||
Implementation. | ||
|
||
:param input: Last logic unit output | ||
:param kwargs: A dictionary of keyword arguments. | ||
- 'logger' (any): The logger for logging. | ||
- 'config' (Config): Global configurations for the test | ||
- 'context' (any): The execution context. | ||
|
||
:return: LogicUnitOutput(prompt) | ||
""" | ||
self.context = kwargs.get("context") | ||
self.logger: Logger = kwargs.get("logger") | ||
|
||
now = datetime.datetime.now() | ||
human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p") | ||
|
||
prompt = JudgeAgentPrompt( | ||
query=input_data.query, | ||
code=input_data.code, | ||
context=self.context, | ||
date=human_readable_datetime, | ||
) | ||
self.logger.log(f"Using prompt: {prompt}") | ||
|
||
return LogicUnitOutput( | ||
prompt, | ||
True, | ||
"Prompt Generated Successfully", | ||
{"content_type": "prompt", "value": prompt.to_string()}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import Any | ||
|
||
from pandasai.exceptions import InvalidOutputValueMismatch | ||
from pandasai.helpers.logger import Logger | ||
from pandasai.pipelines.base_logic_unit import BaseLogicUnit | ||
from pandasai.pipelines.logic_unit_output import LogicUnitOutput | ||
from pandasai.pipelines.pipeline_context import PipelineContext | ||
|
||
|
||
class LLMCall(BaseLogicUnit): | ||
""" | ||
LLM Code Generation Stage | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
def execute(self, input: Any, **kwargs) -> Any: | ||
""" | ||
This method will return output according to | ||
Implementation. | ||
|
||
:param input: Your input data. | ||
:param kwargs: A dictionary of keyword arguments. | ||
- 'logger' (any): The logger for logging. | ||
- 'config' (Config): Global configurations for the test | ||
- 'context' (any): The execution context. | ||
|
||
:return: The result of the execution. | ||
""" | ||
pipeline_context: PipelineContext = kwargs.get("context") | ||
logger: Logger = kwargs.get("logger") | ||
|
||
retry_count = 0 | ||
while retry_count <= pipeline_context.config.max_retries: | ||
response = pipeline_context.config.llm.call(input, pipeline_context) | ||
|
||
logger.log( | ||
f"""LLM response: | ||
{response} | ||
""" | ||
) | ||
try: | ||
result = False | ||
if "<Yes>" in response: | ||
result = True | ||
elif "<No>" in response: | ||
result = False | ||
else: | ||
raise InvalidOutputValueMismatch("Invalid response of LLM Call") | ||
|
||
pipeline_context.add("llm_call", response) | ||
|
||
return LogicUnitOutput( | ||
result, | ||
True, | ||
"Code Generated Successfully", | ||
{"content_type": "string", "value": response}, | ||
) | ||
except Exception: | ||
if retry_count == pipeline_context.config.max_retries: | ||
raise | ||
|
||
retry_count += 1 |
39 changes: 39 additions & 0 deletions
39
pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from pathlib import Path | ||
|
||
from jinja2 import Environment, FileSystemLoader | ||
|
||
from pandasai.prompts.base import BasePrompt | ||
|
||
|
||
class JudgeAgentPrompt(BasePrompt): | ||
"""Prompt to generate Python code from a dataframe.""" | ||
|
||
template_path = "judge_agent_prompt.tmpl" | ||
|
||
def __init__(self, **kwargs): | ||
"""Initialize the prompt.""" | ||
self.props = kwargs | ||
|
||
if self.template: | ||
env = Environment() | ||
self.prompt = env.from_string(self.template) | ||
elif self.template_path: | ||
# find path to template file | ||
current_dir_path = Path(__file__).parent | ||
|
||
path_to_template = current_dir_path / "templates" | ||
env = Environment(loader=FileSystemLoader(path_to_template)) | ||
self.prompt = env.get_template(self.template_path) | ||
|
||
self._resolved_prompt = None | ||
|
||
def to_json(self): | ||
context = self.props["context"] | ||
memory = context.memory | ||
conversations = memory.to_json() | ||
system_prompt = memory.get_system_prompt() | ||
return { | ||
"conversation": conversations, | ||
"system_prompt": system_prompt, | ||
"prompt": self.to_string(), | ||
} |
11 changes: 11 additions & 0 deletions
11
pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
Today is {{date}} | ||
### QUERY | ||
{{query}} | ||
### GENERATED CODE | ||
{{code}} | ||
|
||
Reason step by step and at the end answer: | ||
1. Explain what the code does | ||
2. Explain what the user query asks for | ||
3. Strictly compare the query with the code that is generated | ||
Always return <Yes> or <No> if exactly meets the requirements |
Oops, something went wrong.