feat(Judge): implementation of judge agent to validate code matches the user query (Sinaptik-AI#1238)

* feat(Judge): implementation of judge agent to validate code matches the user query

* fix: ruff errors

* feat(JudgeAgent): make judge agent using memory from chat agent

* chore add datetime in prompt

* add documentation

* docs(judge): update judge documentation
ArslanSaleem authored Jun 18, 2024
1 parent 2c14f15 commit ab0d685
Showing 22 changed files with 1,029 additions and 14 deletions.
64 changes: 64 additions & 0 deletions docs/judge-agent.mdx
@@ -0,0 +1,64 @@
---
title: "Judge Agent"
description: "Enhance the PandasAI library with the JudgeAgent that evaluates the generated code"
---

## Introduction to the Judge Agent

The `JudgeAgent` extends the capabilities of the PandasAI library by adding an extra judgement step to the agent's pipeline that validates the generated code against the user query.

> **Note:** Usage of the Judge Agent may be subject to a license. For more details, refer to the [license documentation](https://github.com/Sinaptik-AI/pandas-ai/blob/master/pandasai/ee/LICENSE).

## Instantiating the Judge Agent

The `JudgeAgent` can be used both as a standalone agent and in conjunction with other agents. To use it with another agent, pass the `JudgeAgent` instance to that agent as a parameter.

### Using with other agents

```python
import os

import pandas as pd

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv")

judge = JudgeAgent()
agent = Agent([github_stars], judge=judge)

print(agent.chat("return total stars count"))
```

### Using as a standalone

```python
import os

import pandas as pd

from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI

# can be used with any LLM
llm = OpenAI("openai_key")
judge_agent = JudgeAgent(config={"llm": llm})
judge_agent.evaluate(
    query="return total github star count for year 2023",
    code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc"
data = execute_sql_query(sql_query)
plt.plot(data['starred_at_by_month'], data['user_count'])
plt.xlabel('Month')
plt.ylabel('User Count')
plt.title('GitHub Star Count Per Month - Year 2023')
plt.legend(loc='best')
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png')
result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'}
""",
)
```

Integrating the Judge Agent with other agents also gives you the flexibility to use a different LLM for each agent.
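
For example, the chat agent and the judge can each be given their own LLM. The following is a minimal sketch along the lines of the examples above (the API keys and the CSV path are placeholders):

```python
import pandas as pd

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI

# Placeholder credentials and data path for illustration only
chat_llm = OpenAI("openai_key_for_chat")
judge_llm = OpenAI("openai_key_for_judge")

github_stars = pd.read_csv("stars.csv")

# The judge has its own config, so it can use a different LLM
# than the chat agent (or reuse the same one)
judge = JudgeAgent(config={"llm": judge_llm})
agent = Agent([github_stars], config={"llm": chat_llm}, judge=judge)

print(agent.chat("return total stars count"))
```

Here the judge's verdict is produced with `judge_llm`, while the main conversation uses `chat_llm`.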
2 changes: 1 addition & 1 deletion docs/mint.json
@@ -33,7 +33,7 @@
},
{
"group": "Advanced agents",
"pages": ["semantic-agent"]
"pages": ["semantic-agent", "judge-agent"]
},
{
"group": "Advanced usage",
34 changes: 34 additions & 0 deletions examples/judge_agent.py
@@ -0,0 +1,34 @@
import os

import pandas as pd

from pandasai.agent.agent import Agent
from pandasai.ee.agents.judge_agent import JudgeAgent
from pandasai.llm.openai import OpenAI

os.environ["PANDASAI_API_KEY"] = "$2a****************************"

github_stars = pd.read_csv("/Users/arslan/Downloads/stars (2).csv")

judge = JudgeAgent()
agent = Agent([github_stars], judge=judge)

print(agent.chat("return total stars count"))


# Using Judge standalone
llm = OpenAI("openai_key")
judge_agent = JudgeAgent(config={"llm": llm})
judge_agent.evaluate(
    query="return total github star count for year 2023",
    code="""sql_query = "SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc"
data = execute_sql_query(sql_query)
plt.plot(data['starred_at_by_month'], data['user_count'])
plt.xlabel('Month')
plt.ylabel('User Count')
plt.title('GitHub Star Count Per Month - Year 2023')
plt.legend(loc='best')
plt.savefig('/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png')
result = {'type': 'plot', 'value': '/Users/arslan/Documents/SinapTik/pandas-ai/exports/charts/temp_chart.png'}
""",
)
4 changes: 4 additions & 0 deletions pandasai/agent/agent.py
@@ -3,6 +3,7 @@
import pandas as pd

from pandasai.agent.base import BaseAgent
from pandasai.agent.base_judge import BaseJudge
from pandasai.connectors.base import BaseConnector
from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline
from pandasai.schemas.df_config import Config
@@ -20,6 +21,7 @@ def __init__(
pipeline: Optional[Type[GenerateChatPipeline]] = None,
vectorstore: Optional[VectorStore] = None,
description: str = None,
judge: BaseJudge = None,
):
super().__init__(dfs, config, memory_size, vectorstore, description)

@@ -31,6 +33,7 @@
on_code_generation=self._callbacks.on_code_generation,
before_code_execution=self._callbacks.before_code_execution,
on_result=self._callbacks.on_result,
judge=judge,
)
if pipeline
else GenerateChatPipeline(
@@ -40,6 +43,7 @@
on_code_generation=self._callbacks.on_code_generation,
before_code_execution=self._callbacks.before_code_execution,
on_result=self._callbacks.on_result,
judge=judge,
)
)

18 changes: 18 additions & 0 deletions pandasai/agent/base_judge.py
@@ -0,0 +1,18 @@
from pandasai.helpers.logger import Logger
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class BaseJudge:
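    """Base interface for judges that validate generated code against a user query.

    Concrete judges (such as JudgeAgent) wrap a Pipeline and implement evaluate().
    """
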
    context: PipelineContext
    pipeline: Pipeline
    logger: Logger

    def __init__(
        self,
        pipeline: Pipeline,
    ) -> None:
        self.pipeline = pipeline

    def evaluate(self, query: str, code: str) -> bool:
        raise NotImplementedError
30 changes: 30 additions & 0 deletions pandasai/ee/agents/judge_agent/__init__.py
@@ -0,0 +1,30 @@
from typing import Optional, Union

from pandasai.agent.base_judge import BaseJudge
from pandasai.config import load_config_from_json
from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline
from pandasai.pipelines.abstract_pipeline import AbstractPipeline
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.schemas.df_config import Config


class JudgeAgent(BaseJudge):
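    """Judge agent that runs an LLM-backed pipeline to decide whether generated code answers the user query."""
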
    def __init__(
        self,
        config: Optional[Union[Config, dict]] = None,
        pipeline: AbstractPipeline = None,
    ) -> None:
        context = None
        if config:
            if isinstance(config, dict):
                config = Config(**load_config_from_json(config))

            context = PipelineContext(None, config)

        pipeline = pipeline or JudgePipeline(context=context)
        super().__init__(pipeline)

    def evaluate(self, query: str, code: str) -> bool:
        input_data = JudgePipelineInput(query, code)
        return self.pipeline.run(input_data)
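
The `JudgePipelineInput` passed to the pipeline above is defined in `pandasai/pipelines/judge/judge_pipeline_input.py`, which is part of this change set but not expanded in this view. Judging from its usage here (the pipeline reads `input_data.query` and `input_data.code`), it is roughly a plain query/code container; a sketch under that assumption:

```python
from dataclasses import dataclass


@dataclass
class JudgePipelineInput:
    """Query/generated-code pair handed to the judge pipeline (sketch inferred from usage)."""

    query: str
    code: str
```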
34 changes: 34 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py
@@ -0,0 +1,34 @@
from typing import Optional

from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import (
JudgePromptGeneration,
)
from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall
from pandasai.helpers.logger import Logger
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.pipeline import Pipeline
from pandasai.pipelines.pipeline_context import PipelineContext


class JudgePipeline:
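    """Pipeline with two steps: build the judge prompt, then call the LLM and parse its verdict."""
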
    def __init__(
        self,
        context: Optional[PipelineContext] = None,
        logger: Optional[Logger] = None,
        query_exec_tracker: QueryExecTracker = None,
    ):
        self.query_exec_tracker = query_exec_tracker

        self.pipeline = Pipeline(
            context=context,
            logger=logger,
            query_exec_tracker=self.query_exec_tracker,
            steps=[
                JudgePromptGeneration(),
                LLMCall(),
            ],
        )

    def run(self, input: JudgePipelineInput):
        return self.pipeline.run(input)
50 changes: 50 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py
@@ -0,0 +1,50 @@
import datetime
from typing import Any

from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput
from pandasai.pipelines.logic_unit_output import LogicUnitOutput


class JudgePromptGeneration(BaseLogicUnit):
    """
    Judge Prompt Generation Stage
    """

    def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any:
        """
        Build the judge prompt for the given user query and generated code.

        :param input_data: Judge pipeline input (user query and generated code).
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the pipeline.
            - 'context' (any): The execution context.

        :return: LogicUnitOutput(prompt)
        """
        self.context = kwargs.get("context")
        self.logger: Logger = kwargs.get("logger")

        now = datetime.datetime.now()
        human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p")

        prompt = JudgeAgentPrompt(
            query=input_data.query,
            code=input_data.code,
            context=self.context,
            date=human_readable_datetime,
        )
        self.logger.log(f"Using prompt: {prompt}")

        return LogicUnitOutput(
            prompt,
            True,
            "Prompt Generated Successfully",
            {"content_type": "prompt", "value": prompt.to_string()},
        )
64 changes: 64 additions & 0 deletions pandasai/ee/agents/judge_agent/pipeline/llm_call.py
@@ -0,0 +1,64 @@
from typing import Any

from pandasai.exceptions import InvalidOutputValueMismatch
from pandasai.helpers.logger import Logger
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
from pandasai.pipelines.pipeline_context import PipelineContext


class LLMCall(BaseLogicUnit):
    """
    LLM Call Stage
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Call the LLM with the judge prompt and parse its <Yes>/<No> verdict.

        :param input: The prompt produced by the previous pipeline step.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the pipeline.
            - 'context' (any): The execution context.

        :return: The result of the execution.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        retry_count = 0
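        # Ask the LLM for a verdict, retrying on unparsable responses and
        # re-raising once config.max_retries attempts are exhausted.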
        while retry_count <= pipeline_context.config.max_retries:
            response = pipeline_context.config.llm.call(input, pipeline_context)

            logger.log(
                f"""LLM response:
                    {response}
                    """
            )
            try:
                result = False
                if "<Yes>" in response:
                    result = True
                elif "<No>" in response:
                    result = False
                else:
                    raise InvalidOutputValueMismatch("Invalid response of LLM Call")

                pipeline_context.add("llm_call", response)

                return LogicUnitOutput(
                    result,
                    True,
                    "Code Validated Successfully",
                    {"content_type": "string", "value": response},
                )
            except Exception:
                if retry_count == pipeline_context.config.max_retries:
                    raise

                retry_count += 1
39 changes: 39 additions & 0 deletions pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py
@@ -0,0 +1,39 @@
from pathlib import Path

from jinja2 import Environment, FileSystemLoader

from pandasai.prompts.base import BasePrompt


class JudgeAgentPrompt(BasePrompt):
    """Prompt asking the LLM to judge whether the generated code matches the user query."""

    template_path = "judge_agent_prompt.tmpl"

    def __init__(self, **kwargs):
        """Initialize the prompt."""
        self.props = kwargs

        if self.template:
            env = Environment()
            self.prompt = env.from_string(self.template)
        elif self.template_path:
            # find path to template file
            current_dir_path = Path(__file__).parent

            path_to_template = current_dir_path / "templates"
            env = Environment(loader=FileSystemLoader(path_to_template))
            self.prompt = env.get_template(self.template_path)

        self._resolved_prompt = None

    def to_json(self):
        context = self.props["context"]
        memory = context.memory
        conversations = memory.to_json()
        system_prompt = memory.get_system_prompt()
        return {
            "conversation": conversations,
            "system_prompt": system_prompt,
            "prompt": self.to_string(),
        }
11 changes: 11 additions & 0 deletions pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl
@@ -0,0 +1,11 @@
Today is {{date}}
### QUERY
{{query}}
### GENERATED CODE
{{code}}

Reason step by step and at the end answer:
1. Explain what the code does
2. Explain what the user query asks for
3. Strictly compare the query with the code that is generated
Always return <Yes> if the code exactly meets the requirements of the query, otherwise return <No>