Skip to content

Commit

Permalink
added symmetry check to notebook (sotopia-lab#112)
Browse files Browse the repository at this point in the history
* added symetry check to notebook; added shell script for running experiments

* removed experiment script

* wrapped symmetry checking

---------

Co-authored-by: zqi2cmu <zqi2@andrew.cmu.edu>
  • Loading branch information
Jasonqi146 and zqi2cmu authored Aug 20, 2023
1 parent ff59e3f commit 9b2514f
Showing 1 changed file with 175 additions and 23 deletions.
198 changes: 175 additions & 23 deletions examples/redis_stats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"import rich\n",
"from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile\n",
"from sotopia.database.logs import EpisodeLog\n",
"from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage\n",
"from collections import Counter \n",
"from redis_om import Migrator "
]
Expand Down Expand Up @@ -93,6 +94,17 @@
" rich.print(agent)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent_pks = AgentProfile.all_pks()\n",
"agent_pks = list(agent_pks)\n",
"print(len(agent_pks))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -134,11 +146,61 @@
"metadata": {},
"outputs": [],
"source": [
"env_profile_id = \"01H6S9W1BMGR7MFRPH0V55J2TD\"\n",
"# get all environments\n",
"all_envs = list(EnvironmentProfile.all_pks())\n",
"print(len(all_envs))\n",
"print(all_envs[:5])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#get a specific environment profile\n",
"env_profile_id = \"01H7VFHPJKR16MD1KC71V4ZRCF\"\n",
"env = EnvironmentProfile.get(env_profile_id)\n",
"rich.print(env)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EnvAgentComboStorage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#all env-agent combos\n",
"all_combos = EnvAgentComboStorage().all_pks()\n",
"all_combos = list(all_combos)\n",
"print(len(all_combos))\n",
"rich.print(EnvAgentComboStorage().get(all_combos[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#check for duplicates in EnvAgentComboStorage\n",
"cache = set()\n",
"for combo_pk in all_combos:\n",
" combo = EnvAgentComboStorage.get(combo_pk)\n",
" curr_tuple = (combo.env_id, combo.agent_ids[0], combo.agent_ids[1])\n",
" if curr_tuple in cache:\n",
" print(\"duplicate\")\n",
" else:\n",
" cache.add(curr_tuple)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -152,9 +214,117 @@
"metadata": {},
"outputs": [],
"source": [
"# tag: 6_initial_aug2\n",
"# Migrator().run()\n",
"Episodes = EpisodeLog.find(EpisodeLog.tag == \"6_initial_aug14_full\").all()"
"# find episode log by tag\n",
"Episodes = EpisodeLog.find(EpisodeLog.tag == \"aug18_gpt3.5_llama-2-70b-chat_zqi2\").all()\n",
"len(Episodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#get all episode logs' primary keys\n",
"episode_pks = EpisodeLog.all_pks()\n",
"episode_pks = list(episode_pks)\n",
"print(len(episode_pks))\n",
"print(episode_pks[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EpisodeLog.get(\"01H83KT1MAD1XYT9VXVBMCPYFW\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# some eps have validation error while loading\n",
"# please look at the buggy_eps list\n",
"gpt35_llama2_eps = []\n",
"buggy_eps = []\n",
"for epid in episode_pks:\n",
" try:\n",
" curr_ep = EpisodeLog.get(epid)\n",
" except Exception as e:\n",
" print(e)\n",
" buggy_eps.append(curr_ep)\n",
" continue\n",
" gpt35_llama2_eps.append(curr_ep)\n",
"len(gpt35_llama2_eps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get the epilogs that contain the specified models\n",
"model1 = 'gpt-3.5-turbo'\n",
"model2 = 'togethercomputer/llama-2-70b-chat'\n",
"model_comp1 = ['gpt-4', model1, model2]\n",
"model_comp2 = ['gpt-4', model2, model1]\n",
"\n",
"gpt35_llama2_eps = []\n",
"for epid in episode_pks:\n",
" try:\n",
" curr_ep = EpisodeLog.get(epid)\n",
" except:\n",
" continue\n",
" if curr_ep.models == model_comp1 or curr_ep.models == model_comp2:\n",
" gpt35_llama2_eps.append(curr_ep)\n",
"len(gpt35_llama2_eps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import logging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check symetry of epilogs, i.e., if we have an epilog for env1, agent1, agent2, then we should have an epilog for env1, agent2, agent1\n",
"def is_symmetric_epilogs(epilogs):\n",
" asymetric_epilogs = []\n",
" gpt35_llama2_epilogs_dict = {}\n",
" for ep in epilogs:\n",
" gpt35_llama2_epilogs_dict[tuple(ep.models)] = ep\n",
" \n",
" for model_comp in gpt35_llama2_epilogs_dict:\n",
" symetric_model_comp = (model_comp[0], model_comp[2], model_comp[1])\n",
" if symetric_model_comp not in gpt35_llama2_epilogs_dict:\n",
" asymetric_epilogs += [gpt35_llama2_epilogs_dict[model_comp]]\n",
" if not asymetric_epilogs:\n",
" return True\n",
" else:\n",
" logging.warning(f\"Found asymetric epilogs. Primary keys: {[log.pk for log in asymetric_epilogs]}\")\n",
" return False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"is_symmetric_epilogs(gpt35_llama2_eps)"
]
},
{
Expand All @@ -177,6 +347,7 @@
"metadata": {},
"outputs": [],
"source": [
"#check environments\n",
"len(set([Episode.environment for Episode in Episodes]))"
]
},
Expand Down Expand Up @@ -249,25 +420,6 @@
" \"on_par\": model_on_par_successRate.mean(axis=0).to_dict(),\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"get_avg_reward_for_models(Episodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = get_avg_successRate_for_models(Episodes)\n",
"a"
]
}
],
"metadata": {
Expand Down

0 comments on commit 9b2514f

Please sign in to comment.