Sotopia Benchmark CLI API (sotopia-lab#69)
* add benchmark social agents

* add benchmark agents

* Add sotopia_benchmark cli api

* fix pre-commit

* add evaluator model argument

* finish benchmarking

* benchmark done

* chore: Fix formatting issue in redis_stats.ipynb and cli.py

* switch back to LLM_Name

* merge main

* add together ai

* fix naming error

* roll back to llama2

* chore: Update langchain-together dependency to version 0.1.2

* use chatopenai for together models

* add logging

* fix pre-commit

* add more logging options

* probably fix the event loop closed error: following NVIDIA/NeMo-Guardrails#336

* modify cli; fix model position bug

* chore: Update benchmark tag to "benchmark_{model}_final"

* Refactor _iterate_all_env_agent_combo_not_in_db function

* chore: Update python version to 3.11.2

* change to dict comparison

* ignore jsonl

* ✨ finish benchmarking script

* chore: Refactor server.py and redis_stats.ipynb

* add type ignore

* push for the eval

* Refactor run_async_benchmark_in_batch function

* Refactor run_async_benchmark_in_batch function

* add doc

* precommit fix

* pre-commit

* refactor

* update w feedback

* pre commit

* chore: Update authors in pyproject.toml and fetch benchmark_agents.json from Hugging Face API

* hotfix

* chore: Remove unnecessary type hint in benchmark/cli.py

---------

Co-authored-by: Hao <prokilchu@gmail.com>
XuhuiZhou and ProKil authored Jun 17, 2024
1 parent f242543 commit 8bff863
Showing 16 changed files with 866 additions and 476 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -137,7 +137,7 @@ data/*
deprecated/*

*.csv

*.jsonl
#backup
backup/*

11 changes: 11 additions & 0 deletions docs/pages/benchmark.md
@@ -0,0 +1,11 @@
# Benchmark your model as a social agent in Sotopia

```
sotopia_benchmark --model=<your_model_name>
```
or

```
python sotopia/benchmark/cli.py --model=<your_model_name>
```
Currently, this script runs over 100 simulations on the Sotopia Hard tasks, and the partner model is fixed to `meta-llama/Llama-3-70b-chat-hf`.
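
The two invocations above can also be assembled from Python, for example to benchmark several models in a row. This is a minimal sketch, not part of the Sotopia package: `build_benchmark_command` and `run_benchmark` are hypothetical helper names, `my-model` is a placeholder, and only the `--model` flag shown above is used. It assumes `sotopia_benchmark` is installed and on `PATH`.

```python
import shlex
import subprocess


def build_benchmark_command(model: str, use_module_path: bool = False) -> list[str]:
    """Return the argv list for one benchmark run of `model` (hypothetical helper)."""
    if not model:
        raise ValueError("model name must be non-empty")
    if use_module_path:
        # Second form from the docs: run the CLI file directly.
        return ["python", "sotopia/benchmark/cli.py", f"--model={model}"]
    # First form from the docs: the installed console script.
    return ["sotopia_benchmark", f"--model={model}"]


def run_benchmark(model: str) -> int:
    """Launch one benchmark run and return the process exit code.

    Assumption: `sotopia_benchmark` is on PATH and its backends are configured.
    """
    command = build_benchmark_command(model)
    print("Running:", shlex.join(command))
    return subprocess.run(command, check=False).returncode


# Only builds and prints the command; call run_benchmark() to actually start a run.
print(shlex.join(build_benchmark_command("my-model")))
```

Passing the arguments as a list avoids shell quoting problems when a model name contains characters such as `/`.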
99 changes: 95 additions & 4 deletions notebooks/redis_stats.ipynb
@@ -183,7 +183,8 @@
"metadata": {},
"source": [
"## EnvAgentComboStorage\n",
"\n",
"Identify the combination of environment and agent that is used in the episodes.\n",
"Once we have the combination, we can use it to start the simulation.\n",
"Combo is a combination of Environment and two agents."
]
},
@@ -204,9 +205,99 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Episode Log\n",
"## EnvironmentList\n",
"Store a list of special environments (e.g., sotopia hard) that can be used to start certain simulations. Agent index is used to identify the special agent in the simulation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.database.persistent_profile import EnvironmentList\n",
"\n",
"all_list = EnvironmentList.all_pks()\n",
"all_list = list(all_list)\n",
"print(len(all_list))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.samplers import ConstraintBasedSampler\n",
"from sotopia.messages import AgentAction, Observation\n",
"from sotopia.agents import LLMAgent\n",
"import json\n",
"# In this example, we will demonstrate using the EnvironmentList class to sample a list of EnvAgentComboStorage and serialize it to a json file that can be used for sharing with others for benchmarking purposes.\n",
"\n",
"\n",
"def _sample_env_agent_combo_and_push_to_db(env_id: str) -> list[EnvAgentComboStorage]:\n",
" combo_list = []\n",
" sampler = ConstraintBasedSampler[Observation, AgentAction](env_candidates=[env_id])\n",
" env_agent_combo_list = list(\n",
" sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False, size=10)\n",
" )\n",
" for env, agent in env_agent_combo_list:\n",
" combo = EnvAgentComboStorage(\n",
" env_id=env.profile.pk,\n",
" agent_ids=[agent[0].profile.pk, agent[1].profile.pk],\n",
" )\n",
" combo_list.append(combo)\n",
" return combo_list\n",
"\n",
"\n",
"# First we will extract the hard environments from the EnvironmentList\n",
"hard_envs = EnvironmentList.get(\"01HAK34YPB1H1RWXQDASDKHSNS\").environments\n",
"print(len(hard_envs))\n",
"hard_envs_set = set(hard_envs)\n",
"\n",
"# Next we will collect the EnvAgentComboStorage entries stored for each hard environment\n",
"final_list_for_benchmark_agents = []\n",
"for env in hard_envs_set:\n",
" combo_list = EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env).all()\n",
" print(len(combo_list))\n",
" final_list_for_benchmark_agents.extend(combo_list)\n",
"\n",
"# Finally we will serialize the list to a json file\n",
"with open(\"../data/benchmark_agents.json\", \"w\") as f:\n",
" json.dump(\n",
" [combo.dict() for combo in final_list_for_benchmark_agents],\n",
" f,\n",
" indent=4,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EnvironmentList.get(\"01HAK34YPB1H1RWXQDASDKHSNS\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Episode Log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# find episode log by tag\n",
"Episodes = EpisodeLog.find(EpisodeLog.tag == \"aug20_gpt4_llama-2-70b-chat_zqi2\").all()\n",
"len(Episodes) ## Episode Log\n",
"\n",
"Episodelog stores the social conversation between two agents in an environment."
"## Episodelog stores the social conversation between two agents in an environment."
]
},
{
@@ -289,7 +380,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.2"
}
},
"nbformat": 4,
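
The notebook cell above serializes a list of combos to `benchmark_agents.json`, each pairing one environment with two agents. As a rough sketch of how such a file could be sanity-checked before sharing, the snippet below counts combos per environment using only the standard library. It assumes each record carries the `env_id` and `agent_ids` fields used when constructing `EnvAgentComboStorage`; the helper name `count_combos_per_env` and the sample records are made up for illustration, and real records may contain additional fields.

```python
import json
import tempfile
from collections import Counter
from pathlib import Path


def count_combos_per_env(path: Path) -> Counter:
    """Count how many env/agent combos the file holds for each env_id."""
    with path.open() as f:
        combos = json.load(f)
    counts: Counter = Counter()
    for combo in combos:
        # A combo is one environment plus exactly two agents.
        if len(combo["agent_ids"]) != 2:
            raise ValueError(f"expected 2 agents, got {combo['agent_ids']}")
        counts[combo["env_id"]] += 1
    return counts


# Made-up records standing in for real serialized combos (assumption).
sample = [
    {"env_id": "env_a", "agent_ids": ["agent_1", "agent_2"]},
    {"env_id": "env_a", "agent_ids": ["agent_3", "agent_4"]},
    {"env_id": "env_b", "agent_ids": ["agent_1", "agent_3"]},
]

with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "benchmark_agents.json"
    sample_path.write_text(json.dumps(sample, indent=4))
    counts = count_combos_per_env(sample_path)

print(dict(counts))
```

An uneven count across environments would suggest that some hard environments are missing stored combos.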
