Skip to content

Commit

Permalink
Remove all_pks from Sampler.sample (sotopia-lab#160)
Browse files Browse the repository at this point in the history
* Move all_pks to constructors of samplers

* lazy profile pull

* use find().all() in place of all_pks()

* update readme

---------

Co-authored-by: XuhuiZhou <zhouxuhui2018@gmail.com>
  • Loading branch information
ProKil and XuhuiZhou authored Aug 13, 2024
1 parent 0edf4b0 commit 9d50654
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 48 deletions.
22 changes: 10 additions & 12 deletions docs/pages/concepts/database.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## Basic Concepts of the Database

We use the `JsonModel` class to define the structure of the database. This class is coming from [Redis-OM](https://github.com/redis/redis-om-node), which is a Redis Object Mapper for Node.js. You should learn a lot about how to use the database by reading their documentation.

We do have more customized method `.all()` would return a list of all the corresponding objects in the database. Note that using `.all()` is going to block the database, so you should use it with caution.
You might see codes using `all_pks()` which is a method from the `JsonModel` class. This method returns a list of all the primary keys of the corresponding objects in the database. This method is not blocking the database. However, it might not be as efficient as using `.all()`.

## Adding new characters and environments
You can use the following function with the `**kwargs` being the properties of the `AgentProfile` class. This is the same for the scenarios/environments.
```python
Expand Down Expand Up @@ -98,10 +105,7 @@ It is very easy to serialize any database structures to JSON or CSV.
```python
from sotopia.database import episodes_to_jsonl, EpisodeLog

episodes: list[EpisodeLog] = [
EpisodeLog.get(pk=pk)
for pk in EpisodeLog.all_pks()
]
episodes: list[EpisodeLog] = EpisodeLog.all()

episodes_to_jsonl(episodes, "episodes.jsonl")
```
Expand All @@ -111,15 +115,9 @@ episodes_to_jsonl(episodes, "episodes.jsonl")
```python
from sotopia.database import environmentprofiles_to_jsonl, agentprofiles_to_jsonl

agent_profiles: list[AgentProfile] = [
AgentProfile.get(pk=pk)
for pk in AgentProfile.all_pks()
]
agent_profiles: list[AgentProfile] = AgentProfile.all()

environment_profiles: list[EnvironmentProfile] = [
EnvironmentProfile.get(pk=pk)
for pk in EnvironmentProfile.all_pks()
]
environment_profiles: list[EnvironmentProfile] = EnvironmentProfile.all()

agentprofiles_to_jsonl(agent_profiles, "agent_profiles.jsonl")
environmentprofiles_to_jsonl(environment_profiles, "environment_profiles.jsonl")
Expand Down
4 changes: 2 additions & 2 deletions examples/benchmark_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def get_human_annotations(
target_model_patterns: list[list[str]],
) -> list[AnnotationForEpisode]:
episodes_with_human_annotation: list[AnnotationForEpisode] = []
for pk in AnnotationForEpisode.all_pks():
episode_human = AnnotationForEpisode.get(pk)
human_annotated_episodes = AnnotationForEpisode.all()
for episode_human in human_annotated_episodes:
episode_model = EpisodeLog.get(episode_human.episode)
if episode_model.models in target_model_patterns:
episodes_with_human_annotation.append(episode_human)
Expand Down
5 changes: 1 addition & 4 deletions examples/evaluate_existing_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ def run_server(
push_to_db: bool = True,
verbose: bool = False,
) -> None:
annotated_episodes_pks = [
AnnotationForEpisode.get(anno).episode
for anno in AnnotationForEpisode.all_pks()
]
annotated_episodes_pks = [anno.episode for anno in AnnotationForEpisode.all()]
annotated_episodes_pks = list(set(annotated_episodes_pks))
model = typing.cast(LLM_Name, model)
# Call the function with the specified parameters
Expand Down
11 changes: 11 additions & 0 deletions sotopia/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TypeVar
from redis_om import JsonModel
from .annotators import Annotator
from .env_agent_combo_storage import EnvAgentComboStorage
from .logs import AnnotationForEpisode, EpisodeLog
Expand Down Expand Up @@ -60,3 +62,12 @@
"jsonl_to_envagnetcombostorage",
"get_rewards_from_episode",
]

InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel")


def _json_model_all(cls: type[InheritedJsonModel]) -> list[InheritedJsonModel]:
return cls.find().all() # type: ignore[return-value]


JsonModel.all = classmethod(_json_model_all) # type: ignore[assignment,method-assign]
16 changes: 9 additions & 7 deletions sotopia/samplers/constraint_based_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def sample(
env_profiles: list[EnvironmentProfile] = []
agents_which_fit_scenario: list[list[str]] = []

if self.env_candidates is None:
self.env_candidates = EnvironmentProfile.all()

if self.agent_candidates is None:
self.agent_candidates = AgentProfile.all()

agent_candidate_ids: set[str] | None = None
if self.agent_candidates:
agent_candidate_ids = set(
Expand Down Expand Up @@ -120,13 +126,9 @@ def sample(
)
else:
for _ in range(size):
if self.env_candidates:
env_profile = random.choice(self.env_candidates)
if isinstance(env_profile, str):
env_profile = EnvironmentProfile.get(env_profile)
else:
env_profile_id = random.choice(list(EnvironmentProfile.all_pks()))
env_profile = EnvironmentProfile.get(env_profile_id)
env_profile = random.choice(self.env_candidates)
if isinstance(env_profile, str):
env_profile = EnvironmentProfile.get(env_profile)
env_profiles.append(env_profile)
env_profile_id = env_profile.pk
assert env_profile_id, "Env candidate must have an id"
Expand Down
33 changes: 10 additions & 23 deletions sotopia/samplers/uniform_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,19 @@ def sample(

assert replacement, "Uniform sampling without replacement is not supported yet"

if self.env_candidates is None:
self.env_candidates = EnvironmentProfile.all()

if self.agent_candidates is None:
self.agent_candidates = AgentProfile.all()

for _ in range(size):
if self.env_candidates:
env_profile = random.choice(self.env_candidates)
if isinstance(env_profile, str):
env_profile = EnvironmentProfile.get(env_profile)
else:
env_profile_id = random.choice(list(EnvironmentProfile.all_pks()))
env_profile = EnvironmentProfile.get(env_profile_id)
env_profile = random.choice(self.env_candidates)
if isinstance(env_profile, str):
env_profile = EnvironmentProfile.get(env_profile)
env = ParallelSotopiaEnv(env_profile=env_profile, **env_params)

if self.agent_candidates:
agent_profile_candidates = self.agent_candidates
if len(agent_profile_candidates) < n_agent:
raise ValueError(
f"Number of agent candidates ({len(agent_profile_candidates)}) is less than number of agents ({n_agent})"
)
else:
agent_profile_candidates_keys = list(AgentProfile.all_pks())
if len(agent_profile_candidates_keys) < n_agent:
raise ValueError(
f"Number of agent profile candidates ({len(agent_profile_candidates_keys)}) in database is less than number of agents ({n_agent})"
)
agent_profile_candidates = [
AgentProfile.get(pk=pk) for pk in agent_profile_candidates_keys
]

agent_profile_candidates = self.agent_candidates
if len(agent_profile_candidates) == n_agent:
agent_profiles_maybe_id = agent_profile_candidates
else:
Expand Down
2 changes: 2 additions & 0 deletions stubs/redis_om/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class JsonModel(RedisModel, abc.ABC):
@classmethod
def all_pks(cls) -> Generator[str, None, None]: ...
@classmethod
def all(cls: type[InheritedJsonModel]) -> list[InheritedJsonModel]: ...
@classmethod
def find(cls, *args: Any, **kwargs: Any) -> FindQuery: ...
def save(self) -> None: ...

Expand Down

0 comments on commit 9d50654

Please sign in to comment.