diff --git a/letta/agent.py b/letta/agent.py index 810f821513..532c4e13a7 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -18,7 +18,7 @@ MESSAGE_SUMMARY_WARNING_FRAC, O1_BASE_TOOLS, REQ_HEARTBEAT_MESSAGE, - STRUCTURED_OUTPUT_MODELS + STRUCTURED_OUTPUT_MODELS, ) from letta.errors import LLMError from letta.helpers import ToolRulesSolver @@ -260,9 +260,6 @@ def __init__( self.user = user - # link tools - self.link_tools(agent_state.tools) - # initialize a tool rules solver if agent_state.tool_rules: # if there are tool rules, print out a warning @@ -385,7 +382,9 @@ def __init__( def check_tool_rules(self): if self.model not in STRUCTURED_OUTPUT_MODELS: if len(self.tool_rules_solver.init_tool_rules) > 1: - raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.") + raise ValueError( + "Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule." + ) self.supports_structured_output = False else: self.supports_structured_output = True @@ -424,11 +423,21 @@ def update_memory_if_change(self, new_memory: Memory) -> bool: return True return False - def execute_tool_and_persist_state(self, function_name, function_to_call, function_args): + def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool): """ Execute tool modifications and persist the state of the agent. Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data """ + # TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args. + env = {} + env.update(globals()) + exec(target_letta_tool.source_code, env) + callable_func = env[target_letta_tool.json_schema["name"]] + spec = inspect.getfullargspec(callable_func).annotations + for name, arg in function_args.items(): + if isinstance(function_args[name], dict): + function_args[name] = spec[name](**function_args[name]) + # TODO: add agent manager here orig_memory_str = self.agent_state.memory.compile() @@ -441,11 +450,11 @@ def execute_tool_and_persist_state(self, function_name, function_to_call, functi if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS: # base tools are allowed to access the `Agent` object and run on the database function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response = function_to_call(**function_args) + function_response = callable_func(**function_args) else: # execute tool in a sandbox # TODO: allow agent_state to specify which sandbox to execute tools in - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run( + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run( agent_state=self.agent_state.__deepcopy__() ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state @@ -470,27 +479,6 @@ def messages(self) -> List[dict]: def messages(self, value): raise Exception("Modifying message list directly not allowed") - def link_tools(self, tools: List[Tool]): - """Bind a tool object (schema + python function) to the agent object""" - - # Store the functions schemas (this is passed as an argument to ChatCompletion) - self.functions = [] - self.functions_python = {} - env = {} - env.update(globals()) - for tool in tools: - try: - # WARNING: name may not be consistent? - # if tool.module: # execute the whole module - # exec(tool.module, env) - # else: - exec(tool.source_code, env) - self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]] - self.functions.append(tool.json_schema) - except Exception: - warnings.warn(f"WARNING: tool {tool.name} failed to link") - assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python - def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]: """Load a list of messages from recall storage""" @@ -599,8 +587,12 @@ def _get_ai_reply( """Get response from LLM API with robust retry mechanism.""" allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names() + agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] + allowed_functions = ( - self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names] + agent_state_tool_jsons + if not allowed_tool_names + else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] ) # For the first message, force the initial tool if one is specified @@ -620,7 +612,7 @@ def _get_ai_reply( messages=message_sequence, user_id=self.agent_state.created_by_id, functions=allowed_functions, - functions_python=self.functions_python, + # functions_python=self.functions_python, do we need this? function_call=function_call, first_message=first_message, force_tool_call=force_tool_call, @@ -729,10 +721,13 @@ def _handle_ai_response( function_name = function_call.name printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}") - # Failure case 1: function name is wrong - try: - function_to_call = self.functions_python[function_name] - except KeyError: + # Failure case 1: function name is wrong (not in agent_state.tools) + target_letta_tool = None + for t in self.agent_state.tools: + if t.name == function_name: + target_letta_tool = t + + if not target_letta_tool: error_msg = f"No function named {function_name}" function_response = package_function_response(False, error_msg) messages.append( @@ -800,14 +795,8 @@ def _handle_ai_response( # this is because the function/tool role message is only created once the function/tool has executed/returned self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1]) try: - spec = inspect.getfullargspec(function_to_call).annotations - - for name, arg in function_args.items(): - if isinstance(function_args[name], dict): - function_args[name] = spec[name](**function_args[name]) - # handle tool execution (sandbox) and state updates - function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args) + function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) # handle trunction if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: @@ -819,8 +808,7 @@ def _handle_ai_response( truncate = True # get the function response limit - tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0] - return_char_limit = tool_obj.return_char_limit + return_char_limit = target_letta_tool.return_char_limit function_response_string = validate_function_response( function_response, return_char_limit=return_char_limit, truncate=truncate ) @@ -1564,9 +1552,10 @@ def get_context_window(self) -> ContextWindowOverview: num_tokens_external_memory_summary = count_tokens(external_memory_summary) # tokens taken up by function definitions - if self.functions: - available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions] - num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model) + agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] + if agent_state_tool_jsons: + available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in agent_state_tool_jsons] + num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model) else: available_functions_definitions = [] num_tokens_available_functions_definitions = 0 diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index e1eb591951..61b89624d4 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -195,7 +195,7 @@ def run_tool_from_source( tool_source_type=request.source_type, tool_args=request.args, tool_name=request.name, - user_id=actor.id, + actor=actor, ) except LettaToolCreateError as e: # HTTP 400 == Bad Request diff --git a/letta/server/server.py b/letta/server/server.py index fa9ca8ed9f..31c87394d7 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -853,10 +853,6 @@ def update_agent( # then (2) setting the attributes ._messages and .state.message_ids letta_agent.set_message_buffer(message_ids=request.message_ids) - # tools - if request.tool_ids: - letta_agent.link_tools(letta_agent.agent_state.tools) - letta_agent.update_state() return agent_state @@ -882,11 +878,6 @@ def add_tool_to_agent( agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) - # TODO: This is very redundant, and should probably be simplified - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - letta_agent.link_tools(agent_state.tools) - return agent_state def remove_tool_from_agent( @@ -900,10 +891,6 @@ def remove_tool_from_agent( actor = self.user_manager.get_user_or_default(user_id=user_id) agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - letta_agent.link_tools(agent_state.tools) - return agent_state # convert name->id @@ -1309,9 +1296,7 @@ def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional if context_window_limit: if context_window_limit > llm_config.context_window: - raise ValueError( - f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})" - ) + raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})") llm_config.context_window = context_window_limit return llm_config @@ -1366,7 +1351,7 @@ def get_agent_context_window( def run_tool_from_source( self, - user_id: str, + actor: User, tool_args: str, tool_source: str, tool_source_type: Optional[str] = None, @@ -1394,7 +1379,7 @@ def run_tool_from_source( # Next, attempt to run the tool with the sandbox try: - sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state) + sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state) return FunctionReturn( id="null", function_call_id="null", @@ -1406,9 +1391,7 @@ def run_tool_from_source( ) except Exception as e: - func_return = get_friendly_error_msg( - function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e) - ) + func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)) return FunctionReturn( id="null", function_call_id="null", diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index b6004c3c74..fc6e1bdd38 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -16,9 +16,9 @@ from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType from letta.schemas.tool import Tool +from letta.schemas.user import User from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager -from letta.services.user_manager import UserManager from letta.settings import tool_settings from letta.utils import get_friendly_error_msg @@ -38,14 +38,10 @@ class ToolExecutionSandbox: # We make this a long random string to avoid collisions with any variables in the user's code LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt" - def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None): + def __init__(self, tool_name: str, args: dict, user: User, force_recreate=False, tool_object: Optional[Tool] = None): self.tool_name = tool_name self.args = args - - # Get the user - # This user corresponds to the agent_state's user_id field - # agent_state is the state of the agent that invoked this run - self.user = UserManager().get_user_by_id(user_id=user_id) + self.user = user # If a tool object is provided, we use it directly, otherwise pull via name if tool_object is not None: @@ -184,7 +180,9 @@ def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, s except subprocess.CalledProcessError as e: logger.error(f"Executing tool {self.tool_name} has process error: {e}") func_return = get_friendly_error_msg( - function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e), + function_name=self.tool_name, + exception_name=type(e).__name__, + exception_message=str(e), ) return SandboxRunResult( func_return=func_return, @@ -202,9 +200,7 @@ def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, s logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") raise e - def run_local_dir_sandbox_runpy( - self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str - ) -> SandboxRunResult: + def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str) -> SandboxRunResult: status = "success" agent_state, stderr = None, None @@ -225,9 +221,7 @@ def run_local_dir_sandbox_runpy( func_return, agent_state = self.parse_best_effort(func_result) except Exception as e: - func_return = get_friendly_error_msg( - function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e) - ) + func_return = get_friendly_error_msg(function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e)) traceback.print_exc(file=sys.stderr) status = "error" @@ -248,7 +242,7 @@ def run_local_dir_sandbox_runpy( def parse_out_function_results_markers(self, text: str): if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text: - return '', text + return "", text marker_len = len(self.LOCAL_SANDBOX_RESULT_START_MARKER) start_index = text.index(self.LOCAL_SANDBOX_RESULT_START_MARKER) + marker_len end_index = text.index(self.LOCAL_SANDBOX_RESULT_END_MARKER) @@ -293,6 +287,7 @@ def run_e2b_sandbox(self, agent_state: AgentState) -> SandboxRunResult: env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) code = self.generate_execution_script(agent_state=agent_state) execution = sbx.run_code(code, envs=env_vars) + if execution.results: func_return, agent_state = self.parse_best_effort(execution.results[0].text) elif execution.error: @@ -303,7 +298,7 @@ def run_e2b_sandbox(self, agent_state: AgentState) -> SandboxRunResult: execution.logs.stderr.append(execution.error.traceback) else: raise ValueError(f"Tool {self.tool_name} returned execution with None") - + return SandboxRunResult( func_return=func_return, agent_state=agent_state, diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index fbe7ffb60b..ddaa1d960d 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -110,8 +110,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet llm_config=agent_state.llm_config, user_id=str(uuid.UUID(int=1)), # dummy user_id messages=agent._messages, - functions=agent.functions, - functions_python=agent.functions_python, + functions=[t.json_schema for t in agent.agent_state.tools], ) # Basic check diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 1b29073f43..299e1e967b 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -283,12 +283,12 @@ def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_us # Mock and assert correct pathway was invoked with patch.object(ToolExecutionSandbox, "run_local_dir_sandbox") as mock_run_local_dir_sandbox: - sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user) sandbox.run() mock_run_local_dir_sandbox.assert_called_once() # Run again to get actual response - sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user) result = sandbox.run() assert result.func_return == args["x"] + args["y"] @@ -297,7 +297,7 @@ def test_local_sandbox_default(mock_e2b_api_key_none, add_integers_tool, test_us def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_tool, test_user, agent_state): args = {} # Run again to get actual response - sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert result.agent_state.memory.get_block("human").value == "" assert result.agent_state.memory.get_block("persona").value == "" @@ -306,7 +306,7 @@ def test_local_sandbox_stateful_tool(mock_e2b_api_key_none, clear_core_memory_to @pytest.mark.local_sandbox def test_local_sandbox_with_list_rv(mock_e2b_api_key_none, list_tool, test_user): - sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user) result = sandbox.run() assert len(result.func_return) == 5 @@ -331,7 +331,7 @@ def test_local_sandbox_env(mock_e2b_api_key_none, get_env_tool, test_user): args = {} # Run the custom sandbox - sandbox = ToolExecutionSandbox(get_env_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(get_env_tool.name, args, user=test_user) result = sandbox.run() assert long_random_string in result.func_return @@ -349,7 +349,7 @@ def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_com actor=test_user, ) - result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user_id=test_user.id).run() + result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run() assert result.func_return["details"] == "Action executed successfully" @@ -359,7 +359,7 @@ def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sand args = {"percentage": 10} # Run again to get actual response - sandbox = ToolExecutionSandbox(external_codebase_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(external_codebase_tool.name, args, user=test_user) result = sandbox.run() # Assert that the function return is correct @@ -371,14 +371,14 @@ def test_local_sandbox_external_codebase(mock_e2b_api_key_none, custom_test_sand def test_local_sandbox_with_venv_and_warnings_does_not_error( mock_e2b_api_key_none, custom_test_sandbox_config, get_warning_tool, test_user ): - sandbox = ToolExecutionSandbox(get_warning_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(get_warning_tool.name, {}, user=test_user) result = sandbox.run() assert result.func_return == "Hello World" @pytest.mark.e2b_sandbox def test_local_sandbox_with_venv_errors(mock_e2b_api_key_none, custom_test_sandbox_config, always_err_tool, test_user): - sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(always_err_tool.name, {}, user=test_user) # run the sandbox result = sandbox.run() @@ -397,12 +397,12 @@ def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user) # Mock and assert correct pathway was invoked with patch.object(ToolExecutionSandbox, "run_e2b_sandbox") as mock_run_local_dir_sandbox: - sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user) sandbox.run() mock_run_local_dir_sandbox.assert_called_once() # Run again to get actual response - sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(add_integers_tool.name, args, user=test_user) result = sandbox.run() assert int(result.func_return) == args["x"] + args["y"] @@ -420,14 +420,14 @@ def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user): SandboxEnvironmentVariableCreate(key=key, value=long_random_string), sandbox_config_id=config.id, actor=test_user ) - sandbox = ToolExecutionSandbox(cowsay_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(cowsay_tool.name, {}, user=test_user) result = sandbox.run() assert long_random_string in result.stdout[0] @pytest.mark.e2b_sandbox def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_user): - sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user) # Run the function once result = sandbox.run() @@ -442,7 +442,7 @@ def test_e2b_sandbox_reuses_same_sandbox(check_e2b_key_is_set, list_tool, test_u @pytest.mark.e2b_sandbox def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state): - sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(clear_core_memory_tool.name, {}, user=test_user) # run the sandbox result = sandbox.run(agent_state=agent_state) @@ -458,7 +458,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e config = manager.create_or_update_sandbox_config(config_create, test_user) # Run the custom sandbox once, assert nothing returns because missing env variable - sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id, force_recreate=True) + sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user, force_recreate=True) result = sandbox.run() # response should be None assert result.func_return is None @@ -471,7 +471,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e ) # Assert that the environment variable gets injected correctly, even when the sandbox is NOT refreshed - sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(get_env_tool.name, {}, user=test_user) result = sandbox.run() assert long_random_string in result.func_return @@ -487,7 +487,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set, config = manager.create_or_update_sandbox_config(config_create, test_user) # Run the custom sandbox once, assert a failure gets returned because missing environment variable - sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user) result = sandbox.run() assert len(result.func_return) == 5 old_config_fingerprint = result.sandbox_config_fingerprint @@ -497,7 +497,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set, config = manager.update_sandbox_config(config.id, config_update, test_user) # Run again - result = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id).run() + result = ToolExecutionSandbox(list_tool.name, {}, user=test_user).run() new_config_fingerprint = result.sandbox_config_fingerprint assert config.fingerprint() == new_config_fingerprint @@ -507,7 +507,7 @@ def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set, @pytest.mark.e2b_sandbox def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user): - sandbox = ToolExecutionSandbox(list_tool.name, {}, user_id=test_user.id) + sandbox = ToolExecutionSandbox(list_tool.name, {}, user=test_user) result = sandbox.run() assert len(result.func_return) == 5 @@ -524,7 +524,7 @@ def test_e2b_e2e_composio_star_github(check_e2b_key_is_set, check_composio_key_s actor=test_user, ) - result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user_id=test_user.id).run() + result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run() assert result.func_return["details"] == "Action executed successfully" @@ -541,7 +541,7 @@ def test_core_memory_replace_local(self, mock_e2b_api_key_none, core_memory_tool """Test successful replacement of content in core memory - local sandbox.""" new_name = "Charles" args = {"label": "human", "old_content": "Chad", "new_content": new_name} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert new_name in result.agent_state.memory.get_block("human").value @@ -552,7 +552,7 @@ def test_core_memory_append_local(self, mock_e2b_api_key_none, core_memory_tools """Test successful appending of content to core memory - local sandbox.""" append_text = "\nLikes coffee" args = {"label": "human", "content": append_text} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert append_text in result.agent_state.memory.get_block("human").value @@ -563,7 +563,7 @@ def test_core_memory_replace_error_local(self, mock_e2b_api_key_none, core_memor """Test error handling when trying to replace non-existent content - local sandbox.""" nonexistent_name = "Alexander Wang" args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert len(result.stderr) != 0 @@ -575,7 +575,7 @@ def test_core_memory_replace_e2b(self, check_e2b_key_is_set, core_memory_tools, """Test successful replacement of content in core memory - e2b sandbox.""" new_name = "Charles" args = {"label": "human", "old_content": "Chad", "new_content": new_name} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert new_name in result.agent_state.memory.get_block("human").value @@ -586,7 +586,7 @@ def test_core_memory_append_e2b(self, check_e2b_key_is_set, core_memory_tools, t """Test successful appending of content to core memory - e2b sandbox.""" append_text = "\nLikes coffee" args = {"label": "human", "content": append_text} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_append"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert append_text in result.agent_state.memory.get_block("human").value @@ -597,7 +597,7 @@ def test_core_memory_replace_error_e2b(self, check_e2b_key_is_set, core_memory_t """Test error handling when trying to replace non-existent content - e2b sandbox.""" nonexistent_name = "Alexander Wang" args = {"label": "human", "old_content": nonexistent_name, "new_content": "Charles"} - sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user_id=test_user.id) + sandbox = ToolExecutionSandbox(core_memory_tools["core_memory_replace"].name, args, user=test_user) result = sandbox.run(agent_state=agent_state) assert len(result.stderr) != 0 diff --git a/tests/test_server.py b/tests/test_server.py index 93159aa583..2003c68880 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -288,15 +288,16 @@ def org_id(server): @pytest.fixture(scope="module") -def user_id(server, org_id): - # create user +def user(server, org_id): user = server.user_manager.create_default_user() - print(f"Created user\n{user.id}") + yield user + server.user_manager.delete_user_by_id(user.id) - yield user.id - # cleanup - server.user_manager.delete_user_by_id(user.id) +@pytest.fixture(scope="module") +def user_id(server, user): + # create user + yield user.id @pytest.fixture(scope="module") @@ -789,11 +790,11 @@ def ingest(message: str): ''' -def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): +def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): """Test that the server can run tools""" result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", tool_args=json.dumps({"message": "Hello, world!"}), @@ -806,7 +807,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): assert not result.stderr result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", tool_args=json.dumps({"message": "Well well well"}), @@ -819,7 +820,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): assert not result.stderr result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE, tool_source_type="python", tool_args=json.dumps({"bad_arg": "oh no"}), @@ -835,7 +836,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): # Test that we can still pull the tool out by default (pulls that last tool in the source) result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", tool_args=json.dumps({"message": "Well well well"}), @@ -850,7 +851,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): # Test that we can pull the tool out by name result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", tool_args=json.dumps({"message": "Well well well"}), @@ -865,7 +866,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id): # Test that we can pull a different tool out by name result = server.run_tool_from_source( - user_id=user_id, + actor=user, tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, tool_source_type="python", tool_args=json.dumps({}),