diff --git a/annotation/agreement.py b/annotation/agreement.py index 5f342cb84..ba5fe47c1 100644 --- a/annotation/agreement.py +++ b/annotation/agreement.py @@ -172,7 +172,7 @@ def computeAlpha( # rater2 agrees with rater1 most of the time rater2 = np.random.uniform(size=rater1.shape) - rater2 = pd.Series((rater2 > 0.1).astype(int) * rater1) + rater2 = pd.Series((rater2 > 0.1).astype(int) * rater1) # type: ignore[assignment] # rater3 should be random rater3 = pd.Series(np.random.randint(0, 5, 100)) diff --git a/examples/profiles.py b/examples/profiles.py index 76f5333f2..80b772bec 100644 --- a/examples/profiles.py +++ b/examples/profiles.py @@ -3,7 +3,7 @@ from typing import Any, cast import pandas as pd -from redis_om import Migrator # type: ignore +from redis_om import Migrator from sotopia.database.persistent_profile import ( AgentProfile, diff --git a/sotopia-chat/chat_server.py b/sotopia-chat/chat_server.py new file mode 100644 index 000000000..1c4b36f51 --- /dev/null +++ b/sotopia-chat/chat_server.py @@ -0,0 +1,272 @@ +import asyncio +import logging +import os +import random +import subprocess +from asyncio import gather +from asyncio import run as aiorun +from datetime import datetime +from logging import FileHandler +from typing import Literal, cast + +import redis.asyncio as redis +import typer +from rich.logging import RichHandler + +from sotopia.agents import redis_agent +from sotopia.agents.llm_agent import LLMAgent +from sotopia.database import EnvAgentComboStorage +from sotopia.database.persistent_profile import ( + AgentProfile, + EnvironmentList, + EnvironmentProfile, +) +from sotopia.envs.evaluators import ( + ReachGoalLLMEvaluator, + RuleBasedTerminatedEvaluator, +) +from sotopia.envs.parallel import ParallelSotopiaEnv +from sotopia.server import arun_one_episode + +process = subprocess.Popen( + ["git", "rev-parse", "HEAD"], shell=False, stdout=subprocess.PIPE +) +git_head_hash = process.communicate()[0].strip() + +FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +logging.basicConfig( + level=15, + format=FORMAT, + datefmt="[%X]", + handlers=[ + RichHandler(), + FileHandler( + datetime.now().strftime( + f"./logs/%H_%M_%d_%m_%Y_{str(git_head_hash.decode('utf-8'))}.log" + ) + ), + ], +) + +app = typer.Typer() + + +async def _start_server_with_two_session_ids_and_agent_env_combo( + session_ids: list[str], agent_env_combo_pk: str +) -> None: + env_agent_combo_storage = EnvAgentComboStorage.get(agent_env_combo_pk) + env = ParallelSotopiaEnv( + env_profile=EnvironmentProfile.get(env_agent_combo_storage.env_id), + model_name="gpt-4", + action_order="round-robin", + evaluators=[ + RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), + ], + terminal_evaluators=[ + ReachGoalLLMEvaluator("gpt-4"), + ], + ) + random.shuffle(session_ids) + agents = [ + redis_agent.RedisAgent( + agent_profile=AgentProfile.get( + env_agent_combo_storage.agent_ids[idx] + ), + session_id=session_id, + ) + for idx, session_id in enumerate(session_ids) + ] + await arun_one_episode( + env, + agents, + {"env": "gpt-4", "agent1": "redis", "agent2": "redis"}, + tag="human_human_v0.0.3_dryrun", + push_to_db=True, + ) + + +async def _start_server_with_one_session_id_and_agent_env_combo( + session_id: str, + agent_env_combo_pk: str, + left_or_right: Literal["left", "right"], +) -> None: + env_agent_combo_storage = EnvAgentComboStorage.get(agent_env_combo_pk) + env = ParallelSotopiaEnv( + env_profile=EnvironmentProfile.get(env_agent_combo_storage.env_id), + model_name="gpt-4", + action_order="round-robin", + evaluators=[ + RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), + ], + terminal_evaluators=[ + ReachGoalLLMEvaluator("gpt-4"), + ], + ) + + agents = ( + [ + redis_agent.RedisAgent( + agent_profile=AgentProfile.get( + env_agent_combo_storage.agent_ids[0] + ), + session_id=session_id, + ), + LLMAgent( + model_name="gpt-4", + agent_profile=AgentProfile.get( + env_agent_combo_storage.agent_ids[1] + ), + ), + ] + if left_or_right == "left" + else [ + LLMAgent( + model_name="gpt-4", + agent_profile=AgentProfile.get( + env_agent_combo_storage.agent_ids[0] + ), + ), + redis_agent.RedisAgent( + agent_profile=AgentProfile.get( + env_agent_combo_storage.agent_ids[1] + ), + session_id=session_id, + ), + ] + ) + await arun_one_episode( + env, + agents, + { + "env": "gpt-4", + "agent1": "redis" if left_or_right == "left" else "gpt-4", + "agent2": "redis" if left_or_right == "right" else "gpt-4", + }, + tag="human_human_v0.0.3_dryrun", + push_to_db=True, + ) + + +async def async_add_env_agent_combo_to_redis_queue( + use_hard_env_set: bool = False, +) -> None: + r = redis.Redis.from_url(os.environ["REDIS_OM_URL"]) + if use_hard_env_set: + env_list = cast( + list[EnvironmentList], + EnvironmentList.find(EnvironmentList.name == "hard_env_set").all(), + )[0] + envs = env_list.environments + agent_indices = env_list.agent_index + env_agent_combo_storage_pks: list[str] = [] + for env in envs: + env_agent_combo_storage = list( + EnvAgentComboStorage.find( + EnvAgentComboStorage.env_id == env + ).all() + )[0] + assert env_agent_combo_storage.pk + env_agent_combo_storage_pks.append(env_agent_combo_storage.pk) + assert agent_indices + await r.rpush( + "chat_server_combos_double", + *tuple(set(env_agent_combo_storage_pks)), + ) + for agent_index, env_agent_combo_storage_pk in zip( + agent_indices, env_agent_combo_storage_pks + ): + if agent_index == "0": + await r.rpush( + "chat_server_combos_single_left", + env_agent_combo_storage_pk, + ) + else: + await r.rpush( + "chat_server_combos_single_right", + env_agent_combo_storage_pk, + ) + + else: + envs = list(EnvironmentProfile.all_pks()) + random.shuffle(envs) + for env in envs: + env_agent_combo_storage = list( + EnvAgentComboStorage.find( + EnvAgentComboStorage.env_id == env + ).all() + )[0] + assert env_agent_combo_storage.pk + await r.rpush( + "chat_server_combos_double", env_agent_combo_storage.pk + ) + await r.rpush( + "chat_server_combos_single_left", env_agent_combo_storage.pk + ) + await r.rpush( + "chat_server_combos_single_right", env_agent_combo_storage.pk + ) + await r.close() + + +@app.command() +def add_env_agent_combo_to_redis_queue(use_hard_env_set: bool = False) -> None: + aiorun(async_add_env_agent_combo_to_redis_queue(use_hard_env_set)) + + +async def async_start_server_with_session_ids(session_ids: list[str]) -> None: + typer.echo(f"Starting server with session ids: {session_ids}") + + r = redis.Redis.from_url(os.environ["REDIS_OM_URL"]) + + async def _assign_left_or_right_and_run(session_id: str) -> None: + assert ( + await r.llen("chat_server_combos_single_left") + + await r.llen("chat_server_combos_single_right") + > 0 + ), "No agent-env combos available" + if await r.llen("chat_server_combos_single_left") >= await r.llen( + "chat_server_combos_single_right" + ): + agent_env_combo_pk = ( + await r.rpop("chat_server_combos_single_left") + ).decode("utf-8") + return await _start_server_with_one_session_id_and_agent_env_combo( + session_id, agent_env_combo_pk, "left" + ) + else: + agent_env_combo_pk = ( + await r.rpop("chat_server_combos_single_right") + ).decode("utf-8") + return await _start_server_with_one_session_id_and_agent_env_combo( + session_id, agent_env_combo_pk, "right" + ) + + match (len(session_ids)): + case 1: + await _assign_left_or_right_and_run(session_ids[0]) + case 2: + if await r.llen("chat_server_combos_double") == 0: + await gather( + _assign_left_or_right_and_run(session_id) + for session_id in session_ids + ) + else: + agent_env_combo_pk: str = ( + await r.rpop("chat_server_combos_double") + ).decode("utf-8") + await _start_server_with_two_session_ids_and_agent_env_combo( + session_ids, agent_env_combo_pk + ) + case _: + raise ValueError( + f"Only 1 or 2 session ids are supported, but got {len(session_ids)}" + ) + + +@app.command() +def start_server_with_session_ids(session_ids: list[str]) -> None: + aiorun(async_start_server_with_session_ids(session_ids)) + + +if __name__ == "__main__": + app() diff --git a/sotopia-chat/fastapi_server.py b/sotopia-chat/fastapi_server.py new file mode 100644 index 000000000..27673b180 --- /dev/null +++ b/sotopia-chat/fastapi_server.py @@ -0,0 +1,379 @@ +import asyncio +import json +import os +import random +import subprocess +import typing +import uuid +from datetime import datetime +from typing import Generator, Literal, cast + +import pydantic +import pytest +from fastapi import Body, Request +from fastapi.applications import FastAPI +from fastapi.exceptions import HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.testclient import TestClient +from redis import Redis +from redis.lock import Lock +from redis_om import JsonModel, Migrator +from redis_om.model.model import Field + +from sotopia.database import ( + AgentProfile, + EpisodeLog, + MatchingInWaitingRoom, + MessageTransaction, + SessionTransaction, +) +from sotopia.messages import AgentAction, Observation, SimpleMessage + +Migrator().run() + +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +conn = Redis.from_url(os.environ["REDIS_OM_URL"]) + +WAITING_ROOM_TIMEOUT = 60 # 60 secs + + +@app.post("/connect/{session_id}/{role}/{id}") +async def connect( + session_id: str, role: Literal["server", "client"], id: str +) -> list[MessageTransaction]: + session_transactions = cast( + list[SessionTransaction], + SessionTransaction.find( + SessionTransaction.session_id == session_id + ).all(), + ) + if not session_transactions: + if role == "client": + raise HTTPException(status_code=404, detail="Session not found") + else: + session_transaction = SessionTransaction( + session_id=session_id, + server_id=id, + client_id="", + message_list=[], + ) + session_transaction.save() + return [] + else: + if role == "client": + if len(session_transactions) > 1: + raise HTTPException( + status_code=500, + detail="Multiple session transactions found", + ) + session_transaction = session_transactions[0] + session_transaction.client_id = id + session_transaction.save() + return session_transaction.message_list + else: + raise HTTPException(status_code=500, detail="Session exists") + + +async def _get_single_exist_session(session_id: str) -> SessionTransaction: + session_transactions = cast( + list[SessionTransaction], + SessionTransaction.find( + SessionTransaction.session_id == session_id + ).all(), + ) + if not session_transactions: + raise HTTPException(status_code=404, detail="Session not found") + elif len(session_transactions) > 1: + raise HTTPException( + status_code=500, detail="Multiple session transactions found" + ) + else: + return session_transactions[0] + + +@app.post("/send/{session_id}/{sender_id}") +async def send( + session_id: str, + sender_id: str, + message: str = Body(...), +) -> list[MessageTransaction]: + session_transaction = await _get_single_exist_session(session_id) + sender: str = "" + if sender_id == session_transaction.server_id: + # Sender is server + sender = "server" + elif sender_id == session_transaction.client_id: + # Sender is client + if session_transaction.client_action_lock == "no action": + raise HTTPException( + status_code=412, detail="Client cannot take action now." + ) + sender = "client" + else: + raise HTTPException(status_code=401, detail="Unauthorized sender") + + session_transaction.message_list.append( + MessageTransaction( + timestamp_str=str(datetime.now().timestamp()), + sender=sender, + message=message, + ) + ) + try: + session_transaction.save() + except pydantic.error_wrappers.ValidationError: + raise HTTPException(status_code=500, detail="timestamp error") + return session_transaction.message_list + + +@app.put("/lock/{session_id}/{server_id}/{lock}") +async def lock( + session_id: str, server_id: str, lock: Literal["no action", "action"] +) -> str: + session_transaction = await _get_single_exist_session(session_id) + if server_id != session_transaction.server_id: + raise HTTPException(status_code=401, detail="Unauthorized sender") + session_transaction.client_action_lock = lock + session_transaction.save() + return "success" + + +@app.get("/get/{session_id}") +async def get(session_id: str) -> list[MessageTransaction]: + session_transaction = await _get_single_exist_session(session_id) + return session_transaction.message_list + + +@app.delete("/delete/{session_id}/{server_id}") +async def delete(session_id: str, server_id: str) -> str: + session_transaction = await _get_single_exist_session(session_id) + if server_id != session_transaction.server_id: + raise HTTPException(status_code=401, detail="Unauthorized sender") + session_transaction.delete(session_transaction.pk) + return "success" + + +@app.get("/get_lock/{session_id}") +async def get_lock(session_id: str) -> str: + session_transaction = await _get_single_exist_session(session_id) + return session_transaction.client_action_lock + + +def _start_server(session_ids: list[str]) -> None: + print("start server", session_ids) + subprocess.Popen( + [ + "python", + "chat_server.py", + "start-server-with-session-ids", + *session_ids, + ] + ) + + +@app.get("/enter_waiting_room/{sender_id}") +async def enter_waiting_room(sender_id: str) -> str: + matchings_in_waiting_room = cast( + list[MatchingInWaitingRoom], + MatchingInWaitingRoom.find().all(), + ) + for matching_in_waiting_room in matchings_in_waiting_room: + if sender_id in matching_in_waiting_room.client_ids: + index = matching_in_waiting_room.client_ids.index(sender_id) + match (index): + case 0: + if len(matching_in_waiting_room.client_ids) > 1: + _start_server(matching_in_waiting_room.session_ids) + matching_in_waiting_room.session_id_retrieved[ + 0 + ] = "true" + return matching_in_waiting_room.session_ids[0] + else: + if ( + datetime.now().timestamp() + - matching_in_waiting_room.timestamp + > WAITING_ROOM_TIMEOUT + ): + MatchingInWaitingRoom.delete( + matching_in_waiting_room.pk + ) + _start_server(matching_in_waiting_room.session_ids) + return matching_in_waiting_room.session_ids[0] + else: + return "" + case 1: + if matching_in_waiting_room.session_id_retrieved[0]: + if ( + datetime.now().timestamp() + - matching_in_waiting_room.timestamp + > WAITING_ROOM_TIMEOUT + ): + MatchingInWaitingRoom.delete( + matching_in_waiting_room.pk + ) + _start_server( + matching_in_waiting_room.session_ids[1:] + ) + return matching_in_waiting_room.session_ids[1] + else: + return "" + else: + matching_in_waiting_room.session_id_retrieved[ + 1 + ] = "true" + MatchingInWaitingRoom.delete( + matching_in_waiting_room.pk + ) + return matching_in_waiting_room.session_ids[1] + case _: + assert ( + False + ), f"{matching_in_waiting_room} has more than 2 clients, not expected" + else: + lock = Lock(conn, f"lock:check_available_spots") + with lock: + matchings_in_waiting_room = cast( + list[MatchingInWaitingRoom], + MatchingInWaitingRoom.find().all(), + ) + for matching_in_waiting_room in matchings_in_waiting_room: + if len(matching_in_waiting_room.client_ids) == 1: + matching_in_waiting_room.timestamp = ( + datetime.now().timestamp() + ) + matching_in_waiting_room.client_ids.append(sender_id) + matching_in_waiting_room.session_ids.append( + str(uuid.uuid4()) + ) + matching_in_waiting_room.session_id_retrieved.append("") + matching_in_waiting_room.save() + return "" + + matching_in_waiting_room = MatchingInWaitingRoom( + timestamp=datetime.now().timestamp(), + client_ids=[sender_id], + session_ids=[str(uuid.uuid4())], + session_id_retrieved=[""], + ) + matching_in_waiting_room.save() + return "" + + +from starlette.responses import Response + + +class PrettyJSONResponse(Response): + media_type = "application/json" + + def render(self, content: typing.Any) -> bytes: + return json.dumps( + content, + ensure_ascii=False, + allow_nan=False, + indent=4, + separators=(", ", ": "), + ).encode("utf-8") + + +@app.get("/get_episode/{episode_id}", response_class=PrettyJSONResponse) +async def get_episode(episode_id: str) -> EpisodeLog: + try: + episode_log = EpisodeLog.get(pk=episode_id) + except Exception as e: + raise HTTPException(status_code=404, detail=f"Episode not found: {e}") + return episode_log + + +@app.get("/get_agent/{agent_id}", response_class=PrettyJSONResponse) +async def get_agent(agent_id: str) -> AgentProfile: + try: + agent_profile = AgentProfile.get(pk=agent_id) + except Exception as e: + raise HTTPException(status_code=404, detail=f"Agent not found: {e}") + return agent_profile + + +client = TestClient(app) + + +def test_connect() -> None: + session_id = str(uuid.uuid4()) + server_id = str(uuid.uuid4()) + response = client.post(f"/connect/{session_id}/server/{server_id}") + assert response.status_code == 200 + assert response.json() == [] + + sessions = cast( + list[SessionTransaction], + SessionTransaction.find( + SessionTransaction.session_id == session_id + ).all(), + ) + assert len(sessions) == 1 + assert sessions[0].server_id == server_id + assert sessions[0].client_id == "" + assert sessions[0].message_list == [] + SessionTransaction.delete(sessions[0].pk) + + +def test_send_message() -> None: + session_id = str(uuid.uuid4()) + server_id = str(uuid.uuid4()) + response = client.post(f"/connect/{session_id}/server/{server_id}") + assert response.status_code == 200 + assert response.json() == [] + + response = client.post( + f"/send/{session_id}/{server_id}", + json="hello", + ) + assert response.status_code == 200 + + sessions = cast( + list[SessionTransaction], + SessionTransaction.find( + SessionTransaction.session_id == session_id + ).all(), + ) + assert len(sessions) == 1 + assert sessions[0].server_id == server_id + assert sessions[0].client_id == "" + assert len(sessions[0].message_list) == 1 + + message = sessions[0].message_list[0] + assert message.sender == "server" + assert message.message == "hello" + + +@pytest.mark.asyncio +async def test_waiting_room() -> None: + async def _join_after_seconds( + seconds: float, + ) -> str: + sender_id = str(uuid.uuid4()) + await asyncio.sleep(seconds) + while True: + response = client.get(f"/enter_waiting_room/{sender_id}") + if response.text: + break + await asyncio.sleep(0.1) + return str(response.text) + + async with asyncio.timeout(200): + _ = await asyncio.gather( + _join_after_seconds(random.random() * 199), + _join_after_seconds(random.random() * 199), + _join_after_seconds(random.random() * 199), + _join_after_seconds(random.random() * 199), + _join_after_seconds(random.random() * 199), + ) diff --git a/sotopia/agents/__init__.py b/sotopia/agents/__init__.py index 8788a210d..e56b1c4f2 100644 --- a/sotopia/agents/__init__.py +++ b/sotopia/agents/__init__.py @@ -4,6 +4,7 @@ generate_background_conversation, ) from .llm_agent import Agents, HumanAgent, LLMAgent, SpeakAgent +from .redis_agent import RedisAgent __all__ = [ "BaseAgent", @@ -13,4 +14,5 @@ "SpeakAgent", "generate_background", "generate_background_conversation", + "RedisAgent", ] diff --git a/sotopia/agents/redis_agent.py b/sotopia/agents/redis_agent.py new file mode 100644 index 000000000..c5956d985 --- /dev/null +++ b/sotopia/agents/redis_agent.py @@ -0,0 +1,173 @@ +import asyncio +import logging +import os +import time +from datetime import datetime +from uuid import uuid4 + +import aiohttp +import pydantic +import redis +import redis.asyncio as aredis +import requests + +from sotopia.agents import BaseAgent +from sotopia.database import AgentProfile, MessageTransaction +from sotopia.messages import AgentAction, Observation + +_URL = "http://tiger.lti.cs.cmu.edu:8000" + + +class RedisAgent(BaseAgent[Observation, AgentAction]): + """An agent use redis as a message broker.""" + + def __init__( + self, + agent_name: str | None = None, + uuid_str: str | None = None, + session_id: str | None = None, + agent_profile: AgentProfile | None = None, + ) -> None: + super().__init__( + agent_name=agent_name, + uuid_str=uuid_str, + agent_profile=agent_profile, + ) + # super().__init__(agent_name=agent_name, uuid_str=uuid_str) + self.session_id = session_id or str(uuid4()) + self.sender_id = str(uuid4()) + response = requests.request( + "POST", f"{_URL}/connect/{self.session_id}/server/{self.sender_id}" + ) + assert ( + response.status_code == 200 and response.text == "[]" + ), "Failed to connect to the server" + logging.info(f"Session ID: {self.session_id}") + # logging.info(f"Sender ID: {self.sender_id}") + + def act( + self, + obs: Observation, + ) -> AgentAction: + raise NotImplementedError + + async def aact( + self, + obs: Observation, + ) -> AgentAction: + self.recv_message("Environment", obs) + + if len(obs.available_actions) == 1 and "none" in obs.available_actions: + if obs.turn_number == 0: + async with aiohttp.ClientSession() as session: + response = await session.request( + "POST", + f"{_URL}/send/{self.session_id}/{self.sender_id}", + data=obs.to_natural_language(), + ) + assert response.status == 200, response + sorted_message_list: list[tuple[float, str, str]] = list( + map( + lambda x: MessageTransaction.parse_obj( + x + ).to_tuple(), + await response.json(), + ) + ) + last_timestamp = sorted_message_list[-1][0] + return AgentAction(action_type="none", argument="") + else: + async with aiohttp.ClientSession() as session: + # 1. post observation to the message list + response = await session.request( + "POST", + f"{_URL}/send/{self.session_id}/{self.sender_id}", + data=obs.to_natural_language(), + ) + assert response.status == 200, response + sorted_message_list = list( + map( + lambda x: MessageTransaction.parse_obj(x).to_tuple(), + await response.json(), + ) + ) + last_timestamp = sorted_message_list[-1][0] + + print("step 2: unlock the server for the client") + # 2. unlock the server for the client + response = await session.request( + "PUT", + f"{_URL}/lock/{self.session_id}/{self.sender_id}/action", + ) + assert response.status == 200, response + + print("step 3: wait for the client to post their message") + # 3. wait for the client to post their message + for _ in range(300): + response = await session.request( + "GET", + f"{_URL}/get/{self.session_id}", + ) + # print(f"get response: {response}") + assert response.status == 200, response + sorted_message_list = list( + map( + lambda x: MessageTransaction.parse_obj( + x + ).to_tuple(), + await response.json(), + ) + ) + if ( + sorted_message_list[-1][0] > last_timestamp + and sorted_message_list[-1][1] == "client" + ): + # 3.a if the client has posted their message, lock the server for the client + response = await session.request( + "PUT", + f"{_URL}/lock/{self.session_id}/{self.sender_id}/no%20action", + ) + assert response.status == 200, response + break + else: + # 3.b if the client has not posted their message, wait for 0.1 second and retry + await asyncio.sleep(1) + else: + response = await session.request( + "PUT", + f"{_URL}/lock/{self.session_id}/{self.sender_id}/no%20action", + ) + self.reset( + "Someone has left or the conversation is too long." + ) + return AgentAction(action_type="leave", argument="") + action_string = sorted_message_list[-1][2] + try: + action = AgentAction.parse_raw(action_string) + return action + except pydantic.error_wrappers.ValidationError: + logging.warn( + "Failed to parse action string {}. Fall back to speak".format( + action_string + ) + ) + return AgentAction( + action_type="speak", argument=sorted_message_list[-1][2] + ) + + def reset( + self, + reset_reason: str = "", + ) -> None: + super().reset() + try: + if reset_reason != "": + response = requests.request( + "POST", + f"{_URL}/send/{self.session_id}/{self.sender_id}", + json=reset_reason, + ) + assert response.status_code == 200 + + except Exception as e: + logging.error(f"Failed to reset RedisAgent {self.sender_id}: {e}") diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py index 338288d31..32c7037cc 100644 --- a/sotopia/database/__init__.py +++ b/sotopia/database/__init__.py @@ -7,6 +7,8 @@ RelationshipProfile, RelationshipType, ) +from .session_transaction import MessageTransaction, SessionTransaction +from .waiting_room import MatchingInWaitingRoom __all__ = [ "AgentProfile", @@ -17,4 +19,8 @@ "Annotator", "RelationshipProfile", "RelationshipType", + "RedisCommunicationMixin", + "SessionTransaction", + "MessageTransaction", + "MatchingInWaitingRoom", ] diff --git a/sotopia/database/auto_expires_mixin.py b/sotopia/database/auto_expires_mixin.py new file mode 100644 index 000000000..90cd4875e --- /dev/null +++ b/sotopia/database/auto_expires_mixin.py @@ -0,0 +1,12 @@ +from redis_om import JsonModel +from redis_om.model.model import Field + +DEFAULT_EXPIRE_TIME = 60 * 60 * 24 * 7 # 7 day + + +class AutoExpireMixin(JsonModel): + expire_time: int = Field(default=DEFAULT_EXPIRE_TIME, index=True) + + def save(self) -> None: + super().save() + self.expire(self.expire_time) diff --git a/sotopia/database/handshake.py b/sotopia/database/handshake.py new file mode 100644 index 000000000..c0f6f2877 --- /dev/null +++ b/sotopia/database/handshake.py @@ -0,0 +1,2 @@ +""" +""" diff --git a/sotopia/database/persistent_profile.py b/sotopia/database/persistent_profile.py index d6a9748a7..2ea0a9e8b 100644 --- a/sotopia/database/persistent_profile.py +++ b/sotopia/database/persistent_profile.py @@ -2,7 +2,7 @@ from enum import IntEnum from typing import Any, cast -from pydantic import validator +from pydantic import root_validator, validator from redis_om import JsonModel from redis_om.model.model import Field @@ -81,3 +81,23 @@ class RelationshipProfile(JsonModel): description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member", ) # this could be improved by limiting str to a relationship Enum background_story: str | None = Field(default_factory=lambda: None) + + +class EnvironmentList(JsonModel): + name: str = Field(index=True) + environments: list[str] = Field(default_factory=lambda: []) + agent_index: list[str] | None = Field(default_factory=lambda: None) + + # validate the length of agent_index should be same as environments + @root_validator + def the_length_agent_index_matches_environments(cls, values: Any) -> Any: + environments, agent_index = ( + values.get("environments"), + values.get("agent_index"), + ) + if agent_index is None: + return values + assert len(environments) == len( + agent_index + ), f"Number of environments {len(environments)} and agent_index {len(agent_index)} do not match" + return values diff --git a/sotopia/database/session_transaction.py b/sotopia/database/session_transaction.py new file mode 100644 index 000000000..46255c3fe --- /dev/null +++ b/sotopia/database/session_transaction.py @@ -0,0 +1,47 @@ +from pydantic import Field as PydanticField +from pydantic import validator +from redis_om import EmbeddedJsonModel, JsonModel +from redis_om.model.model import Field + +from sotopia.messages import AgentAction, Observation, SimpleMessage + +from .auto_expires_mixin import AutoExpireMixin + + +class MessageTransaction(EmbeddedJsonModel): + timestamp_str: str = Field(index=True) + sender: str = Field(index=True) + message: str + + def to_tuple(self) -> tuple[float, str, str]: + return ( + float(self.timestamp_str), + self.sender, + self.message, + ) + + +class SessionTransaction(AutoExpireMixin, JsonModel): + session_id: str = Field(index=True) + client_id: str = Field(index=True) + server_id: str = Field(index=True) + client_action_lock: str = Field(default="no action") + message_list: list[MessageTransaction] = Field( + description="""List of messages in this session. + Each message is a tuple of (timestamp, sender_id, message) + The message list should be sorted by timestamp. + """ + ) + + @validator("message_list") + def validate_message_list( + cls, v: list[MessageTransaction] + ) -> list[MessageTransaction]: + def _is_sorted(x: list[MessageTransaction]) -> bool: + return all( + float(x[i].timestamp_str) <= float(x[i + 1].timestamp_str) + for i in range(len(x) - 1) + ) + + assert _is_sorted(v), "Message list should be sorted by timestamp" + return v diff --git a/sotopia/database/waiting_room.py b/sotopia/database/waiting_room.py new file mode 100644 index 000000000..bde8c465e --- /dev/null +++ b/sotopia/database/waiting_room.py @@ -0,0 +1,11 @@ +from redis_om import JsonModel +from redis_om.model.model import Field + +from .auto_expires_mixin import AutoExpireMixin + + +class MatchingInWaitingRoom(AutoExpireMixin, JsonModel): + timestamp: float = Field() + client_ids: list[str] = Field(default_factory=lambda: []) + session_ids: list[str] = Field(default_factory=lambda: []) + session_id_retrieved: list[str] = Field(default_factory=lambda: []) diff --git a/sotopia/generation_utils/generate.py b/sotopia/generation_utils/generate.py index 05ef0a471..93211affb 100644 --- a/sotopia/generation_utils/generate.py +++ b/sotopia/generation_utils/generate.py @@ -48,6 +48,7 @@ "text-davinci-003", "gpt-4", "human", + "redis", ] OutputType = TypeVar("OutputType", bound=object) diff --git a/stubs/redis_om/__init__.pyi b/stubs/redis_om/__init__.pyi index a9362a0ea..6a7f03467 100644 --- a/stubs/redis_om/__init__.pyi +++ b/stubs/redis_om/__init__.pyi @@ -14,6 +14,9 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): @classmethod def delete(cls, pk: Any) -> None: ... + def expire( + self, num_seconds: int + ) -> None: ... # pipeline arg can be added here class HashModel(RedisModel, abc.ABC): @classmethod @@ -28,3 +31,8 @@ class JsonModel(RedisModel, abc.ABC): @classmethod def find(cls, *args: Any, **kwargs: Any) -> FindQuery: ... def save(self) -> None: ... + +class EmbeddedJsonModel(JsonModel): ... + +class Migrator: + def run(self) -> None: ... diff --git a/tests/generation_utils/test_generation.py b/tests/generation_utils/test_generation.py index a9c33b0e3..b5827c0d8 100644 --- a/tests/generation_utils/test_generation.py +++ b/tests/generation_utils/test_generation.py @@ -43,6 +43,7 @@ async def test_agenerate_list_integer() -> None: assert all(lower <= i <= upper for i in l) +@pytest.mark.skip(reason="togethercompute out of credit") @pytest.mark.asyncio async def test_agenerate_list_integer_together() -> None: """