From 5ad340d2c72abb43dd3b28b42ef58ad0c7f0fafd Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Fri, 24 May 2024 11:44:44 -0400 Subject: [PATCH 0001/1256] Fix chain responses (#1192) --- agixt/Chains.py | 6 ++++-- agixt/db/Chain.py | 16 +++++++--------- agixt/fb/Chain.py | 13 ++++++++++++- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/agixt/Chains.py b/agixt/Chains.py index 3dd73cde5f0a..cb708de11f2d 100644 --- a/agixt/Chains.py +++ b/agixt/Chains.py @@ -155,8 +155,10 @@ async def run_chain( responses[step_data["step"]] = step # Store the response. logging.info(f"Step {step_data['step']} response: {step_response}") # Write the response to the chain responses file. - await self.chain.update_chain_responses( - chain_name=chain_name, responses=responses + await self.chain.update_step_response( + chain_name=chain_name, + step_number=step_data["step"], + response=step_response, ) if all_responses: return responses diff --git a/agixt/db/Chain.py b/agixt/db/Chain.py index d10c95d525cd..b94d6a1aef43 100644 --- a/agixt/db/Chain.py +++ b/agixt/db/Chain.py @@ -630,14 +630,12 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): else: return prompt_content - async def update_chain_responses(self, chain_name, responses): - for response in responses: - step_data = responses[response] - chain_step = self.get_step(chain_name, step_data["step"]) - response_content = { - "chain_step_id": chain_step.id, - "content": step_data["response"], - } - chain_step_response = ChainStepResponse(**response_content) + async def update_step_response(self, chain_name, step_number, response): + chain = self.get_chain(chain_name=chain_name) + chain_step = self.get_step(chain_name=chain_name, step_number=step_number) + if chain_step: + chain_step_response = ChainStepResponse( + chain_step_id=chain_step.id, content=response + ) self.session.add(chain_step_response) self.session.commit() diff --git a/agixt/fb/Chain.py b/agixt/fb/Chain.py index 2b6f3d158d93..ddebf8c27431 100644 --- a/agixt/fb/Chain.py +++ b/agixt/fb/Chain.py @@ -257,7 +257,18 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): else: return prompt_content - async def update_chain_responses(self, chain_name, responses): + async def update_step_response(self, chain_name, step_number, response): file_path = get_chain_responses_file_path(chain_name=chain_name) + try: + with open(file_path, "r") as f: + responses = json.load(f) + except: + responses = {} + + if str(step_number) not in responses: + responses[str(step_number)] = [] + + responses[str(step_number)].append(response) + with open(file_path, "w") as f: json.dump(responses, f) From 9857659fae7a0f2b6e190be42ba632d2ef892560 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 14:18:09 -0400 Subject: [PATCH 0002/1256] fix chain err --- agixt/db/Chain.py | 32 ++++++++++++++++++++++++++++---- agixt/fb/Chain.py | 11 ++++++++--- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/agixt/db/Chain.py b/agixt/db/Chain.py index b94d6a1aef43..7bd699c90fb9 100644 --- a/agixt/db/Chain.py +++ b/agixt/db/Chain.py @@ -634,8 +634,32 @@ async def update_step_response(self, chain_name, step_number, response): chain = self.get_chain(chain_name=chain_name) chain_step = self.get_step(chain_name=chain_name, step_number=step_number) if chain_step: - chain_step_response = ChainStepResponse( - chain_step_id=chain_step.id, content=response + existing_response = ( + self.session.query(ChainStepResponse) + .filter(ChainStepResponse.chain_step_id == chain_step.id) + .order_by(ChainStepResponse.timestamp.desc()) + .first() ) - self.session.add(chain_step_response) - self.session.commit() + if existing_response: + if isinstance(existing_response.content, dict) and isinstance( + response, dict + ): + existing_response.content.update(response) + self.session.commit() + elif isinstance(existing_response.content, list) and isinstance( + response, list + ): + existing_response.content.extend(response) + self.session.commit() + else: + chain_step_response = ChainStepResponse( + chain_step_id=chain_step.id, content=response + ) + self.session.add(chain_step_response) + self.session.commit() + else: + chain_step_response = ChainStepResponse( + chain_step_id=chain_step.id, content=response + ) + self.session.add(chain_step_response) + self.session.commit() diff --git a/agixt/fb/Chain.py b/agixt/fb/Chain.py index ddebf8c27431..5658d77bf5ff 100644 --- a/agixt/fb/Chain.py +++ b/agixt/fb/Chain.py @@ -266,9 +266,14 @@ async def update_step_response(self, chain_name, step_number, response): responses = {} if str(step_number) not in responses: - responses[str(step_number)] = [] - - responses[str(step_number)].append(response) + responses[str(step_number)] = response + else: + if isinstance(responses[str(step_number)], dict): + responses[str(step_number)].update(response) + elif isinstance(responses[str(step_number)], list): + responses[str(step_number)].append(response) + else: + responses[str(step_number)] = response with open(file_path, "w") as f: json.dump(responses, f) From f358be78caebaa1ad47c6c102452e1124f5384d1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 14:56:40 -0400 Subject: [PATCH 0003/1256] fix error --- agixt/fb/Chain.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/agixt/fb/Chain.py b/agixt/fb/Chain.py index 5658d77bf5ff..b465bf6614da 100644 --- a/agixt/fb/Chain.py +++ b/agixt/fb/Chain.py @@ -268,10 +268,15 @@ async def update_step_response(self, chain_name, step_number, response): if str(step_number) not in responses: responses[str(step_number)] = response else: - if isinstance(responses[str(step_number)], dict): + if isinstance(responses[str(step_number)], dict) and isinstance( + response, dict + ): responses[str(step_number)].update(response) elif isinstance(responses[str(step_number)], list): - responses[str(step_number)].append(response) + if isinstance(response, list): + responses[str(step_number)].extend(response) + else: + responses[str(step_number)].append(response) else: responses[str(step_number)] = response From ab8cca6612cc250f14e451d11a67e2b08524218a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 15:13:39 -0400 Subject: [PATCH 0004/1256] default to prompt if not defined --- agixt/db/Chain.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/agixt/db/Chain.py b/agixt/db/Chain.py index 7bd699c90fb9..b0ac07e07564 100644 --- a/agixt/db/Chain.py +++ b/agixt/db/Chain.py @@ -150,6 +150,7 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp prompt_category = prompt["prompt_category"] else: prompt_category = "Default" + argument_key = None if "prompt_name" in prompt: argument_key = "prompt_name" target_id = ( @@ -183,7 +184,20 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp .id ) target_type = "command" - + else: + prompt["prompt_name"] = "User Input" + argument_key = "prompt_name" + target_id = ( + self.session.query(Prompt) + .filter( + Prompt.name == prompt["prompt_name"], + Prompt.user_id == self.user_id, + Prompt.prompt_category.has(name=prompt_category), + ) + .first() + .id + ) + target_type = "prompt" argument_value = prompt[argument_key] prompt_arguments = prompt.copy() del prompt_arguments[argument_key] From 660c486f7e06f339b0f4643fc1e9939399709a1d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 19:55:22 -0400 Subject: [PATCH 0005/1256] support magical auth key type --- agixt/ApiClient.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index da83bd68b562..68d8fc8b43b9 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -3,6 +3,7 @@ from agixtsdk import AGiXTSDK from fastapi import Header, HTTPException from Defaults import getenv +from datetime import datetime logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -28,6 +29,8 @@ def verify_api_key(authorization: str = Header(None)): USING_JWT = True if getenv("USING_JWT").lower() == "true" else False AGIXT_API_KEY = getenv("AGIXT_API_KEY") + if getenv("AUTH_PROVIDER") == "magicalauth": + AGIXT_API_KEY = f"{AGIXT_API_KEY}{datetime.now().strftime("%Y%m%d")}" DEFAULT_USER = getenv("DEFAULT_USER") if DEFAULT_USER == "" or DEFAULT_USER is None or DEFAULT_USER == "None": DEFAULT_USER = "USER" @@ -65,6 +68,8 @@ def get_api_client(authorization: str = Header(None)): def is_admin(email: str = "USER", api_key: str = None): AGIXT_API_KEY = getenv("AGIXT_API_KEY") + if getenv("AUTH_PROVIDER") == "magicalauth": + AGIXT_API_KEY = f"{AGIXT_API_KEY}{datetime.now().strftime("%Y%m%d")}" DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False if DB_CONNECTED != True: return True From 0b18100728ac4a9a8068fab77049ccd17c03c8d6 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:00:24 -0400 Subject: [PATCH 0006/1256] move api key check --- agixt/ApiClient.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index 68d8fc8b43b9..1987774068cc 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -41,6 +41,8 @@ def verify_api_key(authorization: str = Header(None)): status_code=401, detail="Authorization header is missing" ) authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") + if AGIXT_API_KEY == authorization: + return DEFAULT_USER if USING_JWT: try: token = jwt.decode( From 125c23107361639c4f2ac11f21585bf5b364a65e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:01:54 -0400 Subject: [PATCH 0007/1256] remove check from admin --- agixt/ApiClient.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index 1987774068cc..bfeeac10a155 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -70,8 +70,6 @@ def get_api_client(authorization: str = Header(None)): def is_admin(email: str = "USER", api_key: str = None): AGIXT_API_KEY = getenv("AGIXT_API_KEY") - if getenv("AUTH_PROVIDER") == "magicalauth": - AGIXT_API_KEY = f"{AGIXT_API_KEY}{datetime.now().strftime("%Y%m%d")}" DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False if DB_CONNECTED != True: return True From 398758bd6b153fc50a210dd987b1a48d9d5f28a4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:03:49 -0400 Subject: [PATCH 0008/1256] set default to empty string --- agixt/Defaults.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/Defaults.py b/agixt/Defaults.py index 5b4d487ddf27..d3c97841f962 100644 --- a/agixt/Defaults.py +++ b/agixt/Defaults.py @@ -55,8 +55,9 @@ def getenv(var_name: str): "CHROMA_SSL": "false", "DISABLED_EXTENSIONS": "", "DISABLED_PROVIDERS": "", + "AUTH_PROVIDER": "", } - default_value = default_values[var_name] if var_name in default_values else None + default_value = default_values[var_name] if var_name in default_values else "" return os.getenv(var_name, default_value) From 6a48a9f51417d6a5075cd3024fc4c47b1dac47c4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:08:06 -0400 Subject: [PATCH 0009/1256] fix key --- agixt/ApiClient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index bfeeac10a155..ea2807427245 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -30,7 +30,7 @@ def verify_api_key(authorization: str = Header(None)): USING_JWT = True if getenv("USING_JWT").lower() == "true" else False AGIXT_API_KEY = getenv("AGIXT_API_KEY") if getenv("AUTH_PROVIDER") == "magicalauth": - AGIXT_API_KEY = f"{AGIXT_API_KEY}{datetime.now().strftime("%Y%m%d")}" + AGIXT_API_KEY = AGIXT_API_KEY + str(datetime.now().strftime("%Y%m%d")) DEFAULT_USER = getenv("DEFAULT_USER") if DEFAULT_USER == "" or DEFAULT_USER is None or DEFAULT_USER == "None": DEFAULT_USER = "USER" From 85a2126102fdbbae60b83a30fb9043410cb100d9 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:15:35 -0400 Subject: [PATCH 0010/1256] improve order of ops --- agixt/ApiClient.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index ea2807427245..e4c64d45c6b3 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -29,11 +29,25 @@ def verify_api_key(authorization: str = Header(None)): USING_JWT = True if getenv("USING_JWT").lower() == "true" else False AGIXT_API_KEY = getenv("AGIXT_API_KEY") - if getenv("AUTH_PROVIDER") == "magicalauth": - AGIXT_API_KEY = AGIXT_API_KEY + str(datetime.now().strftime("%Y%m%d")) DEFAULT_USER = getenv("DEFAULT_USER") + authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") if DEFAULT_USER == "" or DEFAULT_USER is None or DEFAULT_USER == "None": DEFAULT_USER = "USER" + if getenv("AUTH_PROVIDER") == "magicalauth": + auth_key = AGIXT_API_KEY + str(datetime.now().strftime("%Y%m%d")) + try: + token = jwt.decode( + jwt=authorization, + key=auth_key, + algorithms=["HS256"], + ) + return token["email"] + except Exception as e: + if authorization == auth_key: + return DEFAULT_USER + if authorization != AGIXT_API_KEY: + logging.info(f"Invalid API Key: {authorization}") + raise HTTPException(status_code=401, detail="Invalid API Key") if AGIXT_API_KEY: if authorization is None: logging.info("Authorization header is missing") From 6e5d6cd3dbc9781ead34a51d8d2e30a7e1ed0e97 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 20:57:07 -0400 Subject: [PATCH 0011/1256] add timezone --- .github/workflows/operation-test-with-jupyter.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/operation-test-with-jupyter.yml b/.github/workflows/operation-test-with-jupyter.yml index 69647d399c35..5762cc9e3499 100644 --- a/.github/workflows/operation-test-with-jupyter.yml +++ b/.github/workflows/operation-test-with-jupyter.yml @@ -101,6 +101,7 @@ jobs: SCHEMA: ${{ inputs.auth-schema }} LOG_LEVEL: DEBUG MFA_VERIFY: authenticator + TZ: America/New_York service-under-test: image: ${{ inputs.image }} ports: @@ -121,6 +122,7 @@ jobs: MFA_VERIFY: authenticator MODE: development STRIPE_API_KEY: ${{ inputs.stripe-api-key }} + TZ: America/New_York steps: - uses: actions/setup-python@v5 with: From f96f70291ae707cd11e2d683eb53ee7a81610323 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 21:21:09 -0400 Subject: [PATCH 0012/1256] handle no user --- agixt/db/User.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/db/User.py b/agixt/db/User.py index ea9b435fd074..6688020f5c91 100644 --- a/agixt/db/User.py +++ b/agixt/db/User.py @@ -9,6 +9,8 @@ def is_agixt_admin(email: str = "", api_key: str = ""): return True session = get_session() user = session.query(User).filter_by(email=email).first() + if not user: + return False if user.role == "admin": return True return False From 6397cd8a9bd4b3fa44e3bfb314c2da8e20850791 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 24 May 2024 22:08:20 -0400 Subject: [PATCH 0013/1256] Return true admin --- agixt/ApiClient.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index e4c64d45c6b3..b95d7649b05d 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -83,6 +83,8 @@ def get_api_client(authorization: str = Header(None)): def is_admin(email: str = "USER", api_key: str = None): + return True + # Commenting out functionality until testing is complete. AGIXT_API_KEY = getenv("AGIXT_API_KEY") DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False if DB_CONNECTED != True: From 837b042a3d3c675fdc3da08e6617c7e1c52f302a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 25 May 2024 00:11:55 -0400 Subject: [PATCH 0014/1256] fix chain name injection --- agixt/Chains.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agixt/Chains.py b/agixt/Chains.py index cb708de11f2d..cf2919ff07ce 100644 --- a/agixt/Chains.py +++ b/agixt/Chains.py @@ -45,7 +45,10 @@ async def run_chain_step( if chain_args != {}: for arg, value in chain_args.items(): args[arg] = value - + if "chain_name" in args: + args["chain"] = args["chain_name"] + if "chain" not in args: + args["chain"] = chain_name if "conversation_name" not in args: args["conversation_name"] = f"Chain Execution History: {chain_name}" if "conversation" in args: From 1891cdeb4245d6d05fd605436054ba6f9186df69 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 29 May 2024 01:24:24 +0900 Subject: [PATCH 0015/1256] docs: update tests.ipynb (#1193) minor fix Signed-off-by: Ikko Eltociear Ashimine Co-authored-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- tests/tests.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests.ipynb b/tests/tests.ipynb index 2ea4fc9c6284..44a6057cb6ff 100644 --- a/tests/tests.ipynb +++ b/tests/tests.ipynb @@ -1038,7 +1038,7 @@ "agent_name = \"new_agent\"\n", "instruct_resp = ApiClient.instruct(\n", " agent_name=agent_name,\n", - " user_input=\"Save a file with the the capital of France in it called 'france.txt'.\",\n", + " user_input=\"Save a file with the capital of France in it called 'france.txt'.\",\n", " conversation=\"Talk for Tests\",\n", ")\n", "print(\"Instruct response:\", instruct_resp)" From 0ea48f50e8fd178963116cdcffcc550d1a1da556 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Tue, 28 May 2024 16:08:31 -0400 Subject: [PATCH 0016/1256] Improvements to webseach to memory and activity logging (#1194) * Add website summarization * Recursive summarization * add tts endpoint with url output * fix ref * always inject from collection 1 * add more to agixt.py * add slash * add easy access functions * fix url ref * handle when to pull github repos better * add default user object, improve logic * set url as a string * improve memories function * improve documentation and variable name * get all if no user input * fix ref * replace arxiv with ar5iv * use vision during websearches * fix bug * new agent setting for use_visual_browsing * false by default on tts * add activity logging in conversations * remove support for non autonomous execution * fix lint * test AGiXT.py * fix user email ref * fix user email ref * Only add if a conversation is defined * update sdk and g4f * Improve activities and websearch logging * fix lint * Add conversation and tts disable to website summaries --- agixt/AGiXT.py | 714 ++++++++++++++++++++++ agixt/ApiClient.py | 2 + agixt/Defaults.py | 11 +- agixt/Interactions.py | 115 ++-- agixt/Memories.py | 12 +- agixt/Models.py | 9 + agixt/Websearch.py | 238 +++++++- agixt/db/Agent.py | 10 - agixt/endpoints/Agent.py | 31 + agixt/endpoints/Completions.py | 386 +----------- agixt/endpoints/Memory.py | 10 +- agixt/fb/Agent.py | 13 +- agixt/prompts/Default/Website Summary.txt | 4 + agixt/readers/youtube.py | 10 +- docs/2-Concepts/03-Agents.md | 3 +- requirements.txt | 4 +- start.py | 3 +- static-requirements.txt | 6 +- 18 files changed, 1085 insertions(+), 496 deletions(-) create mode 100644 agixt/AGiXT.py create mode 100644 agixt/prompts/Default/Website Summary.txt diff --git a/agixt/AGiXT.py b/agixt/AGiXT.py new file mode 100644 index 000000000000..c918b77a489a --- /dev/null +++ b/agixt/AGiXT.py @@ -0,0 +1,714 @@ +from Interactions import Interactions +from ApiClient import get_api_client, Conversations, Prompts, Chain +from readers.file import FileReader +from Extensions import Extensions +from Chains import Chains +from pydub import AudioSegment +from Defaults import getenv, get_tokens, DEFAULT_SETTINGS +from Models import ChatCompletions +import os +import base64 +import uuid +import requests +import json +import time + + +class AGiXT: + def __init__(self, user: str, agent_name: str, api_key: str): + self.user_email = user.lower() + self.api_key = api_key + self.agent_name = agent_name + self.uri = getenv("AGIXT_URI") + self.outputs = f"{self.uri}/outputs/" + self.ApiClient = get_api_client(api_key) + self.agent_interactions = Interactions( + agent_name=self.agent_name, user=self.user_email, ApiClient=self.ApiClient + ) + self.agent = self.agent_interactions.agent + self.agent_settings = ( + self.agent.AGENT_CONFIG["settings"] + if "settings" in self.agent.AGENT_CONFIG + else DEFAULT_SETTINGS + ) + + async def prompts(self, prompt_category: str = "Default"): + """ + Get a list of available prompts + + Args: + prompt_category (str): Category of the prompt + + Returns: + list: List of available prompts + """ + return Prompts(user=self.user_email).get_prompts( + prompt_category=prompt_category + ) + + async def chains(self): + """ + Get a list of available chains + + Returns: + list: List of available chains + """ + return Chain(user=self.user_email).get_chains() + + async def settings(self): + """ + Get the agent settings + + Returns: + dict: Agent settings + """ + return self.agent_settings + + async def commands(self): + """ + Get a list of available commands + + Returns: + list: List of available commands + """ + return self.agent.available_commands() + + async def browsed_links(self): + """ + Get a list of browsed links + + Returns: + list: List of browsed links + """ + return self.agent.get_browsed_links() + + async def memories( + self, + user_input: str = "", + limit_per_collection: int = 5, + minimum_relevance_score: float = 0.3, + additional_collection_number: int = 0, + ): + """ + Get a list of memories + + Args: + user_input (str): User input to the agent + limit_per_collection (int): Number of memories to return per collection + minimum_relevance_score (float): Minimum relevance score for memories + additional_collection_number (int): Additional collection number to pull memories from. Collections 0-5 are injected automatically. + + Returns: + str: Agents relevant memories from the user input from collections 0-5 and the additional collection number if provided + """ + formatted_prompt, prompt, tokens = await self.agent_interactions.format_prompt( + user_input=user_input if user_input else "*", + top_results=limit_per_collection, + min_relevance_score=minimum_relevance_score, + inject_memories_from_collection_number=int(additional_collection_number), + ) + return formatted_prompt + + async def inference( + self, + user_input: str, + prompt_category: str = "Default", + prompt_name: str = "Custom Input", + conversation_name: str = "", + images: list = [], + injected_memories: int = 5, + shots: int = 1, + browse_links: bool = False, + voice_response: bool = False, + log_user_input: bool = True, + **kwargs, + ): + """ + Run inference on the AGiXT agent + + Args: + user_input (str): User input to the agent + prompt_category (str): Category of the prompt + prompt_name (str): Name of the prompt to use + injected_memories (int): Number of memories to inject into the conversation + conversation_name (str): Name of the conversation + browse_links (bool): Whether to browse links in the response + images (list): List of image file paths + shots (int): Number of responses to generate + **kwargs: Additional keyword arguments + + Returns: + str: Response from the agent + """ + return await self.agent_interactions.run( + user_input=user_input, + prompt_category=prompt_category, + prompt_name=prompt_name, + context_results=injected_memories, + shots=shots, + conversation_name=conversation_name, + browse_links=browse_links, + images=images, + tts=voice_response, + log_user_input=log_user_input, + **kwargs, + ) + + async def generate_image(self, prompt: str, conversation_name: str = ""): + """ + Generate an image from a prompt + + Args: + prompt (str): Prompt for the image generation + + Returns: + str: URL of the generated image + """ + if conversation_name != "" and conversation_name != None: + c = Conversations( + conversation_name="Image Generation", user=self.user_email + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Generating image... [ACTIVITY_END]", + ) + return await self.agent.generate_image(prompt=prompt) + + async def text_to_speech(self, text: str, conversation_name: str = ""): + """ + Generate Text to Speech audio from text + + Args: + text (str): Text to convert to speech + + Returns: + str: URL of the generated audio + """ + if conversation_name != "" and conversation_name != None: + c = Conversations(conversation_name="Text to Speech", user=self.user_email) + c.log_interaction( + role="USER", + message=f"[ACTIVITY_START] Generating audio from text: {text} [ACTIVITY_END]", + ) + tts_url = await self.agent.text_to_speech(text=text.text) + if not str(tts_url).startswith("http"): + file_type = "wav" + file_name = f"{uuid.uuid4().hex}.{file_type}" + audio_path = f"./WORKSPACE/{file_name}" + audio_data = base64.b64decode(tts_url) + with open(audio_path, "wb") as f: + f.write(audio_data) + tts_url = f"{self.outputs}/{file_name}" + return tts_url + + async def audio_to_text(self, audio_path: str, conversation_name: str = ""): + """ + Audio to Text transcription + + Args: + audio_path (str): Path to the audio file + + Returns + str: Transcription of the audio + """ + response = await self.agent.transcribe_audio(audio_path=audio_path) + if conversation_name != "" and conversation_name != None: + c = Conversations( + conversation_name="Audio Transcription", user=self.user_email + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Transcribed audio to text: {response} [ACTIVITY_END]", + ) + + async def translate_audio(self, audio_path: str, conversation_name: str = ""): + """ + Translate an audio file + + Args: + audio_path (str): Path to the audio file + + Returns + str: Translation of the audio + """ + response = await self.agent.translate_audio(audio_path=audio_path) + if conversation_name != "" and conversation_name != None: + c = Conversations( + conversation_name="Audio Translation", user=self.user_email + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Translated audio: {response} [ACTIVITY_END]", + ) + return response + + async def execute_command( + self, + command_name: str, + command_args: dict, + conversation_name: str = "", + voice_response: bool = False, + ): + """ + Execute a command with arguments + + Args: + command_name (str): Name of the command to execute + command_args (dict): Arguments for the command + conversation_name (str): Name of the conversation + voice_response (bool): Whether to generate a voice response + + Returns: + str: Response from the command + """ + if conversation_name != "" and conversation_name != None: + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction( + role=self.agent, + message=f"[ACTIVITY_START] Execute command: {command_name} with args: {command_args} [ACTIVITY_END]", + ) + response = await Extensions( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + conversation_name=conversation_name, + ApiClient=self.ApiClient, + api_key=self.api_key, + user=self.user_email, + ).execute_command( + command_name=command_name, + command_args=command_args, + ) + if "tts_provider" in self.agent_settings and voice_response: + if ( + self.agent_settings["tts_provider"] != "None" + and self.agent_settings["tts_provider"] != "" + and self.agent_settings["tts_provider"] != None + ): + tts_response = await self.text_to_speech(text=response) + response = f"{response}\n\n{tts_response}" + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + ) + return response + + async def execute_chain( + self, + chain_name: str, + user_input: str, + chain_args: dict = {}, + use_current_agent: bool = True, + conversation_name: str = "", + voice_response: bool = False, + ): + """ + Execute a chain with arguments + + Args: + chain_name (str): Name of the chain to execute + user_input (str): Message to add to conversation log pre-execution + chain_args (dict): Arguments for the chain + use_current_agent (bool): Whether to use the current agent + conversation_name (str): Name of the conversation + voice_response (bool): Whether to generate a voice response + + Returns: + str: Response from the chain + """ + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction(role="USER", message=user_input) + response = await Chains( + user=self.user_email, ApiClient=self.ApiClient + ).run_chain( + chain_name=chain_name, + user_input=user_input, + agent_override=self.agent_name if use_current_agent else None, + all_responses=False, + chain_args=chain_args, + from_step=1, + ) + if "tts_provider" in self.agent_settings and voice_response: + if ( + self.agent_settings["tts_provider"] != "None" + and self.agent_settings["tts_provider"] != "" + and self.agent_settings["tts_provider"] != None + ): + tts_response = await self.text_to_speech(text=response) + response = f'{response}\n\n' + c.log_interaction(role=self.agent_name, message=response) + return response + + async def learn_from_websites( + self, + urls: list = [], + scrape_depth: int = 3, + summarize_content: bool = True, + conversation_name: str = "", + ): + """ + Scrape a website and summarize the content + + Args: + urls (list): List of URLs to scrape + scrape_depth (int): Depth to scrape each URL + summarize_content (bool): Whether to summarize the content + conversation_name (str): Name of the conversation + + Returns: + str: Agent response with a list of scraped links + """ + if isinstance(urls, str): + user_input = f"Learn from the information from this website:\n {urls} " + else: + url_str = {"\n".join(urls)} + user_input = f"Learn from the information from these websites:\n {url_str} " + c = Conversations(conversation_name=conversation_name, user=self.user_email) + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Browsing the web... [ACTIVITY_END]", + ) + response = await self.agent_interactions.websearch.scrape_website( + user_input=user_input, + search_depth=scrape_depth, + summarize_content=summarize_content, + conversation_name=conversation_name, + ) + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + ) + return "I have read the information from the websites into my memory." + + async def learn_from_file( + self, + file_path: str, + collection_number: int = 1, + conversation_name: str = "", + ): + """ + Learn from a file + + Args: + file_path (str): Path to the file + collection_number (int): Collection number to store the file + conversation_name (str): Name of the conversation + + Returns: + str: Response from the agent + """ + + file_name = os.path.basename(file_path) + file_reader = FileReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=collection_number, + ApiClient=self.ApiClient, + user=self.user_email, + ) + res = await file_reader.write_file_to_memory(file_path=file_path) + if res == True: + response = f"I have read the entire content of the file called {file_name} into my memory." + else: + response = f"I was unable to read the file called {file_name}." + if conversation_name != "" and conversation_name != None: + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + ) + return response + + async def chat_completions(self, prompt: ChatCompletions): + """ + Generate an OpenAI style chat completion response with a ChatCompletion prompt + + Args: + prompt (ChatCompletions): Chat completions prompt + + Returns: + dict: Chat completion response + """ + conversation_name = prompt.user + images = [] + new_prompt = "" + browse_links = True + tts = False + urls = [] + if "mode" in self.agent_settings: + mode = self.agent_settings["mode"] + else: + mode = "prompt" + if "prompt_name" in self.agent_settings: + prompt_name = self.agent_settings["prompt_name"] + else: + prompt_name = "Chat" + if "prompt_category" in self.agent_settings: + prompt_category = self.agent_settings["prompt_category"] + else: + prompt_category = "Default" + prompt_args = {} + if "prompt_args" in self.agent_settings: + prompt_args = ( + json.loads(self.agent_settings["prompt_args"]) + if isinstance(self.agent_settings["prompt_args"], str) + else self.agent_settings["prompt_args"] + ) + if "context_results" in self.agent_settings: + context_results = int(self.agent_settings["context_results"]) + else: + context_results = 5 + if "injected_memories" in self.agent_settings: + context_results = int(self.agent_settings["injected_memories"]) + if "command_name" in self.agent_settings: + command_name = self.agent_settings["command_name"] + else: + command_name = "" + if "command_args" in self.agent_settings: + command_args = ( + json.loads(self.agent_settings["command_args"]) + if isinstance(self.agent_settings["command_args"], str) + else self.agent_settings["command_args"] + ) + else: + command_args = {} + if "command_variable" in self.agent_settings: + command_variable = self.agent_settings["command_variable"] + else: + command_variable = "text" + if "chain_name" in self.agent_settings: + chain_name = self.agent_settings["chain_name"] + else: + chain_name = "" + if "chain_args" in self.agent_settings: + chain_args = ( + json.loads(self.agent_settings["chain_args"]) + if isinstance(self.agent_settings["chain_args"], str) + else self.agent_settings["chain_args"] + ) + else: + chain_args = {} + if "tts_provider" in self.agent_settings: + tts_provider = str(self.agent_settings["tts_provider"]).lower() + if tts_provider != "none" and tts_provider != "": + tts = True + for message in prompt.messages: + if "mode" in message: + if message["mode"] in ["prompt", "command", "chain"]: + mode = message["mode"] + if "injected_memories" in message: + context_results = int(message["injected_memories"]) + if "prompt_category" in message: + prompt_category = message["prompt_category"] + if "prompt_name" in message: + prompt_name = message["prompt_name"] + if "prompt_args" in message: + prompt_args = ( + json.loads(message["prompt_args"]) + if isinstance(message["prompt_args"], str) + else message["prompt_args"] + ) + if "command_name" in message: + command_name = message["command_name"] + if "command_args" in message: + command_args = ( + json.loads(message["command_args"]) + if isinstance(message["command_args"], str) + else message["command_args"] + ) + if "command_variable" in message: + command_variable = message["command_variable"] + if "chain_name" in message: + chain_name = message["chain_name"] + if "chain_args" in message: + chain_args = ( + json.loads(message["chain_args"]) + if isinstance(message["chain_args"], str) + else message["chain_args"] + ) + if "browse_links" in message: + browse_links = str(message["browse_links"]).lower() == "true" + if "tts" in message: + tts = str(message["tts"]).lower() == "true" + if "content" not in message: + continue + if isinstance(message["content"], str): + role = message["role"] if "role" in message else "User" + if role.lower() == "system": + if "/" in message["content"]: + new_prompt += f"{message['content']}\n\n" + if role.lower() == "user": + new_prompt += f"{message['content']}\n\n" + if isinstance(message["content"], list): + for msg in message["content"]: + if "text" in msg: + role = message["role"] if "role" in message else "User" + if role.lower() == "user": + new_prompt += f"{msg['text']}\n\n" + if "image_url" in msg: + url = str( + msg["image_url"]["url"] + if "url" in msg["image_url"] + else msg["image_url"] + ) + image_path = f"./WORKSPACE/{uuid.uuid4().hex}.jpg" + if url.startswith("http"): + image = requests.get(url).content + else: + file_type = url.split(",")[0].split("/")[1].split(";")[0] + if file_type == "jpeg": + file_type = "jpg" + file_name = f"{uuid.uuid4().hex}.{file_type}" + image_path = f"./WORKSPACE/{file_name}" + image = base64.b64decode(url.split(",")[1]) + with open(image_path, "wb") as f: + f.write(image) + images.append(image_path) + if "audio_url" in msg: + audio_url = str( + msg["audio_url"]["url"] + if "url" in msg["audio_url"] + else msg["audio_url"] + ) + # If it is not a url, we need to find the file type and convert with pydub + if not audio_url.startswith("http"): + file_type = ( + audio_url.split(",")[0].split("/")[1].split(";")[0] + ) + audio_data = base64.b64decode(audio_url.split(",")[1]) + audio_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" + with open(audio_path, "wb") as f: + f.write(audio_data) + audio_url = audio_path + else: + # Download the audio file from the url, get the file type and convert to wav + audio_type = audio_url.split(".")[-1] + audio_url = f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}" + audio_data = requests.get(audio_url).content + with open(audio_url, "wb") as f: + f.write(audio_data) + wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav" + AudioSegment.from_file(audio_url).set_frame_rate(16000).export( + wav_file, format="wav" + ) + transcribed_audio = await self.audio_to_text( + audio_path=wav_file, + conversation_name=conversation_name, + ) + new_prompt += transcribed_audio + if "video_url" in msg: + video_url = str( + msg["video_url"]["url"] + if "url" in msg["video_url"] + else msg["video_url"] + ) + if video_url.startswith("http"): + urls.append(video_url) + if ( + "file_url" in msg + or "application_url" in msg + or "text_url" in msg + or "url" in msg + ): + file_url = str( + msg["file_url"]["url"] + if "url" in msg["file_url"] + else msg["file_url"] + ) + if file_url.startswith("http"): + urls.append(file_url) + else: + file_type = ( + file_url.split(",")[0].split("/")[1].split(";")[0] + ) + file_data = base64.b64decode(file_url.split(",")[1]) + file_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" + with open(file_path, "wb") as f: + f.write(file_data) + file_url = f"{self.outputs}/{os.path.basename(file_path)}" + urls.append(file_url) + # Add user input to conversation + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction(role="USER", message=new_prompt) + await self.learn_from_websites( + urls=urls, + scrape_depth=3, + summarize_content=True, + conversation_name=conversation_name, + ) + if mode == "command" and command_name and command_variable: + try: + command_args = ( + json.loads(self.agent_settings["command_args"]) + if isinstance(self.agent_settings["command_args"], str) + else self.agent_settings["command_args"] + ) + except Exception as e: + command_args = {} + command_args[self.agent_settings["command_variable"]] = new_prompt + response = await self.execute_command( + command_name=self.agent_settings["command_name"], + command_args=command_args, + conversation_name=conversation_name, + voice_response=tts, + ) + elif mode == "chain" and chain_name: + chain_name = self.agent_settings["chain_name"] + try: + chain_args = ( + json.loads(self.agent_settings["chain_args"]) + if isinstance(self.agent_settings["chain_args"], str) + else self.agent_settings["chain_args"] + ) + except Exception as e: + chain_args = {} + response = await self.execute_chain( + chain_name=chain_name, + user_input=new_prompt, + chain_args=chain_args, + use_current_agent=True, + conversation_name=conversation_name, + voice_response=tts, + ) + elif mode == "prompt": + response = await self.inference( + user_input=new_prompt, + prompt_name=prompt_name, + prompt_category=prompt_category, + conversation_name=conversation_name, + injected_memories=context_results, + shots=prompt.n, + browse_links=browse_links, + voice_response=tts, + images=images, + log_user_input=False, + **prompt_args, + ) + prompt_tokens = get_tokens(new_prompt) + completion_tokens = get_tokens(response) + total_tokens = int(prompt_tokens) + int(completion_tokens) + res_model = { + "id": conversation_name, + "object": "chat.completion", + "created": int(time.time()), + "model": self.agent_name, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": str(response), + }, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, + } + return res_model diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index b95d7649b05d..351c62745c9e 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -19,11 +19,13 @@ from db.Chain import Chain from db.Prompts import Prompts from db.Conversations import Conversations + from db.User import User else: from fb.Agent import Agent, add_agent, delete_agent, rename_agent, get_agents from fb.Chain import Chain from fb.Prompts import Prompts from fb.Conversations import Conversations + from Models import User_fb as User def verify_api_key(authorization: str = Header(None)): diff --git a/agixt/Defaults.py b/agixt/Defaults.py index d3c97841f962..f09bdbef6fe5 100644 --- a/agixt/Defaults.py +++ b/agixt/Defaults.py @@ -1,4 +1,5 @@ import os +import tiktoken from dotenv import load_dotenv load_dotenv() @@ -27,8 +28,7 @@ "WAIT_AFTER_FAILURE": 3, "WORKING_DIRECTORY": "./WORKSPACE", "WORKING_DIRECTORY_RESTRICTED": True, - "AUTONOMOUS_EXECUTION": True, - "PERSONA": "", + "persona": "", } @@ -36,6 +36,7 @@ def getenv(var_name: str): default_values = { "AGIXT_URI": "http://localhost:7437", "AGIXT_API_KEY": None, + "ALLOWED_DOMAINS": "*", "ALLOWLIST": "*", "WORKSPACE": os.path.join(os.getcwd(), "WORKSPACE"), "APP_NAME": "AGiXT", @@ -61,4 +62,10 @@ def getenv(var_name: str): return os.getenv(var_name, default_value) +def get_tokens(text: str) -> int: + encoding = tiktoken.get_encoding("cl100k_base") + num_tokens = len(encoding.encode(text)) + return num_tokens + + DEFAULT_USER = getenv("DEFAULT_USER") diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 92d27296cac5..cb88295b4153 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -4,11 +4,11 @@ import json import time import logging -import tiktoken import base64 import uuid from datetime import datetime from readers.file import FileReader +from readers.github import GithubReader from Websearch import Websearch from Extensions import Extensions from ApiClient import ( @@ -18,7 +18,7 @@ Conversations, AGIXT_URI, ) -from Defaults import getenv, DEFAULT_USER +from Defaults import getenv, DEFAULT_USER, get_tokens logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -26,12 +26,6 @@ ) -def get_tokens(text: str) -> int: - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(text)) - return num_tokens - - class Interactions: def __init__( self, @@ -66,6 +60,13 @@ def __init__( ApiClient=self.ApiClient, user=self.user, ) + self.github_memories = GithubReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=7, + user=self.user, + ApiClient=self.ApiClient, + ) else: self.agent_name = "" self.agent = None @@ -102,7 +103,6 @@ async def format_prompt( step_number=0, conversation_name="", vision_response: str = "", - websearch: bool = False, **kwargs, ): if "user_input" in kwargs and user_input == "": @@ -137,6 +137,16 @@ async def format_prompt( limit=top_results, min_relevance_score=min_relevance_score, ) + context += await self.websearch.agent_memory.get_memories( + user_input=user_input, + limit=top_results, + min_relevance_score=min_relevance_score, + ) + context += await self.github_memories.get_memories( + user_input=user_input, + limit=top_results, + min_relevance_score=min_relevance_score, + ) positive_feedback = await self.positive_feedback_memories.get_memories( user_input=user_input, limit=3, @@ -157,14 +167,8 @@ async def format_prompt( if negative_feedback: joined_feedback = "\n".join(negative_feedback) context.append(f"Negative Feedback:\n{joined_feedback}\n") - if websearch: - context += await self.websearch.agent_memory.get_memories( - user_input=user_input, - limit=top_results, - min_relevance_score=min_relevance_score, - ) if "inject_memories_from_collection_number" in kwargs: - if int(kwargs["inject_memories_from_collection_number"]) > 3: + if int(kwargs["inject_memories_from_collection_number"]) > 5: context += await FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, @@ -389,6 +393,7 @@ async def run( browse_links: bool = False, persist_context_in_history: bool = False, images: list = [], + log_user_input: bool = True, **kwargs, ): global AGIXT_URI @@ -435,8 +440,7 @@ async def run( if "conversation_name" in kwargs: conversation_name = kwargs["conversation_name"] if conversation_name == "": - clean_datetime = re.sub(r"[^a-zA-Z0-9]", "", str(datetime.now())) - conversation_name = f"{clean_datetime} Conversation" + conversation_name = datetime.now().strftime("%Y-%m-%d") if "WEBSEARCH_TIMEOUT" in kwargs: try: websearch_timeout = int(kwargs["WEBSEARCH_TIMEOUT"]) @@ -445,7 +449,7 @@ async def run( else: websearch_timeout = 0 if browse_links != False: - await self.websearch.browse_links_in_input( + await self.websearch.scrape_website( user_input=user_input, search_depth=websearch_depth ) if websearch: @@ -508,10 +512,11 @@ async def run( else formatted_prompt ) c = Conversations(conversation_name=conversation_name, user=self.user) - c.log_interaction( - role="USER", - message=log_message, - ) + if log_user_input: + c.log_interaction( + role="USER", + message=log_message, + ) try: self.response = await self.agent.inference( prompt=formatted_prompt, tokens=tokens @@ -545,6 +550,7 @@ async def run( return await self.run( prompt_name=prompt, prompt_category=prompt_category, + log_user_input=log_user_input, **prompt_args, ) # Handle commands if the prompt contains the {COMMANDS} placeholder @@ -565,10 +571,10 @@ async def run( self.response = re.sub( r"!\[.*?\]\(.*?\)", "", self.response, flags=re.DOTALL ) - tts = True + tts = False if "tts" in kwargs: tts = str(kwargs["tts"]).lower() == "true" - if "tts_provider" in agent_settings and tts: + if "tts_provider" in agent_settings and tts == True: if ( agent_settings["tts_provider"] != "None" and agent_settings["tts_provider"] != "" @@ -651,6 +657,7 @@ async def run( agent_name=self.agent_name, prompt_name=prompt, prompt_category=prompt_category, + log_user_interaction=False, **prompt_args, ) time.sleep(1) @@ -685,6 +692,7 @@ def create_command_suggestion_chain(self, agent_name, command_name, command_args return f"**The command has been added to a chain called '{agent_name} Command Suggestions' for you to review and execute manually.**" async def execution_agent(self, conversation_name): + c = Conversations(conversation_name=conversation_name, user=self.user) command_list = [ available_command["friendly_name"] for available_command in self.agent.available_commands @@ -734,49 +742,34 @@ async def execution_agent(self, conversation_name): }, ) else: - # Check if the command is a valid command in the self.agent.available_commands list try: - if ( - str(self.agent.AUTONOMOUS_EXECUTION).lower() - == "true" - ): - ext = Extensions( - agent_name=self.agent_name, - agent_config=self.agent.AGENT_CONFIG, - conversation_name=conversation_name, - ApiClient=self.ApiClient, - user=self.user, - ) - command_output = await ext.execute_command( - command_name=command_name, - command_args=command_args, - ) - formatted_output = f"```\n{command_output}\n```" - message = f"**Executed Command:** `{command_name}` with the following parameters:\n```json\n{json.dumps(command_args, indent=4)}\n```\n\n**Command Output:**\n{formatted_output}" - Conversations( - conversation_name=f"{self.agent_name} Command Execution Log", - user=self.user, - ).log_interaction( - role=self.agent_name, - message=message, - ) - else: - command_output = ( - self.create_command_suggestion_chain( - agent_name=self.agent_name, - command_name=command_name, - command_args=command_args, - ) - ) - # TODO: Ask the user if they want to execute the suggested chain of commands. - command_output = f"{command_output}\n\n**Would you like to execute the command `{command_name}` with the following parameters?**\n```json\n{json.dumps(command_args, indent=4)}\n```" - # Ask the AI to make the command output more readable and relevant to the conversation and respond with that. + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Executing command `{command_name}` with args `{command_args}`. [ACTIVITY_END]", + ) + ext = Extensions( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + conversation_name=conversation_name, + ApiClient=self.ApiClient, + user=self.user, + ) + command_output = await ext.execute_command( + command_name=command_name, + command_args=command_args, + ) + formatted_output = f"```\n{command_output}\n```" + command_output = f"**Executed Command:** `{command_name}` with the following parameters:\n```json\n{json.dumps(command_args, indent=4)}\n```\n\n**Command Output:**\n{formatted_output}" except Exception as e: logging.error( f"Error: {self.agent_name} failed to execute command `{command_name}`. {e}" ) command_output = f"**Failed to execute command `{command_name}` with args `{command_args}`. Please try again.**" if command_output: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {command_output} [ACTIVITY_END]", + ) reformatted_response = reformatted_response.replace( f"#execute({command_name}, {command_args})", ( diff --git a/agixt/Memories.py b/agixt/Memories.py index 8b609b59612a..1707c7597f36 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -176,7 +176,7 @@ def __init__( if "settings" in self.agent_config else {"embeddings_provider": "default"} ) - self.chroma_client = get_chroma_client() + self.chroma_client = None self.ApiClient = ApiClient self.embedding_provider = Providers( name=( @@ -195,6 +195,8 @@ def __init__( self.summarize_content = summarize_content async def wipe_memory(self): + if self.chroma_client == None: + self.chroma_client = get_chroma_client() try: self.chroma_client.delete_collection(name=self.collection_name) return True @@ -256,6 +258,8 @@ async def import_collections_from_json(self, json_data: List[dict]): # get collections that start with the collection name async def get_collections(self): + if self.chroma_client == None: + self.chroma_client = get_chroma_client() collections = self.chroma_client.list_collections() if int(self.collection_number) > 0: collection_name = snake(self.agent_name) @@ -269,6 +273,8 @@ async def get_collections(self): ] async def get_collection(self): + if self.chroma_client == None: + self.chroma_client = get_chroma_client() try: return self.chroma_client.get_collection( name=self.collection_name, embedding_function=self.embedder @@ -426,6 +432,8 @@ async def get_memories( return response def delete_memories_from_external_source(self, external_source: str): + if self.chroma_client == None: + self.chroma_client = get_chroma_client() collection = self.chroma_client.get_collection(name=self.collection_name) if collection: results = collection.query( @@ -439,6 +447,8 @@ def delete_memories_from_external_source(self, external_source: str): return False def get_external_data_sources(self): + if self.chroma_client == None: + self.chroma_client = get_chroma_client() collection = self.chroma_client.get_collection(name=self.collection_name) if collection: results = collection.query( diff --git a/agixt/Models.py b/agixt/Models.py index 532c1bb9dc9a..26a5e7197ea0 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel from typing import Optional, Dict, List, Any, Union +from Defaults import DEFAULT_USER class AgentName(BaseModel): @@ -221,6 +222,10 @@ class ConversationHistoryModel(BaseModel): conversation_content: List[dict] = [] +class TTSInput(BaseModel): + text: str + + class ConversationHistoryMessageModel(BaseModel): agent_name: str conversation_name: str @@ -268,3 +273,7 @@ class User(BaseModel): commands: Optional[Dict[str, Any]] = {} training_urls: Optional[List[str]] = [] github_repos: Optional[List[str]] = [] + + +class User_fb(BaseModel): + email: str = DEFAULT_USER diff --git a/agixt/Websearch.py b/agixt/Websearch.py index bf4a26f93b43..c03db7f75a6e 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -1,4 +1,5 @@ import re +import os import json import random import requests @@ -9,9 +10,10 @@ from playwright.async_api import async_playwright from bs4 import BeautifulSoup from typing import List -from ApiClient import Agent -from Defaults import getenv +from ApiClient import Agent, Conversations +from Defaults import getenv, get_tokens from readers.youtube import YoutubeReader +from readers.github import GithubReader logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -30,6 +32,7 @@ def __init__( ): self.ApiClient = ApiClient self.agent = agent + self.user = user self.agent_name = self.agent.agent_name self.agent_config = self.agent.AGENT_CONFIG self.agent_settings = self.agent_config["settings"] @@ -69,13 +72,142 @@ def verify_link(self, link: str = "") -> bool: return True return False - async def get_web_content(self, url): - if str(url).startswith("https://www.youtube.com/watch?v="): - video_id = url.split("watch?v=")[1] - await self.agent_memory.write_youtube_captions_to_memory(video_id=video_id) + async def summarize_web_content(self, url, content): + max_tokens = ( + int(self.agent_settings["MAX_TOKENS"]) + if "MAX_TOKENS" in self.agent_settings + else 8000 + ) + # max_tokens is max input tokens for the model + max_tokens = int(max_tokens) - 1000 + if max_tokens < 0: + max_tokens = 5000 + if max_tokens > 8000: + # The reason for this is that most models max output tokens is 4096 + # It is unlikely to reduce the content by more than half. + # We don't want to hit the max tokens limit and risk losing content. + max_tokens = 8000 + if get_tokens(text=content) < int(max_tokens): + return self.ApiClient.prompt_agent( + agent_name=self.agent_name, + prompt_name="Website Summary", + prompt_args={ + "user_input": content, + "url": url, + "browse_links": False, + "disable_memory": True, + "conversation_name": "AGiXT Terminal", + "tts": "false", + }, + ) + chunks = await self.agent_memory.chunk_content( + text=content, chunk_size=int(max_tokens) + ) + new_content = [] + for chunk in chunks: + new_content.append( + self.ApiClient.prompt_agent( + agent_name=self.agent_name, + prompt_name="Website Summary", + prompt_args={ + "user_input": chunk, + "url": url, + "browse_links": False, + "disable_memory": True, + "conversation_name": "AGiXT Terminal", + "tts": "false", + }, + ) + ) + new_content = "\n".join(new_content) + if get_tokens(text=new_content) > int(max_tokens): + # If the content is still too long, we will just send it to be chunked into memory. + return new_content + else: + # If the content isn't too long, we will ask AI to resummarize the combined chunks. + return await self.summarize_web_content(url=url, content=new_content) + + async def get_web_content(self, url: str, summarize_content=False): + if url.startswith("https://arxiv.org/") or url.startswith( + "https://www.arxiv.org/" + ): + url = url.replace("arxiv.org", "ar5iv.org") + if ( + url.startswith("https://www.youtube.com/watch?v=") + or url.startswith("https://youtube.com/watch?v=") + or url.startswith("https://youtu.be/") + ): + video_id = ( + url.split("watch?v=")[1] + if "watch?v=" in url + else url.split("youtu.be/")[1] + ) + if "&" in video_id: + video_id = video_id.split("&")[0] + content = await self.agent_memory.get_transcription(video_id=video_id) self.browsed_links.append(url) self.agent.add_browsed_link(url=url) - return None, None + if summarize_content: + content = await self.summarize_web_content(url=url, content=content) + await self.agent_memory.write_text_to_memory( + user_input=url, + text=f"Content from YouTube video: {url}\n\n{content}", + external_source=url, + ) + return content, None + if url.startswith("https://github.com/"): + do_not_pull_repo = [ + "/pull/", + "/issues", + "/discussions", + "/actions/", + "/projects", + "/security", + "/releases", + "/commits", + "/branches", + "/tags", + "/stargazers", + "/watchers", + "/network", + "/settings", + "/compare", + "/archive", + ] + if any(x in url for x in do_not_pull_repo): + res = False + else: + if "/tree/" in url: + branch = url.split("/tree/")[1].split("/")[0] + else: + branch = "main" + res = await GithubReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=7, + user=self.user, + ApiClient=self.ApiClient, + ).write_github_repository_to_memory( + github_repo=url, + github_user=( + self.agent_settings["GITHUB_USER"] + if "GITHUB_USER" in self.agent_settings + else None + ), + github_token=( + self.agent_settings["GITHUB_TOKEN"] + if "GITHUB_TOKEN" in self.agent_settings + else None + ), + github_branch=branch, + ) + if res: + self.browsed_links.append(url) + self.agent.add_browsed_link(url=url) + return ( + f"Content from GitHub repository at {url} has been added to memory.", + None, + ) try: async with async_playwright() as p: browser = await p.chromium.launch() @@ -96,13 +228,44 @@ async def get_web_content(self, url): title = title.replace(" ", "") href = await page.evaluate("(link) => link.href", link) link_list.append((title, href)) + vision_response = "" + if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: + vision_provider = str( + self.agent.AGENT_CONFIG["settings"]["vision_provider"] + ).lower() + if "use_visual_browsing" in self.agent.AGENT_CONFIG["settings"]: + use_visual_browsing = str( + self.agent.AGENT_CONFIG["settings"]["use_visual_browsing"] + ).lower() + if use_visual_browsing != "true": + vision_provider = "none" + else: + vision_provider = "none" + if vision_provider != "none" and vision_provider != "": + try: + random_screenshot_name = str(random.randint(100000, 999999)) + screenshot_path = f"WORKSPACE/{random_screenshot_name}.png" + await page.screenshot(path=screenshot_path) + vision_response = self.agent.inference( + prompt=f"Provide a detailed visual description of the screenshotted website in the image. The website in the screenshot is from {url}.", + images=[screenshot_path], + ) + os.remove(screenshot_path) + except: + vision_response = "" await browser.close() soup = BeautifulSoup(content, "html.parser") text_content = soup.get_text() text_content = " ".join(text_content.split()) + if vision_response != "": + text_content = f"{text_content}\n\nVisual description from viewing {url}:\n{vision_response}" + if summarize_content: + text_content = await self.summarize_web_content( + url=url, content=text_content + ) await self.agent_memory.write_text_to_memory( user_input=url, - text=f"From website: {url}\n\nContent:\n{text_content}", + text=f"Content from website: {url}\n\n{text_content}", external_source=url, ) self.browsed_links.append(url) @@ -264,23 +427,65 @@ async def search(self, query: str) -> List[str]: self.searx_instance_url = "" return await self.search(query=query) - async def browse_links_in_input(self, user_input: str = "", search_depth: int = 0): + async def scrape_website( + self, + user_input: str = "", + search_depth: int = 0, + summarize_content: bool = False, + conversation_name: str = "", + ): + # user_input = "I am browsing {url} and collecting data from it to learn more." + c = Conversations(conversation_name=conversation_name, user=self.user) links = re.findall(r"(?Phttps?://[^\s]+)", user_input) + scraped_links = [] if links is not None and len(links) > 0: for link in links: if self.verify_link(link=link): - text_content, link_list = await self.get_web_content(url=link) - if int(search_depth) > 0: + if conversation_name != "" and conversation_name is not None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Browsing {link}... [ACTIVITY_END]", + ) + text_content, link_list = await self.get_web_content( + url=link, summarize_content=summarize_content + ) + scraped_links.append(link) + if ( + int(search_depth) > 0 + and "youtube.com/" not in link + and "youtu.be/" not in link + ): if link_list is not None and len(link_list) > 0: i = 0 for sublink in link_list: if self.verify_link(link=sublink[1]): if i <= search_depth: + if ( + conversation_name != "" + and conversation_name is not None + ): + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Browsing {sublink[1]}... [ACTIVITY_END]", + ) ( text_content, link_list, - ) = await self.get_web_content(url=sublink[1]) + ) = await self.get_web_content( + url=sublink[1], + summarize_content=summarize_content, + ) i = i + 1 + scraped_links.append(sublink[1]) + str_links = "\n".join(scraped_links) + message = f"I have read all of the content from the following links into my memory:\n{str_links}" + if conversation_name: + c = Conversations(conversation_name=conversation_name, user=self.user) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {message} [ACTIVITY_END]", + ) + return message async def websearch_agent( self, @@ -288,9 +493,7 @@ async def websearch_agent( websearch_depth: int = 0, websearch_timeout: int = 0, ): - await self.browse_links_in_input( - user_input=user_input, search_depth=websearch_depth - ) + await self.scrape_website(user_input=user_input, search_depth=websearch_depth) try: websearch_depth = int(websearch_depth) except: @@ -312,7 +515,10 @@ async def websearch_agent( if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") - if self.searx_instance_url != "": + if ( + self.searx_instance_url != "" + and self.searx_instance_url is not None + ): links = await self.search(query=search_string) else: links = await self.ddg_search(query=search_string) diff --git a/agixt/db/Agent.py b/agixt/db/Agent.py index 46bbef7d0a95..4047e6fef022 100644 --- a/agixt/db/Agent.py +++ b/agixt/db/Agent.py @@ -236,15 +236,6 @@ def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): name="default", ApiClient=ApiClient, **self.PROVIDER_SETTINGS ).embedder ) - if "AUTONOMOUS_EXECUTION" in self.PROVIDER_SETTINGS: - self.AUTONOMOUS_EXECUTION = self.PROVIDER_SETTINGS["AUTONOMOUS_EXECUTION"] - if isinstance(self.AUTONOMOUS_EXECUTION, str): - self.AUTONOMOUS_EXECUTION = self.AUTONOMOUS_EXECUTION.lower() - self.AUTONOMOUS_EXECUTION = ( - False if self.AUTONOMOUS_EXECUTION.lower() == "false" else True - ) - else: - self.AUTONOMOUS_EXECUTION = True if hasattr(self.EMBEDDINGS_PROVIDER, "chunk_size"): self.chunk_size = self.EMBEDDINGS_PROVIDER.chunk_size else: @@ -261,7 +252,6 @@ def load_config_keys(self): "AI_MODEL", "AI_TEMPERATURE", "MAX_TOKENS", - "AUTONOMOUS_EXECUTION", "embedder", ] for key in config_keys: diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index 881cca979cf0..ef7949c7c127 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends, Header from Interactions import Interactions from Websearch import Websearch +from Defaults import getenv from ApiClient import ( Agent, add_agent, @@ -21,7 +22,10 @@ AgentConfig, ResponseMessage, UrlInput, + TTSInput, ) +import base64 +import uuid app = APIRouter() @@ -190,6 +194,7 @@ async def prompt_agent( agent = Interactions(agent_name=agent_name, user=user, ApiClient=ApiClient) response = await agent.run( prompt=agent_prompt.prompt_name, + log_user_input=True, **agent_prompt.prompt_args, ) return {"response": str(response)} @@ -292,3 +297,29 @@ async def delete_browsed_link( websearch.agent_memory.delete_memories_from_external_source(url=url.url) agent.delete_browsed_link(url=url.url) return {"message": "Browsed links deleted."} + + +@app.post( + "/api/agent/{agent_name}/text_to_speech", + tags=["Agent"], + dependencies=[Depends(verify_api_key)], +) +async def text_to_speech( + agent_name: str, + text: TTSInput, + user=Depends(verify_api_key), + authorization: str = Header(None), +): + ApiClient = get_api_client(authorization=authorization) + agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) + AGIXT_URI = getenv("AGIXT_URI") + tts_response = await agent.text_to_speech(text=text.text) + if not str(tts_response).startswith("http"): + file_type = "wav" + file_name = f"{uuid.uuid4().hex}.{file_type}" + audio_path = f"./WORKSPACE/{file_name}" + audio_data = base64.b64decode(tts_response) + with open(audio_path, "wb") as f: + f.write(audio_data) + tts_response = f"{AGIXT_URI}/outputs/{file_name}" + return {"url": tts_response} diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py index 1829ad7127a4..c6717e3d9b82 100644 --- a/agixt/endpoints/Completions.py +++ b/agixt/endpoints/Completions.py @@ -1,17 +1,9 @@ import time import base64 import uuid -import json -import requests from fastapi import APIRouter, Depends, Header -from Interactions import Interactions, get_tokens -from ApiClient import Agent, Conversations, verify_api_key, get_api_client, AGIXT_URI -from Extensions import Extensions -from Chains import Chains -from Websearch import Websearch -from readers.file import FileReader -from readers.youtube import YoutubeReader -from readers.github import GithubReader +from Defaults import get_tokens +from ApiClient import Agent, verify_api_key, get_api_client from providers.default import DefaultProvider from fastapi import UploadFile, File, Form from typing import Optional, List @@ -21,7 +13,7 @@ TextToSpeech, ImageCreation, ) -from pydub import AudioSegment +from AGiXT import AGiXT app = APIRouter() @@ -40,376 +32,8 @@ async def chat_completion( ): # prompt.model is the agent name # prompt.user is the conversation name - global AGIXT_URI - ApiClient = get_api_client(authorization=authorization) - agent_name = prompt.model - conversation_name = prompt.user - c = Conversations(conversation_name=conversation_name, user=user) - agent = Interactions(agent_name=agent_name, user=user, ApiClient=ApiClient) - agent_config = agent.agent.AGENT_CONFIG - agent_settings = agent_config["settings"] if "settings" in agent_config else {} - images = [] - new_prompt = "" - browse_links = True - if "mode" in agent_settings: - mode = agent_settings["mode"] - else: - mode = "prompt" - if "prompt_name" in agent_settings: - prompt_name = agent_settings["prompt_name"] - else: - prompt_name = "Chat" - if "prompt_category" in agent_settings: - prompt_category = agent_settings["prompt_category"] - else: - prompt_category = "Default" - prompt_args = {} - if "prompt_args" in agent_settings: - prompt_args = ( - json.loads(agent_settings["prompt_args"]) - if isinstance(agent_settings["prompt_args"], str) - else agent_settings["prompt_args"] - ) - if "command_name" in agent_settings: - command_name = agent_settings["command_name"] - else: - command_name = "" - if "command_args" in agent_settings: - command_args = ( - json.loads(agent_settings["command_args"]) - if isinstance(agent_settings["command_args"], str) - else agent_settings["command_args"] - ) - else: - command_args = {} - if "command_variable" in agent_settings: - command_variable = agent_settings["command_variable"] - else: - command_variable = "text" - if "chain_name" in agent_settings: - chain_name = agent_settings["chain_name"] - else: - chain_name = "" - if "chain_args" in agent_settings: - chain_args = ( - json.loads(agent_settings["chain_args"]) - if isinstance(agent_settings["chain_args"], str) - else agent_settings["chain_args"] - ) - else: - chain_args = {} - for message in prompt.messages: - if "mode" in message: - if message["mode"] in ["prompt", "command", "chain"]: - mode = message["mode"] - if "context_results" in message: - context_results = int(message["context_results"]) - else: - context_results = 5 - if "prompt_category" in message: - prompt_category = message["prompt_category"] - if "prompt_name" in message: - prompt_name = message["prompt_name"] - if "prompt_args" in message: - prompt_args = ( - json.loads(message["prompt_args"]) - if isinstance(message["prompt_args"], str) - else message["prompt_args"] - ) - if "command_name" in message: - command_name = message["command_name"] - if "command_args" in message: - command_args = ( - json.loads(message["command_args"]) - if isinstance(message["command_args"], str) - else message["command_args"] - ) - if "command_variable" in message: - command_variable = message["command_variable"] - if "chain_name" in message: - chain_name = message["chain_name"] - if "chain_args" in message: - chain_args = ( - json.loads(message["chain_args"]) - if isinstance(message["chain_args"], str) - else message["chain_args"] - ) - if "browse_links" in message: - browse_links = str(message["browse_links"]).lower() == "true" - tts = True - if "tts" in message: - tts = str(message["tts"]).lower() == "true" - if "content" not in message: - continue - if isinstance(message["content"], str): - role = message["role"] if "role" in message else "User" - if role.lower() == "system": - if "/" in message["content"]: - new_prompt += f"{message['content']}\n\n" - if role.lower() == "user": - new_prompt += f"{message['content']}\n\n" - if isinstance(message["content"], list): - for msg in message["content"]: - if "text" in msg: - role = message["role"] if "role" in message else "User" - if role.lower() == "user": - new_prompt += f"{msg['text']}\n\n" - if "image_url" in msg: - url = ( - msg["image_url"]["url"] - if "url" in msg["image_url"] - else msg["image_url"] - ) - image_path = f"./WORKSPACE/{uuid.uuid4().hex}.jpg" - if url.startswith("http"): - image = requests.get(url).content - else: - file_type = url.split(",")[0].split("/")[1].split(";")[0] - if file_type == "jpeg": - file_type = "jpg" - file_name = f"{uuid.uuid4().hex}.{file_type}" - image_path = f"./WORKSPACE/{file_name}" - image = base64.b64decode(url.split(",")[1]) - with open(image_path, "wb") as f: - f.write(image) - images.append(image_path) - if "audio_url" in msg: - audio_url = ( - msg["audio_url"]["url"] - if "url" in msg["audio_url"] - else msg["audio_url"] - ) - # If it is not a url, we need to find the file type and convert with pydub - if not audio_url.startswith("http"): - file_type = audio_url.split(",")[0].split("/")[1].split(";")[0] - audio_data = base64.b64decode(audio_url.split(",")[1]) - audio_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" - with open(audio_path, "wb") as f: - f.write(audio_data) - audio_url = audio_path - else: - # Download the audio file from the url, get the file type and convert to wav - audio_type = audio_url.split(".")[-1] - audio_url = f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}" - audio_data = requests.get(audio_url).content - with open(audio_url, "wb") as f: - f.write(audio_data) - wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav" - AudioSegment.from_file(audio_url).set_frame_rate(16000).export( - wav_file, format="wav" - ) - transcribed_audio = await agent.agent.transcribe_audio( - audio_path=wav_file - ) - new_prompt += transcribed_audio - if "video_url" in msg: - video_url = str( - msg["video_url"]["url"] - if "url" in msg["video_url"] - else msg["video_url"] - ) - if "collection_number" in msg: - collection_number = int(msg["collection_number"]) - else: - collection_number = 0 - if video_url.startswith("https://www.youtube.com/watch?v="): - youtube_reader = YoutubeReader( - agent_name=agent_name, - agent_config=agent_config, - collection_number=collection_number, - ApiClient=ApiClient, - user=user, - ) - await youtube_reader.write_youtube_captions_to_memory(video_url) - if ( - "file_url" in msg - or "application_url" in msg - or "text_url" in msg - or "url" in msg - ): - file_url = str( - msg["file_url"]["url"] - if "url" in msg["file_url"] - else msg["file_url"] - ) - if "collection_number" in message or "collection_number" in msg: - collection_number = int( - message["collection_number"] - if "collection_number" in message - else msg["collection_number"] - ) - else: - collection_number = 0 - if file_url.startswith("http"): - if file_url.startswith("https://www.youtube.com/watch?v="): - youtube_reader = YoutubeReader( - agent_name=agent_name, - agent_config=agent_config, - collection_number=collection_number, - ApiClient=ApiClient, - user=user, - ) - await youtube_reader.write_youtube_captions_to_memory( - file_url - ) - elif file_url.startswith("https://github.com"): - github_reader = GithubReader( - agent_name=agent_name, - agent_config=agent_config, - collection_number=collection_number, - ApiClient=ApiClient, - user=user, - ) - await github_reader.write_github_repository_to_memory( - github_repo=file_url, - github_user=( - agent_settings["GITHUB_USER"] - if "GITHUB_USER" in agent_settings - else None - ), - github_token=( - agent_settings["GITHUB_TOKEN"] - if "GITHUB_TOKEN" in agent_settings - else None - ), - github_branch=( - "main" - if "branch" not in message - else message["branch"] - ), - ) - else: - website_reader = Websearch( - collection_number=collection_number, - agent=agent.agent, - ApiClient=ApiClient, - user=user, - ) - await website_reader.get_web_content(url=file_url) - else: - file_type = file_url.split(",")[0].split("/")[1].split(";")[0] - file_data = base64.b64decode(file_url.split(",")[1]) - file_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" - with open(file_path, "wb") as f: - f.write(file_data) - file_reader = FileReader( - agent_name=agent_name, - agent_config=agent_config, - collection_number=collection_number, - ApiClient=ApiClient, - user=user, - ) - await file_reader.write_file_to_memory(file_path) - if mode == "command" and command_name and command_variable: - command_args = ( - json.loads(agent_settings["command_args"]) - if isinstance(agent_settings["command_args"], str) - else agent_settings["command_args"] - ) - command_args[agent_settings["command_variable"]] = new_prompt - c.log_interaction(role="USER", message=new_prompt) - response = await Extensions( - agent_name=agent_name, - agent_config=agent_config, - conversation_name=conversation_name, - ApiClient=ApiClient, - api_key=authorization, - user=user, - ).execute_command( - command_name=agent_settings["command_name"], - command_args=command_args, - ) - if "tts_provider" in agent_settings and tts: - if ( - agent_settings["tts_provider"] != "None" - and agent_settings["tts_provider"] != "" - and agent_settings["tts_provider"] != None - ): - tts_response = await agent.agent.text_to_speech(text=response) - # If tts_response is a not a url starting with http, it is a base64 encoded audio file - if not str(tts_response).startswith("http"): - file_type = "wav" - file_name = f"{uuid.uuid4().hex}.{file_type}" - audio_path = f"./WORKSPACE/{file_name}" - audio_data = base64.b64decode(tts_response) - with open(audio_path, "wb") as f: - f.write(audio_data) - tts_response = f'' - response = f"{response}\n\n{tts_response}" - c.log_interaction(role=agent_name, message=response) - elif mode == "chain" and chain_name: - chain_name = agent_settings["chain_name"] - chain_args = ( - json.loads(agent_settings["chain_args"]) - if isinstance(agent_settings["chain_args"], str) - else agent_settings["chain_args"] - ) - c.log_interaction(role="USER", message=new_prompt) - response = await Chains(user=user, ApiClient=ApiClient).run_chain( - chain_name=chain_name, - user_input=new_prompt, - agent_override=agent_name, - all_responses=False, - chain_args=chain_args, - from_step=1, - ) - if "tts_provider" in agent_settings: - if ( - agent_settings["tts_provider"] != "None" - and agent_settings["tts_provider"] != "" - and agent_settings["tts_provider"] != None - ): - tts_response = await agent.agent.text_to_speech(text=response) - # If tts_response is a not a url starting with http, it is a base64 encoded audio file - if not str(tts_response).startswith("http"): - file_type = "wav" - file_name = f"{uuid.uuid4().hex}.{file_type}" - audio_path = f"./WORKSPACE/{file_name}" - audio_data = base64.b64decode(tts_response) - with open(audio_path, "wb") as f: - f.write(audio_data) - - tts_response = f'' - response = f"{response}\n\n{tts_response}" - c.log_interaction(role=agent_name, message=response) - elif mode == "prompt": - response = await agent.run( - user_input=new_prompt, - context_results=context_results, - shots=prompt.n, - conversation_name=conversation_name, - browse_links=browse_links, - images=images, - prompt_name=prompt_name, - prompt_category=prompt_category, - **prompt_args, - ) - prompt_tokens = get_tokens(new_prompt) - completion_tokens = get_tokens(response) - total_tokens = int(prompt_tokens) + int(completion_tokens) - res_model = { - "id": conversation_name, - "object": "chat.completion", - "created": int(time.time()), - "model": agent_name, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": str(response), - }, - "finish_reason": "stop", - "logprobs": None, - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - }, - } - return res_model + agixt = AGiXT(user=user, agent_name=prompt.model, api_key=authorization) + return await agixt.chat_completions(prompt=prompt) # Embedding endpoint diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index f75d71ecfd55..8a487243e527 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -191,13 +191,17 @@ async def learn_url( ) -> ResponseMessage: ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) - await Websearch( + url.url = url.url.replace(" ", "%20") + response = await Websearch( collection_number=url.collection_number, agent=agent, user=user, ApiClient=ApiClient, - ).get_web_content(url=url.url) - return ResponseMessage(message="Agent learned the content from the url.") + ).scrape_website( + user_input=f"I am browsing {url.url} and collecting data from it to learn more.", + search_depth=3, + ) + return ResponseMessage(message=response) @app.post( diff --git a/agixt/fb/Agent.py b/agixt/fb/Agent.py index 6a19862ee95b..97865409b2fd 100644 --- a/agixt/fb/Agent.py +++ b/agixt/fb/Agent.py @@ -120,9 +120,7 @@ def __init__(self, agent_name=None, user="USER", ApiClient=None): self.PROVIDER = Providers( name=self.AI_PROVIDER, ApiClient=ApiClient, **self.PROVIDER_SETTINGS ) - self._load_agent_config_keys( - ["AI_MODEL", "AI_TEMPERATURE", "MAX_TOKENS", "AUTONOMOUS_EXECUTION"] - ) + self._load_agent_config_keys(["AI_MODEL", "AI_TEMPERATURE", "MAX_TOKENS"]) tts_provider = ( self.AGENT_CONFIG["settings"]["tts_provider"] if "tts_provider" in self.AGENT_CONFIG["settings"] @@ -188,15 +186,6 @@ def __init__(self, agent_name=None, user="USER", ApiClient=None): self.MAX_TOKENS = self.PROVIDER_SETTINGS["MAX_TOKENS"] else: self.MAX_TOKENS = 4000 - if "AUTONOMOUS_EXECUTION" in self.PROVIDER_SETTINGS: - self.AUTONOMOUS_EXECUTION = self.PROVIDER_SETTINGS["AUTONOMOUS_EXECUTION"] - if isinstance(self.AUTONOMOUS_EXECUTION, str): - self.AUTONOMOUS_EXECUTION = self.AUTONOMOUS_EXECUTION.lower() - self.AUTONOMOUS_EXECUTION = ( - False if self.AUTONOMOUS_EXECUTION.lower() == "false" else True - ) - else: - self.AUTONOMOUS_EXECUTION = True self.commands = self.load_commands() self.available_commands = Extensions( agent_name=self.agent_name, diff --git a/agixt/prompts/Default/Website Summary.txt b/agixt/prompts/Default/Website Summary.txt new file mode 100644 index 000000000000..2fc73a0602c9 --- /dev/null +++ b/agixt/prompts/Default/Website Summary.txt @@ -0,0 +1,4 @@ +Content of {url} to summarize for the user: +{user_input} + +**Task: Summarize the content in as little text as possible without losing any details, it is important to retain details. If something in the content does not belong, such as a third party ad, do not include it in the summary. Do not summarize anything inside of code blocks, return fully populated code blocks if they exist. Do not mention the URL of the content in the summary.** diff --git a/agixt/readers/youtube.py b/agixt/readers/youtube.py index fc547a852272..93c8a8f07e02 100644 --- a/agixt/readers/youtube.py +++ b/agixt/readers/youtube.py @@ -20,7 +20,7 @@ def __init__( user=user, ) - async def write_youtube_captions_to_memory(self, video_id: str = None): + async def get_transcription(self, video_id: str = None): if "?v=" in video_id: video_id = video_id.split("?v=")[1] srt = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"]) @@ -28,8 +28,14 @@ async def write_youtube_captions_to_memory(self, video_id: str = None): for line in srt: if line["text"] != "[Music]": content += line["text"].replace("[Music]", "") + " " + return content + + async def write_youtube_captions_to_memory(self, video_id: str = None): + content = await self.get_transcription(video_id=video_id) if content != "": - stored_content = f"From YouTube video: {video_id}\nContent: {content}" + stored_content = ( + f"Content from video at youtube.com/watch?v={video_id}:\n{content}" + ) await self.write_text_to_memory( user_input=video_id, text=stored_content, diff --git a/docs/2-Concepts/03-Agents.md b/docs/2-Concepts/03-Agents.md index 23c9fb41382c..d24eb6fcb596 100644 --- a/docs/2-Concepts/03-Agents.md +++ b/docs/2-Concepts/03-Agents.md @@ -1,7 +1,9 @@ # Agents + Agents are a combintation of a single model and a single directive. An agent can be given a task to pursue. In the course of pursuing this task, the agent may request the execution of commands through the AGiXT server. If this occurs, the result of that command will be passed back into the agent and execution will continue until the agent is satisfied that its goal is complete. ## Agent Settings + Agent Settings allow users to manage and configure their agents. This includes adding new agents, updating existing agents, and deleting agents as needed. Users can customize the provider and embedder used by the agent to generate responses. Additionally, users have the option to set custom settings and enable or disable specific agent commands, giving them fine control over the agent's behavior and capabilities. If the agent settings are not specified, the agent will use the default settings. The default settings are as follows: @@ -21,4 +23,3 @@ If the agent settings are not specified, the agent will use the default settings | `stream` | `False` | Whether or not to stream the response from the LLM provider. | | `WORKING_DIRECTORY` | `./WORKSPACE` | The working directory to use for the agent. | | `WORKING_DIRECTORY_RESTRICTED` | `True` | Whether or not to restrict the working directory to the agent's working directory. | -| `AUTONOMOUS_EXECUTION` | `True` | Whether or not to allow the agent to execute commands autonomously. Enabled by default. | diff --git a/requirements.txt b/requirements.txt index d8505e55de44..26ab36e2fa9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -agixtsdk==0.0.37 +agixtsdk==0.0.45 safeexecute==0.0.9 google-generativeai==0.4.1 discord==2.3.2 @@ -16,4 +16,4 @@ google-auth==2.29.0 google-api-python-client==2.125.0 python-multipart==0.0.9 nest_asyncio -g4f==0.3.1.8 \ No newline at end of file +g4f==0.3.1.9 \ No newline at end of file diff --git a/start.py b/start.py index c55d61819e2b..0c954f57896c 100644 --- a/start.py +++ b/start.py @@ -176,8 +176,7 @@ def start_ezlocalai(): "WAIT_AFTER_FAILURE": 3, "WORKING_DIRECTORY": "./WORKSPACE", "WORKING_DIRECTORY_RESTRICTED": True, - "AUTONOMOUS_EXECUTION": True, - "PERSONA": "", + "persona": "", }, } os.makedirs("agixt/agents/AGiXT", exist_ok=True) diff --git a/static-requirements.txt b/static-requirements.txt index c9674331e693..8e1fcf42322b 100644 --- a/static-requirements.txt +++ b/static-requirements.txt @@ -1,10 +1,10 @@ chromadb==0.4.24 beautifulsoup4==4.12.3 -docker==7.0.0 +docker==6.1.3 docx2txt==0.8 GitPython==3.1.42 pdfplumber==0.11.0 -playwright==1.43.0 +playwright==1.44.0 pandas==2.1.4 PyYAML==6.0.1 requests==2.32.0 @@ -15,7 +15,7 @@ pillow==10.3.0 SQLAlchemy==2.0.29 psycopg2-binary==2.9.9 gTTS==2.5.1 -tiktoken==0.6.0 +tiktoken==0.7.0 PyJWT==2.8.0 websocket-client==1.7.0 lxml==5.1.1 From 676dbfc006ace2de0ad37ba6c17824373d08f11a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 28 May 2024 16:34:11 -0400 Subject: [PATCH 0017/1256] initialize client in init --- agixt/Memories.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 1707c7597f36..8b609b59612a 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -176,7 +176,7 @@ def __init__( if "settings" in self.agent_config else {"embeddings_provider": "default"} ) - self.chroma_client = None + self.chroma_client = get_chroma_client() self.ApiClient = ApiClient self.embedding_provider = Providers( name=( @@ -195,8 +195,6 @@ def __init__( self.summarize_content = summarize_content async def wipe_memory(self): - if self.chroma_client == None: - self.chroma_client = get_chroma_client() try: self.chroma_client.delete_collection(name=self.collection_name) return True @@ -258,8 +256,6 @@ async def import_collections_from_json(self, json_data: List[dict]): # get collections that start with the collection name async def get_collections(self): - if self.chroma_client == None: - self.chroma_client = get_chroma_client() collections = self.chroma_client.list_collections() if int(self.collection_number) > 0: collection_name = snake(self.agent_name) @@ -273,8 +269,6 @@ async def get_collections(self): ] async def get_collection(self): - if self.chroma_client == None: - self.chroma_client = get_chroma_client() try: return self.chroma_client.get_collection( name=self.collection_name, embedding_function=self.embedder @@ -432,8 +426,6 @@ async def get_memories( return response def delete_memories_from_external_source(self, external_source: str): - if self.chroma_client == None: - self.chroma_client = get_chroma_client() collection = self.chroma_client.get_collection(name=self.collection_name) if collection: results = collection.query( @@ -447,8 +439,6 @@ def delete_memories_from_external_source(self, external_source: str): return False def get_external_data_sources(self): - if self.chroma_client == None: - self.chroma_client = get_chroma_client() collection = self.chroma_client.get_collection(name=self.collection_name) if collection: results = collection.query( From ed134f85101f5ae75c277d8b6b302e3c26a3698d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 28 May 2024 17:02:40 -0400 Subject: [PATCH 0018/1256] try again if failing --- agixt/Memories.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 8b609b59612a..4e26371a704b 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -3,6 +3,7 @@ import asyncio import sys import json +import time import spacy import chromadb from chromadb.config import Settings @@ -193,6 +194,7 @@ def __init__( ) self.embedder = self.embedding_provider.embedder self.summarize_content = summarize_content + self.failures = 0 async def wipe_memory(self): try: @@ -329,11 +331,32 @@ async def write_text_to_memory( (chunk + datetime.now().isoformat()).encode() ).hexdigest(), } - collection.add( - ids=metadata["id"], - metadatas=metadata, - documents=chunk, - ) + try: + collection.add( + ids=metadata["id"], + metadatas=metadata, + documents=chunk, + ) + except: + logging.warning(f"Error writing to memory: {chunk}") + # Try again 5 times before giving up + self.failures += 1 + for i in range(5): + try: + time.sleep(0.1) + collection.add( + ids=metadata["id"], + metadatas=metadata, + documents=chunk, + ) + self.failures = 0 + break + except: + self.failures += 1 + if self.failures > 5: + break + continue + return True async def get_memories_data( self, From 37ab9be922c0af08459b2537aeda5696caa3a05b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 28 May 2024 17:05:55 -0400 Subject: [PATCH 0019/1256] handle new conversation better --- agixt/db/Conversations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/db/Conversations.py b/agixt/db/Conversations.py index 1952b074f24e..18aaf2cd7943 100644 --- a/agixt/db/Conversations.py +++ b/agixt/db/Conversations.py @@ -149,6 +149,7 @@ def log_interaction(self, role, message): if not conversation: conversation = self.new_conversation() + session.close() session = get_session() timestamp = datetime.now().strftime("%B %d, %Y %I:%M %p") try: @@ -159,8 +160,8 @@ def log_interaction(self, role, message): conversation_id=conversation.id, ) except Exception as e: - logging.info(f"Error logging interaction: {e}") conversation = self.new_conversation() + session.close() session = get_session() new_message = Message( role=role, From e68b681eb34cc8600a4bcbc5de29f2a0e9d88719 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 28 May 2024 17:29:26 -0400 Subject: [PATCH 0020/1256] add get or create flag --- agixt/Memories.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 4e26371a704b..3752e5781aaa 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -276,10 +276,15 @@ async def get_collection(self): name=self.collection_name, embedding_function=self.embedder ) except: - self.chroma_client.create_collection( - name=self.collection_name, - embedding_function=self.embedder, - ) + try: + return self.chroma_client.create_collection( + name=self.collection_name, + embedding_function=self.embedder, + get_or_create=True, + ) + except: + # Collection already exists + pass return self.chroma_client.get_collection( name=self.collection_name, embedding_function=self.embedder ) From c19438fe7ae464b993cc8c16e2c0286282f7dba9 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 28 May 2024 17:41:29 -0400 Subject: [PATCH 0021/1256] scrape_websites vs scrape_website --- agixt/AGiXT.py | 2 +- agixt/Interactions.py | 7 +++++-- agixt/Websearch.py | 6 ++++-- agixt/endpoints/Memory.py | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/agixt/AGiXT.py b/agixt/AGiXT.py index c918b77a489a..dd2cb8b65b67 100644 --- a/agixt/AGiXT.py +++ b/agixt/AGiXT.py @@ -369,7 +369,7 @@ async def learn_from_websites( role=self.agent_name, message=f"[ACTIVITY_START] Browsing the web... [ACTIVITY_END]", ) - response = await self.agent_interactions.websearch.scrape_website( + response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, search_depth=scrape_depth, summarize_content=summarize_content, diff --git a/agixt/Interactions.py b/agixt/Interactions.py index cb88295b4153..1746a495c6c3 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -449,8 +449,11 @@ async def run( else: websearch_timeout = 0 if browse_links != False: - await self.websearch.scrape_website( - user_input=user_input, search_depth=websearch_depth + await self.websearch.scrape_websites( + user_input=user_input, + search_depth=websearch_depth, + summarize_content=True, + conversation_name=conversation_name, ) if websearch: if user_input == "": diff --git a/agixt/Websearch.py b/agixt/Websearch.py index c03db7f75a6e..f5b074b7cb1b 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -427,7 +427,7 @@ async def search(self, query: str) -> List[str]: self.searx_instance_url = "" return await self.search(query=query) - async def scrape_website( + async def scrape_websites( self, user_input: str = "", search_depth: int = 0, @@ -437,6 +437,8 @@ async def scrape_website( # user_input = "I am browsing {url} and collecting data from it to learn more." c = Conversations(conversation_name=conversation_name, user=self.user) links = re.findall(r"(?Phttps?://[^\s]+)", user_input) + if len(links) < 1: + return "" scraped_links = [] if links is not None and len(links) > 0: for link in links: @@ -493,7 +495,7 @@ async def websearch_agent( websearch_depth: int = 0, websearch_timeout: int = 0, ): - await self.scrape_website(user_input=user_input, search_depth=websearch_depth) + await self.scrape_websites(user_input=user_input, search_depth=websearch_depth) try: websearch_depth = int(websearch_depth) except: diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 8a487243e527..1c821e9dc281 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -197,7 +197,7 @@ async def learn_url( agent=agent, user=user, ApiClient=ApiClient, - ).scrape_website( + ).scrape_websites( user_input=f"I am browsing {url.url} and collecting data from it to learn more.", search_depth=3, ) From e28e5488ef8a28a427129713a6762dc8a40a6491 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 13:13:14 -0400 Subject: [PATCH 0022/1256] v1.6.0 (#1195) * v1.6.0 * add db to compose files * Allow sqlite * fix user creation * add default user * remove unused functions * create default user if it doesn't exist * add user id * add logging for test * add wait for server in test * fix searxng search default * lower username * clean up * change func name to webhook_create_user --- .github/workflows/publish-docker-dev.yml | 13 - .github/workflows/publish-docker.yml | 13 - agixt/AGiXT.py | 2 +- agixt/{db => }/Agent.py | 15 +- agixt/ApiClient.py | 34 +- agixt/{db => }/Chain.py | 4 +- agixt/Chains.py | 2 +- agixt/{db => }/Conversations.py | 4 +- agixt/DB.py | 481 +++++++++++++++++++++++ agixt/DBConnection.py | 287 -------------- agixt/Extensions.py | 2 +- agixt/{Defaults.py => Globals.py} | 5 +- agixt/Interactions.py | 6 +- agixt/MagicalAuth.py | 425 ++++++++++++++++++++ agixt/Memories.py | 6 +- agixt/Models.py | 29 +- agixt/{db => }/Prompts.py | 4 +- agixt/Providers.py | 2 +- agixt/{db/imports.py => SeedImports.py} | 10 +- agixt/Tunnel.py | 2 +- agixt/Websearch.py | 7 +- agixt/app.py | 2 +- agixt/db/User.py | 55 --- agixt/endpoints/Agent.py | 2 +- agixt/endpoints/Auth.py | 132 +++++++ agixt/endpoints/Completions.py | 2 +- agixt/endpoints/Provider.py | 45 +-- agixt/extensions/agixt_actions.py | 2 +- agixt/fb/Agent.py | 389 ------------------ agixt/fb/Chain.py | 284 ------------- agixt/fb/Conversations.py | 103 ----- agixt/fb/Prompts.py | 112 ------ agixt/launch-backend.sh | 9 +- agixt/providers/ezlocalai.py | 2 +- agixt/providers/openai.py | 2 +- agixt/version | 2 +- docker-compose-dev.yml | 3 +- docker-compose.yml | 22 +- tests/completions-tests.ipynb | 2 + 39 files changed, 1143 insertions(+), 1380 deletions(-) rename agixt/{db => }/Agent.py (97%) rename agixt/{db => }/Chain.py (99%) rename agixt/{db => }/Conversations.py (99%) create mode 100644 agixt/DB.py delete mode 100644 agixt/DBConnection.py rename agixt/{Defaults.py => Globals.py} (94%) create mode 100644 agixt/MagicalAuth.py rename agixt/{db => }/Prompts.py (98%) rename agixt/{db/imports.py => SeedImports.py} (98%) delete mode 100644 agixt/db/User.py create mode 100644 agixt/endpoints/Auth.py delete mode 100644 agixt/fb/Agent.py delete mode 100644 agixt/fb/Chain.py delete mode 100644 agixt/fb/Conversations.py delete mode 100644 agixt/fb/Prompts.py diff --git a/.github/workflows/publish-docker-dev.yml b/.github/workflows/publish-docker-dev.yml index 3cdaa72548a8..2a755385f806 100644 --- a/.github/workflows/publish-docker-dev.yml +++ b/.github/workflows/publish-docker-dev.yml @@ -51,28 +51,15 @@ jobs: notebook: tests/tests.ipynb image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} port: "7437" - db-connected: false report-name: "agixt-tests" additional-python-dependencies: agixtsdk needs: build-agixt - test-agixt-db: - uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main - with: - notebook: tests/tests.ipynb - image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} - port: "7437" - db-connected: true - database-type: "postgresql" - report-name: "agixt-db-tests" - additional-python-dependencies: agixtsdk - needs: build-agixt test-completions: uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main with: notebook: tests/completions-tests.ipynb image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} port: "7437" - db-connected: false report-name: "completions-tests" additional-python-dependencies: openai requests python-dotenv needs: build-agixt \ No newline at end of file diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml index db0799feeada..39e1d65f5467 100644 --- a/.github/workflows/publish-docker.yml +++ b/.github/workflows/publish-docker.yml @@ -31,28 +31,15 @@ jobs: notebook: tests/tests.ipynb image: ${{ needs.build-agixt.outputs.primary-image }} port: "7437" - db-connected: false report-name: "agixt-tests" additional-python-dependencies: agixtsdk needs: build-agixt - test-agixt-db: - uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main - with: - notebook: tests/tests.ipynb - image: ${{ needs.build-agixt.outputs.primary-image }} - port: "7437" - db-connected: true - database-type: "postgresql" - report-name: "agixt-db-tests" - additional-python-dependencies: agixtsdk - needs: build-agixt test-completions: uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main with: notebook: tests/completions-tests.ipynb image: ${{ needs.build-agixt.outputs.primary-image }} port: "7437" - db-connected: false report-name: "completions-tests" additional-python-dependencies: openai requests python-dotenv needs: build-agixt \ No newline at end of file diff --git a/agixt/AGiXT.py b/agixt/AGiXT.py index dd2cb8b65b67..494ee8b0a126 100644 --- a/agixt/AGiXT.py +++ b/agixt/AGiXT.py @@ -4,7 +4,7 @@ from Extensions import Extensions from Chains import Chains from pydub import AudioSegment -from Defaults import getenv, get_tokens, DEFAULT_SETTINGS +from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions import os import base64 diff --git a/agixt/db/Agent.py b/agixt/Agent.py similarity index 97% rename from agixt/db/Agent.py rename to agixt/Agent.py index 4047e6fef022..919011b059e4 100644 --- a/agixt/db/Agent.py +++ b/agixt/Agent.py @@ -1,4 +1,4 @@ -from DBConnection import ( +from DB import ( Agent as AgentModel, AgentSetting as AgentSettingModel, AgentBrowsedLink, @@ -15,7 +15,7 @@ ) from Providers import Providers from Extensions import Extensions -from Defaults import getenv, DEFAULT_SETTINGS, DEFAULT_USER +from Globals import getenv, DEFAULT_SETTINGS, DEFAULT_USER from datetime import datetime, timezone, timedelta import logging import json @@ -169,9 +169,14 @@ class Agent: def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): self.agent_name = agent_name if agent_name is not None else "AGiXT" self.session = get_session() - self.user = user - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id + user = user if user is not None else DEFAULT_USER + self.user = user.lower() + try: + user_data = self.session.query(User).filter(User.email == self.user).first() + self.user_id = user_data.id + except Exception as e: + logging.error(f"User {self.user} not found.") + raise self.AGENT_CONFIG = self.get_agent_config() self.load_config_keys() if "settings" not in self.AGENT_CONFIG: diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index 351c62745c9e..82fb43d6c0a2 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -2,30 +2,21 @@ import jwt from agixtsdk import AGiXTSDK from fastapi import Header, HTTPException -from Defaults import getenv +from Globals import getenv from datetime import datetime logging.basicConfig( level=getenv("LOG_LEVEL"), format=getenv("LOG_FORMAT"), ) -DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False WORKERS = int(getenv("UVICORN_WORKERS")) AGIXT_URI = getenv("AGIXT_URI") # Defining these here to be referenced externally. -if DB_CONNECTED: - from db.Agent import Agent, add_agent, delete_agent, rename_agent, get_agents - from db.Chain import Chain - from db.Prompts import Prompts - from db.Conversations import Conversations - from db.User import User -else: - from fb.Agent import Agent, add_agent, delete_agent, rename_agent, get_agents - from fb.Chain import Chain - from fb.Prompts import Prompts - from fb.Conversations import Conversations - from Models import User_fb as User +from Agent import Agent, add_agent, delete_agent, rename_agent, get_agents +from Chain import Chain +from Prompts import Prompts +from Conversations import Conversations def verify_api_key(authorization: str = Header(None)): @@ -34,7 +25,7 @@ def verify_api_key(authorization: str = Header(None)): DEFAULT_USER = getenv("DEFAULT_USER") authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") if DEFAULT_USER == "" or DEFAULT_USER is None or DEFAULT_USER == "None": - DEFAULT_USER = "USER" + DEFAULT_USER = "user" if getenv("AUTH_PROVIDER") == "magicalauth": auth_key = AGIXT_API_KEY + str(datetime.now().strftime("%Y%m%d")) try: @@ -88,20 +79,15 @@ def is_admin(email: str = "USER", api_key: str = None): return True # Commenting out functionality until testing is complete. AGIXT_API_KEY = getenv("AGIXT_API_KEY") - DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False - if DB_CONNECTED != True: - return True if api_key is None: api_key = "" api_key = api_key.replace("Bearer ", "").replace("bearer ", "") if AGIXT_API_KEY == api_key: return True - if DB_CONNECTED == True: - from db.User import is_agixt_admin + if email == "" or email is None or email == "None": + email = getenv("DEFAULT_USER") if email == "" or email is None or email == "None": - email = getenv("DEFAULT_USER") - if email == "" or email is None or email == "None": - email = "USER" - return is_agixt_admin(email=email, api_key=api_key) + email = "USER" + return is_agixt_admin(email=email, api_key=api_key) return False diff --git a/agixt/db/Chain.py b/agixt/Chain.py similarity index 99% rename from agixt/db/Chain.py rename to agixt/Chain.py index b0ac07e07564..054a402b3b95 100644 --- a/agixt/db/Chain.py +++ b/agixt/Chain.py @@ -1,4 +1,4 @@ -from DBConnection import ( +from DB import ( get_session, Chain as ChainDB, ChainStep, @@ -10,7 +10,7 @@ Command, User, ) -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER import logging logging.basicConfig( diff --git a/agixt/Chains.py b/agixt/Chains.py index cf2919ff07ce..6951a5055c31 100644 --- a/agixt/Chains.py +++ b/agixt/Chains.py @@ -1,5 +1,5 @@ import logging -from Defaults import getenv +from Globals import getenv from ApiClient import Chain, Prompts, Conversations from Extensions import Extensions diff --git a/agixt/db/Conversations.py b/agixt/Conversations.py similarity index 99% rename from agixt/db/Conversations.py rename to agixt/Conversations.py index 18aaf2cd7943..420e546ecdb1 100644 --- a/agixt/db/Conversations.py +++ b/agixt/Conversations.py @@ -1,12 +1,12 @@ from datetime import datetime import logging -from DBConnection import ( +from DB import ( Conversation, Message, User, get_session, ) -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/DB.py b/agixt/DB.py new file mode 100644 index 000000000000..6d78c2879516 --- /dev/null +++ b/agixt/DB.py @@ -0,0 +1,481 @@ +import uuid +import time +import logging +from sqlalchemy import ( + create_engine, + Column, + Text, + String, + Integer, + ForeignKey, + DateTime, + Boolean, + func, +) +from sqlalchemy.orm import sessionmaker, relationship, declarative_base +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import text +from Globals import getenv + +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) +DEFAULT_USER = getenv("DEFAULT_USER") +try: + DATABASE_TYPE = getenv("DATABASE_TYPE") + DATABASE_NAME = getenv("DATABASE_NAME") + if DATABASE_TYPE != "sqlite": + DATABASE_USER = getenv("DATABASE_USER") + DATABASE_PASSWORD = getenv("DATABASE_PASSWORD") + DATABASE_HOST = getenv("DATABASE_HOST") + DATABASE_PORT = getenv("DATABASE_PORT") + LOGIN_URI = f"{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}" + DATABASE_URI = f"postgresql://{LOGIN_URI}" + else: + DATABASE_URI = f"sqlite:///{DATABASE_NAME}.db" + engine = create_engine(DATABASE_URI, pool_size=40, max_overflow=-1) + connection = engine.connect() + Base = declarative_base() +except Exception as e: + logging.error(f"Error connecting to database: {e}") + Base = None + engine = None + + +def get_session(): + Session = sessionmaker(bind=engine, autoflush=False) + session = Session() + return session + + +class User(Base): + __tablename__ = "user" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + email = Column(String, unique=True) + first_name = Column(String, default="", nullable=True) + last_name = Column(String, default="", nullable=True) + admin = Column(Boolean, default=False, nullable=False) + mfa_token = Column(String, default="", nullable=True) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + is_active = Column(Boolean, default=True) + + +class FailedLogins(Base): + __tablename__ = "failed_logins" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + ) + user = relationship("User") + ip_address = Column(String, default="", nullable=True) + created_at = Column(DateTime, server_default=func.now()) + + +class Provider(Base): + __tablename__ = "provider" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + provider_settings = relationship("ProviderSetting", backref="provider") + + +class ProviderSetting(Base): + __tablename__ = "provider_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=False, + ) + name = Column(Text, nullable=False) + value = Column(Text, nullable=True) + + +class AgentProviderSetting(Base): + __tablename__ = "agent_provider_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_setting_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider_setting.id"), + nullable=False, + ) + agent_provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent_provider.id"), + nullable=False, + ) + value = Column(Text, nullable=True) + + +class AgentProvider(Base): + __tablename__ = "agent_provider" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + settings = relationship("AgentProviderSetting", backref="agent_provider") + + +class AgentBrowsedLink(Base): + __tablename__ = "agent_browsed_link" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + link = Column(Text, nullable=False) + timestamp = Column(DateTime, server_default=text("now()")) + + +class Agent(Base): + __tablename__ = "agent" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=True, + default=None, + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + settings = relationship("AgentSetting", backref="agent") # One-to-many relationship + browsed_links = relationship("AgentBrowsedLink", backref="agent") + user = relationship("User", backref="agent") + + +class Command(Base): + __tablename__ = "command" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + extension_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("extension.id"), + ) + extension = relationship("Extension", backref="commands") + + +class AgentCommand(Base): + __tablename__ = "agent_command" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + state = Column(Boolean, nullable=False) + command = relationship("Command") # Add this line to define the relationship + + +class Conversation(Base): + __tablename__ = "conversation" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + user = relationship("User", backref="conversation") + + +class Message(Base): + __tablename__ = "message" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + role = Column(Text, nullable=False) + content = Column(Text, nullable=False) + timestamp = Column(DateTime, server_default=text("now()")) + conversation_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("conversation.id"), + nullable=False, + ) + + +class Setting(Base): + __tablename__ = "setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + extension_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("extension.id"), + ) + value = Column(Text) + + +class AgentSetting(Base): + __tablename__ = "agent_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + name = Column(String) + value = Column(String) + + +class Chain(Base): + __tablename__ = "chain" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=True) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + steps = relationship( + "ChainStep", + backref="chain", + cascade="all, delete", # Add the cascade option for deleting steps + passive_deletes=True, + foreign_keys="ChainStep.chain_id", + ) + target_steps = relationship( + "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" + ) + user = relationship("User", backref="chain") + + +class ChainStep(Base): + __tablename__ = "chain_step" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="CASCADE"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + prompt_type = Column(Text) # Add the prompt_type field + prompt = Column(Text) # Add the prompt field + target_chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="SET NULL"), + ) + target_command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id", ondelete="SET NULL"), + ) + target_prompt_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt.id", ondelete="SET NULL"), + ) + step_number = Column(Integer, nullable=False) + responses = relationship( + "ChainStepResponse", backref="chain_step", cascade="all, delete" + ) + + def add_response(self, content): + session = get_session() + response = ChainStepResponse(content=content, chain_step=self) + session.add(response) + session.commit() + + +class ChainStepArgument(Base): + __tablename__ = "chain_step_argument" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + argument_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("argument.id"), + nullable=False, + ) + chain_step_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_step.id", ondelete="CASCADE"), + nullable=False, # Add the ondelete option + ) + value = Column(Text, nullable=True) + + +class ChainStepResponse(Base): + __tablename__ = "chain_step_response" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_step_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_step.id", ondelete="CASCADE"), + nullable=False, # Add the ondelete option + ) + timestamp = Column(DateTime, server_default=text("now()")) + content = Column(Text, nullable=False) + + +class Extension(Base): + __tablename__ = "extension" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=True, default="") + + +class Argument(Base): + __tablename__ = "argument" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + prompt_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt.id"), + ) + command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id"), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id"), + ) + name = Column(Text, nullable=False) + + +class PromptCategory(Base): + __tablename__ = "prompt_category" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + user = relationship("User", backref="prompt_category") + + +class Prompt(Base): + __tablename__ = "prompt" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + prompt_category_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt_category.id"), + nullable=False, + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=False) + content = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + prompt_category = relationship("PromptCategory", backref="prompts") + user = relationship("User", backref="prompt") + arguments = relationship("Argument", backref="prompt", cascade="all, delete-orphan") + + +if __name__ == "__main__": + logging.info("Connecting to database...") + time.sleep(10) + Base.metadata.create_all(engine) + logging.info("Connected to database.") + # Check if the user table is empty + from SeedImports import import_all_data + + import_all_data() diff --git a/agixt/DBConnection.py b/agixt/DBConnection.py deleted file mode 100644 index 4f372dc8f049..000000000000 --- a/agixt/DBConnection.py +++ /dev/null @@ -1,287 +0,0 @@ -import uuid -import time -import logging -from sqlalchemy import ( - create_engine, - Column, - Text, - String, - Integer, - ForeignKey, - DateTime, - Boolean, -) -from sqlalchemy.orm import sessionmaker, relationship, declarative_base -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.sql import text -from Defaults import getenv - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) -DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False -DEFAULT_USER = getenv("DEFAULT_USER") -if DB_CONNECTED: - DATABASE_USER = getenv("DATABASE_USER") - DATABASE_PASSWORD = getenv("DATABASE_PASSWORD") - DATABASE_HOST = getenv("DATABASE_HOST") - DATABASE_PORT = getenv("DATABASE_PORT") - DATABASE_NAME = getenv("DATABASE_NAME") - LOGIN_URI = f"{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}" - DATABASE_URL = f"postgresql://{LOGIN_URI}" - try: - engine = create_engine(DATABASE_URL, pool_size=40, max_overflow=-1) - except Exception as e: - logging.error(f"Error connecting to database: {e}") - connection = engine.connect() - Base = declarative_base() -else: - Base = None - engine = None - - -def get_session(): - Session = sessionmaker(bind=engine, autoflush=False) - session = Session() - return session - - -class User(Base): - __tablename__ = "user" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - email = Column(String, default=DEFAULT_USER, unique=True) - role = Column(String, default="user") - - -class Provider(Base): - __tablename__ = "provider" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - provider_settings = relationship("ProviderSetting", backref="provider") - - -class ProviderSetting(Base): - __tablename__ = "provider_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_id = Column(UUID(as_uuid=True), ForeignKey("provider.id"), nullable=False) - name = Column(Text, nullable=False) - value = Column(Text, nullable=True) - - -class AgentProviderSetting(Base): - __tablename__ = "agent_provider_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_setting_id = Column( - UUID(as_uuid=True), ForeignKey("provider_setting.id"), nullable=False - ) - agent_provider_id = Column( - UUID(as_uuid=True), ForeignKey("agent_provider.id"), nullable=False - ) - value = Column(Text, nullable=True) - - -class AgentProvider(Base): - __tablename__ = "agent_provider" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_id = Column(UUID(as_uuid=True), ForeignKey("provider.id"), nullable=False) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - settings = relationship("AgentProviderSetting", backref="agent_provider") - - -class AgentBrowsedLink(Base): - __tablename__ = "agent_browsed_link" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - link = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) - - -class Agent(Base): - __tablename__ = "agent" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - provider_id = Column( - UUID(as_uuid=True), ForeignKey("provider.id"), nullable=True, default=None - ) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - settings = relationship("AgentSetting", backref="agent") # One-to-many relationship - browsed_links = relationship("AgentBrowsedLink", backref="agent") - user = relationship("User", backref="agent") - - -class Command(Base): - __tablename__ = "command" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - extension_id = Column(UUID(as_uuid=True), ForeignKey("extension.id")) - extension = relationship("Extension", backref="commands") - - -class AgentCommand(Base): - __tablename__ = "agent_command" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - command_id = Column(UUID(as_uuid=True), ForeignKey("command.id"), nullable=False) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - state = Column(Boolean, nullable=False) - command = relationship("Command") # Add this line to define the relationship - - -class Conversation(Base): - __tablename__ = "conversation" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - user = relationship("User", backref="conversation") - - -class Message(Base): - __tablename__ = "message" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - role = Column(Text, nullable=False) - content = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) - conversation_id = Column( - UUID(as_uuid=True), ForeignKey("conversation.id"), nullable=False - ) - - -class Setting(Base): - __tablename__ = "setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - extension_id = Column(UUID(as_uuid=True), ForeignKey("extension.id")) - value = Column(Text) - - -class AgentSetting(Base): - __tablename__ = "agent_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - name = Column(String) - value = Column(String) - - -class Chain(Base): - __tablename__ = "chain" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=True) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - steps = relationship( - "ChainStep", - backref="chain", - cascade="all, delete", # Add the cascade option for deleting steps - passive_deletes=True, - foreign_keys="ChainStep.chain_id", - ) - target_steps = relationship( - "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" - ) - user = relationship("User", backref="chain") - - -class ChainStep(Base): - __tablename__ = "chain_step" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - chain_id = Column( - UUID(as_uuid=True), ForeignKey("chain.id", ondelete="CASCADE"), nullable=False - ) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - prompt_type = Column(Text) # Add the prompt_type field - prompt = Column(Text) # Add the prompt field - target_chain_id = Column( - UUID(as_uuid=True), ForeignKey("chain.id", ondelete="SET NULL") - ) - target_command_id = Column( - UUID(as_uuid=True), ForeignKey("command.id", ondelete="SET NULL") - ) - target_prompt_id = Column( - UUID(as_uuid=True), ForeignKey("prompt.id", ondelete="SET NULL") - ) - step_number = Column(Integer, nullable=False) - responses = relationship( - "ChainStepResponse", backref="chain_step", cascade="all, delete" - ) - - def add_response(self, content): - session = get_session() - response = ChainStepResponse(content=content, chain_step=self) - session.add(response) - session.commit() - - -class ChainStepArgument(Base): - __tablename__ = "chain_step_argument" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - argument_id = Column(UUID(as_uuid=True), ForeignKey("argument.id"), nullable=False) - chain_step_id = Column( - UUID(as_uuid=True), - ForeignKey("chain_step.id", ondelete="CASCADE"), - nullable=False, # Add the ondelete option - ) - value = Column(Text, nullable=True) - - -class ChainStepResponse(Base): - __tablename__ = "chain_step_response" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - chain_step_id = Column( - UUID(as_uuid=True), - ForeignKey("chain_step.id", ondelete="CASCADE"), - nullable=False, # Add the ondelete option - ) - timestamp = Column(DateTime, server_default=text("now()")) - content = Column(Text, nullable=False) - - -class Extension(Base): - __tablename__ = "extension" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=True, default="") - - -class Argument(Base): - __tablename__ = "argument" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - prompt_id = Column(UUID(as_uuid=True), ForeignKey("prompt.id")) - command_id = Column(UUID(as_uuid=True), ForeignKey("command.id")) - chain_id = Column(UUID(as_uuid=True), ForeignKey("chain.id")) - name = Column(Text, nullable=False) - - -class PromptCategory(Base): - __tablename__ = "prompt_category" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - user = relationship("User", backref="prompt_category") - - -class Prompt(Base): - __tablename__ = "prompt" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - prompt_category_id = Column( - UUID(as_uuid=True), ForeignKey("prompt_category.id"), nullable=False - ) - name = Column(Text, nullable=False) - description = Column(Text, nullable=False) - content = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - prompt_category = relationship("PromptCategory", backref="prompts") - user = relationship("User", backref="prompt") - arguments = relationship("Argument", backref="prompt", cascade="all, delete-orphan") - - -if __name__ == "__main__": - if DB_CONNECTED: - logging.info("Connecting to database...") - time.sleep(10) - Base.metadata.create_all(engine) - logging.info("Connected to database.") - # Check if the user table is empty - from db.imports import import_all_data - - import_all_data() diff --git a/agixt/Extensions.py b/agixt/Extensions.py index dcc7a3c842c4..9c8bbe02bcb5 100644 --- a/agixt/Extensions.py +++ b/agixt/Extensions.py @@ -4,7 +4,7 @@ from inspect import signature, Parameter import logging import inspect -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/Defaults.py b/agixt/Globals.py similarity index 94% rename from agixt/Defaults.py rename to agixt/Globals.py index f09bdbef6fe5..28efef1582bf 100644 --- a/agixt/Defaults.py +++ b/agixt/Globals.py @@ -44,13 +44,12 @@ def getenv(var_name: str): "LOG_LEVEL": "INFO", "LOG_FORMAT": "%(asctime)s | %(levelname)s | %(message)s", "UVICORN_WORKERS": 10, - "DB_CONNECTED": "false", "DATABASE_NAME": "postgres", "DATABASE_USER": "postgres", "DATABASE_PASSWORD": "postgres", "DATABASE_HOST": "localhost", "DATABASE_PORT": "5432", - "DEFAULT_USER": "USER", + "DEFAULT_USER": "user", "USING_JWT": "false", "CHROMA_PORT": "8000", "CHROMA_SSL": "false", @@ -68,4 +67,4 @@ def get_tokens(text: str) -> int: return num_tokens -DEFAULT_USER = getenv("DEFAULT_USER") +DEFAULT_USER = str(getenv("DEFAULT_USER")).lower() diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 1746a495c6c3..b58eef729934 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -18,7 +18,7 @@ Conversations, AGIXT_URI, ) -from Defaults import getenv, DEFAULT_USER, get_tokens +from Globals import getenv, DEFAULT_USER, get_tokens logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -470,8 +470,8 @@ async def run( websearch_depth=websearch_depth, websearch_timeout=websearch_timeout, ) - except: - logging.warning("Failed to websearch.") + except Exception as e: + logging.warning("Failed to websearch. Error: {e}") vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py new file mode 100644 index 000000000000..471580bbf525 --- /dev/null +++ b/agixt/MagicalAuth.py @@ -0,0 +1,425 @@ +from DB import User, FailedLogins, get_session +from Models import UserInfo, Register, Login +from fastapi import Header, HTTPException +from Globals import getenv +from datetime import datetime, timedelta +from Agent import add_agent +from agixtsdk import AGiXTSDK +from fastapi import HTTPException +from sendgrid import SendGridAPIClient +from sendgrid.helpers.mail import ( + Attachment, + FileContent, + FileName, + FileType, + Disposition, + Mail, +) +import pyotp +import requests +import logging +import jwt + + +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) +""" +Required environment variables: + +- SENDGRID_API_KEY: SendGrid API key +- SENDGRID_FROM_EMAIL: Default email address to send emails from +- ENCRYPTION_SECRET: Encryption key to encrypt and decrypt data +- MAGIC_LINK_URL: URL to send in the email for the user to click on +- REGISTRATION_WEBHOOK: URL to send a POST request to when a user registers +""" + + +def is_agixt_admin(email: str = "", api_key: str = ""): + if api_key == getenv("AGIXT_API_KEY"): + return True + session = get_session() + user = session.query(User).filter_by(email=email).first() + if not user: + return False + if user.admin is True: + return True + return False + + +def webhook_create_user( + api_key: str, + email: str, + role: str = "user", + agent_name: str = "", + settings: dict = {}, + commands: dict = {}, + training_urls: list = [], + github_repos: list = [], + ApiClient: AGiXTSDK = AGiXTSDK(), +): + if not is_agixt_admin(email=email, api_key=api_key): + return {"error": "Access Denied"}, 403 + session = get_session() + email = email.lower() + user_exists = session.query(User).filter_by(email=email).first() + if user_exists: + session.close() + return {"error": "User already exists"}, 400 + admin = True if role.lower() == "admin" else False + user = User( + email=email, + admin=admin, + first_name="", + last_name="", + ) + session.add(user) + session.commit() + session.close() + if agent_name != "" and agent_name is not None: + add_agent( + agent_name=agent_name, + provider_settings=settings, + commands=commands, + user=email, + ) + if training_urls != []: + for url in training_urls: + ApiClient.learn_url(agent_name=agent_name, url=url) + if github_repos != []: + for repo in github_repos: + ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) + return {"status": "Success"}, 200 + + +def verify_api_key(authorization: str = Header(None)): + ENCRYPTION_SECRET = getenv("ENCRYPTION_SECRET") + if getenv("AUTH_PROVIDER") == "magicalauth": + ENCRYPTION_SECRET = f'{ENCRYPTION_SECRET}{datetime.now().strftime("%Y%m%d")}' + authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") + if ENCRYPTION_SECRET: + if authorization is None: + raise HTTPException( + status_code=401, detail="Authorization header is missing" + ) + if authorization == ENCRYPTION_SECRET: + return "ADMIN" + try: + if authorization == ENCRYPTION_SECRET: + return "ADMIN" + token = jwt.decode( + jwt=authorization, + key=ENCRYPTION_SECRET, + algorithms=["HS256"], + ) + db = get_session() + user = db.query(User).filter(User.id == token["sub"]).first() + db.close() + return user + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid API Key") + else: + return authorization + + +def send_email( + email: str, + subject: str, + body: str, + attachment_content=None, + attachment_file_type=None, + attachment_file_name=None, +): + message = Mail( + from_email=getenv("SENDGRID_FROM_EMAIL"), + to_emails=email, + subject=subject, + html_content=body, + ) + if ( + attachment_content != None + and attachment_file_type != None + and attachment_file_name != None + ): + attachment = Attachment( + FileContent(attachment_content), + FileName(attachment_file_name), + FileType(attachment_file_type), + Disposition("attachment"), + ) + message.attachment = attachment + + try: + response = SendGridAPIClient(getenv("SENDGRID_API_KEY")).send(message) + except Exception as e: + print(e) + raise HTTPException(status_code=400, detail="Email could not be sent.") + if response.status_code != 202: + raise HTTPException(status_code=400, detail="Email could not be sent.") + return None + + +class MagicalAuth: + def __init__(self, token: str = None): + encryption_key = getenv("ENCRYPTION_SECRET") + self.link = getenv("MAGIC_LINK_URL") + self.encryption_key = f'{encryption_key}{datetime.now().strftime("%Y%m%d")}' + self.token = ( + str(token) + .replace("%2B", "+") + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%20", " ") + .replace("%3A", ":") + .replace("%3F", "?") + .replace("%26", "&") + .replace("%23", "#") + .replace("%3B", ";") + .replace("%40", "@") + .replace("%21", "!") + .replace("%24", "$") + .replace("%27", "'") + .replace("%28", "(") + .replace("%29", ")") + .replace("%2A", "*") + .replace("%2C", ",") + .replace("%3B", ";") + .replace("%5B", "[") + .replace("%5D", "]") + .replace("%7B", "{") + .replace("%7D", "}") + .replace("%7C", "|") + .replace("%5C", "\\") + .replace("%5E", "^") + .replace("%60", "`") + .replace("%7E", "~") + .replace("Bearer ", "") + .replace("bearer ", "") + if token + else None + ) + try: + # Decode jwt + decoded = jwt.decode( + jwt=token, key=self.encryption_key, algorithms=["HS256"] + ) + self.email = decoded["email"] + self.token = token + except: + self.email = None + self.token = None + + def user_exists(self, email: str = None): + self.email = email.lower() + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + session.close() + if not user: + raise HTTPException(status_code=404, detail="User not found") + return True + + def add_failed_login(self, ip_address): + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is not None: + failed_login = FailedLogins(user_id=user.id, ip_address=ip_address) + session.add(failed_login) + session.commit() + session.close() + + def count_failed_logins(self): + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is None: + session.close() + return 0 + failed_logins = ( + session.query(FailedLogins) + .filter(FailedLogins.user_id == user.id) + .filter(FailedLogins.created_at >= datetime.now() - timedelta(hours=24)) + .count() + ) + session.close() + return failed_logins + + def send_magic_link(self, ip_address, login: Login, referrer=None): + self.email = login.email.lower() + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + session.close() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + if not pyotp.TOTP(user.mfa_token).verify(login.token): + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, detail="Invalid MFA token. Please try again." + ) + self.token = jwt.encode( + { + "sub": str(user.id), + "email": self.email, + "admin": user.admin, + "exp": datetime.utcnow() + timedelta(hours=24), + }, + self.encryption_key, + algorithm="HS256", + ) + token = ( + self.token.replace("+", "%2B") + .replace("/", "%2F") + .replace("=", "%3D") + .replace(" ", "%20") + .replace(":", "%3A") + .replace("?", "%3F") + .replace("&", "%26") + .replace("#", "%23") + .replace(";", "%3B") + .replace("@", "%40") + .replace("!", "%21") + .replace("$", "%24") + .replace("'", "%27") + .replace("(", "%28") + .replace(")", "%29") + .replace("*", "%2A") + .replace(",", "%2C") + .replace(";", "%3B") + .replace("[", "%5B") + .replace("]", "%5D") + .replace("{", "%7B") + .replace("}", "%7D") + .replace("|", "%7C") + .replace("\\", "%5C") + .replace("^", "%5E") + .replace("`", "%60") + .replace("~", "%7E") + ) + if referrer is not None: + self.link = referrer + magic_link = f"{self.link}?token={token}" + if ( + getenv("SENDGRID_API_KEY") != "" + and str(getenv("SENDGRID_API_KEY")).lower() != "none" + and getenv("SENDGRID_FROM_EMAIL") != "" + and str(getenv("SENDGRID_FROM_EMAIL")).lower() != "none" + ): + send_email( + email=self.email, + subject="Magic Link", + body=f"Click here to log in", + ) + else: + return magic_link + # Upon clicking the link, the front end will call the login method and save the email and encrypted_id in the session + return f"A login link has been sent to {self.email}, please check your email and click the link to log in. The link will expire in 24 hours." + + def login(self, ip_address): + """ " + Login method to verify the token and return the user object + + :param ip_address: IP address of the user + :return: User object + """ + session = get_session() + failures = self.count_failed_logins() + if failures >= 50: + raise HTTPException( + status_code=429, + detail="Too many failed login attempts today. Please try again tomorrow.", + ) + try: + user_info = jwt.decode( + jwt=self.token, key=self.encryption_key, algorithms=["HS256"] + ) + except: + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, + detail="Invalid login token. Please log out and try again.", + ) + user_id = user_info["sub"] + user = session.query(User).filter(User.id == user_id).first() + session.close() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + if str(user.id) == str(user_id): + return user + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, + detail="Invalid login token. Please log out and try again.", + ) + + def register( + self, + new_user: Register, + ): + new_user.email = new_user.email.lower() + self.email = new_user.email + allowed_domains = getenv("ALLOWED_DOMAINS") + if allowed_domains is None or allowed_domains == "": + allowed_domains = "*" + if allowed_domains != "*": + if "," in allowed_domains: + allowed_domains = allowed_domains.split(",") + else: + allowed_domains = [allowed_domains] + domain = self.email.split("@")[1] + if domain not in allowed_domains: + raise HTTPException( + status_code=403, + detail="Registration is not allowed for this domain.", + ) + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is not None: + session.close() + raise HTTPException( + status_code=409, detail="User already exists with this email." + ) + mfa_token = pyotp.random_base32() + user = User( + mfa_token=mfa_token, + **new_user.model_dump(), + ) + session.add(user) + session.commit() + session.close() + # Send registration webhook out to third party application such as AGiXT to create a user there. + registration_webhook = getenv("REGISTRATION_WEBHOOK") + if registration_webhook: + try: + requests.post( + registration_webhook, + json={"email": self.email}, + headers={"Authorization": getenv("ENCRYPTION_SECRET")}, + ) + except Exception as e: + pass + # Return mfa_token for QR code generation + return mfa_token + + def update_user(self, **kwargs): + user = verify_api_key(self.token) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + session = get_session() + user = session.query(User).filter(User.id == user.id).first() + allowed_keys = list(UserInfo.__annotations__.keys()) + for key, value in kwargs.items(): + if key in allowed_keys: + setattr(user, key, value) + session.commit() + session.close() + return "User updated successfully" + + def delete_user(self): + user = verify_api_key(self.token) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + session = get_session() + user = session.query(User).filter(User.id == user.id).first() + user.is_active = False + session.commit() + session.close() + return "User deleted successfully" diff --git a/agixt/Memories.py b/agixt/Memories.py index 3752e5781aaa..6c2040415a97 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -14,7 +14,7 @@ from datetime import datetime from collections import Counter from typing import List -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -154,9 +154,9 @@ def __init__( global DEFAULT_USER self.agent_name = agent_name if not DEFAULT_USER: - DEFAULT_USER = "USER" + DEFAULT_USER = "user" if not user: - user = "USER" + user = "user" if user != DEFAULT_USER: self.collection_name = f"{snake(user)}_{snake(agent_name)}" else: diff --git a/agixt/Models.py b/agixt/Models.py index 26a5e7197ea0..2fb75888dfa3 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel from typing import Optional, Dict, List, Any, Union -from Defaults import DEFAULT_USER +from Globals import DEFAULT_USER class AgentName(BaseModel): @@ -266,7 +266,7 @@ class CommandExecution(BaseModel): conversation_name: str = "AGiXT Terminal Command Execution" -class User(BaseModel): +class WebhookUser(BaseModel): email: str agent_name: Optional[str] = "" settings: Optional[Dict[str, Any]] = {} @@ -275,5 +275,26 @@ class User(BaseModel): github_repos: Optional[List[str]] = [] -class User_fb(BaseModel): - email: str = DEFAULT_USER +# Auth user models +class Login(BaseModel): + email: str + token: str + + +class Register(BaseModel): + email: str + first_name: str + last_name: str + company_name: str + job_title: str + + +class UserInfo(BaseModel): + first_name: str + last_name: str + company_name: str + job_title: str + + +class Detail(BaseModel): + detail: str diff --git a/agixt/db/Prompts.py b/agixt/Prompts.py similarity index 98% rename from agixt/db/Prompts.py rename to agixt/Prompts.py index ee8ac50467d7..544e3fa422c5 100644 --- a/agixt/db/Prompts.py +++ b/agixt/Prompts.py @@ -1,5 +1,5 @@ -from DBConnection import Prompt, PromptCategory, Argument, User, get_session -from Defaults import DEFAULT_USER +from DB import Prompt, PromptCategory, Argument, User, get_session +from Globals import DEFAULT_USER class Prompts: diff --git a/agixt/Providers.py b/agixt/Providers.py index da151f492010..9c9234294d14 100644 --- a/agixt/Providers.py +++ b/agixt/Providers.py @@ -5,7 +5,7 @@ import os import inspect import logging -from Defaults import getenv +from Globals import getenv logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/db/imports.py b/agixt/SeedImports.py similarity index 98% rename from agixt/db/imports.py rename to agixt/SeedImports.py index 59cf487594f9..ff854e0a4430 100644 --- a/agixt/db/imports.py +++ b/agixt/SeedImports.py @@ -2,7 +2,7 @@ import json import yaml import logging -from DBConnection import ( +from DB import ( get_session, Provider, ProviderSetting, @@ -17,8 +17,8 @@ User, ) from Providers import get_providers, get_provider_options -from db.Agent import add_agent -from Defaults import getenv, DEFAULT_USER +from Agent import add_agent +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -186,7 +186,7 @@ def import_chains(user=DEFAULT_USER): if not chain_files: logging.info(f"No JSON files found in chains directory.") return - from db.Chain import Chain + from Chain import Chain chain_importer = Chain(user=user) for file in chain_files: @@ -406,7 +406,7 @@ def import_all_data(): if user_count == 0: # Create the default user logging.info("Creating default admin user...") - user = User(email=DEFAULT_USER, role="admin") + user = User(email=DEFAULT_USER, admin=True) session.add(user) session.commit() logging.info("Default user created.") diff --git a/agixt/Tunnel.py b/agixt/Tunnel.py index 9f9ac779c2d4..2d51484a38ea 100644 --- a/agixt/Tunnel.py +++ b/agixt/Tunnel.py @@ -1,5 +1,5 @@ import logging -from Defaults import getenv +from Globals import getenv logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/Websearch.py b/agixt/Websearch.py index f5b074b7cb1b..74d7b4cace47 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -11,7 +11,7 @@ from bs4 import BeautifulSoup from typing import List from ApiClient import Agent, Conversations -from Defaults import getenv, get_tokens +from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader @@ -517,10 +517,7 @@ async def websearch_agent( if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") - if ( - self.searx_instance_url != "" - and self.searx_instance_url is not None - ): + if self.searx_instance_url != "": links = await self.search(query=search_string) else: links = await self.ddg_search(query=search_string) diff --git a/agixt/app.py b/agixt/app.py index 0647fbdb1df2..e1644c2d6e7f 100644 --- a/agixt/app.py +++ b/agixt/app.py @@ -12,7 +12,7 @@ from endpoints.Memory import app as memory_endpoints from endpoints.Prompt import app as prompt_endpoints from endpoints.Provider import app as provider_endpoints -from Defaults import getenv +from Globals import getenv os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/agixt/db/User.py b/agixt/db/User.py deleted file mode 100644 index 6688020f5c91..000000000000 --- a/agixt/db/User.py +++ /dev/null @@ -1,55 +0,0 @@ -from DBConnection import User, get_session -from db.Agent import add_agent -from Defaults import getenv -from agixtsdk import AGiXTSDK - - -def is_agixt_admin(email: str = "", api_key: str = ""): - if api_key == getenv("AGIXT_API_KEY"): - return True - session = get_session() - user = session.query(User).filter_by(email=email).first() - if not user: - return False - if user.role == "admin": - return True - return False - - -def create_user( - api_key: str, - email: str, - role: str = "user", - agent_name: str = "", - settings: dict = {}, - commands: dict = {}, - training_urls: list = [], - github_repos: list = [], - ApiClient: AGiXTSDK = AGiXTSDK(), -): - if not is_agixt_admin(email=email, api_key=api_key): - return {"error": "Access Denied"}, 403 - session = get_session() - email = email.lower() - user_exists = session.query(User).filter_by(email=email).first() - if user_exists: - session.close() - return {"error": "User already exists"}, 400 - user = User(email=email, role=role.lower()) - session.add(user) - session.commit() - session.close() - if agent_name != "" and agent_name is not None: - add_agent( - agent_name=agent_name, - provider_settings=settings, - commands=commands, - user=email, - ) - if training_urls != []: - for url in training_urls: - ApiClient.learn_url(agent_name=agent_name, url=url) - if github_repos != []: - for repo in github_repos: - ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) - return {"status": "Success"}, 200 diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index ef7949c7c127..c63133a03772 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends, Header from Interactions import Interactions from Websearch import Websearch -from Defaults import getenv +from Globals import getenv from ApiClient import ( Agent, add_agent, diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py new file mode 100644 index 000000000000..44e5108fa340 --- /dev/null +++ b/agixt/endpoints/Auth.py @@ -0,0 +1,132 @@ +from fastapi import APIRouter, Request, Header, Depends, HTTPException +from Models import Detail, Login, UserInfo, Register +from MagicalAuth import MagicalAuth, verify_api_key, webhook_create_user +from ApiClient import get_api_client, is_admin +from Models import WebhookUser +from Globals import getenv +import pyotp + +app = APIRouter() + + +@app.post("/v1/user") +def register(register: Register): + mfa_token = MagicalAuth().register(new_user=register) + totp = pyotp.TOTP(mfa_token) + otp_uri = totp.provisioning_uri(name=register.email, issuer_name=getenv("APP_NAME")) + return {"otp_uri": otp_uri} + + +@app.get("/v1/user/exists", response_model=bool, summary="Check if user exists") +def get_user(email: str) -> bool: + try: + return MagicalAuth().user_exists(email=email) + except: + return False + + +@app.get( + "/v1/user", + dependencies=[Depends(verify_api_key)], + summary="Get user details", +) +def log_in( + request: Request, + authorization: str = Header(None), +): + user_data = MagicalAuth(token=authorization).login(ip_address=request.client.host) + return { + "email": user_data.email, + "first_name": user_data.first_name, + "last_name": user_data.last_name, + } + + +@app.post( + "/v1/login", + response_model=Detail, + summary="Login with email and OTP token", +) +async def send_magic_link(request: Request, login: Login): + auth = MagicalAuth() + data = await request.json() + referrer = None + if "referrer" in data: + referrer = data["referrer"] + magic_link = auth.send_magic_link( + ip_address=request.client.host, login=login, referrer=referrer + ) + return Detail(detail=magic_link) + + +@app.put( + "/v1/user", + dependencies=[Depends(verify_api_key)], + response_model=Detail, + summary="Update user details", +) +def update_user(update: UserInfo, request: Request, authorization: str = Header(None)): + user = MagicalAuth(token=authorization).update_user( + ip_address=request.client.host, **update.model_dump() + ) + return Detail(detail=user) + + +# Delete user +@app.delete( + "/v1/user", + dependencies=[Depends(verify_api_key)], + response_model=Detail, + summary="Delete user", +) +def delete_user( + user=Depends(verify_api_key), + authorization: str = Header(None), +): + MagicalAuth(token=authorization).delete_user() + return Detail(detail="User deleted successfully.") + + +# Webhook user creations from other applications +@app.post("/api/user", tags=["User"]) +async def createuser( + account: WebhookUser, + authorization: str = Header(None), + user=Depends(verify_api_key), +): + if is_admin(email=user, api_key=authorization) != True: + raise HTTPException(status_code=403, detail="Access Denied") + ApiClient = get_api_client(authorization=authorization) + return webhook_create_user( + api_key=authorization, + email=account.email, + role="user", + agent_name=account.agent_name, + settings=account.settings, + commands=account.commands, + training_urls=account.training_urls, + github_repos=account.github_repos, + ApiClient=ApiClient, + ) + + +@app.post("/api/admin", tags=["User"]) +async def createadmin( + account: WebhookUser, + authorization: str = Header(None), + user=Depends(verify_api_key), +): + if is_admin(email=user, api_key=authorization) != True: + raise HTTPException(status_code=403, detail="Access Denied") + ApiClient = get_api_client(authorization=authorization) + return webhook_create_user( + api_key=authorization, + email=account.email, + role="admin", + agent_name=account.agent_name, + settings=account.settings, + commands=account.commands, + training_urls=account.training_urls, + github_repos=account.github_repos, + ApiClient=ApiClient, + ) diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py index c6717e3d9b82..1362f282dd78 100644 --- a/agixt/endpoints/Completions.py +++ b/agixt/endpoints/Completions.py @@ -2,7 +2,7 @@ import base64 import uuid from fastapi import APIRouter, Depends, Header -from Defaults import get_tokens +from Globals import get_tokens from ApiClient import Agent, verify_api_key, get_api_client from providers.default import DefaultProvider from fastapi import UploadFile, File, Form diff --git a/agixt/endpoints/Provider.py b/agixt/endpoints/Provider.py index 3f0c458e11b2..256459f7deff 100644 --- a/agixt/endpoints/Provider.py +++ b/agixt/endpoints/Provider.py @@ -6,7 +6,7 @@ get_providers_with_settings, get_providers_by_service, ) -from ApiClient import verify_api_key, DB_CONNECTED, get_api_client, is_admin +from ApiClient import verify_api_key, get_api_client, is_admin from typing import Any app = APIRouter() @@ -67,46 +67,3 @@ async def get_embed_providers(user=Depends(verify_api_key)): ) async def get_embedder_info(user=Depends(verify_api_key)) -> Dict[str, Any]: return {"embedders": get_providers_by_service(service="embeddings")} - - -if DB_CONNECTED: - from db.User import create_user - from Models import User - - @app.post("/api/user", tags=["User"]) - async def createuser( - account: User, authorization: str = Header(None), user=Depends(verify_api_key) - ): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - return create_user( - api_key=authorization, - email=account.email, - role="user", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, - ) - - @app.post("/api/admin", tags=["User"]) - async def createadmin( - account: User, authorization: str = Header(None), user=Depends(verify_api_key) - ): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - return create_user( - api_key=authorization, - email=account.email, - role="admin", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, - ) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 82455c7250ed..8bd4e7b26d6b 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -163,7 +163,7 @@ def __init__(self, **kwargs): "Strip CSV Data from Code Block": self.get_csv_from_response, "Convert a string to a Pydantic model": self.convert_string_to_pydantic_model, } - user = kwargs["user"] if "user" in kwargs else "USER" + user = kwargs["user"] if "user" in kwargs else "user" for chain in Chain(user=user).get_chains(): self.commands[chain] = self.run_chain self.command_name = ( diff --git a/agixt/fb/Agent.py b/agixt/fb/Agent.py deleted file mode 100644 index 97865409b2fd..000000000000 --- a/agixt/fb/Agent.py +++ /dev/null @@ -1,389 +0,0 @@ -import os -import json -import glob -import shutil -import importlib -import numpy as np -from inspect import signature, Parameter -from Providers import Providers -from Extensions import Extensions -from Defaults import DEFAULT_SETTINGS -from datetime import datetime, timezone, timedelta - - -def get_agent_file_paths(agent_name, user="USER"): - base_path = os.path.join(os.getcwd(), "agents") - folder_path = os.path.normpath(os.path.join(base_path, agent_name)) - config_path = os.path.normpath(os.path.join(folder_path, "config.json")) - if not config_path.startswith(base_path) or not folder_path.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - if not os.path.exists(folder_path): - os.mkdir(folder_path) - return config_path, folder_path - - -def add_agent(agent_name, provider_settings=None, commands={}, user="USER"): - if not agent_name: - return "Agent name cannot be empty." - provider_settings = ( - DEFAULT_SETTINGS - if not provider_settings or provider_settings == {} - else provider_settings - ) - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - if provider_settings is None or provider_settings == "" or provider_settings == {}: - provider_settings = DEFAULT_SETTINGS - settings = json.dumps( - { - "commands": commands, - "settings": provider_settings, - } - ) - # Write the settings to the agent config file - with open(config_path, "w") as f: - f.write(settings) - return {"message": f"Agent {agent_name} created."} - - -def delete_agent(agent_name, user="USER"): - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - try: - if os.path.exists(folder_path): - shutil.rmtree(folder_path) - return {"message": f"Agent {agent_name} deleted."}, 200 - except: - return {"message": f"Agent {agent_name} could not be deleted."}, 400 - - -def rename_agent(agent_name, new_name, user="USER"): - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - base_path = os.path.join(os.getcwd(), "agents") - new_agent_folder = os.path.normpath(os.path.join(base_path, new_name)) - if not new_agent_folder.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - - if os.path.exists(folder_path): - # Check if the new name is already taken - if os.path.exists(new_agent_folder): - # Add a number to the end of the new name - i = 1 - while os.path.exists(new_agent_folder): - i += 1 - new_name = f"{new_name}_{i}" - new_agent_folder = os.path.normpath(os.path.join(base_path, new_name)) - if not new_agent_folder.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - os.rename(folder_path, new_agent_folder) - return {"message": f"Agent {agent_name} renamed to {new_name}."}, 200 - - -def get_agents(user="USER"): - agents_dir = "agents" - if not os.path.exists(agents_dir): - os.makedirs(agents_dir) - agents = [ - dir_name - for dir_name in os.listdir(agents_dir) - if os.path.isdir(os.path.join(agents_dir, dir_name)) - ] - output = [] - if agents: - for agent in agents: - agent_config = Agent(agent_name=agent, user=user).get_agent_config() - if "settings" not in agent_config: - agent_config["settings"] = {} - if "training" in agent_config["settings"]: - if str(agent_config["settings"]["training"]).lower() == "true": - output.append({"name": agent, "status": True}) - else: - output.append({"name": agent, "status": False}) - else: - output.append({"name": agent, "status": False}) - return output - - -class Agent: - def __init__(self, agent_name=None, user="USER", ApiClient=None): - self.USER = user - self.agent_name = agent_name if agent_name is not None else "AGiXT" - self.config_path, self.folder_path = get_agent_file_paths( - agent_name=self.agent_name - ) - self.AGENT_CONFIG = self.get_agent_config() - if "settings" not in self.AGENT_CONFIG: - self.AGENT_CONFIG["settings"] = {} - self.PROVIDER_SETTINGS = self.AGENT_CONFIG["settings"] - for setting in DEFAULT_SETTINGS: - if setting not in self.PROVIDER_SETTINGS: - self.PROVIDER_SETTINGS[setting] = DEFAULT_SETTINGS[setting] - self.AI_PROVIDER = self.PROVIDER_SETTINGS["provider"] - self.PROVIDER = Providers( - name=self.AI_PROVIDER, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - self._load_agent_config_keys(["AI_MODEL", "AI_TEMPERATURE", "MAX_TOKENS"]) - tts_provider = ( - self.AGENT_CONFIG["settings"]["tts_provider"] - if "tts_provider" in self.AGENT_CONFIG["settings"] - else "None" - ) - if tts_provider != "None" and tts_provider != None and tts_provider != "": - self.TTS_PROVIDER = Providers( - name=tts_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - else: - self.TTS_PROVIDER = None - transcription_provider = ( - self.AGENT_CONFIG["settings"]["transcription_provider"] - if "transcription_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.TRANSCRIPTION_PROVIDER = Providers( - name=transcription_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - translation_provider = ( - self.AGENT_CONFIG["settings"]["translation_provider"] - if "translation_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.TRANSLATION_PROVIDER = Providers( - name=translation_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - image_provider = ( - self.AGENT_CONFIG["settings"]["image_provider"] - if "image_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.IMAGE_PROVIDER = Providers( - name=image_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - embeddings_provider = ( - self.AGENT_CONFIG["settings"]["embeddings_provider"] - if "embeddings_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.EMBEDDINGS_PROVIDER = Providers( - name=embeddings_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - if hasattr(self.EMBEDDINGS_PROVIDER, "chunk_size"): - self.chunk_size = self.EMBEDDINGS_PROVIDER.chunk_size - else: - self.chunk_size = 256 - self.embedder = self.EMBEDDINGS_PROVIDER.embedder - if "AI_MODEL" in self.PROVIDER_SETTINGS: - self.AI_MODEL = self.PROVIDER_SETTINGS["AI_MODEL"] - if self.AI_MODEL == "": - self.AI_MODEL = "default" - else: - self.AI_MODEL = "openassistant" - if "embedder" in self.PROVIDER_SETTINGS: - self.EMBEDDER = self.PROVIDER_SETTINGS["embedder"] - else: - if self.AI_PROVIDER == "openai": - self.EMBEDDER = "openai" - else: - self.EMBEDDER = "default" - if "MAX_TOKENS" in self.PROVIDER_SETTINGS: - self.MAX_TOKENS = self.PROVIDER_SETTINGS["MAX_TOKENS"] - else: - self.MAX_TOKENS = 4000 - self.commands = self.load_commands() - self.available_commands = Extensions( - agent_name=self.agent_name, - agent_config=self.AGENT_CONFIG, - ApiClient=ApiClient, - user=user, - ).get_available_commands() - self.clean_agent_config_commands() - - async def inference(self, prompt: str, tokens: int = 0, images: list = []): - if not prompt: - return "" - answer = await self.PROVIDER.inference( - prompt=prompt, tokens=tokens, images=images - ) - return answer.replace("\_", "_") - - def embeddings(self, input) -> np.ndarray: - return self.embedder(input=input) - - async def transcribe_audio(self, audio_path: str): - return await self.TRANSCRIPTION_PROVIDER.transcribe_audio(audio_path=audio_path) - - async def translate_audio(self, audio_path: str): - return await self.TRANSLATION_PROVIDER.translate_audio(audio_path=audio_path) - - async def generate_image(self, prompt: str): - return await self.IMAGE_PROVIDER.generate_image(prompt=prompt) - - async def text_to_speech(self, text: str): - if self.TTS_PROVIDER is not None: - return await self.TTS_PROVIDER.text_to_speech(text=text) - - def _load_agent_config_keys(self, keys): - for key in keys: - if key in self.AGENT_CONFIG: - setattr(self, key, self.AGENT_CONFIG[key]) - - def clean_agent_config_commands(self): - for command in self.commands: - friendly_name = command[0] - if friendly_name not in self.AGENT_CONFIG["commands"]: - self.AGENT_CONFIG["commands"][friendly_name] = False - for command in list(self.AGENT_CONFIG["commands"]): - if command not in [cmd[0] for cmd in self.commands]: - del self.AGENT_CONFIG["commands"][command] - with open(self.config_path, "w") as f: - json.dump(self.AGENT_CONFIG, f) - - def get_commands_string(self): - if len(self.available_commands) == 0: - return "" - working_dir = ( - self.AGENT_CONFIG["WORKING_DIRECTORY"] - if "WORKING_DIRECTORY" in self.AGENT_CONFIG - else os.path.join(os.getcwd(), "WORKSPACE") - ) - verbose_commands = f"### Available Commands\n**The assistant has commands available to use if they would be useful to provide a better user experience.**\nIf a file needs saved, the assistant's working directory is {working_dir}, use that as the file path.\n\n" - verbose_commands += "**See command execution examples of commands that the assistant has access to below:**\n" - for command in self.available_commands: - command_args = json.dumps(command["args"]) - command_args = command_args.replace( - '""', - '"The assistant will fill in the value based on relevance to the conversation."', - ) - verbose_commands += ( - f"\n- #execute('{command['friendly_name']}', {command_args})" - ) - verbose_commands += "\n\n**To execute an available command, the assistant can reference the examples and the command execution response will be replaced with the commands output for the user in the assistants response. The assistant can execute a command anywhere in the response and the commands will be executed in the order they are used.**\n**THE ASSISTANT CANNOT EXECUTE A COMMAND THAT IS NOT ON THE LIST OF EXAMPLES!**\n\n" - return verbose_commands - - def get_provider(self): - config_file = self.get_agent_config() - if "provider" in config_file: - return config_file["provider"] - else: - return "openai" - - def get_command_params(self, func): - params = {} - sig = signature(func) - for name, param in sig.parameters.items(): - if param.default == Parameter.empty: - params[name] = None - else: - params[name] = param.default - return params - - def load_commands(self): - commands = [] - command_files = glob.glob("extensions/*.py") - for command_file in command_files: - module_name = os.path.splitext(os.path.basename(command_file))[0] - module = importlib.import_module(f"extensions.{module_name}") - command_class = getattr(module, module_name.lower())() - if hasattr(command_class, "commands"): - for command_name, command_function in command_class.commands.items(): - params = self.get_command_params(command_function) - commands.append((command_name, command_function.__name__, params)) - return commands - - def get_agent_config(self): - while True: - if os.path.exists(self.config_path): - try: - with open(self.config_path, "r") as f: - file_content = f.read().strip() - if file_content: - return json.loads(file_content) - except: - None - add_agent(agent_name=self.agent_name) - return self.get_agent_config() - - def update_agent_config(self, new_config, config_key): - if os.path.exists(self.config_path): - with open(self.config_path, "r") as f: - current_config = json.load(f) - - # Ensure the config_key is present in the current configuration - if config_key not in current_config: - current_config[config_key] = {} - - # Update the specified key with new_config while preserving other keys and values - for key, value in new_config.items(): - current_config[config_key][key] = value - - # Save the updated configuration back to the file - with open(self.config_path, "w") as f: - json.dump(current_config, f) - return f"Agent {self.agent_name} configuration updated." - else: - return f"Agent {self.agent_name} configuration not found." - - def get_browsed_links(self): - """ - Get the list of URLs that have been browsed by the agent. - - Returns: - list: The list of URLs that have been browsed by the agent. - """ - # They will be stored in the agent's config file as: - # "browsed_links": [{"url": "https://example.com", "timestamp": "2021-01-01T00:00:00Z"}] - return self.AGENT_CONFIG.get("browsed_links", []) - - def browsed_recently(self, url) -> bool: - """ - Check if the given URL has been browsed by the agent within the last 24 hours. - - Args: - url (str): The URL to check. - - Returns: - bool: True if the URL has been browsed within the last 24 hours, False otherwise. - """ - browsed_links = self.get_browsed_links() - if not browsed_links: - return False - for link in browsed_links: - if link["url"] == url: - if link["timestamp"] >= datetime.now(timezone.utc) - timedelta(days=1): - return True - return False - - def add_browsed_link(self, url): - """ - Add a URL to the list of browsed links for the agent. - - Args: - url (str): The URL to add. - - Returns: - str: The response message. - """ - browsed_links = self.get_browsed_links() - # check if the URL has already been browsed - if self.browsed_recently(url): - return "URL has already been browsed recently." - browsed_links.append( - {"url": url, "timestamp": datetime.now(timezone.utc).isoformat()} - ) - self.update_agent_config(browsed_links, "browsed_links") - return "URL added to browsed links." - - def delete_browsed_link(self, url): - """ - Delete a URL from the list of browsed links for the agent. - - Args: - url (str): The URL to delete. - - Returns: - str: The response message. - """ - browsed_links = self.get_browsed_links() - for link in browsed_links: - if link["url"] == url: - browsed_links.remove(link) - self.update_agent_config(browsed_links, "browsed_links") - return "URL deleted from browsed links." - return "URL not found in browsed links." diff --git a/agixt/fb/Chain.py b/agixt/fb/Chain.py deleted file mode 100644 index b465bf6614da..000000000000 --- a/agixt/fb/Chain.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import json -import logging -from Defaults import getenv - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -def create_command_suggestion_chain( - agent_name, command_name, command_args, user="USER" -): - chain = Chain() - chains = chain.get_chains() - chain_name = f"{agent_name} Command Suggestions" - if chain_name in chains: - step = int(chain.get_chain(chain_name=chain_name)["steps"][-1]["step"]) + 1 - else: - chain.add_chain(chain_name=chain_name) - step = 1 - chain.add_chain_step( - chain_name=chain_name, - agent_name=agent_name, - step_number=step, - prompt_type="Command", - prompt={ - "command_name": command_name, - **command_args, - }, - ) - return f"The command has been added to a chain called '{agent_name} Command Suggestions' for you to review and execute manually." - - -def get_chain_file_path(chain_name, user="USER"): - base_path = os.path.join(os.getcwd(), "chains") - folder_path = os.path.normpath(os.path.join(base_path, chain_name)) - file_path = os.path.normpath(os.path.join(base_path, f"{chain_name}.json")) - if not file_path.startswith(base_path) or not folder_path.startswith(base_path): - raise ValueError("Invalid path, chain name must not contain slashes.") - if not os.path.exists(folder_path): - os.mkdir(folder_path) - return file_path - - -def get_chain_responses_file_path(chain_name, user="USER"): - base_path = os.path.join(os.getcwd(), "chains") - file_path = os.path.normpath(os.path.join(base_path, chain_name, "responses.json")) - if not file_path.startswith(base_path): - raise ValueError("Invalid path, chain name must not contain slashes.") - return file_path - - -class Chain: - def __init__(self, user="USER"): - self.user = user - - def import_chain(self, chain_name: str, steps: dict): - file_path = get_chain_file_path(chain_name=chain_name) - steps = steps["steps"] if "steps" in steps else steps - with open(file_path, "w") as f: - json.dump({"chain_name": chain_name, "steps": steps}, f) - return f"Chain '{chain_name}' imported." - - def get_chain(self, chain_name): - try: - file_path = get_chain_file_path(chain_name=chain_name) - with open(file_path, "r") as f: - chain_data = json.load(f) - return chain_data - except: - return {} - - def get_chains(self): - chains = [ - f.replace(".json", "") for f in os.listdir("chains") if f.endswith(".json") - ] - return chains - - def add_chain(self, chain_name): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = {"chain_name": chain_name, "steps": []} - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def rename_chain(self, chain_name, new_name): - file_path = get_chain_file_path(chain_name=chain_name) - new_file_path = get_chain_file_path(chain_name=new_name) - os.rename( - os.path.join(file_path), - os.path.join(new_file_path), - ) - chain_data = self.get_chain(chain_name=new_name) - chain_data["chain_name"] = new_name - with open(new_file_path, "w") as f: - json.dump(chain_data, f) - - def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, prompt): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - chain_data["steps"].append( - { - "step": step_number, - "agent_name": agent_name, - "prompt_type": prompt_type, - "prompt": prompt, - } - ) - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - for step in chain_data["steps"]: - if step["step"] == step_number: - step["agent_name"] = agent_name - step["prompt_type"] = prompt_type - step["prompt"] = prompt - break - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def delete_step(self, chain_name, step_number): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - chain_data["steps"] = [ - step for step in chain_data["steps"] if step["step"] != step_number - ] - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def delete_chain(self, chain_name): - file_path = get_chain_file_path(chain_name=chain_name) - os.remove(file_path) - - def get_step(self, chain_name, step_number): - chain_data = self.get_chain(chain_name=chain_name) - for step in chain_data["steps"]: - if step["step"] == step_number: - return step - return None - - def get_steps(self, chain_name): - chain_data = self.get_chain(chain_name=chain_name) - return chain_data["steps"] - - def move_step(self, chain_name, current_step_number, new_step_number): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - if not 1 <= new_step_number <= len( - chain_data["steps"] - ) or current_step_number not in [step["step"] for step in chain_data["steps"]]: - logging.info(f"Error: Invalid step numbers.") - return - moved_step = None - for step in chain_data["steps"]: - if step["step"] == current_step_number: - moved_step = step - chain_data["steps"].remove(step) - break - for step in chain_data["steps"]: - if new_step_number < current_step_number: - if new_step_number <= step["step"] < current_step_number: - step["step"] += 1 - else: - if current_step_number < step["step"] <= new_step_number: - step["step"] -= 1 - moved_step["step"] = new_step_number - chain_data["steps"].append(moved_step) - chain_data["steps"] = sorted(chain_data["steps"], key=lambda x: x["step"]) - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def get_step_response(self, chain_name, step_number="all"): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - if step_number == "all": - return responses - else: - data = responses.get(str(step_number)) - if isinstance(data, dict) and "response" in data: - data = data["response"] - logging.info(f"Step {step_number} response: {data}") - return data - except: - return "" - - def get_chain_responses(self, chain_name): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - return responses - except: - return {} - - def get_step_content(self, chain_name, prompt_content, user_input, agent_name): - if isinstance(prompt_content, dict): - new_prompt_content = {} - for arg, value in prompt_content.items(): - if isinstance(value, str): - if "{user_input}" in value: - value = value.replace("{user_input}", user_input) - if "{agent_name}" in value: - value = value.replace("{agent_name}", agent_name) - if "{STEP" in value: - step_count = value.count("{STEP") - for i in range(step_count): - new_step_number = int(value.split("{STEP")[1].split("}")[0]) - step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number - ) - if step_response: - resp = ( - step_response[0] - if isinstance(step_response, list) - else step_response - ) - value = value.replace( - f"{{STEP{new_step_number}}}", f"{resp}" - ) - new_prompt_content[arg] = value - return new_prompt_content - elif isinstance(prompt_content, str): - new_prompt_content = prompt_content - if "{user_input}" in prompt_content: - new_prompt_content = new_prompt_content.replace( - "{user_input}", user_input - ) - if "{agent_name}" in new_prompt_content: - new_prompt_content = new_prompt_content.replace( - "{agent_name}", agent_name - ) - if "{STEP" in prompt_content: - step_count = prompt_content.count("{STEP") - for i in range(step_count): - new_step_number = int( - prompt_content.split("{STEP")[1].split("}")[0] - ) - step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number - ) - if step_response: - resp = ( - step_response[0] - if isinstance(step_response, list) - else step_response - ) - new_prompt_content = new_prompt_content.replace( - f"{{STEP{new_step_number}}}", f"{resp}" - ) - return new_prompt_content - else: - return prompt_content - - async def update_step_response(self, chain_name, step_number, response): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - except: - responses = {} - - if str(step_number) not in responses: - responses[str(step_number)] = response - else: - if isinstance(responses[str(step_number)], dict) and isinstance( - response, dict - ): - responses[str(step_number)].update(response) - elif isinstance(responses[str(step_number)], list): - if isinstance(response, list): - responses[str(step_number)].extend(response) - else: - responses[str(step_number)].append(response) - else: - responses[str(step_number)] = response - - with open(file_path, "w") as f: - json.dump(responses, f) diff --git a/agixt/fb/Conversations.py b/agixt/fb/Conversations.py deleted file mode 100644 index c129143c1184..000000000000 --- a/agixt/fb/Conversations.py +++ /dev/null @@ -1,103 +0,0 @@ -from datetime import datetime -import yaml -import os -import logging -from Defaults import getenv, DEFAULT_USER - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -class Conversations: - def __init__(self, conversation_name=None, user=DEFAULT_USER): - self.conversation_name = conversation_name - self.user = user - - def export_conversation(self): - if not self.conversation_name: - self.conversation_name = f"{str(datetime.now())} Conversation" - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if os.path.exists(history_file): - with open(history_file, "r") as file: - history = yaml.safe_load(file) - return history - return {"interactions": []} - - def get_conversation(self, limit=100, page=1): - history = {"interactions": []} - try: - history_file = os.path.join( - "conversations", f"{self.conversation_name}.yaml" - ) - if os.path.exists(history_file): - with open(history_file, "r") as file: - history = yaml.safe_load(file) - except: - history = self.new_conversation() - return history - - def get_conversations(self): - conversation_dir = os.path.join("conversations") - if os.path.exists(conversation_dir): - conversations = os.listdir(conversation_dir) - return [conversation.split(".")[0] for conversation in conversations] - return [] - - def new_conversation(self, conversation_content=[]): - history = {"interactions": conversation_content} - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - os.makedirs(os.path.dirname(history_file), exist_ok=True) - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - return history - - def log_interaction(self, role: str, message: str): - history = self.get_conversation() - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if not os.path.exists(history_file): - os.makedirs(os.path.dirname(history_file), exist_ok=True) - if not history: - history = {"interactions": []} - if "interactions" not in history: - history["interactions"] = [] - history["interactions"].append( - { - "role": role, - "message": message, - "timestamp": datetime.now().strftime("%B %d, %Y %I:%M %p"), - } - ) - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - if role.lower() == "user": - logging.info(f"{self.user}: {message}") - else: - logging.info(f"{role}: {message}") - - def delete_conversation(self): - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if os.path.exists(history_file): - os.remove(history_file) - - def delete_message(self, message): - history = self.get_conversation() - history["interactions"] = [ - interaction - for interaction in history["interactions"] - if interaction["message"] != message - ] - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - - def update_message(self, message, new_message): - history = self.get_conversation() - for interaction in history["interactions"]: - if interaction["message"] == message: - interaction["message"] = new_message - break - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - with open(history_file, "w") as file: - yaml.safe_dump(history, file) diff --git a/agixt/fb/Prompts.py b/agixt/fb/Prompts.py deleted file mode 100644 index 1b845bfd37cd..000000000000 --- a/agixt/fb/Prompts.py +++ /dev/null @@ -1,112 +0,0 @@ -import os - - -def get_prompt_file_path(prompt_name, prompt_category="Default", user="USER"): - base_path = os.path.join(os.getcwd(), "prompts") - base_model_path = os.path.normpath( - os.path.join(os.getcwd(), "prompts", prompt_category) - ) - model_prompt_file = os.path.normpath( - os.path.join(base_model_path, f"{prompt_name}.txt") - ) - default_prompt_file = os.path.normpath( - os.path.join(base_path, "Default", f"{prompt_name}.txt") - ) - if ( - not base_model_path.startswith(base_path) - or not model_prompt_file.startswith(base_model_path) - or not default_prompt_file.startswith(base_path) - ): - raise ValueError( - "Invalid file path. Prompt name cannot contain '/', '\\' or '..' in" - ) - if not os.path.exists(base_path): - os.mkdir(base_path) - if not os.path.exists(base_model_path): - os.mkdir(base_model_path) - prompt_file = ( - model_prompt_file if os.path.isfile(model_prompt_file) else default_prompt_file - ) - return prompt_file - - -class Prompts: - def __init__(self, user="USER"): - self.user = user - - def add_prompt(self, prompt_name, prompt, prompt_category="Default"): - # if prompts folder does not exist, create it - file_path = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - # if prompt file does not exist, create it - if not os.path.exists(file_path): - with open(file_path, "w") as f: - f.write(prompt) - - def get_prompt(self, prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - with open(prompt_file, "r") as f: - prompt = f.read() - return prompt - - def get_prompts(self, prompt_category="Default"): - # Get all files in prompts folder that end in .txt and replace .txt with empty string - prompts = [] - # For each folder in prompts folder, get all files that end in .txt and replace .txt with empty string - base_path = os.path.join("prompts", prompt_category) - base_path = os.path.join(os.getcwd(), "prompts") - base_model_path = os.path.normpath( - os.path.join(os.getcwd(), "prompts", prompt_category) - ) - if not base_model_path.startswith(base_path) or not base_model_path.startswith( - base_model_path - ): - raise ValueError( - "Invalid file path. Prompt name cannot contain '/', '\\' or '..' in" - ) - if not os.path.exists(base_model_path): - os.mkdir(base_model_path) - for file in os.listdir(base_model_path): - if file.endswith(".txt"): - prompts.append(file.replace(".txt", "")) - return prompts - - def get_prompt_categories(self): - prompt_categories = [] - for folder in os.listdir("prompts"): - if os.path.isdir(os.path.join("prompts", folder)): - prompt_categories.append(folder) - return prompt_categories - - def get_prompt_args(self, prompt_text): - # Find anything in the file between { and } and add them to a list to return - prompt_vars = [] - for word in prompt_text.split(): - if word.startswith("{") and word.endswith("}"): - prompt_vars.append(word[1:-1]) - return prompt_vars - - def delete_prompt(self, prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - os.remove(prompt_file) - - def update_prompt(self, prompt_name, prompt, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - with open(prompt_file, "w") as f: - f.write(prompt) - - def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - new_prompt_file = get_prompt_file_path( - prompt_name=new_prompt_name, prompt_category=prompt_category - ) - os.rename(prompt_file, new_prompt_file) diff --git a/agixt/launch-backend.sh b/agixt/launch-backend.sh index 7ce76dc82d81..a415e693886e 100755 --- a/agixt/launch-backend.sh +++ b/agixt/launch-backend.sh @@ -1,11 +1,8 @@ #!/bin/sh echo "Starting AGiXT..." -if [ "$DB_CONNECTED" = "true" ]; then - sleep 15 - echo "Connecting to DB..." - python3 DBConnection.py - sleep 5 -fi +sleep 15 +python3 DB.py +sleep 5 if [ -n "$NGROK_TOKEN" ]; then echo "Starting ngrok..." python3 Tunnel.py diff --git a/agixt/providers/ezlocalai.py b/agixt/providers/ezlocalai.py index 3325bccac970..11290b0cf887 100644 --- a/agixt/providers/ezlocalai.py +++ b/agixt/providers/ezlocalai.py @@ -3,7 +3,7 @@ import re import numpy as np import requests -from Defaults import getenv +from Globals import getenv import uuid from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction diff --git a/agixt/providers/openai.py b/agixt/providers/openai.py index 14fba0223e00..6a095d5b5b0d 100644 --- a/agixt/providers/openai.py +++ b/agixt/providers/openai.py @@ -3,7 +3,7 @@ import random import requests import uuid -from Defaults import getenv +from Globals import getenv import numpy as np from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction diff --git a/agixt/version b/agixt/version index d9f31c4efcf1..05f629f1b7f2 100644 --- a/agixt/version +++ b/agixt/version @@ -1 +1 @@ -v1.5.18 \ No newline at end of file +v1.6.0 \ No newline at end of file diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 8273bd26524f..5736bc2d89b1 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -1,4 +1,3 @@ -version: "3.7" services: db: image: postgres @@ -14,7 +13,6 @@ services: image: joshxt/agixt:main init: true environment: - - DB_CONNECTED=${DB_CONNECTED:-false} - DATABASE_HOST=${DATABASE_HOST:-db} - DATABASE_USER=${DATABASE_USER:-postgres} - DATABASE_PASSWORD=${DATABASE_PASSWORD:-postgres} @@ -29,6 +27,7 @@ services: - WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE} - TOKENIZERS_PARALLELISM=False - LOG_LEVEL=${LOG_LEVEL:-INFO} + - AUTH_PROVIDER=${AUTH_PROVIDER:-none} - TZ=${TZ-America/New_York} ports: - "7437:7437" diff --git a/docker-compose.yml b/docker-compose.yml index 84b05ae02daf..d03b6fbddfb8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,33 @@ -version: "3.7" services: + db: + image: postgres + ports: + - 5432:5432 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: ${DATABASE_PASSWORD:-postgres} + POSTGRES_DB: postgres + volumes: + - ./data:/var/lib/postgresql/data agixt: image: joshxt/agixt:latest init: true environment: + - DATABASE_HOST=${DATABASE_HOST:-db} + - DATABASE_USER=${DATABASE_USER:-postgres} + - DATABASE_PASSWORD=${DATABASE_PASSWORD:-postgres} + - DATABASE_NAME=${DATABASE_NAME:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} - UVICORN_WORKERS=${UVICORN_WORKERS:-10} + - USING_JWT=${USING_JWT:-false} - AGIXT_API_KEY=${AGIXT_API_KEY} - AGIXT_URI=${AGIXT_URI-http://agixt:7437} + - DISABLED_EXTENSIONS=${DISABLED_EXTENSIONS:-} + - DISABLED_PROVIDERS=${DISABLED_PROVIDERS:-} - WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE} - TOKENIZERS_PARALLELISM=False - - LOG_LEVEL=${LOG_LEVEL:-ERROR} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - AUTH_PROVIDER=${AUTH_PROVIDER:-none} - TZ=${TZ-America/New_York} ports: - "7437:7437" diff --git a/tests/completions-tests.ipynb b/tests/completions-tests.ipynb index 74f6a7949380..543369816572 100644 --- a/tests/completions-tests.ipynb +++ b/tests/completions-tests.ipynb @@ -32,7 +32,9 @@ "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", + "import time\n", "\n", + "time.sleep(180) # wait for the AGiXT server to start\n", "# Set your system message, max tokens, temperature, and top p here, or use the defaults.\n", "AGENT_NAME = \"gpt4free\"\n", "AGIXT_SERVER = \"http://localhost:7437\"\n", From 2ccb4f93a4a2bc038c1a04a764aa0aa22f709942 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 13:14:13 -0400 Subject: [PATCH 0023/1256] Delete sdextensions directory Signed-off-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- sdextensions/Put your stable diffusion extensions here.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 sdextensions/Put your stable diffusion extensions here.txt diff --git a/sdextensions/Put your stable diffusion extensions here.txt b/sdextensions/Put your stable diffusion extensions here.txt deleted file mode 100644 index 13be6f3f09b4..000000000000 --- a/sdextensions/Put your stable diffusion extensions here.txt +++ /dev/null @@ -1 +0,0 @@ -Put your stable diffusion extensions in this folder. From 3b801ec63c8d78626110f6fa4be5c68b30e68a59 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 13:14:26 -0400 Subject: [PATCH 0024/1256] Delete localizations directory Signed-off-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- localizations/Put your stable diffusion localizations here.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 localizations/Put your stable diffusion localizations here.txt diff --git a/localizations/Put your stable diffusion localizations here.txt b/localizations/Put your stable diffusion localizations here.txt deleted file mode 100644 index aa2e872075ea..000000000000 --- a/localizations/Put your stable diffusion localizations here.txt +++ /dev/null @@ -1 +0,0 @@ -Put your stable diffusion localizations in this folder. From dba02412c8f087cbdc91ed572ae69513e8f22598 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 13:18:24 -0400 Subject: [PATCH 0025/1256] Delete agixt/data directory Signed-off-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- agixt/data/placeholder.txt | 1 - 1 file changed, 1 deletion(-) delete mode 100644 agixt/data/placeholder.txt diff --git a/agixt/data/placeholder.txt b/agixt/data/placeholder.txt deleted file mode 100644 index dcef3683c7c0..000000000000 --- a/agixt/data/placeholder.txt +++ /dev/null @@ -1 +0,0 @@ -Persisted databases can go inside of this folder. From d936c823f7fb849ec86110db2ccf288503a29e38 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 29 May 2024 14:59:23 -0400 Subject: [PATCH 0026/1256] add auth endpoints --- agixt/app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/app.py b/agixt/app.py index e1644c2d6e7f..3f08446fc9b1 100644 --- a/agixt/app.py +++ b/agixt/app.py @@ -12,6 +12,7 @@ from endpoints.Memory import app as memory_endpoints from endpoints.Prompt import app as prompt_endpoints from endpoints.Provider import app as provider_endpoints +from endpoints.Auth import app as auth_endpoints from Globals import getenv os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -49,6 +50,7 @@ app.include_router(memory_endpoints) app.include_router(prompt_endpoints) app.include_router(provider_endpoints) +app.include_router(auth_endpoints) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7437) From dc1a28d979374543498fcf5a52254e022dab5665 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 29 May 2024 15:06:08 -0400 Subject: [PATCH 0027/1256] add pyotp --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26ab36e2fa9f..f65835c88eb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ google-auth==2.29.0 google-api-python-client==2.125.0 python-multipart==0.0.9 nest_asyncio -g4f==0.3.1.9 \ No newline at end of file +g4f==0.3.1.9 +pyotp \ No newline at end of file From 1271f9c60f4a98fa16468bddec68a32e1098dc91 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 21:27:02 -0400 Subject: [PATCH 0028/1256] Chain improvements for tracking step responses (#1196) * Chain improvements for tracking step responses * fix requirement on user * fix import issue, move run chain functions * fix user ref * fix user_email ref * add user input to inference call * fix role * fix endpoint * remove old endpoint * rerun tests * cascade on delete * cascade when deleting chain --- agixt/Chain.py | 117 +++++++++++++++--- agixt/Chains.py | 212 -------------------------------- agixt/DB.py | 29 +++++ agixt/Interactions.py | 40 ------- agixt/Models.py | 5 +- agixt/{AGiXT.py => XT.py} | 213 +++++++++++++++++++++++++++------ agixt/endpoints/Chain.py | 40 +++---- agixt/endpoints/Completions.py | 2 +- tests/tests.ipynb | 34 +----- 9 files changed, 329 insertions(+), 363 deletions(-) delete mode 100644 agixt/Chains.py rename agixt/{AGiXT.py => XT.py} (77%) diff --git a/agixt/Chain.py b/agixt/Chain.py index 054a402b3b95..a327cf3596ee 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -3,6 +3,7 @@ Chain as ChainDB, ChainStep, ChainStepResponse, + ChainRun, Agent, Argument, ChainStepArgument, @@ -11,6 +12,9 @@ User, ) from Globals import getenv, DEFAULT_USER +from Prompts import Prompts +from Conversations import Conversations +from Extensions import Extensions import logging logging.basicConfig( @@ -20,9 +24,10 @@ class Chain: - def __init__(self, user=DEFAULT_USER): + def __init__(self, user=DEFAULT_USER, ApiClient=None): self.session = get_session() self.user = user + self.ApiClient = ApiClient try: user_data = self.session.query(User).filter(User.email == self.user).first() self.user_id = user_data.id @@ -430,12 +435,14 @@ def move_step(self, chain_name, current_step_number, new_step_number): ) self.session.commit() - def get_step_response(self, chain_name, step_number="all"): - chain = self.get_chain(chain_name=chain_name) + def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): + if chain_run_id is None: + chain_run_id = self.get_last_chain_run_id(chain_name=chain_name) + chain_data = self.get_chain(chain_name=chain_name) if step_number == "all": chain_steps = ( self.session.query(ChainStep) - .filter(ChainStep.chain_id == chain["id"]) + .filter(ChainStep.chain_id == chain_data["id"]) .order_by(ChainStep.step_number) .all() ) @@ -444,7 +451,10 @@ def get_step_response(self, chain_name, step_number="all"): for step in chain_steps: chain_step_responses = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == step.id) + .filter( + ChainStepResponse.chain_step_id == step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp) .all() ) @@ -456,7 +466,7 @@ def get_step_response(self, chain_name, step_number="all"): chain_step = ( self.session.query(ChainStep) .filter( - ChainStep.chain_id == chain["id"], + ChainStep.chain_id == chain_data["id"], ChainStep.step_number == step_number, ) .first() @@ -465,7 +475,10 @@ def get_step_response(self, chain_name, step_number="all"): if chain_step: chain_step_responses = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == chain_step.id) + .filter( + ChainStepResponse.chain_step_id == chain_step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp) .all() ) @@ -585,7 +598,9 @@ def import_chain(self, chain_name: str, steps: dict): return f"Imported chain: {chain_name}" - def get_step_content(self, chain_name, prompt_content, user_input, agent_name): + def get_step_content( + self, chain_run_id, chain_name, prompt_content, user_input, agent_name + ): if isinstance(prompt_content, dict): new_prompt_content = {} for arg, value in prompt_content.items(): @@ -599,7 +614,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): for i in range(step_count): new_step_number = int(value.split("{STEP")[1].split("}")[0]) step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=new_step_number, ) if step_response: resp = ( @@ -629,7 +646,9 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): prompt_content.split("{STEP")[1].split("}")[0] ) step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=new_step_number, ) if step_response: resp = ( @@ -644,13 +663,17 @@ def get_step_content(self, chain_name, prompt_content, user_input, agent_name): else: return prompt_content - async def update_step_response(self, chain_name, step_number, response): - chain = self.get_chain(chain_name=chain_name) + async def update_step_response( + self, chain_run_id, chain_name, step_number, response + ): chain_step = self.get_step(chain_name=chain_name, step_number=step_number) if chain_step: existing_response = ( self.session.query(ChainStepResponse) - .filter(ChainStepResponse.chain_step_id == chain_step.id) + .filter( + ChainStepResponse.chain_step_id == chain_step.id, + ChainStepResponse.chain_run_id == chain_run_id, + ) .order_by(ChainStepResponse.timestamp.desc()) .first() ) @@ -667,13 +690,77 @@ async def update_step_response(self, chain_name, step_number, response): self.session.commit() else: chain_step_response = ChainStepResponse( - chain_step_id=chain_step.id, content=response + chain_step_id=chain_step.id, + chain_run_id=chain_run_id, + content=response, ) self.session.add(chain_step_response) self.session.commit() else: chain_step_response = ChainStepResponse( - chain_step_id=chain_step.id, content=response + chain_step_id=chain_step.id, + chain_run_id=chain_run_id, + content=response, ) self.session.add(chain_step_response) self.session.commit() + + async def get_chain_run_id(self, chain_name): + chain_run = ChainRun( + chain_id=self.get_chain(chain_name=chain_name)["id"], + user_id=self.user_id, + ) + self.session.add(chain_run) + self.session.commit() + return chain_run.id + + async def get_last_chain_run_id(self, chain_name): + chain_data = self.get_chain(chain_name=chain_name) + chain_run = ( + self.session.query(ChainRun) + .filter(ChainRun.chain_id == chain_data["id"]) + .order_by(ChainRun.timestamp.desc()) + .first() + ) + if chain_run: + return chain_run.id + else: + return await self.get_chain_run_id(chain_name=chain_name) + + def get_chain_args(self, chain_name): + skip_args = [ + "command_list", + "context", + "COMMANDS", + "date", + "conversation_history", + "agent_name", + "working_directory", + "helper_agent_name", + ] + chain_data = self.get_chain(chain_name=chain_name) + steps = chain_data["steps"] + prompt_args = [] + args = [] + for step in steps: + try: + prompt = step["prompt"] + if "prompt_name" in prompt: + prompt_text = Prompts(user=self.user).get_prompt( + prompt_name=prompt["prompt_name"] + ) + args = Prompts(user=self.user).get_prompt_args( + prompt_text=prompt_text + ) + elif "command_name" in prompt: + args = Extensions().get_command_args( + command_name=prompt["command_name"] + ) + elif "chain_name" in prompt: + args = self.get_chain_args(chain_name=prompt["chain_name"]) + for arg in args: + if arg not in prompt_args and arg not in skip_args: + prompt_args.append(arg) + except Exception as e: + logging.error(f"Error getting chain args: {e}") + return prompt_args diff --git a/agixt/Chains.py b/agixt/Chains.py deleted file mode 100644 index 6951a5055c31..000000000000 --- a/agixt/Chains.py +++ /dev/null @@ -1,212 +0,0 @@ -import logging -from Globals import getenv -from ApiClient import Chain, Prompts, Conversations -from Extensions import Extensions - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -class Chains: - def __init__(self, user="USER", ApiClient=None): - self.user = user - self.chain = Chain(user=user) - self.ApiClient = ApiClient - - async def run_chain_step( - self, - step: dict = {}, - chain_name="", - user_input="", - agent_override="", - chain_args={}, - ): - if step: - if "prompt_type" in step: - if agent_override != "": - agent_name = agent_override - else: - agent_name = step["agent_name"] - - prompt_type = step["prompt_type"] - step_number = step["step"] - if "prompt_name" in step["prompt"]: - prompt_name = step["prompt"]["prompt_name"] - else: - prompt_name = "" - args = self.chain.get_step_content( - chain_name=chain_name, - prompt_content=step["prompt"], - user_input=user_input, - agent_name=step["agent_name"], - ) - if chain_args != {}: - for arg, value in chain_args.items(): - args[arg] = value - if "chain_name" in args: - args["chain"] = args["chain_name"] - if "chain" not in args: - args["chain"] = chain_name - if "conversation_name" not in args: - args["conversation_name"] = f"Chain Execution History: {chain_name}" - if "conversation" in args: - args["conversation_name"] = args["conversation"] - if prompt_type == "Command": - return self.ApiClient.execute_command( - agent_name=agent_name, - command_name=step["prompt"]["command_name"], - command_args=args, - conversation_name=args["conversation_name"], - ) - elif prompt_type == "Prompt": - result = self.ApiClient.prompt_agent( - agent_name=agent_name, - prompt_name=prompt_name, - prompt_args={ - "chain_name": chain_name, - "step_number": step_number, - "user_input": user_input, - **args, - }, - ) - elif prompt_type == "Chain": - result = self.ApiClient.run_chain( - chain_name=args["chain"], - user_input=args["input"], - agent_name=agent_name, - all_responses=( - args["all_responses"] if "all_responses" in args else False - ), - from_step=args["from_step"] if "from_step" in args else 1, - chain_args=( - args["chain_args"] - if "chain_args" in args - else {"conversation_name": args["conversation_name"]} - ), - ) - if result: - if isinstance(result, dict) and "response" in result: - result = result["response"] - if result == "Unable to retrieve data.": - result = None - if not isinstance(result, str): - result = str(result) - return result - else: - return None - - async def run_chain( - self, - chain_name, - user_input=None, - all_responses=True, - agent_override="", - from_step=1, - chain_args={}, - ): - chain_data = self.ApiClient.get_chain(chain_name=chain_name) - if chain_data == {}: - return f"Chain `{chain_name}` not found." - c = Conversations( - conversation_name=( - f"Chain Execution History: {chain_name}" - if "conversation_name" not in chain_args - else chain_args["conversation_name"] - ), - user=self.user, - ) - c.log_interaction( - role="USER", - message=user_input, - ) - logging.info(f"Running chain '{chain_name}'") - responses = {} # Create a dictionary to hold responses. - last_response = "" - for step_data in chain_data["steps"]: - if int(step_data["step"]) >= int(from_step): - if "prompt" in step_data and "step" in step_data: - step = {} - step["agent_name"] = ( - agent_override - if agent_override != "" - else step_data["agent_name"] - ) - step["prompt_type"] = step_data["prompt_type"] - step["prompt"] = step_data["prompt"] - step["step"] = step_data["step"] - logging.info( - f"Running step {step_data['step']} with agent {step['agent_name']}." - ) - # try: - step_response = await self.run_chain_step( - step=step, - chain_name=chain_name, - user_input=user_input, - agent_override=agent_override, - chain_args=chain_args, - ) # Get the response of the current step. - # except Exception as e: - # logging.error(f"Error running chain step: {e}") - # step_response = None - if step_response == None: - return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed." - step["response"] = step_response - last_response = step_response - logging.info(f"Last response: {last_response}") - responses[step_data["step"]] = step # Store the response. - logging.info(f"Step {step_data['step']} response: {step_response}") - # Write the response to the chain responses file. - await self.chain.update_step_response( - chain_name=chain_name, - step_number=step_data["step"], - response=step_response, - ) - if all_responses: - return responses - else: - # Return only the last response in the chain. - c.log_interaction( - role=agent_override if agent_override != "" else "AGiXT", - message=last_response, - ) - return last_response - - def get_chain_args(self, chain_name): - skip_args = [ - "command_list", - "context", - "COMMANDS", - "date", - "conversation_history", - "agent_name", - "working_directory", - "helper_agent_name", - ] - chain_data = self.chain.get_chain(chain_name=chain_name) - steps = chain_data["steps"] - prompt_args = [] - args = [] - for step in steps: - try: - prompt = step["prompt"] - if "prompt_name" in prompt: - prompt_text = Prompts(user=self.user).get_prompt( - prompt_name=prompt["prompt_name"] - ) - args = Prompts(user=self.user).get_prompt_args( - prompt_text=prompt_text - ) - elif "command_name" in prompt: - args = Extensions().get_command_args( - command_name=prompt["command_name"] - ) - elif "chain_name" in prompt: - args = self.get_chain_args(chain_name=prompt["chain_name"]) - for arg in args: - if arg not in prompt_args and arg not in skip_args: - prompt_args.append(arg) - except Exception as e: - logging.error(f"Error getting chain args: {e}") - return prompt_args diff --git a/agixt/DB.py b/agixt/DB.py index 6d78c2879516..5dea1e5e0880 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -314,6 +314,7 @@ class Chain(Base): "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" ) user = relationship("User", backref="chain") + runs = relationship("ChainRun", backref="chain", cascade="all, delete-orphan") class ChainStep(Base): @@ -379,6 +380,29 @@ class ChainStepArgument(Base): value = Column(Text, nullable=True) +class ChainRun(Base): + __tablename__ = "chain_run" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="CASCADE"), + nullable=False, + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + timestamp = Column(DateTime, server_default=text("now()")) + chain_step_responses = relationship( + "ChainStepResponse", backref="chain_run", cascade="all, delete-orphan" + ) + + class ChainStepResponse(Base): __tablename__ = "chain_step_response" id = Column( @@ -391,6 +415,11 @@ class ChainStepResponse(Base): ForeignKey("chain_step.id", ondelete="CASCADE"), nullable=False, # Add the ondelete option ) + chain_run_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_run.id", ondelete="CASCADE"), + nullable=True, + ) timestamp = Column(DateTime, server_default=text("now()")) content = Column(Text, nullable=False) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index b58eef729934..47e550c175c6 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -99,8 +99,6 @@ async def format_prompt( user_input: str = "", top_results: int = 5, prompt="", - chain_name="", - step_number=0, conversation_name="", vision_response: str = "", **kwargs, @@ -195,36 +193,6 @@ async def format_prompt( context = f"The user's input causes you remember these things:\n{context}\n" else: context = "" - if chain_name != "": - try: - for arg, value in kwargs.items(): - if "{STEP" in value: - # get the response from the step number - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - # replace the {STEPx} with the response - value = value.replace( - f"{{STEP{step_number}}}", - step_response if step_response else "", - ) - kwargs[arg] = value - except: - logging.info("No args to replace.") - if "{STEP" in prompt: - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - prompt = prompt.replace( - f"{{STEP{step_number}}}", step_response if step_response else "" - ) - if "{STEP" in user_input: - step_response = self.chain.get_step_response( - chain_name=chain_name, step_number=step_number - ) - user_input = user_input.replace( - f"{{STEP{step_number}}}", step_response if step_response else "" - ) try: working_directory = self.agent.AGENT_CONFIG["settings"]["WORKING_DIRECTORY"] except: @@ -385,8 +353,6 @@ async def run( self, user_input: str = "", context_results: int = 5, - chain_name: str = "", - step_number: int = 0, shots: int = 1, disable_memory: bool = True, conversation_name: str = "", @@ -501,8 +467,6 @@ async def run( top_results=int(context_results), prompt=prompt, prompt_category=prompt_category, - chain_name=chain_name, - step_number=step_number, conversation_name=conversation_name, websearch=websearch, vision_response=vision_response, @@ -541,8 +505,6 @@ async def run( if context_results > 0: context_results = context_results - 1 prompt_args = { - "chain_name": chain_name, - "step_number": step_number, "shots": shots, "disable_memory": disable_memory, "user_input": user_input, @@ -648,8 +610,6 @@ async def run( responses = [self.response] for shot in range(shots - 1): prompt_args = { - "chain_name": chain_name, - "step_number": step_number, "user_input": user_input, "context_results": context_results, "conversation_name": conversation_name, diff --git a/agixt/Models.py b/agixt/Models.py index 2fb75888dfa3..d25b69a0066f 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -117,6 +117,7 @@ class RunChainStep(BaseModel): prompt: str agent_override: Optional[str] = "" chain_args: Optional[dict] = {} + chain_run_id: Optional[str] = "" class StepInfo(BaseModel): @@ -285,15 +286,11 @@ class Register(BaseModel): email: str first_name: str last_name: str - company_name: str - job_title: str class UserInfo(BaseModel): first_name: str last_name: str - company_name: str - job_title: str class Detail(BaseModel): diff --git a/agixt/AGiXT.py b/agixt/XT.py similarity index 77% rename from agixt/AGiXT.py rename to agixt/XT.py index 494ee8b0a126..60bb75b48575 100644 --- a/agixt/AGiXT.py +++ b/agixt/XT.py @@ -2,7 +2,6 @@ from ApiClient import get_api_client, Conversations, Prompts, Chain from readers.file import FileReader from Extensions import Extensions -from Chains import Chains from pydub import AudioSegment from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions @@ -31,6 +30,7 @@ def __init__(self, user: str, agent_name: str, api_key: str): if "settings" in self.agent.AGENT_CONFIG else DEFAULT_SETTINGS ) + self.chain = Chain(user=self.user_email) async def prompts(self, prompt_category: str = "Default"): """ @@ -53,7 +53,7 @@ async def chains(self): Returns: list: List of available chains """ - return Chain(user=self.user_email).get_chains() + return self.chain.get_chains() async def settings(self): """ @@ -264,8 +264,8 @@ async def execute_command( if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( - role=self.agent, - message=f"[ACTIVITY_START] Execute command: {command_name} with args: {command_args} [ACTIVITY_END]", + role=self.agent_name, + message=f"[ACTIVITY_START] Executing command: {command_name} with args: {command_args} [ACTIVITY_END]", ) response = await Extensions( agent_name=self.agent_name, @@ -293,41 +293,181 @@ async def execute_command( ) return response - async def execute_chain( + async def run_chain_step( self, - chain_name: str, - user_input: str, - chain_args: dict = {}, - use_current_agent: bool = True, - conversation_name: str = "", - voice_response: bool = False, + chain_run_id=None, + step: dict = {}, + chain_name="", + user_input="", + agent_override="", + chain_args={}, + conversation_name="", ): - """ - Execute a chain with arguments - - Args: - chain_name (str): Name of the chain to execute - user_input (str): Message to add to conversation log pre-execution - chain_args (dict): Arguments for the chain - use_current_agent (bool): Whether to use the current agent - conversation_name (str): Name of the conversation - voice_response (bool): Whether to generate a voice response + if not chain_run_id: + chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) + if step: + if "prompt_type" in step: + c = None + if conversation_name != "": + c = Conversations( + conversation_name=conversation_name, + user=self.user_email, + ) + if agent_override != "": + agent_name = agent_override + else: + agent_name = step["agent_name"] + prompt_type = step["prompt_type"] + step_number = step["step"] + if "prompt_name" in step["prompt"]: + prompt_name = step["prompt"]["prompt_name"] + else: + prompt_name = "" + args = self.chain.get_step_content( + chain_run_id=chain_run_id, + chain_name=chain_name, + prompt_content=step["prompt"], + user_input=user_input, + agent_name=agent_name, + ) + if chain_args != {}: + for arg, value in chain_args.items(): + args[arg] = value + if "chain_name" in args: + args["chain"] = args["chain_name"] + if "chain" not in args: + args["chain"] = chain_name + if "conversation_name" not in args: + args["conversation_name"] = f"Chain Execution History: {chain_name}" + if "conversation" in args: + args["conversation_name"] = args["conversation"] + if prompt_type == "Command": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Executing command: {step['prompt']['command_name']} with args: {args} [ACTIVITY_END]", + ) + result = await self.execute_command( + command_name=step["prompt"]["command_name"], + command_args=args, + conversation_name=args["conversation_name"], + voice_response=False, + ) + elif prompt_type == "Prompt": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Running prompt: {prompt_name} with args: {args} [ACTIVITY_END]", + ) + if "prompt_name" not in args: + args["prompt_name"] = prompt_name + result = await self.inference( + agent_name=agent_name, + user_input=user_input, + log_user_input=False, + **args, + ) + elif prompt_type == "Chain": + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Running chain: {step['prompt']['chain_name']} with args: {args} [ACTIVITY_END]", + ) + result = await self.execute_chain( + chain_name=args["chain"], + user_input=args["input"], + agent_override=agent_name, + from_step=args["from_step"] if "from_step" in args else 1, + chain_args=( + args["chain_args"] + if "chain_args" in args + else {"conversation_name": args["conversation_name"]} + ), + conversation_name=args["conversation_name"], + log_user_input=False, + voice_response=False, + ) + if result: + if isinstance(result, dict) and "response" in result: + result = result["response"] + if result == "Unable to retrieve data.": + result = None + if isinstance(result, dict): + result = json.dumps(result) + if not isinstance(result, str): + result = str(result) + await self.chain.update_step_response( + chain_run_id=chain_run_id, + chain_name=chain_name, + step_number=step_number, + response=result, + ) + return result + else: + return None - Returns: - str: Response from the chain - """ - c = Conversations(conversation_name=conversation_name, user=self.user_email) - c.log_interaction(role="USER", message=user_input) - response = await Chains( - user=self.user_email, ApiClient=self.ApiClient - ).run_chain( - chain_name=chain_name, - user_input=user_input, - agent_override=self.agent_name if use_current_agent else None, - all_responses=False, - chain_args=chain_args, - from_step=1, + async def execute_chain( + self, + chain_name, + chain_run_id=None, + user_input=None, + agent_override="", + from_step=1, + chain_args={}, + log_user_input=False, + conversation_name="", + voice_response=False, + ): + chain_data = self.chain.get_chain(chain_name=chain_name) + if not chain_run_id: + chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) + if chain_data == {}: + return f"Chain `{chain_name}` not found." + c = Conversations( + conversation_name=conversation_name, + user=self.user_email, ) + if log_user_input: + c.log_interaction( + role="USER", + message=user_input, + ) + agent_name = agent_override if agent_override != "" else "AGiXT" + if conversation_name != "": + c.log_interaction( + role=agent_name, + message=f"[ACTIVITY_START] Running chain `{chain_name}`... [ACTIVITY_END]", + ) + response = "" + for step_data in chain_data["steps"]: + if int(step_data["step"]) >= int(from_step): + if "prompt" in step_data and "step" in step_data: + step = {} + step["agent_name"] = ( + agent_override + if agent_override != "" + else step_data["agent_name"] + ) + step["prompt_type"] = step_data["prompt_type"] + step["prompt"] = step_data["prompt"] + step["step"] = step_data["step"] + step_response = await self.run_chain_step( + chain_run_id=chain_run_id, + step=step, + chain_name=chain_name, + user_input=user_input, + agent_override=agent_override, + chain_args=chain_args, + conversation_name=conversation_name, + ) + if step_response == None: + return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." + response = step_response + if conversation_name != "": + c.log_interaction( + role=agent_name, + message=response, + ) if "tts_provider" in self.agent_settings and voice_response: if ( self.agent_settings["tts_provider"] != "None" @@ -667,8 +807,9 @@ async def chat_completions(self, prompt: ChatCompletions): response = await self.execute_chain( chain_name=chain_name, user_input=new_prompt, + agent_override=self.agent_name, chain_args=chain_args, - use_current_agent=True, + log_user_input=False, conversation_name=conversation_name, voice_response=tts, ) diff --git a/agixt/endpoints/Chain.py b/agixt/endpoints/Chain.py index 9f245a21489b..b5c70001d77b 100644 --- a/agixt/endpoints/Chain.py +++ b/agixt/endpoints/Chain.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, Depends, Header from ApiClient import Chain, verify_api_key, get_api_client, is_admin -from Chains import Chains +from XT import AGiXT from Models import ( RunChain, RunChainStep, @@ -32,21 +32,6 @@ async def get_chain(chain_name: str, user=Depends(verify_api_key)): return {"chain": chain_data} -@app.get( - "/api/chain/{chain_name}/responses", - tags=["Chain"], - dependencies=[Depends(verify_api_key)], -) -async def get_chain_responses(chain_name: str, user=Depends(verify_api_key)): - try: - chain_data = Chain(user=user).get_step_response( - chain_name=chain_name, step_number="all" - ) - return {"chain": chain_data} - except: - raise HTTPException(status_code=404, detail="Chain not found") - - @app.post( "/api/chain/{chain_name}/run", tags=["Chain", "Admin"], @@ -60,14 +45,18 @@ async def run_chain( ): if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - chain_response = await Chains(user=user, ApiClient=ApiClient).run_chain( + agent_name = user_input.agent_override if user_input.agent_override else "gpt4free" + chain_response = await AGiXT( + user=user, + agent_name=agent_name, + api_key=authorization, + ).execute_chain( chain_name=chain_name, user_input=user_input.prompt, agent_override=user_input.agent_override, - all_responses=user_input.all_responses, from_step=user_input.from_step, chain_args=user_input.chain_args, + log_user_input=False, ) try: if "Chain failed to complete" in chain_response: @@ -99,8 +88,15 @@ async def run_chain_step( raise HTTPException( status_code=404, detail=f"Step {step_number} not found. {e}" ) - ApiClient = get_api_client(authorization=authorization) - chain_step_response = await Chains(user=user, ApiClient=ApiClient).run_chain_step( + agent_name = ( + user_input.agent_override if user_input.agent_override else step["agent"] + ) + chain_step_response = await AGiXT( + user=user, + agent_name=agent_name, + api_key=authorization, + ).run_chain_step( + chain_run_id=user_input.chain_run_id, step=step, chain_name=chain_name, user_input=user_input.prompt, @@ -129,7 +125,7 @@ async def get_chain_args( if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") ApiClient = get_api_client(authorization=authorization) - chain_args = Chains(user=user, ApiClient=ApiClient).get_chain_args( + chain_args = Chain(user=user, ApiClient=ApiClient).get_chain_args( chain_name=chain_name ) return {"chain_args": chain_args} diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py index 1362f282dd78..59f0b4d2bbc8 100644 --- a/agixt/endpoints/Completions.py +++ b/agixt/endpoints/Completions.py @@ -13,7 +13,7 @@ TextToSpeech, ImageCreation, ) -from AGiXT import AGiXT +from XT import AGiXT app = APIRouter() diff --git a/tests/tests.ipynb b/tests/tests.ipynb index 44a6057cb6ff..fff013d4d0c3 100644 --- a/tests/tests.ipynb +++ b/tests/tests.ipynb @@ -1534,37 +1534,6 @@ "print(\"Run chain response:\", run_chain_resp)" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Get the responses from the chain running\n" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chain: {'1': {'agent_name': 'new_agent', 'prompt_type': 'Prompt', 'prompt': {'prompt_name': 'Write a Poem', 'subject': 'Quantum Computers'}, 'step': 1, 'response': \"In the depths of the quantum realm,\\nWhere possibilities unfurl,\\nLies a marvel of modern science,\\nA computer that defies our world.\\n\\nQuantum entanglement's mysterious dance,\\nHarnessing particles, in a cosmic trance,\\nBits of information, not just zero or one,\\nA quantum computer, where wonders are spun.\\n\\nThe qubits, like tiny dancers on a stage,\\nCan exist in multiple states, they engage,\\nA quantum superposition, a delicate balance,\\nComputing power that leaves us in a trance.\\n\\nThrough quantum gates, these qubits entwine,\\nCreating a web of possibilities, oh so fine,\\nParallel universes, in computation they roam,\\nQuantum computers, bringing the unknown home.\\n\\nComplex algorithms, they can quickly solve,\\nShattering encryption, with problems they evolve,\\nFrom cryptography to simulating the universe,\\nQuantum computers, a scientific traverse.\\n\\nYet, in this realm of infinite potential,\\nErrors and decoherence can be consequential,\\nNoise and disturbances, they threaten the state,\\nA challenge to overcome, for quantum's fate.\\n\\nBut fear not, for scientists persist,\\nAdvancing quantum technology, a fervent twist,\\nWith every breakthrough, a step closer we come,\\nTo a future where quantum computers will hum.\\n\\nIn this world of uncertainty and flux,\\nQuantum computers, the next paradigm, unbox,\\nUnveiling the secrets of our reality's core,\\nA technological marvel, forever to adore.\"}, '2': {'agent_name': 'new_agent', 'prompt_type': 'Command', 'prompt': {'command_name': 'Write to File', 'filename': '{user_input}.txt', 'text': 'Poem:\\n{STEP1}'}, 'step': 2, 'response': 'File written to successfully.'}}\n" - ] - } - ], - "source": [ - "from agixtsdk import AGiXTSDK\n", - "\n", - "base_uri = \"http://localhost:7437\"\n", - "ApiClient = AGiXTSDK(base_uri=base_uri)\n", - "chain_name = \"Poem Writing Chain\"\n", - "chain = ApiClient.get_chain_responses(chain_name=chain_name)\n", - "print(\"Chain:\", chain)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -1591,8 +1560,7 @@ "\n", "base_uri = \"http://localhost:7437\"\n", "ApiClient = AGiXTSDK(base_uri=base_uri)\n", - "chain_name = \"Poem Writing Chain\"\n", - "delete_chain_resp = ApiClient.delete_chain(chain_name=chain_name)\n", + "delete_chain_resp = ApiClient.delete_chain(chain_name=\"Poem Writing Chain\")\n", "print(\"Delete chain response:\", delete_chain_resp)" ] }, From 687982567efdc307bad2fb71bff2dd01451f5f8d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 29 May 2024 21:50:36 -0400 Subject: [PATCH 0029/1256] clean up duplicate --- agixt/chains/Create New Command.json | 31 +++++++++++++++++++++++++++- agixt/extensions/agixt_actions.py | 28 ------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/agixt/chains/Create New Command.json b/agixt/chains/Create New Command.json index 7a361911d14b..d86c987599fd 100644 --- a/agixt/chains/Create New Command.json +++ b/agixt/chains/Create New Command.json @@ -1 +1,30 @@ -{"chain_name": "Create New Command", "steps": [{"step": 1, "agent_name": "gpt4free", "prompt_type": "Prompt", "prompt": {"prompt_name": "Create New Command", "NEW_FUNCTION_DESCRIPTION": "{user_input}"}}, {"step": 2, "agent_name": "gpt4free", "prompt_type": "Prompt", "prompt": {"prompt_name": "Title a Chain"}}, {"step": 3, "agent_name": "gpt4free", "prompt_type": "Command", "prompt": {"command_name": "Write to File", "filename": "{STEP2}.py", "text": "{STEP1}"}}]} \ No newline at end of file +{ + "chain_name": "Create New Command", + "steps": [ + { + "step": 1, + "agent_name": "gpt4free", + "prompt_type": "Prompt", + "prompt": { + "prompt_name": "Create New Command", + "NEW_FUNCTION_DESCRIPTION": "{user_input}" + } + }, + { + "step": 2, + "agent_name": "gpt4free", + "prompt_type": "Prompt", + "prompt": { "prompt_name": "Title a Chain" } + }, + { + "step": 3, + "agent_name": "gpt4free", + "prompt_type": "Command", + "prompt": { + "command_name": "Write to File", + "filename": "{STEP2}.py", + "text": "{STEP1}" + } + } + ] +} diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 8bd4e7b26d6b..a972fe9c230a 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -148,7 +148,6 @@ def __init__(self, **kwargs): "Generate Extension from OpenAPI": self.generate_openapi_chain, "Generate Agent Helper Chain": self.generate_helper_chain, "Ask for Help or Further Clarification to Complete Task": self.ask_for_help, - "Create a new command": self.create_command, "Execute Python Code": self.execute_python_code_internal, "Get Python Code from Response": self.get_python_code_from_response, "Get Mindmap for task to break it down": self.get_mindmap, @@ -645,33 +644,6 @@ async def ask_for_help(self, your_agent_name, your_task): }, ) - async def create_command( - self, function_description: str, agent: str = "AGiXT" - ) -> List[str]: - """ - Create a new command - - Args: - function_description (str): The description of the function - agent (str): The agent to create the command for - - Returns: - str: The response from the chain - """ - try: - return self.ApiClient.run_chain( - chain_name="Create New Command", - user_input=function_description, - agent_name=self.agent_name, - all_responses=False, - from_step=1, - chain_args={ - "conversation_name": self.conversation_name, - }, - ) - except Exception as e: - return f"Unable to create command: {e}" - async def ask(self, user_input: str) -> str: """ Ask a question From 3a0a9a9e66b516fc2fbfd9cc4e152762c4051713 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 29 May 2024 22:14:46 -0400 Subject: [PATCH 0030/1256] Update XT.py Signed-off-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 60bb75b48575..dbe5bd7b048e 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -371,7 +371,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Running chain: {step['prompt']['chain_name']} with args: {args} [ACTIVITY_END]", + message=f"[ACTIVITY_START] Running chain: {args['chain']} with args: {args} [ACTIVITY_END]", ) result = await self.execute_chain( chain_name=args["chain"], From b97cb528b416eff173446795fc8d922cc5e874b8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 07:41:07 -0400 Subject: [PATCH 0031/1256] Fix prompt --- agixt/prompts/Default/SmartTask-Execution.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agixt/prompts/Default/SmartTask-Execution.txt b/agixt/prompts/Default/SmartTask-Execution.txt index b5d639984044..841fc3b6253c 100644 --- a/agixt/prompts/Default/SmartTask-Execution.txt +++ b/agixt/prompts/Default/SmartTask-Execution.txt @@ -10,4 +10,7 @@ Task: Task Response: ``` {previous_response} -``` \ No newline at end of file +``` + +The assistant is a command execution agent that executes commands where necessary to complete the task related to the primary objective, but nothing else. +**If there are no commands to use for this task, respond only with empty Json like {}.** \ No newline at end of file From e777052723149026fe5487204bbfd3908409aa09 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 30 May 2024 09:47:06 -0400 Subject: [PATCH 0032/1256] Move dataset creation functions (#1197) * Move dataset creation functions * rename funcs * Automatically set dataset name --- agixt/Agent.py | 12 +++ agixt/Memories.py | 183 -------------------------------------- agixt/Models.py | 1 - agixt/Tuning.py | 44 +++++---- agixt/XT.py | 143 +++++++++++++++++++++++++++++ agixt/endpoints/Memory.py | 14 +-- 6 files changed, 180 insertions(+), 217 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 919011b059e4..6a1194dade08 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -515,3 +515,15 @@ def delete_browsed_link(self, url): self.session.delete(browsed_link) self.session.commit() return f"Link {url} deleted from browsed links." + + def get_agent_id(self): + agent = ( + self.session.query(AgentModel) + .filter( + AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id + ) + .first() + ) + if not agent: + return None + return agent.id diff --git a/agixt/Memories.py b/agixt/Memories.py index 6c2040415a97..1311ce9dcabc 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -511,186 +511,3 @@ async def chunk_content(self, text: str, chunk_size: int) -> List[str]: # Sort the chunks by their score in descending order before returning them content_chunks.sort(key=lambda x: x[0], reverse=True) return [chunk_text for score, chunk_text in content_chunks] - - async def get_context( - self, - user_input: str, - limit: int = 10, - websearch: bool = False, - additional_collections: List[str] = [], - ) -> str: - self.collection_number = 0 - context = await self.get_memories( - user_input=user_input, - limit=limit, - min_relevance_score=0.2, - ) - self.collection_number = 2 - positive_feedback = await self.get_memories( - user_input=user_input, - limit=3, - min_relevance_score=0.7, - ) - self.collection_number = 3 - negative_feedback = await self.get_memories( - user_input=user_input, - limit=3, - min_relevance_score=0.7, - ) - if positive_feedback or negative_feedback: - context.append( - f"The users input makes you to remember some feedback from previous interactions:\n" - ) - if positive_feedback: - context += f"Positive Feedback:\n{positive_feedback}\n" - if negative_feedback: - context += f"Negative Feedback:\n{negative_feedback}\n" - if websearch: - self.collection_number = 1 - context += await self.get_memories( - user_input=user_input, - limit=limit, - min_relevance_score=0.2, - ) - if additional_collections: - for collection in additional_collections: - self.collection_number = collection - context += await self.get_memories( - user_input=user_input, - limit=limit, - min_relevance_score=0.2, - ) - return context - - async def batch_prompt( - self, - user_inputs: List[str] = [], - prompt_name: str = "Ask Questions", - prompt_category: str = "Default", - batch_size: int = 10, - **kwargs, - ): - i = 0 - tasks = [] - responses = [] - if user_inputs == []: - return [] - for user_input in user_inputs: - i += 1 - logging.info(f"[{i}/{len(user_inputs)}] Running Prompt: {prompt_name}") - if i % batch_size == 0: - responses += await asyncio.gather(**tasks) - tasks = [] - task = asyncio.create_task( - await self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name=prompt_name, - prompt_args={ - "prompt_category": prompt_category, - "user_input": user_input, - **kwargs, - }, - ) - ) - tasks.append(task) - responses += await asyncio.gather(**tasks) - return responses - - # Answer a question with context injected, return in sharegpt format - async def agent_dpo_qa(self, question: str = "", context_results: int = 10): - context = await self.get_context(user_input=question, limit=context_results) - prompt = f"### Context\n{context}\n### Question\n{question}" - chosen = await self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="Answer Question with Memory", - prompt_args={ - "prompt_category": "Default", - "user_input": question, - "context_results": context_results, - }, - ) - # Create a memory with question and answer - self.collection_number = 0 - await self.write_text_to_memory( - user_input=question, - text=chosen, - external_source="Synthetic QA", - ) - rejected = await self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="Wrong Answers Only", - prompt_args={ - "prompt_category": "Default", - "user_input": question, - }, - ) - return prompt, chosen, rejected - - # Creates a synthetic dataset from memories in sharegpt format - async def create_dataset_from_memories( - self, dataset_name: str = "", batch_size: int = 10 - ): - self.agent_config["settings"]["training"] = True - self.ApiClient.update_agent_settings( - agent_name=self.agent_name, settings=self.agent_config["settings"] - ) - memories = [] - questions = [] - if dataset_name == "": - dataset_name = f"{datetime.now().isoformat()}-dataset" - collections = await self.get_collections() - for collection in collections: - self.collection_name = collection - memories += await self.export_collection_to_json() - logging.info(f"There are {len(memories)} memories.") - memories = [memory["text"] for memory in memories] - # Get a list of questions about each memory - question_list = self.batch_prompt( - user_inputs=memories, - batch_size=batch_size, - ) - for question in question_list: - # Convert the response to a list of questions - question = question.split("\n") - question = [ - item.lstrip("0123456789.*- ") for item in question if item.lstrip() - ] - question = [item for item in question if item] - question = [item.lstrip("0123456789.*- ") for item in question] - questions += question - prompts = [] - good_answers = [] - bad_answers = [] - for question in questions: - prompt, chosen, rejected = await self.agent_dpo_qa( - question=question, context_results=10 - ) - prompts.append(prompt) - good_answers.append( - [ - {"content": prompt, "role": "user"}, - {"content": chosen, "role": "assistant"}, - ] - ) - bad_answers.append( - [ - {"content": prompt, "role": "user"}, - {"content": rejected, "role": "assistant"}, - ] - ) - dpo_dataset = { - "prompt": questions, - "chosen": good_answers, - "rejected": bad_answers, - } - # Save messages to a json file to be used as a dataset - os.makedirs(f"./WORKSPACE/{self.agent_name}/datasets", exist_ok=True) - with open( - f"./WORKSPACE/{self.agent_name}/datasets/{dataset_name}.json", "w" - ) as f: - f.write(json.dumps(dpo_dataset)) - self.agent_config["settings"]["training"] = False - self.ApiClient.update_agent_settings( - agent_name=self.agent_name, settings=self.agent_config["settings"] - ) - return dpo_dataset diff --git a/agixt/Models.py b/agixt/Models.py index d25b69a0066f..13d48f64647f 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -23,7 +23,6 @@ class AgentMemoryQuery(BaseModel): class Dataset(BaseModel): - dataset_name: str batch_size: int = 5 diff --git a/agixt/Tuning.py b/agixt/Tuning.py index d7236654be16..684ffd7d254d 100644 --- a/agixt/Tuning.py +++ b/agixt/Tuning.py @@ -54,8 +54,7 @@ from peft import PeftModel from bitsandbytes.functional import dequantize_4bit -from agixtsdk import AGiXTSDK -from Memories import Memories +from XT import AGiXT def fine_tune_llm( @@ -65,31 +64,31 @@ def fine_tune_llm( max_seq_length: int = 16384, huggingface_output_path: str = "JoshXT/finetuned-mistral-7b-v0.2", private_repo: bool = True, - ApiClient=AGiXTSDK(), + user: str = "user", + api_key: str = "", ): output_path = "./models" # Step 1: Build AGiXT dataset - agent_config = ApiClient.get_agentconfig(agent_name) - if not agent_config: - agent_config = {} + agixt = AGiXT(user=user, api_key=api_key, agent_name=agent_name) + agent_settings = agixt.agent_settings + if not agent_settings: + agent_settings = {} huggingface_api_key = ( - agent_config["settings"]["HUGGINGFACE_API_KEY"] - if "HUGGINGFACE_API_KEY" in agent_config["settings"] + agent_settings["HUGGINGFACE_API_KEY"] + if "HUGGINGFACE_API_KEY" in agent_settings else None ) - response = Memories( + response = AGiXT( agent_name=agent_name, - agent_config=agent_config, - collection_number=0, - ApiClient=ApiClient, + api_key=api_key, ).create_dataset_from_memories(dataset_name=dataset_name, batch_size=5) dataset_name = ( response["message"].split("Creation of dataset ")[1].split(" for agent")[0] ) dataset_path = f"./WORKSPACE/{agent_name}/datasets/{dataset_name}.json" - agent_config["settings"]["training"] = True - ApiClient.update_agent_settings( - agent_name=agent_name, settings=agent_config["settings"] + agent_settings["training"] = True + agixt.agent_interactions.agent.update_agent_config( + new_config=agent_settings, config_key="settings" ) # Step 2: Create qLora adapter model, tokenizer = FastLanguageModel.from_pretrained( @@ -169,9 +168,9 @@ def fine_tune_llm( tokenizer.push_to_hub( huggingface_output_path, use_temp_dir=False, private=private_repo ) - agent_config["settings"]["training"] = False - ApiClient.update_agent_settings( - agent_name=agent_name, settings=agent_config["settings"] + agent_settings["training"] = False + agixt.agent_interactions.agent.update_agent_config( + new_config=agent_settings, config_key="settings" ) @@ -180,11 +179,10 @@ def fine_tune_llm( fine_tune_llm( agent_name="AGiXT", dataset_name="dataset", - base_uri="http://localhost:7437", - api_key="Your AGiXT API Key", - model_name="unsloth/mistral-7b-v0.2", + model_name="unsloth/llama-3-8b-Instruct-bnb-4bit", max_seq_length=16384, - output_path="./WORKSPACE/merged_model", - huggingface_output_path="JoshXT/finetuned-mistral-7b-v0.2", + huggingface_output_path="JoshXT/finetuned-llama-3-8b", private_repo=True, + user="user", + api_key="Your AGiXT API Key", ) diff --git a/agixt/XT.py b/agixt/XT.py index dbe5bd7b048e..f9c13c041678 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -5,6 +5,10 @@ from pydub import AudioSegment from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions +from datetime import datetime +from typing import List +import logging +import asyncio import os import base64 import uuid @@ -853,3 +857,142 @@ async def chat_completions(self, prompt: ChatCompletions): }, } return res_model + + async def batch_inference( + self, + user_inputs: List[str] = [], + prompt_category: str = "Default", + prompt_name: str = "Ask Questions", + conversation_name: str = "", + images: list = [], + injected_memories: int = 5, + batch_size: int = 10, + browse_links: bool = False, + voice_response: bool = False, + log_user_input: bool = False, + **kwargs, + ): + i = 0 + tasks = [] + responses = [] + if user_inputs == []: + return [] + for user_input in user_inputs: + i += 1 + if i % batch_size == 0: + responses += await asyncio.gather(**tasks) + tasks = [] + task = asyncio.create_task( + await self.inference( + user_input=user_input, + prompt_category=prompt_category, + prompt_name=prompt_name, + conversation_name=conversation_name, + images=images, + injected_memories=injected_memories, + browse_links=browse_links, + voice_response=voice_response, + log_user_input=log_user_input, + **kwargs, + ) + ) + tasks.append(task) + responses += await asyncio.gather(**tasks) + return responses + + async def dpo(self, question: str = "", context_results: int = 10): + context = await self.memories( + user_input=question, + limit_per_collection=context_results, + ) + prompt = f"### Context\n{context}\n### Question\n{question}" + chosen = await self.inference( + user_input=question, + prompt_category="Default", + prompt_name="Answer Question with Memory", + injected_memories=context_results, + log_user_input=False, + ) + # Create a memory with question and answer + self.collection_number = 0 + await self.agent_interactions.agent_memory.write_text_to_memory( + user_input=question, + text=chosen, + external_source="Synthetic QA", + ) + rejected = await self.inference( + user_input=question, + prompt_category="Default", + prompt_name="Wrong Answers Only", + log_user_input=False, + ) + return prompt, chosen, rejected + + # Creates a synthetic dataset from memories in sharegpt format + async def create_dataset_from_memories(self, batch_size: int = 10): + self.agent_settings["training"] = True + self.agent_interactions.agent.update_agent_config( + new_config=self.agent_settings, config_key="settings" + ) + memories = [] + questions = [] + dataset_name = f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-DPO-Dataset" + collections = await self.agent_interactions.agent_memory.get_collections() + for collection in collections: + self.collection_name = collection + memories += ( + await self.agent_interactions.agent_memory.export_collection_to_json() + ) + logging.info(f"There are {len(memories)} memories.") + memories = [memory["text"] for memory in memories] + # Get a list of questions about each memory + question_list = self.batch_inference( + user_inputs=memories, + batch_size=batch_size, + prompt_category="Default", + prompt_name="Ask Questions", + ) + for question in question_list: + # Convert the response to a list of questions + question = question.split("\n") + question = [ + item.lstrip("0123456789.*- ") for item in question if item.lstrip() + ] + question = [item for item in question if item] + question = [item.lstrip("0123456789.*- ") for item in question] + questions += question + prompts = [] + good_answers = [] + bad_answers = [] + for question in questions: + prompt, chosen, rejected = await self.dpo( + question=question, context_results=10 + ) + prompts.append(prompt) + good_answers.append( + [ + {"content": prompt, "role": "user"}, + {"content": chosen, "role": "assistant"}, + ] + ) + bad_answers.append( + [ + {"content": prompt, "role": "user"}, + {"content": rejected, "role": "assistant"}, + ] + ) + dpo_dataset = { + "prompt": questions, + "chosen": good_answers, + "rejected": bad_answers, + } + # Save messages to a json file to be used as a dataset + agent_id = self.agent_interactions.agent.get_agent_id() + os.makedirs(f"./WORKSPACE/{agent_id}/datasets", exist_ok=True) + with open(f"./WORKSPACE/{agent_id}/datasets/{dataset_name}.json", "w") as f: + f.write(json.dumps(dpo_dataset)) + self.agent_settings["training"] = False + self.agent_interactions.agent.update_agent_config( + new_config=self.agent_settings, config_key="settings" + ) + return dpo_dataset diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 1c821e9dc281..9b37ac08991c 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -5,6 +5,7 @@ from ApiClient import Agent, verify_api_key, get_api_client, WORKERS, is_admin from typing import Dict, Any, List from Websearch import Websearch +from XT import AGiXT from Memories import Memories from readers.github import GithubReader from readers.file import FileReader @@ -462,20 +463,13 @@ async def create_dataset( ) -> ResponseMessage: if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) batch_size = dataset.batch_size if dataset.batch_size < (int(WORKERS) - 2) else 4 asyncio.create_task( - await Memories( + await AGiXT( agent_name=agent_name, - agent_config=agent.AGENT_CONFIG, - collection_number=0, - ApiClient=ApiClient, user=user, - ).create_dataset_from_memories( - dataset_name=dataset.dataset_name, - batch_size=batch_size, - ) + api_key=authorization, + ).create_dataset_from_memories(batch_size=batch_size) ) return ResponseMessage( message=f"Creation of dataset {dataset.dataset_name} for agent {agent_name} started." From 30e8e811fc5539fb887a5db3f35b25d1e29c899d Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 30 May 2024 11:22:29 -0400 Subject: [PATCH 0033/1256] Add DPO endpoint (#1198) * Add DPO endpoint * fix error output and handle websearches --- agixt/Interactions.py | 31 ++++++++++++++++++-- agixt/Models.py | 5 ++++ agixt/Websearch.py | 7 +++-- agixt/XT.py | 31 ++++++++++---------- agixt/endpoints/Memory.py | 24 ++++++++++++++++ agixt/extensions/google.py | 43 ---------------------------- agixt/extensions/searxng.py | 56 ------------------------------------- 7 files changed, 77 insertions(+), 120 deletions(-) delete mode 100644 agixt/extensions/searxng.py diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 47e550c175c6..56b780f2b764 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -393,6 +393,19 @@ async def run( ) del kwargs["browse_links"] websearch = False + if "websearch" in self.agent.AGENT_CONFIG["settings"]: + websearch = ( + str(self.agent.AGENT_CONFIG["settings"]["websearch"]).lower() == "true" + ) + if "websearch_depth" in self.agent.AGENT_CONFIG["settings"]: + websearch_depth = int( + self.agent.AGENT_CONFIG["settings"]["websearch_depth"] + ) + if "browse_links" in self.agent.AGENT_CONFIG["settings"]: + browse_links = ( + str(self.agent.AGENT_CONFIG["settings"]["browse_links"]).lower() + == "true" + ) if "websearch" in kwargs: websearch = True if str(kwargs["websearch"]).lower() == "true" else False del kwargs["websearch"] @@ -407,6 +420,7 @@ async def run( conversation_name = kwargs["conversation_name"] if conversation_name == "": conversation_name = datetime.now().strftime("%Y-%m-%d") + c = Conversations(conversation_name=conversation_name, user=self.user) if "WEBSEARCH_TIMEOUT" in kwargs: try: websearch_timeout = int(kwargs["WEBSEARCH_TIMEOUT"]) @@ -414,7 +428,7 @@ async def run( websearch_timeout = 0 else: websearch_timeout = 0 - if browse_links != False: + if browse_links != False and websearch == False: await self.websearch.scrape_websites( user_input=user_input, search_depth=websearch_depth, @@ -430,6 +444,17 @@ async def run( else: search_string = user_input if search_string != "": + search_string = self.run( + user_input=search_string, + context_results=context_results if context_results > 0 else 5, + log_user_input=False, + browse_links=False, + websearch=False, + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Searching web for: {search_string} [ACTIVITY_END]", + ) try: await self.websearch.websearch_agent( user_input=search_string, @@ -437,7 +462,7 @@ async def run( websearch_timeout=websearch_timeout, ) except Exception as e: - logging.warning("Failed to websearch. Error: {e}") + logging.warning(f"Failed to websearch. Error: {e}") vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] @@ -478,7 +503,7 @@ async def run( if user_input != "" and persist_context_in_history == False else formatted_prompt ) - c = Conversations(conversation_name=conversation_name, user=self.user) + if log_user_input: c.log_interaction( role="USER", diff --git a/agixt/Models.py b/agixt/Models.py index 13d48f64647f..a903de7c39c8 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -22,6 +22,11 @@ class AgentMemoryQuery(BaseModel): min_relevance_score: float = 0.0 +class UserInput(BaseModel): + user_input: str + injected_memories: Optional[int] = 10 + + class Dataset(BaseModel): batch_size: int = 5 diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 74d7b4cace47..7831726ef170 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -435,7 +435,9 @@ async def scrape_websites( conversation_name: str = "", ): # user_input = "I am browsing {url} and collecting data from it to learn more." - c = Conversations(conversation_name=conversation_name, user=self.user) + c = None + if conversation_name != "" and conversation_name is not None: + c = Conversations(conversation_name=conversation_name, user=self.user) links = re.findall(r"(?Phttps?://[^\s]+)", user_input) if len(links) < 1: return "" @@ -481,8 +483,7 @@ async def scrape_websites( scraped_links.append(sublink[1]) str_links = "\n".join(scraped_links) message = f"I have read all of the content from the following links into my memory:\n{str_links}" - if conversation_name: - c = Conversations(conversation_name=conversation_name, user=self.user) + if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, message=f"[ACTIVITY_START] {message} [ACTIVITY_END]", diff --git a/agixt/XT.py b/agixt/XT.py index f9c13c041678..f4a0c43d263c 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -900,32 +900,32 @@ async def batch_inference( responses += await asyncio.gather(**tasks) return responses - async def dpo(self, question: str = "", context_results: int = 10): - context = await self.memories( + async def dpo( + self, + question: str = "", + injected_memories: int = 10, + ): + context_async = self.memories( user_input=question, - limit_per_collection=context_results, + limit_per_collection=injected_memories, ) - prompt = f"### Context\n{context}\n### Question\n{question}" - chosen = await self.inference( + chosen_async = self.inference( user_input=question, prompt_category="Default", prompt_name="Answer Question with Memory", - injected_memories=context_results, + injected_memories=injected_memories, log_user_input=False, ) - # Create a memory with question and answer - self.collection_number = 0 - await self.agent_interactions.agent_memory.write_text_to_memory( - user_input=question, - text=chosen, - external_source="Synthetic QA", - ) - rejected = await self.inference( + rejected_async = self.inference( user_input=question, prompt_category="Default", prompt_name="Wrong Answers Only", log_user_input=False, ) + chosen = await chosen_async + rejected = await rejected_async + context = await context_async + prompt = f"### Context\n{context}\n### Question\n{question}" return prompt, chosen, rejected # Creates a synthetic dataset from memories in sharegpt format @@ -966,7 +966,8 @@ async def create_dataset_from_memories(self, batch_size: int = 10): bad_answers = [] for question in questions: prompt, chosen, rejected = await self.dpo( - question=question, context_results=10 + question=question, + injected_memories=10, ) prompts.append(prompt) good_answers.append( diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 9b37ac08991c..65d2866964c4 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -23,6 +23,7 @@ Dataset, FinetuneAgentModel, ExternalSource, + UserInput, ) app = APIRouter() @@ -476,6 +477,29 @@ async def create_dataset( ) +@app.post( + "/api/agent/{agent_name}/dpo", + tags=["Memory"], + dependencies=[Depends(verify_api_key)], + summary="Gets a DPO response for a question", +) +async def get_dpo_response( + agent_name: str, + user_input: UserInput, + user=Depends(verify_api_key), + authorization: str = Header(None), +) -> Dict[str, Any]: + agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) + prompt, chosen, rejected = await agixt.dpo( + question=user_input, injected_memories=int(user_input.injected_memories) + ) + return { + "prompt": prompt, + "chosen": chosen, + "rejected": rejected, + } + + # Train model @app.post( "/api/agent/{agent_name}/memory/dataset/{dataset_name}/finetune", diff --git a/agixt/extensions/google.py b/agixt/extensions/google.py index 94e2dae93629..3d2cd4d32d4e 100644 --- a/agixt/extensions/google.py +++ b/agixt/extensions/google.py @@ -9,8 +9,6 @@ import mimetypes import email from base64 import urlsafe_b64decode -from typing import List, Union -import json import logging from Extensions import Extensions @@ -29,8 +27,6 @@ ) from googleapiclient.discovery import build -from googleapiclient.errors import HttpError - class google(Extensions): def __init__( @@ -38,15 +34,11 @@ def __init__( GOOGLE_CLIENT_ID: str = "", GOOGLE_CLIENT_SECRET: str = "", GOOGLE_REFRESH_TOKEN: str = "", - GOOGLE_API_KEY: str = "", - GOOGLE_SEARCH_ENGINE_ID: str = "", **kwargs, ): self.GOOGLE_CLIENT_ID = GOOGLE_CLIENT_ID self.GOOGLE_CLIENT_SECRET = GOOGLE_CLIENT_SECRET self.GOOGLE_REFRESH_TOKEN = GOOGLE_REFRESH_TOKEN - self.GOOGLE_API_KEY = GOOGLE_API_KEY - self.GOOGLE_SEARCH_ENGINE_ID = GOOGLE_SEARCH_ENGINE_ID self.attachments_dir = "./WORKSPACE/email_attachments/" os.makedirs(self.attachments_dir, exist_ok=True) self.commands = { @@ -61,7 +53,6 @@ def __init__( "Google - Get Calendar Items": self.get_calendar_items, "Google - Add Calendar Item": self.add_calendar_item, "Google - Remove Calendar Item": self.remove_calendar_item, - "Google Search": self.google_official_search, } def authenticate(self): @@ -557,37 +548,3 @@ async def remove_calendar_item(self, item_id): except Exception as e: logging.info(f"Error removing calendar item: {str(e)}") return "Failed to remove calendar item." - - async def google_official_search( - self, query: str, num_results: int = 8 - ) -> Union[str, List[str]]: - """ - Perform a Google search using the official Google API - - Args: - query (str): The search query - num_results (int): The number of search results to retrieve - - Returns: - Union[str, List[str]]: The search results - """ - try: - service = build("customsearch", "v1", developerKey=self.GOOGLE_API_KEY) - result = ( - service.cse() - .list(q=query, cx=self.GOOGLE_SEARCH_ENGINE_ID, num=num_results) - .execute() - ) - search_results = result.get("items", []) - search_results_links = [item["link"] for item in search_results] - except HttpError as e: - error_details = json.loads(e.content.decode()) - if error_details.get("error", {}).get( - "code" - ) == 403 and "invalid API key" in error_details.get("error", {}).get( - "message", "" - ): - return "Error: The provided Google API key is invalid or missing." - else: - return f"Error: {e}" - return search_results_links diff --git a/agixt/extensions/searxng.py b/agixt/extensions/searxng.py deleted file mode 100644 index b227780a1170..000000000000 --- a/agixt/extensions/searxng.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -import random -import requests -from typing import List -from Extensions import Extensions - - -class searxng(Extensions): - def __init__(self, SEARXNG_INSTANCE_URL: str = "", **kwargs): - self.SEARXNG_INSTANCE_URL = SEARXNG_INSTANCE_URL - self.SEARXNG_ENDPOINT = self.get_server() - self.commands = {"Use The SearXNG Search Engine": self.search} - - def get_server(self): - if self.SEARXNG_INSTANCE_URL == "": - try: # SearXNG - List of these at https://searx.space/ - response = requests.get("https://searx.space/data/instances.json") - data = json.loads(response.text) - servers = list(data["instances"].keys()) - random_index = random.randint(0, len(servers) - 1) - self.SEARXNG_INSTANCE_URL = servers[random_index] - except: # Select default remote server that typically works if unable to get list. - self.SEARXNG_INSTANCE_URL = "https://search.us.projectsegfau.lt" - server = self.SEARXNG_INSTANCE_URL.rstrip("/") - endpoint = f"{server}/search" - return endpoint - - async def search(self, query: str) -> List[str]: - """ - Search using the SearXNG search engine - - Args: - query (str): The query to search for - - Returns: - List[str]: A list of search results - """ - try: - response = requests.get( - self.SEARXNG_ENDPOINT, - params={ - "q": query, - "language": "en", - "safesearch": 1, - "format": "json", - }, - ) - results = response.json() - summaries = [ - result["title"] + " - " + result["url"] for result in results["results"] - ] - return summaries - except: - # The SearXNG server is down or refusing connection, so we will use the default one. - self.SEARXNG_ENDPOINT = "https://search.us.projectsegfau.lt/search" - return await self.search(query) From d4b10748dc845e7654b731fa28213e044371286c Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 30 May 2024 12:13:21 -0400 Subject: [PATCH 0034/1256] Handle activity log in conversation history injection (#1199) * Add DPO endpoint * fix error output and handle websearches * fix websearch * strip activities from conversation injection --------- Signed-off-by: Josh XT <102809327+Josh-XT@users.noreply.github.com> --- agixt/Interactions.py | 46 ++++++++++++++++++++++++++++++++++++++----- agixt/Websearch.py | 10 +--------- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 56b780f2b764..299f79fab6d7 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -219,6 +219,19 @@ async def format_prompt( total_results = len(conversation["interactions"]) # Get the last conversation_results interactions from the conversation new_conversation_history = [] + # Strip out any interactions where the message starts with [ACTIVITY_START] + activity_history = [ + interaction + for interaction in conversation["interactions"] + if interaction["message"].startswith("[ACTIVITY_START]") + ] + if len(activity_history) > 5: + activity_history = activity_history[-5:] + conversation["interactions"] = [ + interaction + for interaction in conversation["interactions"] + if not interaction["message"].startswith("[ACTIVITY_START]") + ] if total_results > conversation_results: new_conversation_history = conversation["interactions"][ total_results - conversation_results : total_results @@ -236,6 +249,12 @@ async def format_prompt( # Strip code blocks out of the message message = regex.sub(r"(```.*?```)", "", message) conversation_history += f"{timestamp} {role}: {message} \n " + conversation_history += "\nThe assistant's recent activities:\n" + for activity in activity_history: + timestamp = activity["timestamp"] + role = activity["role"] + message = activity["message"] + conversation_history += f"{timestamp} {role}: {message} \n " persona = "" if "persona" in prompt_args: if "PERSONA" in self.agent.AGENT_CONFIG["settings"]: @@ -438,14 +457,17 @@ async def run( if websearch: if user_input == "": if "primary_objective" in kwargs and "task" in kwargs: - search_string = f"Primary Objective: {kwargs['primary_objective']}\n\nTask: {kwargs['task']}" + user_input = f"Primary Objective: {kwargs['primary_objective']}\n\nTask: {kwargs['task']}" else: - search_string = "" - else: - search_string = user_input + user_input = "" if search_string != "": + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Searching the web... [ACTIVITY_END]", + ) search_string = self.run( user_input=search_string, + prompt_name="WebSearch", context_results=context_results if context_results > 0 else 5, log_user_input=False, browse_links=False, @@ -457,7 +479,8 @@ async def run( ) try: await self.websearch.websearch_agent( - user_input=search_string, + user_input=user_input, + search_string=search_string, websearch_depth=websearch_depth, websearch_timeout=websearch_timeout, ) @@ -479,12 +502,25 @@ async def run( ) image_urls.append(image_url) logging.info(f"Getting vision response for images: {image_urls}") + message = ( + "Looking at images..." + if len(image_urls) > 1 + else "Looking at image..." + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] {message} [ACTIVITY_END]", + ) try: vision_response = await self.agent.inference( prompt=user_input, images=image_urls ) logging.info(f"Vision Response: {vision_response}") except Exception as e: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY_START] Unable to view image. [ACTIVITY_END]", + ) logging.error(f"Error getting vision response: {e}") logging.warning("Failed to get vision response.") formatted_prompt, unformatted_prompt, tokens = await self.format_prompt( diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 7831726ef170..091880ef8e9f 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -493,6 +493,7 @@ async def scrape_websites( async def websearch_agent( self, user_input: str = "What are the latest breakthroughs in AI?", + search_string: str = "", websearch_depth: int = 0, websearch_timeout: int = 0, ): @@ -506,15 +507,6 @@ async def websearch_agent( except: websearch_timeout = 0 if websearch_depth > 0: - search_string = self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="WebSearch", - prompt_args={ - "user_input": user_input, - "disable_memory": True, - "browse_links": False, - }, - ) if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") From 5b42300b5ef1b3e637483703e99c141ff07b6c63 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 30 May 2024 12:56:09 -0400 Subject: [PATCH 0035/1256] use activity instead of activity start and end (#1200) --- agixt/Interactions.py | 18 +++++++++--------- agixt/Websearch.py | 6 +++--- agixt/XT.py | 26 +++++++++++++------------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 299f79fab6d7..4cef8b0cb06e 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -219,18 +219,18 @@ async def format_prompt( total_results = len(conversation["interactions"]) # Get the last conversation_results interactions from the conversation new_conversation_history = [] - # Strip out any interactions where the message starts with [ACTIVITY_START] + # Strip out any interactions where the message starts with [ACTIVITY] activity_history = [ interaction for interaction in conversation["interactions"] - if interaction["message"].startswith("[ACTIVITY_START]") + if interaction["message"].startswith("[ACTIVITY]") ] if len(activity_history) > 5: activity_history = activity_history[-5:] conversation["interactions"] = [ interaction for interaction in conversation["interactions"] - if not interaction["message"].startswith("[ACTIVITY_START]") + if not interaction["message"].startswith("[ACTIVITY]") ] if total_results > conversation_results: new_conversation_history = conversation["interactions"][ @@ -463,7 +463,7 @@ async def run( if search_string != "": c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Searching the web... [ACTIVITY_END]", + message=f"[ACTIVITY] Searching the web...", ) search_string = self.run( user_input=search_string, @@ -475,7 +475,7 @@ async def run( ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Searching web for: {search_string} [ACTIVITY_END]", + message=f"[ACTIVITY] Searching web for: {search_string}", ) try: await self.websearch.websearch_agent( @@ -509,7 +509,7 @@ async def run( ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {message} [ACTIVITY_END]", + message=f"[ACTIVITY] {message}", ) try: vision_response = await self.agent.inference( @@ -519,7 +519,7 @@ async def run( except Exception as e: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Unable to view image. [ACTIVITY_END]", + message=f"[ACTIVITY] Unable to view image.", ) logging.error(f"Error getting vision response: {e}") logging.warning("Failed to get vision response.") @@ -769,7 +769,7 @@ async def execution_agent(self, conversation_name): try: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Executing command `{command_name}` with args `{command_args}`. [ACTIVITY_END]", + message=f"[ACTIVITY] Executing command `{command_name}` with args `{command_args}`.", ) ext = Extensions( agent_name=self.agent_name, @@ -792,7 +792,7 @@ async def execution_agent(self, conversation_name): if command_output: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {command_output} [ACTIVITY_END]", + message=f"[ACTIVITY] {command_output}", ) reformatted_response = reformatted_response.replace( f"#execute({command_name}, {command_args})", diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 091880ef8e9f..b00d9f960830 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -448,7 +448,7 @@ async def scrape_websites( if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Browsing {link}... [ACTIVITY_END]", + message=f"[ACTIVITY] Browsing {link}...", ) text_content, link_list = await self.get_web_content( url=link, summarize_content=summarize_content @@ -470,7 +470,7 @@ async def scrape_websites( ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Browsing {sublink[1]}... [ACTIVITY_END]", + message=f"[ACTIVITY] Browsing {sublink[1]}...", ) ( text_content, @@ -486,7 +486,7 @@ async def scrape_websites( if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {message} [ACTIVITY_END]", + message=f"[ACTIVITY] {message}", ) return message diff --git a/agixt/XT.py b/agixt/XT.py index f4a0c43d263c..cd72cb0c4919 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -174,7 +174,7 @@ async def generate_image(self, prompt: str, conversation_name: str = ""): ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Generating image... [ACTIVITY_END]", + message=f"[ACTIVITY] Generating image...", ) return await self.agent.generate_image(prompt=prompt) @@ -192,7 +192,7 @@ async def text_to_speech(self, text: str, conversation_name: str = ""): c = Conversations(conversation_name="Text to Speech", user=self.user_email) c.log_interaction( role="USER", - message=f"[ACTIVITY_START] Generating audio from text: {text} [ACTIVITY_END]", + message=f"[ACTIVITY] Generating audio from text: {text}", ) tts_url = await self.agent.text_to_speech(text=text.text) if not str(tts_url).startswith("http"): @@ -222,7 +222,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Transcribed audio to text: {response} [ACTIVITY_END]", + message=f"[ACTIVITY] Transcribed audio to text: {response}", ) async def translate_audio(self, audio_path: str, conversation_name: str = ""): @@ -242,7 +242,7 @@ async def translate_audio(self, audio_path: str, conversation_name: str = ""): ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Translated audio: {response} [ACTIVITY_END]", + message=f"[ACTIVITY] Translated audio: {response}", ) return response @@ -269,7 +269,7 @@ async def execute_command( c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Executing command: {command_name} with args: {command_args} [ACTIVITY_END]", + message=f"[ACTIVITY] Executing command: {command_name} with args: {command_args}", ) response = await Extensions( agent_name=self.agent_name, @@ -293,7 +293,7 @@ async def execute_command( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + message=f"[ACTIVITY] {response}", ) return response @@ -349,7 +349,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Executing command: {step['prompt']['command_name']} with args: {args} [ACTIVITY_END]", + message=f"[ACTIVITY] Executing command: {step['prompt']['command_name']} with args: {args}", ) result = await self.execute_command( command_name=step["prompt"]["command_name"], @@ -361,7 +361,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Running prompt: {prompt_name} with args: {args} [ACTIVITY_END]", + message=f"[ACTIVITY] Running prompt: {prompt_name} with args: {args}", ) if "prompt_name" not in args: args["prompt_name"] = prompt_name @@ -375,7 +375,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Running chain: {args['chain']} with args: {args} [ACTIVITY_END]", + message=f"[ACTIVITY] Running chain: {args['chain']} with args: {args}", ) result = await self.execute_chain( chain_name=args["chain"], @@ -440,7 +440,7 @@ async def execute_chain( if conversation_name != "": c.log_interaction( role=agent_name, - message=f"[ACTIVITY_START] Running chain `{chain_name}`... [ACTIVITY_END]", + message=f"[ACTIVITY] Running chain `{chain_name}`...", ) response = "" for step_data in chain_data["steps"]: @@ -511,7 +511,7 @@ async def learn_from_websites( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] Browsing the web... [ACTIVITY_END]", + message=f"[ACTIVITY] Browsing the web...", ) response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, @@ -522,7 +522,7 @@ async def learn_from_websites( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + message=f"[ACTIVITY] {response}", ) return "I have read the information from the websites into my memory." @@ -561,7 +561,7 @@ async def learn_from_file( c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY_START] {response} [ACTIVITY_END]", + message=f"[ACTIVITY] {response}", ) return response From 7952f58d3ca4b7f74918a74308a225c675d8e1c1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 13:17:40 -0400 Subject: [PATCH 0036/1256] skip invalid image urls --- agixt/XT.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index cd72cb0c4919..15f8fc343f11 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -699,7 +699,17 @@ async def chat_completions(self, prompt: ChatCompletions): ) image_path = f"./WORKSPACE/{uuid.uuid4().hex}.jpg" if url.startswith("http"): - image = requests.get(url).content + # Validate if url is an image + if ( + url.endswith( + (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp") + ) + and "localhost" not in url + and "127.0.0.1" not in url + ): + image = requests.get(url).content + else: + image = None else: file_type = url.split(",")[0].split("/")[1].split(";")[0] if file_type == "jpeg": @@ -707,9 +717,10 @@ async def chat_completions(self, prompt: ChatCompletions): file_name = f"{uuid.uuid4().hex}.{file_type}" image_path = f"./WORKSPACE/{file_name}" image = base64.b64decode(url.split(",")[1]) - with open(image_path, "wb") as f: - f.write(image) - images.append(image_path) + if image: + with open(image_path, "wb") as f: + f.write(image) + images.append(image_path) if "audio_url" in msg: audio_url = str( msg["audio_url"]["url"] From 8f18a2f1984a404f76f06731899934f7e6cd0571 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 13:29:49 -0400 Subject: [PATCH 0037/1256] Improve path expression security --- agixt/XT.py | 65 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 15f8fc343f11..786b611e6ea3 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -581,6 +581,7 @@ async def chat_completions(self, prompt: ChatCompletions): browse_links = True tts = False urls = [] + base_path = os.path.join(os.getcwd(), "WORKSPACE") if "mode" in self.agent_settings: mode = self.agent_settings["mode"] else: @@ -697,7 +698,9 @@ async def chat_completions(self, prompt: ChatCompletions): if "url" in msg["image_url"] else msg["image_url"] ) - image_path = f"./WORKSPACE/{uuid.uuid4().hex}.jpg" + image_path = os.path.join( + os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" + ) if url.startswith("http"): # Validate if url is an image if ( @@ -715,12 +718,15 @@ async def chat_completions(self, prompt: ChatCompletions): if file_type == "jpeg": file_type = "jpg" file_name = f"{uuid.uuid4().hex}.{file_type}" - image_path = f"./WORKSPACE/{file_name}" + image_path = os.path.join( + os.getcwd(), "WORKSPACE", file_name + ) image = base64.b64decode(url.split(",")[1]) if image: - with open(image_path, "wb") as f: - f.write(image) - images.append(image_path) + if image_path.startswith(base_path): + with open(image_path, "wb") as f: + f.write(image) + images.append(image_path) if "audio_url" in msg: audio_url = str( msg["audio_url"]["url"] @@ -733,26 +739,35 @@ async def chat_completions(self, prompt: ChatCompletions): audio_url.split(",")[0].split("/")[1].split(";")[0] ) audio_data = base64.b64decode(audio_url.split(",")[1]) - audio_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" + audio_path = os.path.join( + os.getcwd(), + "WORKSPACE", + f"{uuid.uuid4().hex}.{file_type}", + ) with open(audio_path, "wb") as f: f.write(audio_data) audio_url = audio_path else: # Download the audio file from the url, get the file type and convert to wav audio_type = audio_url.split(".")[-1] - audio_url = f"./WORKSPACE/{uuid.uuid4().hex}.{audio_type}" + audio_url = os.path.join( + os.getcwd(), + "WORKSPACE", + f"{uuid.uuid4().hex}.{audio_type}", + ) audio_data = requests.get(audio_url).content with open(audio_url, "wb") as f: f.write(audio_data) - wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav" - AudioSegment.from_file(audio_url).set_frame_rate(16000).export( - wav_file, format="wav" - ) - transcribed_audio = await self.audio_to_text( - audio_path=wav_file, - conversation_name=conversation_name, - ) - new_prompt += transcribed_audio + if audio_url.startswith(base_path): + wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav" + AudioSegment.from_file(audio_url).set_frame_rate( + 16000 + ).export(wav_file, format="wav") + transcribed_audio = await self.audio_to_text( + audio_path=wav_file, + conversation_name=conversation_name, + ) + new_prompt += transcribed_audio if "video_url" in msg: video_url = str( msg["video_url"]["url"] @@ -779,11 +794,19 @@ async def chat_completions(self, prompt: ChatCompletions): file_url.split(",")[0].split("/")[1].split(";")[0] ) file_data = base64.b64decode(file_url.split(",")[1]) - file_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" - with open(file_path, "wb") as f: - f.write(file_data) - file_url = f"{self.outputs}/{os.path.basename(file_path)}" - urls.append(file_url) + # file_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" + file_path = os.path.join( + os.getcwd(), + "WORKSPACE", + f"{uuid.uuid4().hex}.{file_type}", + ) + if file_path.startswith(base_path): + with open(file_path, "wb") as f: + f.write(file_data) + file_url = ( + f"{self.outputs}/{os.path.basename(file_path)}" + ) + urls.append(file_url) # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction(role="USER", message=new_prompt) From 3a83d775197f003f398f463db8f979ad5001c3ab Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 30 May 2024 17:22:32 -0400 Subject: [PATCH 0038/1256] Parallel chain steps (#1201) * parallel chain steps * remove ref temporarily * lint * add task creation * improve logic on step responses * add sleep --- agixt/Chain.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++ agixt/XT.py | 59 +++++++++++++++++++++++++++++++++-------- 2 files changed, 120 insertions(+), 11 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index a327cf3596ee..5b73d2ffc046 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -16,6 +16,7 @@ from Conversations import Conversations from Extensions import Extensions import logging +import asyncio logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -463,6 +464,7 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): return responses else: + step_number = int(step_number) chain_step = ( self.session.query(ChainStep) .filter( @@ -598,6 +600,76 @@ def import_chain(self, chain_name: str, steps: dict): return f"Imported chain: {chain_name}" + def get_chain_step_dependencies(self, chain_name): + chain_steps = self.get_steps(chain_name=chain_name) + prompts = Prompts(user=self.user) + chain_dependencies = {} + for step in chain_steps: + step_dependencies = [] + prompt = step.prompt + if not isinstance(prompt, dict) and not isinstance(prompt, str): + prompt = str(prompt) + if isinstance(prompt, dict): + for key, value in prompt.items(): + if "{STEP" in value: + step_count = value.count("{STEP") + for i in range(step_count): + new_step_number = int(value.split("{STEP")[1].split("}")[0]) + step_dependencies.append(new_step_number) + if "prompt_name" in prompt: + prompt_text = prompts.get_prompt( + prompt_name=prompt["prompt_name"], + prompt_category=( + prompt["prompt_category"] + if "prompt_category" in prompt + else "Default" + ), + ) + # See if "{context}" is in the prompt + if "{context}" in prompt_text: + # Add all prior steps in the chain as deps + for i in range(step.step_number): + step_dependencies.append(i) + elif isinstance(prompt, str): + if "{STEP" in prompt: + step_count = prompt.count("{STEP") + for i in range(step_count): + new_step_number = int(prompt.split("{STEP")[1].split("}")[0]) + step_dependencies.append(new_step_number) + if "{context}" in prompt: + # Add all prior steps in the chain as deps + for i in range(step.step_number): + step_dependencies.append(i) + chain_dependencies[str(step.step_number)] = step_dependencies + return chain_dependencies + + async def check_if_dependencies_met( + self, chain_run_id, chain_name, step_number, dependencies=[] + ): + if dependencies == []: + chain_dependencies = self.get_chain_step_dependencies(chain_name=chain_name) + dependencies = chain_dependencies[str(step_number)] + + async def check_dependencies_met(dependencies): + for dependency in dependencies: + try: + step_responses = self.get_step_response( + chain_name=chain_name, + chain_run_id=chain_run_id, + step_number=int(dependency), + ) + except: + return False + if not step_responses: + return False + return True + + dependencies_met = await check_dependencies_met(dependencies) + while not dependencies_met: + await asyncio.sleep(1) + dependencies_met = await check_dependencies_met(dependencies) + return True + def get_step_content( self, chain_run_id, chain_name, prompt_content, user_input, agent_name ): diff --git a/agixt/XT.py b/agixt/XT.py index 786b611e6ea3..03fd66841867 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -423,6 +423,24 @@ async def execute_chain( voice_response=False, ): chain_data = self.chain.get_chain(chain_name=chain_name) + chain_dependencies = self.chain.get_chain_step_dependencies( + chain_name=chain_name + ) + + async def check_dependencies_met(dependencies): + for dependency in dependencies: + try: + step_responses = self.chain.get_step_response( + chain_name=chain_name, + chain_run_id=chain_run_id, + step_number=int(dependency), + ) + except: + return False + if not step_responses: + return False + return True + if not chain_run_id: chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) if chain_data == {}: @@ -443,6 +461,8 @@ async def execute_chain( message=f"[ACTIVITY] Running chain `{chain_name}`...", ) response = "" + tasks = [] + step_responses = [] for step_data in chain_data["steps"]: if int(step_data["step"]) >= int(from_step): if "prompt" in step_data and "step" in step_data: @@ -455,18 +475,35 @@ async def execute_chain( step["prompt_type"] = step_data["prompt_type"] step["prompt"] = step_data["prompt"] step["step"] = step_data["step"] - step_response = await self.run_chain_step( - chain_run_id=chain_run_id, - step=step, - chain_name=chain_name, - user_input=user_input, - agent_override=agent_override, - chain_args=chain_args, - conversation_name=conversation_name, + # Get the step dependencies from chain_dependencies then check if the dependencies are + # met before running the step + step_dependencies = chain_dependencies[str(step["step"])] + dependencies_met = await check_dependencies_met(step_dependencies) + while not dependencies_met: + await asyncio.sleep(1) + if step_responses == []: + step_responses = await asyncio.gather(*tasks) + else: + step_responses += await asyncio.gather(*tasks) + dependencies_met = await check_dependencies_met( + step_dependencies + ) + task = asyncio.create_task( + self.run_chain_step( + chain_run_id=chain_run_id, + step=step, + chain_name=chain_name, + user_input=user_input, + agent_override=agent_override, + chain_args=chain_args, + conversation_name=conversation_name, + ) ) - if step_response == None: - return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." - response = step_response + tasks.append(task) + step_responses = await asyncio.gather(*tasks) + response = step_responses[-1] + if response == None: + return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." if conversation_name != "": c.log_interaction( role=agent_name, From 63467c04a9bfa7cf26866a3abbd1e4747ef57d41 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 18:02:58 -0400 Subject: [PATCH 0039/1256] update endpoint --- agixt/endpoints/Auth.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 44e5108fa340..3f561624d0e8 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -92,10 +92,7 @@ def delete_user( async def createuser( account: WebhookUser, authorization: str = Header(None), - user=Depends(verify_api_key), ): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") ApiClient = get_api_client(authorization=authorization) return webhook_create_user( api_key=authorization, @@ -108,25 +105,3 @@ async def createuser( github_repos=account.github_repos, ApiClient=ApiClient, ) - - -@app.post("/api/admin", tags=["User"]) -async def createadmin( - account: WebhookUser, - authorization: str = Header(None), - user=Depends(verify_api_key), -): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - return webhook_create_user( - api_key=authorization, - email=account.email, - role="admin", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, - ) From c5a366ba49db534b5f608cc8be99ce8b14006ed3 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 18:27:12 -0400 Subject: [PATCH 0040/1256] handle collection number issue --- agixt/endpoints/Memory.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 65d2866964c4..c584fe397457 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -118,13 +118,18 @@ async def learn_text( agent_config = Agent( agent_name=agent_name, user=user, ApiClient=ApiClient ).get_agent_config() - await Memories( + try: + collection_number = int(data.collection_number) + except: + collection_number = 0 + memory = Memories( agent_name=agent_name, agent_config=agent_config, - collection_number=data.collection_number, + collection_number=collection_number, ApiClient=ApiClient, user=user, - ).write_text_to_memory( + ) + await memory.write_text_to_memory( user_input=data.user_input, text=data.text, external_source="user input" ) return ResponseMessage( From 69c25d42ed0d8769c27a26cef6325129bc766f6e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 18:49:01 -0400 Subject: [PATCH 0041/1256] clean up logging --- agixt/XT.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 03fd66841867..945295ccd4bb 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -545,18 +545,13 @@ async def learn_from_websites( url_str = {"\n".join(urls)} user_input = f"Learn from the information from these websites:\n {url_str} " c = Conversations(conversation_name=conversation_name, user=self.user_email) - if conversation_name != "" and conversation_name != None: - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] Browsing the web...", - ) response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, search_depth=scrape_depth, summarize_content=summarize_content, conversation_name=conversation_name, ) - if conversation_name != "" and conversation_name != None: + if conversation_name != "" and conversation_name != None and response != "": c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] {response}", From 709cd0d8ce477d2774efec021b498750f2b094dc Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 20:52:46 -0400 Subject: [PATCH 0042/1256] fix ref --- agixt/Interactions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 4cef8b0cb06e..cc4ecd917db3 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -460,13 +460,13 @@ async def run( user_input = f"Primary Objective: {kwargs['primary_objective']}\n\nTask: {kwargs['task']}" else: user_input = "" - if search_string != "": + if user_input != "": c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Searching the web...", ) search_string = self.run( - user_input=search_string, + user_input=user_input, prompt_name="WebSearch", context_results=context_results if context_results > 0 else 5, log_user_input=False, From 8b5ce788c60503b52c8ae0399d9f5825f9994cb4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 21:04:31 -0400 Subject: [PATCH 0043/1256] await --- agixt/Interactions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index cc4ecd917db3..2e9874d2f6fb 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -465,7 +465,7 @@ async def run( role=self.agent_name, message=f"[ACTIVITY] Searching the web...", ) - search_string = self.run( + search_string = await self.run( user_input=user_input, prompt_name="WebSearch", context_results=context_results if context_results > 0 else 5, From 62f0224a91fe2b2649cdce50759b1d05c343136b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 21:19:07 -0400 Subject: [PATCH 0044/1256] remove handling to expose error --- agixt/Interactions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 2e9874d2f6fb..74c8f1303864 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -477,15 +477,15 @@ async def run( role=self.agent_name, message=f"[ACTIVITY] Searching web for: {search_string}", ) - try: - await self.websearch.websearch_agent( - user_input=user_input, - search_string=search_string, - websearch_depth=websearch_depth, - websearch_timeout=websearch_timeout, - ) - except Exception as e: - logging.warning(f"Failed to websearch. Error: {e}") + # try: + await self.websearch.websearch_agent( + user_input=user_input, + search_string=search_string, + websearch_depth=websearch_depth, + websearch_timeout=websearch_timeout, + ) + # except Exception as e: + # logging.warning(f"Failed to websearch. Error: {e}") vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] From 6017f29b55a76d5f06c939bffbcb501a48c1bc51 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 21:30:52 -0400 Subject: [PATCH 0045/1256] dont use searxng --- agixt/Websearch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index b00d9f960830..9e3e87da9de7 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -510,10 +510,10 @@ async def websearch_agent( if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") - if self.searx_instance_url != "": - links = await self.search(query=search_string) - else: - links = await self.ddg_search(query=search_string) + # if self.searx_instance_url != "": + # links = await self.search(query=search_string) + # else: + links = await self.ddg_search(query=search_string) logging.info(f"Found {len(links)} results for {search_string}") if len(links) > websearch_depth: links = links[:websearch_depth] From 5bb63a7fff9cfe9460143a45f8f5e18f2b1f6134 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 21:45:28 -0400 Subject: [PATCH 0046/1256] force playwright update --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 5dff9659e0f6..a1e0dd742699 100644 --- a/Dockerfile +++ b/Dockerfile @@ -57,7 +57,7 @@ RUN pip install spacy && \ # Install Playwright RUN npm install -g playwright && \ npx playwright install && \ - playwright install + playwright install --with-deps COPY . . From 699b57393ff7ef71fd4a0c4f6def8781b18bd6ed Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 22:08:20 -0400 Subject: [PATCH 0047/1256] add launch option and allow searxng again --- agixt/Websearch.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 9e3e87da9de7..e3bdb84b1b20 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -325,12 +325,15 @@ async def resursive_browsing(self, user_input, links): "links": link_list, "visited_links": self.browsed_links, "disable_memory": True, + "websearch": False, "browse_links": False, "user_input": user_input, "context_results": 0, + "tts": False, + "conversation_name": "Link selection", }, ) - if not pick_a_link.startswith("None"): + if not str(pick_a_link).lower().startswith("none"): logging.info( f"AI has decided to click: {pick_a_link}" ) @@ -345,6 +348,7 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: launch_options = {} if proxy: launch_options["proxy"] = {"server": proxy} + launch_options["headless"] = True browser = await p.chromium.launch(**launch_options) context = await browser.new_context() page = await context.new_page() @@ -384,12 +388,13 @@ async def search(self, query: str) -> List[str]: self.ApiClient.update_agent_settings( agent_name=self.agent_name, settings=self.agent_settings ) - server = self.searx_instance_url.rstrip("/") - self.agent_settings["SEARXNG_INSTANCE_URL"] = server + self.searx_instance_url = str(self.searx_instance_url).rstrip("/") + logging.info(f"Using {self.searx_instance_url}") + self.agent_settings["SEARXNG_INSTANCE_URL"] = self.searx_instance_url self.ApiClient.update_agent_settings( agent_name=self.agent_name, settings=self.agent_settings ) - endpoint = f"{server}/search" + endpoint = f"{self.searx_instance_url}/search" try: logging.info(f"Trying to connect to SearXNG Search at {endpoint}...") response = requests.get( @@ -510,10 +515,10 @@ async def websearch_agent( if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") - # if self.searx_instance_url != "": - # links = await self.search(query=search_string) - # else: - links = await self.ddg_search(query=search_string) + if self.searx_instance_url != "": + links = await self.search(query=search_string) + else: + links = await self.ddg_search(query=search_string) logging.info(f"Found {len(links)} results for {search_string}") if len(links) > websearch_depth: links = links[:websearch_depth] From a028f4084d1e4b0bc01cd60c264dda769e3dd467 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 22:24:25 -0400 Subject: [PATCH 0048/1256] add logging --- agixt/Agent.py | 3 --- agixt/Websearch.py | 8 ++++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 6a1194dade08..0ea11661186c 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -392,9 +392,6 @@ def update_agent_config(self, new_config, config_key): self.session.add(agent_command) else: for setting_name, setting_value in new_config.items(): - logging.info( - f"Updating agent setting: {setting_name} = {setting_value}" - ) agent_setting = ( self.session.query(AgentSettingModel) .filter_by(agent_id=agent.id, name=setting_name) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index e3bdb84b1b20..e2cb2e145093 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -354,6 +354,10 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: page = await context.new_page() url = f"https://lite.duckduckgo.com/lite/?q={query}" await page.goto(url) + page_content = await page.content() + soup = BeautifulSoup(page_content, "html.parser") + # print the page content + logging.info(f"Page content from DDG search: {soup.get_text()}") links = await page.query_selector_all("a") results = [] for link in links: @@ -417,8 +421,8 @@ async def search(self, query: str) -> List[str]: return summaries except: self.failures.append(self.searx_instance_url) - if len(self.failures) > 5: - logging.info("Failed 5 times. Trying DDG...") + if len(self.failures) > 25: + logging.info("Failed 25 times. Trying DDG...") self.agent_settings["SEARXNG_INSTANCE_URL"] = "" self.ApiClient.update_agent_settings( agent_name=self.agent_name, settings=self.agent_settings From b3948df788d76072e9df09057a25d6be13c6d7e8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 22:39:08 -0400 Subject: [PATCH 0049/1256] use ddgs --- agixt/Agent.py | 3 --- agixt/Websearch.py | 33 +++++---------------------------- requirements.txt | 3 ++- 3 files changed, 7 insertions(+), 32 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 0ea11661186c..e442d2cbbe92 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -369,9 +369,6 @@ def update_agent_config(self, new_config, config_key): if not agent: logging.error(f"Agent '{self.agent_name}' not found in the database.") return - logging.info( - f"Updating agent config for '{self.agent_name}'. Config key: {config_key}, New config: {new_config}" - ) if config_key == "commands": for command_name, enabled in new_config.items(): command = ( diff --git a/agixt/Websearch.py b/agixt/Websearch.py index e2cb2e145093..73eda88cfa7f 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -14,6 +14,7 @@ from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader +from duckduckgo_search import DDGS logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -344,34 +345,10 @@ async def resursive_browsing(self, user_input, links): logging.info(f"Issues reading {url}. Moving on...") async def ddg_search(self, query: str, proxy=None) -> List[str]: - async with async_playwright() as p: - launch_options = {} - if proxy: - launch_options["proxy"] = {"server": proxy} - launch_options["headless"] = True - browser = await p.chromium.launch(**launch_options) - context = await browser.new_context() - page = await context.new_page() - url = f"https://lite.duckduckgo.com/lite/?q={query}" - await page.goto(url) - page_content = await page.content() - soup = BeautifulSoup(page_content, "html.parser") - # print the page content - logging.info(f"Page content from DDG search: {soup.get_text()}") - links = await page.query_selector_all("a") - results = [] - for link in links: - summary = await page.evaluate("(link) => link.textContent", link) - summary = summary.replace("\n", "").replace("\t", "").replace(" ", "") - href = await page.evaluate("(link) => link.href", link) - parsed_url = urllib.parse.urlparse(href) - query_params = urllib.parse.parse_qs(parsed_url.query) - uddg = query_params.get("uddg", [None])[0] - if uddg: - href = urllib.parse.unquote(uddg) - if summary: - results.append(f"{summary} - {href}") - await browser.close() + search_results = DDGS(proxy=proxy).text(query, max_results=10) + results = [] + for result in search_results: + results.append(f"{result['title']} - {result['href']}") return results async def search(self, query: str) -> List[str]: diff --git a/requirements.txt b/requirements.txt index f65835c88eb4..7a7cc06aa6ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ google-api-python-client==2.125.0 python-multipart==0.0.9 nest_asyncio g4f==0.3.1.9 -pyotp \ No newline at end of file +pyotp +duckduckgo_search \ No newline at end of file From bc1c44ecea238e00fa97f726f2bd85d5c04042d1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 22:59:48 -0400 Subject: [PATCH 0050/1256] force version --- agixt/Websearch.py | 4 ++-- requirements.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 73eda88cfa7f..497d2370e9b0 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -398,8 +398,8 @@ async def search(self, query: str) -> List[str]: return summaries except: self.failures.append(self.searx_instance_url) - if len(self.failures) > 25: - logging.info("Failed 25 times. Trying DDG...") + if len(self.failures) > 10: + logging.info("Failed 10 times. Trying DDG...") self.agent_settings["SEARXNG_INSTANCE_URL"] = "" self.ApiClient.update_agent_settings( agent_name=self.agent_name, settings=self.agent_settings diff --git a/requirements.txt b/requirements.txt index 7a7cc06aa6ba..3d3f757d331f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,4 @@ python-multipart==0.0.9 nest_asyncio g4f==0.3.1.9 pyotp -duckduckgo_search \ No newline at end of file +duckduckgo_search==6.1.4 \ No newline at end of file From 073e661fe2156eb4ea6b21de27d6911738e43b6a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 30 May 2024 23:21:08 -0400 Subject: [PATCH 0051/1256] add websearch test --- tests/tests.ipynb | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/tests.ipynb b/tests/tests.ipynb index fff013d4d0c3..448b626fc153 100644 --- a/tests/tests.ipynb +++ b/tests/tests.ipynb @@ -1132,6 +1132,41 @@ "print(\"Agent prompt response:\", agent_prompt_resp)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Websearch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from agixtsdk import AGiXTSDK\n", + "\n", + "base_uri = \"http://localhost:7437\"\n", + "ApiClient = AGiXTSDK(base_uri=base_uri)\n", + "agent_name = \"new_agent\"\n", + "prompt_name = \"Chat\"\n", + "prompt_args = {\n", + " \"user_input\": \"What is the latest news in AI today?\",\n", + " \"websearch\": True,\n", + " \"websearch_depth\": 3,\n", + " \"context_results\": 10,\n", + " \"conversation_name\": \"Search for info\",\n", + "}\n", + "\n", + "agent_prompt_resp = ApiClient.prompt_agent(\n", + " agent_name=agent_name,\n", + " prompt_name=prompt_name,\n", + " prompt_args=prompt_args,\n", + ")\n", + "print(\"Agent prompt response:\", agent_prompt_resp)" + ] + }, { "attachments": {}, "cell_type": "markdown", From 6dbdb8dc5bc0724a5a7bd5e532b7a579fd5603e4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 07:18:44 -0400 Subject: [PATCH 0052/1256] ddg fix attempt --- agixt/Websearch.py | 62 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 497d2370e9b0..326bf3539e03 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -14,7 +14,6 @@ from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader -from duckduckgo_search import DDGS logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -53,11 +52,9 @@ def __init__( user=user, ) self.searx_instance_url = ( - ( - self.agent.AGENT_CONFIG["settings"]["SEARXNG_INSTANCE_URL"] - if "SEARXNG_INSTANCE_URL" in self.agent.AGENT_CONFIG["settings"] - else "" - ), + self.agent.AGENT_CONFIG["settings"]["SEARXNG_INSTANCE_URL"] + if "SEARXNG_INSTANCE_URL" in self.agent.AGENT_CONFIG["settings"] + else "" ) def verify_link(self, link: str = "") -> bool: @@ -345,10 +342,51 @@ async def resursive_browsing(self, user_input, links): logging.info(f"Issues reading {url}. Moving on...") async def ddg_search(self, query: str, proxy=None) -> List[str]: - search_results = DDGS(proxy=proxy).text(query, max_results=10) - results = [] - for result in search_results: - results.append(f"{result['title']} - {result['href']}") + async with async_playwright() as p: + launch_options = {} + if proxy: + launch_options["proxy"] = {"server": proxy} + launch_options["headless"] = True + browser = await p.chromium.launch(**launch_options) + query = urllib.parse.quote(query) + headers = { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", + "Accept-Encoding": "gzip, deflate, br, zstd", + "Accept-Language": "en-US,en;q=0.9", + "Cache-Control": "max-age=0", + "Priority": "u=0, i", + "Sec-Ch-Ua": '"Google Chrome";v="125", "Chromium";v="125", "Not.A/Brand";v="24"', + "Sec-Ch-Ua-Mobile": "?0", + "Sec-Ch-Ua-Platform": '"Windows"', + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "none", + "Sec-Fetch-User": "?1", + "Upgrade-Insecure-Requests": "1", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", + } + context = await browser.new_context(extra_http_headers=headers) + page = await context.new_page() + url = f"https://lite.duckduckgo.com/lite/?q={query}" + await page.goto(url) + page_content = await page.content() + soup = BeautifulSoup(page_content, "html.parser") + # print the page content + logging.info(f"Page content from DDG search: {soup.get_text()}") + links = await page.query_selector_all("a") + results = [] + for link in links: + summary = await page.evaluate("(link) => link.textContent", link) + summary = summary.replace("\n", "").replace("\t", "").replace(" ", "") + href = await page.evaluate("(link) => link.href", link) + parsed_url = urllib.parse.urlparse(href) + query_params = urllib.parse.parse_qs(parsed_url.query) + uddg = query_params.get("uddg", [None])[0] + if uddg: + href = urllib.parse.unquote(uddg) + if summary: + results.append(f"{summary} - {href}") + await browser.close() return results async def search(self, query: str) -> List[str]: @@ -398,8 +436,8 @@ async def search(self, query: str) -> List[str]: return summaries except: self.failures.append(self.searx_instance_url) - if len(self.failures) > 10: - logging.info("Failed 10 times. Trying DDG...") + if len(self.failures) > 5: + logging.info("Failed 5 times. Trying DDG...") self.agent_settings["SEARXNG_INSTANCE_URL"] = "" self.ApiClient.update_agent_settings( agent_name=self.agent_name, settings=self.agent_settings From 9220481c1b1b68808f7890097071673ab8d14f18 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 07:26:17 -0400 Subject: [PATCH 0053/1256] remove headless flag --- agixt/Websearch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 326bf3539e03..f9f461ba28b6 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -346,7 +346,6 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: launch_options = {} if proxy: launch_options["proxy"] = {"server": proxy} - launch_options["headless"] = True browser = await p.chromium.launch(**launch_options) query = urllib.parse.quote(query) headers = { @@ -370,9 +369,8 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: url = f"https://lite.duckduckgo.com/lite/?q={query}" await page.goto(url) page_content = await page.content() - soup = BeautifulSoup(page_content, "html.parser") # print the page content - logging.info(f"Page content from DDG search: {soup.get_text()}") + logging.info(f"Page content from DDG search: {page_content}") links = await page.query_selector_all("a") results = [] for link in links: From b1c07b8d425ddf5e16b2598ed81213c66110ab72 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 07:29:41 -0400 Subject: [PATCH 0054/1256] use firefox instead --- agixt/Websearch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index f9f461ba28b6..a853b3a993d2 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -346,7 +346,8 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: launch_options = {} if proxy: launch_options["proxy"] = {"server": proxy} - browser = await p.chromium.launch(**launch_options) + launch_options["headless"] = True + browser = await p.firefox.launch(**launch_options) query = urllib.parse.quote(query) headers = { "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", From b8cdc28c9fec9ce46949c1427faf0f02d738fafd Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 07:35:11 -0400 Subject: [PATCH 0055/1256] add user_agent --- agixt/Websearch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index a853b3a993d2..58a11434db61 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -365,7 +365,11 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: "Upgrade-Insecure-Requests": "1", "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", } - context = await browser.new_context(extra_http_headers=headers) + context = await browser.new_context( + extra_http_headers=headers, + user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", + ignore_https_errors=True, + ) page = await context.new_page() url = f"https://lite.duckduckgo.com/lite/?q={query}" await page.goto(url) From e21b8b41a7cfcdebf463624b15c1fbb1e1c9f58a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 07:37:35 -0400 Subject: [PATCH 0056/1256] add wait for page load --- agixt/Websearch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 58a11434db61..352a231000df 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -373,6 +373,8 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: page = await context.new_page() url = f"https://lite.duckduckgo.com/lite/?q={query}" await page.goto(url) + # wait for page to load + await page.wait_for_load_state("load") page_content = await page.content() # print the page content logging.info(f"Page content from DDG search: {page_content}") From 4c7d448c084498bc504d962dc790d70245dbe39e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 31 May 2024 08:00:13 -0400 Subject: [PATCH 0057/1256] use undetected chromedriver --- agixt/Websearch.py | 69 ++++++++++++---------------------------------- requirements.txt | 2 +- 2 files changed, 18 insertions(+), 53 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 352a231000df..88293beea83f 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -14,6 +14,7 @@ from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader +import undetected_chromedriver as uc logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -341,58 +342,22 @@ async def resursive_browsing(self, user_input, links): except: logging.info(f"Issues reading {url}. Moving on...") - async def ddg_search(self, query: str, proxy=None) -> List[str]: - async with async_playwright() as p: - launch_options = {} - if proxy: - launch_options["proxy"] = {"server": proxy} - launch_options["headless"] = True - browser = await p.firefox.launch(**launch_options) - query = urllib.parse.quote(query) - headers = { - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "en-US,en;q=0.9", - "Cache-Control": "max-age=0", - "Priority": "u=0, i", - "Sec-Ch-Ua": '"Google Chrome";v="125", "Chromium";v="125", "Not.A/Brand";v="24"', - "Sec-Ch-Ua-Mobile": "?0", - "Sec-Ch-Ua-Platform": '"Windows"', - "Sec-Fetch-Dest": "document", - "Sec-Fetch-Mode": "navigate", - "Sec-Fetch-Site": "none", - "Sec-Fetch-User": "?1", - "Upgrade-Insecure-Requests": "1", - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", - } - context = await browser.new_context( - extra_http_headers=headers, - user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36", - ignore_https_errors=True, - ) - page = await context.new_page() - url = f"https://lite.duckduckgo.com/lite/?q={query}" - await page.goto(url) - # wait for page to load - await page.wait_for_load_state("load") - page_content = await page.content() - # print the page content - logging.info(f"Page content from DDG search: {page_content}") - links = await page.query_selector_all("a") - results = [] - for link in links: - summary = await page.evaluate("(link) => link.textContent", link) - summary = summary.replace("\n", "").replace("\t", "").replace(" ", "") - href = await page.evaluate("(link) => link.href", link) - parsed_url = urllib.parse.urlparse(href) - query_params = urllib.parse.parse_qs(parsed_url.query) - uddg = query_params.get("uddg", [None])[0] - if uddg: - href = urllib.parse.unquote(uddg) - if summary: - results.append(f"{summary} - {href}") - await browser.close() - return results + async def ddg_search(query: str) -> List[str]: + driver = uc.Chrome(headless=True, use_subpress=False) + driver.get(f"https://lite.duckduckgo.com/lite/?q={query}") + page_content = driver.page_source + logging.info(f"DDG Page Content: {page_content}...") + soup = BeautifulSoup(page_content, "html.parser") + links = soup.find_all("a") + parsed_links = [] + for link in links: + new_link = str(link) + new_link = new_link.split("?uddg=")[1].split("&rut=")[0] + new_link = urllib.parse.unquote(new_link) + summary = str(link).split(">")[1].split("")[0].replace(" List[str]: if self.searx_instance_url == "": diff --git a/requirements.txt b/requirements.txt index 3d3f757d331f..2152316cafa5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,4 @@ python-multipart==0.0.9 nest_asyncio g4f==0.3.1.9 pyotp -duckduckgo_search==6.1.4 \ No newline at end of file +undetected-chromedriver==3.5.5 \ No newline at end of file From c6e6fba54ca5a9908d4bf1a9d9f072a16b4f4851 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Sun, 2 Jun 2024 00:15:28 -0400 Subject: [PATCH 0058/1256] Rework websearch to rotate endpoints on failure (#1202) * add brave search * improve prompt agent handling of input * set as str * Rotate search providers until one works. * fix ref * dont use subprocess * clean up functions and req * improve web summary * fix endpoint * set link_list to empty list if none * use failure list * add logging * fix none len * improve prompt order * move link parsing * add 'be concise' * fix agent config update * set endpoint * extract keywords for search * define version * move textacy install * try ddg first * await ddg search * Disable summarization by default while experimental * Clean up before merge --- Dockerfile | 5 +- agixt/Agent.py | 1 + agixt/Extensions.py | 2 - agixt/Interactions.py | 15 +- agixt/Memories.py | 11 +- agixt/Websearch.py | 224 ++++++++++-------- agixt/XT.py | 70 +++++- agixt/endpoints/Agent.py | 11 +- .../Default/Convert to Pydantic Model.txt | 11 + agixt/prompts/Default/Web Summary.txt | 6 + agixt/prompts/Default/WebSearch.txt | 4 +- agixt/prompts/Default/Website Summary.txt | 4 - requirements.txt | 3 +- static-requirements.txt | 2 +- 14 files changed, 237 insertions(+), 132 deletions(-) create mode 100644 agixt/prompts/Default/Convert to Pydantic Model.txt create mode 100644 agixt/prompts/Default/Web Summary.txt delete mode 100644 agixt/prompts/Default/Website Summary.txt diff --git a/Dockerfile b/Dockerfile index a1e0dd742699..0d8ec1c509a2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -52,12 +52,13 @@ RUN pip install -r requirements.txt # Download spaCy language model RUN pip install spacy && \ - python -m spacy download en_core_web_sm + python -m spacy download en_core_web_sm && \ + pip install textacy==0.13.0 # Install Playwright RUN npm install -g playwright && \ npx playwright install && \ - playwright install --with-deps + playwright install COPY . . diff --git a/agixt/Agent.py b/agixt/Agent.py index e442d2cbbe92..3a5452ce85f1 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -389,6 +389,7 @@ def update_agent_config(self, new_config, config_key): self.session.add(agent_command) else: for setting_name, setting_value in new_config.items(): + logging.info(f"Setting {setting_name} to {setting_value}.") agent_setting = ( self.session.query(AgentSettingModel) .filter_by(agent_id=agent.id, name=setting_name) diff --git a/agixt/Extensions.py b/agixt/Extensions.py index 9c8bbe02bcb5..97142859fd2b 100644 --- a/agixt/Extensions.py +++ b/agixt/Extensions.py @@ -107,8 +107,6 @@ def load_commands(self): params, ) ) - # Return the commands list - logging.debug(f"loaded commands: {commands}") return commands def get_extension_settings(self): diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 74c8f1303864..8c3226d4af21 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -451,7 +451,7 @@ async def run( await self.websearch.scrape_websites( user_input=user_input, search_depth=websearch_depth, - summarize_content=True, + summarize_content=False, conversation_name=conversation_name, ) if websearch: @@ -465,22 +465,9 @@ async def run( role=self.agent_name, message=f"[ACTIVITY] Searching the web...", ) - search_string = await self.run( - user_input=user_input, - prompt_name="WebSearch", - context_results=context_results if context_results > 0 else 5, - log_user_input=False, - browse_links=False, - websearch=False, - ) - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] Searching web for: {search_string}", - ) # try: await self.websearch.websearch_agent( user_input=user_input, - search_string=search_string, websearch_depth=websearch_depth, websearch_timeout=websearch_timeout, ) diff --git a/agixt/Memories.py b/agixt/Memories.py index 1311ce9dcabc..41ef38eca36b 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -15,6 +15,7 @@ from collections import Counter from typing import List from Globals import getenv, DEFAULT_USER +from textacy.extract.keyterms import textrank logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -34,6 +35,12 @@ def nlp(text): return sp(text) +def extract_keywords(doc=None, text="", limit=10): + if not doc: + doc = nlp(text) + return [k for k, s in textrank(doc, topn=limit)] + + def snake(old_str: str = ""): if not old_str: return "" @@ -488,9 +495,7 @@ async def chunk_content(self, text: str, chunk_size: int) -> List[str]: content_chunks = [] chunk = [] chunk_len = 0 - keywords = [ - token.text for token in doc if token.pos_ in {"NOUN", "PROPN", "VERB"} - ] + keywords = set(extract_keywords(doc=doc, limit=10)) for sentence in sentences: sentence_tokens = len(sentence) if chunk_len + sentence_tokens > chunk_size and chunk: diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 88293beea83f..6aca7264f7a4 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -14,7 +14,9 @@ from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader -import undetected_chromedriver as uc +from datetime import datetime +from Memories import extract_keywords + logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -52,10 +54,10 @@ def __init__( ApiClient=ApiClient, user=user, ) - self.searx_instance_url = ( - self.agent.AGENT_CONFIG["settings"]["SEARXNG_INSTANCE_URL"] - if "SEARXNG_INSTANCE_URL" in self.agent.AGENT_CONFIG["settings"] - else "" + self.websearch_endpoint = ( + self.agent_settings["websearch_endpoint"] + if "websearch_endpoint" in self.agent_settings + else "https://search.brave.com" ) def verify_link(self, link: str = "") -> bool: @@ -89,7 +91,7 @@ async def summarize_web_content(self, url, content): if get_tokens(text=content) < int(max_tokens): return self.ApiClient.prompt_agent( agent_name=self.agent_name, - prompt_name="Website Summary", + prompt_name="Web Summary", prompt_args={ "user_input": content, "url": url, @@ -107,7 +109,7 @@ async def summarize_web_content(self, url, content): new_content.append( self.ApiClient.prompt_agent( agent_name=self.agent_name, - prompt_name="Website Summary", + prompt_name="Web Summary", prompt_args={ "user_input": chunk, "url": url, @@ -273,7 +275,7 @@ async def get_web_content(self, url: str, summarize_content=False): except: return None, None - async def resursive_browsing(self, user_input, links): + async def recursive_browsing(self, user_input, links): try: words = links.split() links = [ @@ -336,91 +338,12 @@ async def resursive_browsing(self, user_input, links): logging.info( f"AI has decided to click: {pick_a_link}" ) - await self.resursive_browsing( + await self.recursive_browsing( user_input=user_input, links=pick_a_link ) except: logging.info(f"Issues reading {url}. Moving on...") - async def ddg_search(query: str) -> List[str]: - driver = uc.Chrome(headless=True, use_subpress=False) - driver.get(f"https://lite.duckduckgo.com/lite/?q={query}") - page_content = driver.page_source - logging.info(f"DDG Page Content: {page_content}...") - soup = BeautifulSoup(page_content, "html.parser") - links = soup.find_all("a") - parsed_links = [] - for link in links: - new_link = str(link) - new_link = new_link.split("?uddg=")[1].split("&rut=")[0] - new_link = urllib.parse.unquote(new_link) - summary = str(link).split(">")[1].split("")[0].replace(" List[str]: - if self.searx_instance_url == "": - try: # SearXNG - List of these at https://searx.space/ - response = requests.get("https://searx.space/data/instances.json") - data = json.loads(response.text) - if self.failures != []: - for failure in self.failures: - if failure in data["instances"]: - del data["instances"][failure] - servers = list(data["instances"].keys()) - random_index = random.randint(0, len(servers) - 1) - self.searx_instance_url = servers[random_index] - except: # Select default remote server that typically works if unable to get list. - self.searx_instance_url = "https://search.us.projectsegfau.lt" - self.agent_settings["SEARXNG_INSTANCE_URL"] = self.searx_instance_url - self.ApiClient.update_agent_settings( - agent_name=self.agent_name, settings=self.agent_settings - ) - self.searx_instance_url = str(self.searx_instance_url).rstrip("/") - logging.info(f"Using {self.searx_instance_url}") - self.agent_settings["SEARXNG_INSTANCE_URL"] = self.searx_instance_url - self.ApiClient.update_agent_settings( - agent_name=self.agent_name, settings=self.agent_settings - ) - endpoint = f"{self.searx_instance_url}/search" - try: - logging.info(f"Trying to connect to SearXNG Search at {endpoint}...") - response = requests.get( - endpoint, - params={ - "q": query, - "language": "en", - "safesearch": 1, - "format": "json", - }, - ) - results = response.json() - summaries = [ - result["title"] + " - " + result["url"] for result in results["results"] - ] - if len(summaries) < 1: - self.failures.append(self.searx_instance_url) - self.searx_instance_url = "" - return await self.search(query=query) - return summaries - except: - self.failures.append(self.searx_instance_url) - if len(self.failures) > 5: - logging.info("Failed 5 times. Trying DDG...") - self.agent_settings["SEARXNG_INSTANCE_URL"] = "" - self.ApiClient.update_agent_settings( - agent_name=self.agent_name, settings=self.agent_settings - ) - return await self.ddg_search(query=query) - times = "times" if len(self.failures) != 1 else "time" - logging.info( - f"Failed to find a working SearXNG server {len(self.failures)} {times}. Trying again..." - ) - # The SearXNG server is down or refusing connection, so we will use the default one. - self.searx_instance_url = "" - return await self.search(query=query) - async def scrape_websites( self, user_input: str = "", @@ -484,10 +407,101 @@ async def scrape_websites( ) return message + async def ddg_search(self, query: str, proxy=None) -> List[str]: + async with async_playwright() as p: + launch_options = {} + if proxy: + launch_options["proxy"] = {"server": proxy} + browser = await p.chromium.launch(**launch_options) + context = await browser.new_context() + page = await context.new_page() + url = f"https://lite.duckduckgo.com/lite/?q={query}" + await page.goto(url) + links = await page.query_selector_all("a") + results = [] + for link in links: + summary = await page.evaluate("(link) => link.textContent", link) + summary = summary.replace("\n", "").replace("\t", "").replace(" ", "") + href = await page.evaluate("(link) => link.href", link) + parsed_url = urllib.parse.urlparse(href) + query_params = urllib.parse.parse_qs(parsed_url.query) + uddg = query_params.get("uddg", [None])[0] + if uddg: + href = urllib.parse.unquote(uddg) + if summary: + results.append(f"{summary} - {href}") + await browser.close() + return results + + async def update_search_provider(self): + # SearXNG - List of these at https://searx.space/ + # Check if the instances-todays date.json file exists + instances_file = ( + f"./WORKSPACE/instances-{datetime.now().strftime('%Y-%m-%d')}.json" + ) + if os.path.exists(instances_file): + with open(instances_file, "r") as f: + data = json.load(f) + else: + response = requests.get("https://searx.space/data/instances.json") + data = json.loads(response.text) + with open(instances_file, "w") as f: + json.dump(data, f) + servers = list(data["instances"].keys()) + servers.append("https://search.brave.com") + servers.append("https://lite.duckduckgo.com/lite") + websearch_endpoint = self.websearch_endpoint + if "websearch_endpoint" not in self.agent_settings: + self.agent_settings["websearch_endpoint"] = websearch_endpoint + self.agent.update_agent_config( + new_config={"websearch_endpoint": websearch_endpoint}, + config_key="settings", + ) + return websearch_endpoint + if ( + self.agent_settings["websearch_endpoint"] == "" + or self.agent_settings["websearch_endpoint"] is None + ): + self.agent_settings["websearch_endpoint"] = websearch_endpoint + self.agent.update_agent_config( + new_config={"websearch_endpoint": websearch_endpoint}, + config_key="settings", + ) + return websearch_endpoint + random_index = random.randint(0, len(servers) - 1) + websearch_endpoint = servers[random_index] + while websearch_endpoint in self.failures: + random_index = random.randint(0, len(servers) - 1) + websearch_endpoint = servers[random_index] + self.agent_settings["websearch_endpoint"] = websearch_endpoint + self.agent.update_agent_config( + new_config={"websearch_endpoint": websearch_endpoint}, + config_key="settings", + ) + self.websearch_endpoint = websearch_endpoint + return websearch_endpoint + + async def web_search(self, query: str) -> List[str]: + endpoint = self.websearch_endpoint + if endpoint.endswith("/"): + endpoint = endpoint[:-1] + if endpoint.endswith("search"): + endpoint = endpoint[:-6] + logging.info(f"Websearching for {query} on {endpoint}") + text_content, link_list = await self.get_web_content( + url=f"{endpoint}/search?q={query}" + ) + if link_list is None: + link_list = [] + if len(link_list) < 5: + self.failures.append(self.websearch_endpoint) + await self.update_search_provider() + return await self.web_search(query=query) + return text_content, link_list + async def websearch_agent( self, user_input: str = "What are the latest breakthroughs in AI?", - search_string: str = "", websearch_depth: int = 0, websearch_timeout: int = 0, ): @@ -501,19 +515,33 @@ async def websearch_agent( except: websearch_timeout = 0 if websearch_depth > 0: - if len(search_string) > 0: - links = [] - logging.info(f"Searching for: {search_string}") - if self.searx_instance_url != "": - links = await self.search(query=search_string) - else: - links = await self.ddg_search(query=search_string) - logging.info(f"Found {len(links)} results for {search_string}") + if len(user_input) > 0: + search_string = self.ApiClient.prompt_agent( + agent_name=self.agent_name, + prompt_name="WebSearch", + prompt_args={ + "user_input": user_input, + "browse_links": "false", + "websearch": "false", + }, + ) + keywords = extract_keywords(text=user_input, limit=5) + if keywords: + search_string = " ".join(keywords) + # add month and year to the end of the search string + search_string += f" {datetime.now().strftime('%B %Y')}" + links = await self.ddg_search(query=search_string) + if links == [] or links is None: + links = [] + content, links = await self.web_search(query=search_string) + logging.info( + f"Found {len(links)} results for {search_string} using DDG." + ) if len(links) > websearch_depth: links = links[:websearch_depth] if links is not None and len(links) > 0: task = asyncio.create_task( - self.resursive_browsing(user_input=user_input, links=links) + self.recursive_browsing(user_input=user_input, links=links) ) self.tasks.append(task) diff --git a/agixt/XT.py b/agixt/XT.py index 945295ccd4bb..ee1ba776302c 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -6,7 +6,9 @@ from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions from datetime import datetime -from typing import List +from typing import Type, get_args, get_origin, Union, List +from enum import Enum +from pydantic import BaseModel import logging import asyncio import os @@ -524,7 +526,7 @@ async def learn_from_websites( self, urls: list = [], scrape_depth: int = 3, - summarize_content: bool = True, + summarize_content: bool = False, conversation_name: str = "", ): """ @@ -845,7 +847,7 @@ async def chat_completions(self, prompt: ChatCompletions): await self.learn_from_websites( urls=urls, scrape_depth=3, - summarize_content=True, + summarize_content=False, conversation_name=conversation_name, ) if mode == "command" and command_name and command_variable: @@ -1063,3 +1065,65 @@ async def create_dataset_from_memories(self, batch_size: int = 10): new_config=self.agent_settings, config_key="settings" ) return dpo_dataset + + def convert_to_pydantic_model( + self, + input_string: str, + model: Type[BaseModel], + max_failures: int = 3, + response_type: str = None, + **kwargs, + ): + input_string = str(input_string) + fields = model.__annotations__ + field_descriptions = [] + for field, field_type in fields.items(): + description = f"{field}: {field_type}" + if get_origin(field_type) == Union: + field_type = get_args(field_type)[0] + if isinstance(field_type, type) and issubclass(field_type, Enum): + enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type]) + description += f" (Enum values: {enum_values})" + field_descriptions.append(description) + schema = "\n".join(field_descriptions) + response = self.inference( + user_input=input_string, + schema=schema, + prompt_category="Default", + prompt_name="Convert to Pydantic Model", + log_user_input=False, + ) + if "```json" in response: + response = response.split("```json")[1].split("```")[0].strip() + elif "```" in response: + response = response.split("```")[1].strip() + try: + response = json.loads(response) + if response_type == "json": + return response + else: + return model(**response) + except Exception as e: + if "failures" in kwargs: + failures = int(kwargs["failures"]) + 1 + if failures > max_failures: + logging.error( + f"Error: {e} . Failed to convert the response to the model after 3 attempts. Response: {response}" + ) + return ( + response + if response + else "Failed to convert the response to the model." + ) + else: + failures = 1 + logging.warning( + f"Error: {e} . Failed to convert the response to the model, trying again. {failures}/3 failures. Response: {response}" + ) + return self.convert_to_pydantic_model( + input_string=input_string, + model=model, + max_failures=max_failures, + response_type=response_type, + failures=failures, + ) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index c63133a03772..d033ed8ba92a 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -192,8 +192,17 @@ async def prompt_agent( ): ApiClient = get_api_client(authorization=authorization) agent = Interactions(agent_name=agent_name, user=user, ApiClient=ApiClient) + if ( + "prompt" in agent_prompt.prompt_args + and "prompt_name" not in agent_prompt.prompt_args + ): + agent_prompt.prompt_args["prompt_name"] = agent_prompt.prompt_args["prompt"] + if "prompt_name" not in agent_prompt.prompt_args: + agent_prompt.prompt_args["prompt_name"] = "Chat" + if "prompt_category" not in agent_prompt.prompt_args: + agent_prompt.prompt_args["prompt_category"] = "Default" + agent_prompt.prompt_args = {k: v for k, v in agent_prompt.prompt_args.items()} response = await agent.run( - prompt=agent_prompt.prompt_name, log_user_input=True, **agent_prompt.prompt_args, ) diff --git a/agixt/prompts/Default/Convert to Pydantic Model.txt b/agixt/prompts/Default/Convert to Pydantic Model.txt new file mode 100644 index 000000000000..ab6d3a4c9252 --- /dev/null +++ b/agixt/prompts/Default/Convert to Pydantic Model.txt @@ -0,0 +1,11 @@ +Act as a JSON converter that converts any text into the desired JSON format based on the schema provided. Respond only with JSON in a properly formatted markdown code block, no explanations. Make your best assumptions based on data to try to fill in information to match the schema provided. +**DO NOT ADD FIELDS TO THE MODEL OR CHANGE TYPES OF FIELDS, FOLLOW THE PYDANTIC SCHEMA!** +**Reformat the following information into a structured format according to the schema provided:** + +## Information: +{user_input} + +## Pydantic Schema: +{schema} + +JSON Structured Output: diff --git a/agixt/prompts/Default/Web Summary.txt b/agixt/prompts/Default/Web Summary.txt new file mode 100644 index 000000000000..3091465b999d --- /dev/null +++ b/agixt/prompts/Default/Web Summary.txt @@ -0,0 +1,6 @@ +Content of {url} to notate for the user: +{user_input} + +**The assistant breaks down the content of the website into bulletpoints of important detailed notes about the content that may be important to know to a reader. It is important to not lose any of the essence of the content.** +**If something in the content does not belong, such as a third party ad, do not include it in the notes. Do not mention the URL of the content in the notes. If the website is related to coding tutorials, return full code blocks in notes. Remember the importance of retaining as much detail as possible!** +**The assistant should respond with well-structured notes that are detailed about the content.** \ No newline at end of file diff --git a/agixt/prompts/Default/WebSearch.txt b/agixt/prompts/Default/WebSearch.txt index c32fbd7660ce..02f3a000d3d9 100644 --- a/agixt/prompts/Default/WebSearch.txt +++ b/agixt/prompts/Default/WebSearch.txt @@ -5,6 +5,6 @@ Recent conversation history for context: Today's date is {date}. -You are a web search suggestion agent. Suggest a good web search string for the user's input. Attempt to make suggestions that will ensure top results are recent information and from reputable information sources and give proper keywords to maximize the chance of finding the information for the user's input. **Respond only with the search string and nothing else.** - User's input: {user_input} + +The assistant is a web search suggestion agent. Suggest a good web search string for the user's input. Attempt to make suggestions that will ensure top results are recent information and from reputable information sources and give proper keywords to maximize the chance of finding the information for the user's input. **Respond only with the search string and nothing else. Be concise and only rephrase the user's input to create an optimal search string that will be used directly in Google search!** diff --git a/agixt/prompts/Default/Website Summary.txt b/agixt/prompts/Default/Website Summary.txt deleted file mode 100644 index 2fc73a0602c9..000000000000 --- a/agixt/prompts/Default/Website Summary.txt +++ /dev/null @@ -1,4 +0,0 @@ -Content of {url} to summarize for the user: -{user_input} - -**Task: Summarize the content in as little text as possible without losing any details, it is important to retain details. If something in the content does not belong, such as a third party ad, do not include it in the summary. Do not summarize anything inside of code blocks, return fully populated code blocks if they exist. Do not mention the URL of the content in the summary.** diff --git a/requirements.txt b/requirements.txt index 2152316cafa5..f65835c88eb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,5 +17,4 @@ google-api-python-client==2.125.0 python-multipart==0.0.9 nest_asyncio g4f==0.3.1.9 -pyotp -undetected-chromedriver==3.5.5 \ No newline at end of file +pyotp \ No newline at end of file diff --git a/static-requirements.txt b/static-requirements.txt index 8e1fcf42322b..8009022ad887 100644 --- a/static-requirements.txt +++ b/static-requirements.txt @@ -1,4 +1,4 @@ -chromadb==0.4.24 +chromadb==0.5.0 beautifulsoup4==4.12.3 docker==6.1.3 docx2txt==0.8 From 5147f0507bd49ac938916df3f84679cd383bf494 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:42:21 -0400 Subject: [PATCH 0059/1256] Move google search to Websearch.py (#1203) --- agixt/Websearch.py | 51 +++++++++++++++++++++++++++++++++++++- agixt/extensions/google.py | 4 +++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 6aca7264f7a4..d94fff5124c1 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -16,6 +16,8 @@ from readers.github import GithubReader from datetime import datetime from Memories import extract_keywords +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError logging.basicConfig( @@ -433,6 +435,29 @@ async def ddg_search(self, query: str, proxy=None) -> List[str]: await browser.close() return results + async def google_search( + self, + query: str, + depth: int = 5, + google_api_key: str = "", + google_search_engine_id: str = "", + ) -> List[str]: + if google_api_key == "" or google_search_engine_id == "": + return [] + try: + service = build("customsearch", "v1", developerKey=google_api_key) + result = ( + service.cse() + .list(q=query, cx=google_search_engine_id, num=depth) + .execute() + ) + search_results = result.get("items", []) + search_results_links = [item["link"] for item in search_results] + except Exception as e: + logging.error(f"Google Search Error: {e}") + search_results_links = [] + return search_results_links + async def update_search_provider(self): # SearXNG - List of these at https://searx.space/ # Check if the instances-todays date.json file exists @@ -530,7 +555,31 @@ async def websearch_agent( search_string = " ".join(keywords) # add month and year to the end of the search string search_string += f" {datetime.now().strftime('%B %Y')}" - links = await self.ddg_search(query=search_string) + google_api_key = ( + self.agent_settings["GOOGLE_API_KEY"] + if "GOOGLE_API_KEY" in self.agent_settings + else "" + ) + google_search_engine_id = ( + self.agent_settings["GOOGLE_SEARCH_ENGINE_ID"] + if "GOOGLE_SEARCH_ENGINE_ID" in self.agent_settings + else "" + ) + links = [] + if ( + google_api_key != "" + and google_search_engine_id != "" + and google_api_key is not None + and google_search_engine_id is not None + ): + links = await self.google_search( + query=search_string, + depth=websearch_depth, + google_api_key=google_api_key, + google_search_engine_id=google_search_engine_id, + ) + if links == [] or links is None: + links = await self.ddg_search(query=search_string) if links == [] or links is None: links = [] content, links = await self.web_search(query=search_string) diff --git a/agixt/extensions/google.py b/agixt/extensions/google.py index 3d2cd4d32d4e..bddb2ee2b996 100644 --- a/agixt/extensions/google.py +++ b/agixt/extensions/google.py @@ -34,11 +34,15 @@ def __init__( GOOGLE_CLIENT_ID: str = "", GOOGLE_CLIENT_SECRET: str = "", GOOGLE_REFRESH_TOKEN: str = "", + GOOGLE_API_KEY: str = "", + GOOGLE_SEARCH_ENGINE_ID: str = "", **kwargs, ): self.GOOGLE_CLIENT_ID = GOOGLE_CLIENT_ID self.GOOGLE_CLIENT_SECRET = GOOGLE_CLIENT_SECRET self.GOOGLE_REFRESH_TOKEN = GOOGLE_REFRESH_TOKEN + self.GOOGLE_API_KEY = GOOGLE_API_KEY + self.GOOGLE_SEARCH_ENGINE_ID = GOOGLE_SEARCH_ENGINE_ID self.attachments_dir = "./WORKSPACE/email_attachments/" os.makedirs(self.attachments_dir, exist_ok=True) self.commands = { From b30a70fab14b88bf807db38324b193621084b3f7 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 4 Jun 2024 10:46:52 -0400 Subject: [PATCH 0060/1256] Allow defining google as search engine in env by adding API keys --- agixt/Websearch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index d94fff5124c1..f87f0f451161 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -442,6 +442,9 @@ async def google_search( google_api_key: str = "", google_search_engine_id: str = "", ) -> List[str]: + if google_api_key == "" or google_search_engine_id == "": + google_api_key = getenv("GOOGLE_API_KEY") + google_search_engine_id = getenv("GOOGLE_SEARCH_ENGINE_ID") if google_api_key == "" or google_search_engine_id == "": return [] try: From 6f77a63981111987908ed7395bc883b5a8e84c77 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 5 Jun 2024 19:24:43 -0400 Subject: [PATCH 0061/1256] Fix activity response --- agixt/DB.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ agixt/XT.py | 10 +++++----- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/agixt/DB.py b/agixt/DB.py index 5dea1e5e0880..8586f859c263 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -474,6 +474,54 @@ class PromptCategory(Base): user = relationship("User", backref="prompt_category") +class TaskCategory(Base): + __tablename__ = "task_category" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(String) + description = Column(String) + memory_collection = Column(Integer, default=0) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + category_id = Column( + UUID(as_uuid=True), ForeignKey("task_category.id"), nullable=True + ) + parent_category = relationship("TaskCategory", remote_side=[id]) + + +class TaskItem(Base): + __tablename__ = "task_item" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + user_id = Column(UUID(as_uuid=True), ForeignKey("user.id")) + category_id = Column(UUID(as_uuid=True), ForeignKey("task_category.id")) + category = relationship("TaskCategory") + title = Column(String) + description = Column(String) + memory_collection = Column(Integer, default=0) + # agent_id is the action item owner. If it is null, it is an item for the user + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=True, + ) + estimated_hours = Column(Integer) + scheduled = Column(Boolean, default=False) + completed = Column(Boolean, default=False) + due_date = Column(DateTime, nullable=True) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + completed_at = Column(DateTime, nullable=True) + priority = Column(Integer) + user = relationship("User", backref="task_item") + + class Prompt(Base): __tablename__ = "prompt" id = Column( diff --git a/agixt/XT.py b/agixt/XT.py index ee1ba776302c..c923c0153b86 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -547,17 +547,17 @@ async def learn_from_websites( url_str = {"\n".join(urls)} user_input = f"Learn from the information from these websites:\n {url_str} " c = Conversations(conversation_name=conversation_name, user=self.user_email) + if conversation_name != "" and conversation_name != None and response != "": + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Learning from websites..", + ) response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, search_depth=scrape_depth, summarize_content=summarize_content, conversation_name=conversation_name, ) - if conversation_name != "" and conversation_name != None and response != "": - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] {response}", - ) return "I have read the information from the websites into my memory." async def learn_from_file( From da69fd43bb35086d214d5754d6279fbf6c36007e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 5 Jun 2024 20:03:09 -0400 Subject: [PATCH 0062/1256] fix ref --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index c923c0153b86..b8e5a6639c96 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -547,7 +547,7 @@ async def learn_from_websites( url_str = {"\n".join(urls)} user_input = f"Learn from the information from these websites:\n {url_str} " c = Conversations(conversation_name=conversation_name, user=self.user_email) - if conversation_name != "" and conversation_name != None and response != "": + if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Learning from websites..", From 19a23e833203182723aea91712c822ff07dc6422 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 5 Jun 2024 23:43:01 -0400 Subject: [PATCH 0063/1256] Update google args --- agixt/extensions/google.py | 37 +++++++++++++++++++++++++------------ requirements.txt | 1 + 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/agixt/extensions/google.py b/agixt/extensions/google.py index bddb2ee2b996..6cbfd658b9cc 100644 --- a/agixt/extensions/google.py +++ b/agixt/extensions/google.py @@ -27,20 +27,28 @@ ) from googleapiclient.discovery import build +try: + from google_auth_oauthlib.flow import InstalledAppFlow +except: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "google-auth-oauthlib"] + ) + from google_auth_oauthlib.flow import InstalledAppFlow + class google(Extensions): def __init__( self, GOOGLE_CLIENT_ID: str = "", GOOGLE_CLIENT_SECRET: str = "", - GOOGLE_REFRESH_TOKEN: str = "", + GOOGLE_PROJECT_ID: str = "", GOOGLE_API_KEY: str = "", GOOGLE_SEARCH_ENGINE_ID: str = "", **kwargs, ): self.GOOGLE_CLIENT_ID = GOOGLE_CLIENT_ID self.GOOGLE_CLIENT_SECRET = GOOGLE_CLIENT_SECRET - self.GOOGLE_REFRESH_TOKEN = GOOGLE_REFRESH_TOKEN + self.GOOGLE_PROJECT_ID = GOOGLE_PROJECT_ID self.GOOGLE_API_KEY = GOOGLE_API_KEY self.GOOGLE_SEARCH_ENGINE_ID = GOOGLE_SEARCH_ENGINE_ID self.attachments_dir = "./WORKSPACE/email_attachments/" @@ -61,19 +69,24 @@ def __init__( def authenticate(self): try: - creds = Credentials.from_authorized_user_info( - info={ - "client_id": self.GOOGLE_CLIENT_ID, - "client_secret": self.GOOGLE_CLIENT_SECRET, - "refresh_token": self.GOOGLE_REFRESH_TOKEN, - } + flow = InstalledAppFlow.from_client_config( + { + "installed": { + "client_id": self.GOOGLE_CLIENT_ID, + "project_id": self.GOOGLE_PROJECT_ID, + "client_secret": self.GOOGLE_CLIENT_SECRET, + "redirect_uris": ["http://localhost"], + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + } + }, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], ) - - if creds.expired and creds.refresh_token: - creds.refresh(Request()) - + creds = flow.run_local_server(port=0) return creds except Exception as e: + print(f"Error authenticating: {str(e)}") return None async def get_emails(self, query=None, max_emails=10): diff --git a/requirements.txt b/requirements.txt index f65835c88eb4..39d259e1bf87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ youtube-transcript-api==0.6.2 O365==2.0.34 google-auth==2.29.0 google-api-python-client==2.125.0 +google-auth-oauthlib python-multipart==0.0.9 nest_asyncio g4f==0.3.1.9 From f7f4190339ee3d55c5d8cc639dd519b8452dc01f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 5 Jun 2024 23:51:57 -0400 Subject: [PATCH 0064/1256] Update scopes --- agixt/extensions/google.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/agixt/extensions/google.py b/agixt/extensions/google.py index 6cbfd658b9cc..83e5e0ea472a 100644 --- a/agixt/extensions/google.py +++ b/agixt/extensions/google.py @@ -12,13 +12,6 @@ import logging from Extensions import Extensions -try: - from google.oauth2.credentials import Credentials -except: - subprocess.check_call([sys.executable, "-m", "pip", "install", "google-auth"]) - from google.oauth2.credentials import Credentials -from google.auth.transport.requests import Request - try: from googleapiclient.discovery import build except: @@ -70,7 +63,7 @@ def __init__( def authenticate(self): try: flow = InstalledAppFlow.from_client_config( - { + client_config={ "installed": { "client_id": self.GOOGLE_CLIENT_ID, "project_id": self.GOOGLE_PROJECT_ID, @@ -81,7 +74,11 @@ def authenticate(self): "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", } }, - scopes=["https://www.googleapis.com/auth/gmail.readonly"], + scopes=[ + "https://www.googleapis.com/auth/gmail.send", + "https://www.googleapis.com/auth/calendar", + "https://www.googleapis.com/auth/calendar.events", + ], ) creds = flow.run_local_server(port=0) return creds From 2651a561a77b8cc33267a4806760dbfff31672d8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 14:15:14 -0400 Subject: [PATCH 0065/1256] update image_path --- agixt/XT.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index b8e5a6639c96..d9ac47116d82 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -736,17 +736,7 @@ async def chat_completions(self, prompt: ChatCompletions): os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" ) if url.startswith("http"): - # Validate if url is an image - if ( - url.endswith( - (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp") - ) - and "localhost" not in url - and "127.0.0.1" not in url - ): - image = requests.get(url).content - else: - image = None + image = requests.get(url).content else: file_type = url.split(",")[0].split("/")[1].split(";")[0] if file_type == "jpeg": From d71fc5edc074bf58f99f9fc5c19750ca60e83b7a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 14:18:37 -0400 Subject: [PATCH 0066/1256] improve activity messages --- agixt/XT.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index d9ac47116d82..e5861eeb5fae 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -550,7 +550,7 @@ async def learn_from_websites( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Learning from websites..", + message=f"[ACTIVITY] Researching online.", ) response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, @@ -577,7 +577,12 @@ async def learn_from_file( Returns: str: Response from the agent """ - + if conversation_name != "" and conversation_name != None: + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Reading file.", + ) file_name = os.path.basename(file_path) file_reader = FileReader( agent_name=self.agent_name, @@ -592,7 +597,6 @@ async def learn_from_file( else: response = f"I was unable to read the file called {file_name}." if conversation_name != "" and conversation_name != None: - c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] {response}", From 11abf0afab4e6ece73b0e7d880c461552183a826 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 14:30:13 -0400 Subject: [PATCH 0067/1256] improve activity message --- agixt/XT.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index e5861eeb5fae..d1cd14da14b2 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -577,13 +577,14 @@ async def learn_from_file( Returns: str: Response from the agent """ + + file_name = os.path.basename(file_path) if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Reading file.", + message=f"[ACTIVITY] Reading file {file_name} into memory.", ) - file_name = os.path.basename(file_path) file_reader = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, From 19c4cbaf00f8a3d5b5a93d6595d9920af26e3235 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 15:47:42 -0400 Subject: [PATCH 0068/1256] add return on audio --- agixt/XT.py | 1 + 1 file changed, 1 insertion(+) diff --git a/agixt/XT.py b/agixt/XT.py index d1cd14da14b2..9609cfb78ad7 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -226,6 +226,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): role=self.agent_name, message=f"[ACTIVITY] Transcribed audio to text: {response}", ) + return response async def translate_audio(self, audio_path: str, conversation_name: str = ""): """ From 67c1baf2d39ded2fe3fd1e7966f4d0fe0a9aa8e1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 15:51:50 -0400 Subject: [PATCH 0069/1256] Move activity logging --- agixt/XT.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 9609cfb78ad7..0a8b723cf7fc 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -217,15 +217,13 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): Returns str: Transcription of the audio """ - response = await self.agent.transcribe_audio(audio_path=audio_path) if conversation_name != "" and conversation_name != None: - c = Conversations( - conversation_name="Audio Transcription", user=self.user_email - ) + c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Transcribed audio to text: {response}", + message=f"[ACTIVITY] Transcribing audio.", ) + response = await self.agent.transcribe_audio(audio_path=audio_path) return response async def translate_audio(self, audio_path: str, conversation_name: str = ""): @@ -238,15 +236,13 @@ async def translate_audio(self, audio_path: str, conversation_name: str = ""): Returns str: Translation of the audio """ - response = await self.agent.translate_audio(audio_path=audio_path) if conversation_name != "" and conversation_name != None: - c = Conversations( - conversation_name="Audio Translation", user=self.user_email - ) + c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Translated audio: {response}", + message=f"[ACTIVITY] Translating audio.", ) + response = await self.agent.translate_audio(audio_path=audio_path) return response async def execute_command( From 537acd6cfec5d5c061d151cbcfca43be73b1143c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 15:53:37 -0400 Subject: [PATCH 0070/1256] fix logs --- agixt/XT.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 0a8b723cf7fc..a0b065abc70c 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -172,11 +172,12 @@ async def generate_image(self, prompt: str, conversation_name: str = ""): """ if conversation_name != "" and conversation_name != None: c = Conversations( - conversation_name="Image Generation", user=self.user_email + conversation_name=conversation_name, + user=self.user_email, ) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Generating image...", + message=f"[ACTIVITY] Generating image.", ) return await self.agent.generate_image(prompt=prompt) @@ -191,10 +192,10 @@ async def text_to_speech(self, text: str, conversation_name: str = ""): str: URL of the generated audio """ if conversation_name != "" and conversation_name != None: - c = Conversations(conversation_name="Text to Speech", user=self.user_email) + c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( - role="USER", - message=f"[ACTIVITY] Generating audio from text: {text}", + role=self.agent_name, + message=f"[ACTIVITY] Generating audio.", ) tts_url = await self.agent.text_to_speech(text=text.text) if not str(tts_url).startswith("http"): @@ -289,11 +290,6 @@ async def execute_command( ): tts_response = await self.text_to_speech(text=response) response = f"{response}\n\n{tts_response}" - if conversation_name != "" and conversation_name != None: - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] {response}", - ) return response async def run_chain_step( From 2e257165c8bad88ee08695a65967dc806079bcb0 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 6 Jun 2024 16:07:14 -0400 Subject: [PATCH 0071/1256] remove activity from audio to text functions --- agixt/XT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index a0b065abc70c..a1fc28e8fc14 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -222,7 +222,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Transcribing audio.", + message=f"Transcribing audio...", ) response = await self.agent.transcribe_audio(audio_path=audio_path) return response @@ -241,7 +241,7 @@ async def translate_audio(self, audio_path: str, conversation_name: str = ""): c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Translating audio.", + message=f"Translating audio...", ) response = await self.agent.translate_audio(audio_path=audio_path) return response From cf0e2f734c8c29b87bf7e93ee74159b245106a7b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 06:02:38 -0400 Subject: [PATCH 0072/1256] clean up dots --- agixt/Interactions.py | 6 ++---- agixt/Websearch.py | 4 ++-- agixt/XT.py | 6 +++--- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 8c3226d4af21..3002e9f0b677 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -463,7 +463,7 @@ async def run( if user_input != "": c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Searching the web...", + message=f"[ACTIVITY] Searching the web.", ) # try: await self.websearch.websearch_agent( @@ -490,9 +490,7 @@ async def run( image_urls.append(image_url) logging.info(f"Getting vision response for images: {image_urls}") message = ( - "Looking at images..." - if len(image_urls) > 1 - else "Looking at image..." + "Looking at images." if len(image_urls) > 1 else "Looking at image." ) c.log_interaction( role=self.agent_name, diff --git a/agixt/Websearch.py b/agixt/Websearch.py index f87f0f451161..ec43b8c950fc 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -367,7 +367,7 @@ async def scrape_websites( if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Browsing {link}...", + message=f"[ACTIVITY] Browsing {link} ", ) text_content, link_list = await self.get_web_content( url=link, summarize_content=summarize_content @@ -389,7 +389,7 @@ async def scrape_websites( ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Browsing {sublink[1]}...", + message=f"[ACTIVITY] Browsing {sublink[1]} ", ) ( text_content, diff --git a/agixt/XT.py b/agixt/XT.py index a1fc28e8fc14..875445e34454 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -222,7 +222,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"Transcribing audio...", + message=f"Transcribing audio.", ) response = await self.agent.transcribe_audio(audio_path=audio_path) return response @@ -241,7 +241,7 @@ async def translate_audio(self, audio_path: str, conversation_name: str = ""): c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"Translating audio...", + message=f"Translating audio.", ) response = await self.agent.translate_audio(audio_path=audio_path) return response @@ -453,7 +453,7 @@ async def check_dependencies_met(dependencies): if conversation_name != "": c.log_interaction( role=agent_name, - message=f"[ACTIVITY] Running chain `{chain_name}`...", + message=f"[ACTIVITY] Running chain `{chain_name}`.", ) response = "" tasks = [] From 562b9857e73d85156e1d3e4037adb9049444ec9d Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:55:49 -0400 Subject: [PATCH 0073/1256] use url for image uploads (#1206) --- agixt/Interactions.py | 12 +++--------- agixt/XT.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 3002e9f0b677..b4452ad264c6 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -482,15 +482,9 @@ async def run( and vision_provider != "" and vision_provider != None ): - image_urls = [] - for image in images: - image_url = str(image).replace( - "./WORKSPACE/", f"{AGIXT_URI}/outputs/" - ) - image_urls.append(image_url) - logging.info(f"Getting vision response for images: {image_urls}") + logging.info(f"Getting vision response for images: {images}") message = ( - "Looking at images." if len(image_urls) > 1 else "Looking at image." + "Looking at images." if len(images) > 1 else "Looking at image." ) c.log_interaction( role=self.agent_name, @@ -498,7 +492,7 @@ async def run( ) try: vision_response = await self.agent.inference( - prompt=user_input, images=image_urls + prompt=user_input, images=images ) logging.info(f"Vision Response: {vision_response}") except Exception as e: diff --git a/agixt/XT.py b/agixt/XT.py index 875445e34454..10dfd330aef9 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -734,21 +734,28 @@ async def chat_completions(self, prompt: ChatCompletions): os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" ) if url.startswith("http"): - image = requests.get(url).content + # image = requests.get(url).content + images.append(url) else: file_type = url.split(",")[0].split("/")[1].split(";")[0] if file_type == "jpeg": file_type = "jpg" - file_name = f"{uuid.uuid4().hex}.{file_type}" + if "file_name" in msg: + file_name = str(msg["file_name"]) + if file_name == "": + file_name = f"{uuid.uuid4().hex}.{file_type}" + file_name = "".join( + c if c.isalnum() else "_" for c in file_name + ) + else: + file_name = f"{uuid.uuid4().hex}.{file_type}" image_path = os.path.join( os.getcwd(), "WORKSPACE", file_name ) image = base64.b64decode(url.split(",")[1]) - if image: - if image_path.startswith(base_path): - with open(image_path, "wb") as f: - f.write(image) - images.append(image_path) + with open(image_path, "wb") as f: + f.write(image) + images.append(f"{self.outputs}/{file_name}") if "audio_url" in msg: audio_url = str( msg["audio_url"]["url"] From cb93063c312263f1345fe87c7bf785deeac863c2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 10:15:12 -0400 Subject: [PATCH 0074/1256] learn from files in pipeline --- agixt/XT.py | 60 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 10dfd330aef9..354c3315eee6 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -555,7 +555,8 @@ async def learn_from_websites( async def learn_from_file( self, - file_path: str, + file_url: str = "", + file_name: str = "", collection_number: int = 1, conversation_name: str = "", ): @@ -563,6 +564,7 @@ async def learn_from_file( Learn from a file Args: + file_url (str): URL of the file file_path (str): Path to the file collection_number (int): Collection number to store the file conversation_name (str): Name of the conversation @@ -570,8 +572,14 @@ async def learn_from_file( Returns: str: Response from the agent """ - - file_name = os.path.basename(file_path) + if file_name == "": + file_name = file_url.split("/")[-1] + if file_url.startswith(self.outputs): + file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) + else: + file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) + with open(file_path, "wb") as f: + f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( @@ -609,10 +617,11 @@ async def chat_completions(self, prompt: ChatCompletions): """ conversation_name = prompt.user images = [] + urls = [] + files = [] new_prompt = "" browse_links = True tts = False - urls = [] base_path = os.path.join(os.getcwd(), "WORKSPACE") if "mode" in self.agent_settings: mode = self.agent_settings["mode"] @@ -734,7 +743,6 @@ async def chat_completions(self, prompt: ChatCompletions): os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" ) if url.startswith("http"): - # image = requests.get(url).content images.append(url) else: file_type = url.split(",")[0].split("/")[1].split(";")[0] @@ -816,29 +824,49 @@ async def chat_completions(self, prompt: ChatCompletions): if "url" in msg["file_url"] else msg["file_url"] ) + if file_url.startswith("data:"): + file_type = ( + file_url.split(",")[0].split("/")[1].split(";")[0] + ) + else: + file_type = file_url.split(".")[-1] + file_name = f"{uuid.uuid4().hex}.{file_type}" + if "file_name" in msg: + file_name = str(msg["file_name"]) + if file_name == "": + file_name = f"{uuid.uuid4().hex}.{file_type}" + file_name = "".join( + c if c.isalnum() else "_" for c in file_name + ) + file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) if file_url.startswith("http"): - urls.append(file_url) + if "." in file_url[-5:]: + file_type = file_url.split(".")[-1] + if file_type in ["jpg", "jpeg", "png"]: + images.append(file_url) + else: + files.append(file_url) + else: + urls.append(file_url) else: file_type = ( file_url.split(",")[0].split("/")[1].split(";")[0] ) file_data = base64.b64decode(file_url.split(",")[1]) - # file_path = f"./WORKSPACE/{uuid.uuid4().hex}.{file_type}" - file_path = os.path.join( - os.getcwd(), - "WORKSPACE", - f"{uuid.uuid4().hex}.{file_type}", - ) if file_path.startswith(base_path): with open(file_path, "wb") as f: f.write(file_data) - file_url = ( - f"{self.outputs}/{os.path.basename(file_path)}" - ) - urls.append(file_url) + file_url = f"{self.outputs}/{file_name}" + files.append(file_url) # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction(role="USER", message=new_prompt) + for file in files: + await self.learn_from_file( + file_path=file, + collection_number=1, + conversation_name=conversation_name, + ) await self.learn_from_websites( urls=urls, scrape_depth=3, From 25d6d5f913fb9b3d58507aace3213a356cbc1d3c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 12:02:53 -0400 Subject: [PATCH 0075/1256] improve file upload pipeline --- agixt/Websearch.py | 8 ++- agixt/XT.py | 120 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 102 insertions(+), 26 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index ec43b8c950fc..b1e2e2ceee89 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -355,11 +355,15 @@ async def scrape_websites( ): # user_input = "I am browsing {url} and collecting data from it to learn more." c = None - if conversation_name != "" and conversation_name is not None: - c = Conversations(conversation_name=conversation_name, user=self.user) links = re.findall(r"(?Phttps?://[^\s]+)", user_input) if len(links) < 1: return "" + if conversation_name != "" and conversation_name is not None: + c = Conversations(conversation_name=conversation_name, user=self.user) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Researching online.", + ) scraped_links = [] if links is not None and len(links) > 0: for link in links: diff --git a/agixt/XT.py b/agixt/XT.py index 354c3315eee6..22ff9e5dcbf7 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -539,12 +539,6 @@ async def learn_from_websites( else: url_str = {"\n".join(urls)} user_input = f"Learn from the information from these websites:\n {url_str} " - c = Conversations(conversation_name=conversation_name, user=self.user_email) - if conversation_name != "" and conversation_name != None: - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] Researching online.", - ) response = await self.agent_interactions.websearch.scrape_websites( user_input=user_input, search_depth=scrape_depth, @@ -557,6 +551,7 @@ async def learn_from_file( self, file_url: str = "", file_name: str = "", + user_input: str = "", collection_number: int = 1, conversation_name: str = "", ): @@ -586,6 +581,9 @@ async def learn_from_file( role=self.agent_name, message=f"[ACTIVITY] Reading file {file_name} into memory.", ) + if user_input == "": + user_input = "Describe each stage of this image." + file_type = file_name.split(".")[-1] file_reader = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, @@ -593,11 +591,73 @@ async def learn_from_file( ApiClient=self.ApiClient, user=self.user_email, ) - res = await file_reader.write_file_to_memory(file_path=file_path) - if res == True: - response = f"I have read the entire content of the file called {file_name} into my memory." + if ( + file_type == "wav" + or file_type == "mp3" + or file_type == "ogg" + or file_type == "m4a" + or file_type == "flac" + or file_type == "wma" + or file_type == "aac" + ): + audio = AudioSegment.from_file(file_path) + audio.export(file_path, format="wav") + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Transcribing audio file `{file_name}` into memory.", + ) + audio_response = await self.audio_to_text(audio_path=file_path) + await file_reader.write_text_to_memory( + user_input=user_input, + text=f"Transcription from the audio file called `{file_name}`:\n{audio_response}\n", + external_source=f"Audio file called `{file_name}`", + ) + response = ( + f"I have transcribed the audio from `{file_name}` into my memory." + ) + # If it is an image, generate a description then save to memory + elif file_type in ["jpg", "jpeg", "png", "gif"]: + if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: + vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] + if ( + vision_provider != "None" + and vision_provider != "" + and vision_provider != None + ): + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Viewing image `{file_name}`.", + ) + try: + vision_response = await self.agent.inference( + prompt=user_input, images=[file_url] + ) + await file_reader.write_text_to_memory( + user_input=user_input, + text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}`:\n{vision_response}\n", + external_source=f"Image called `{file_name}`", + ) + response = f"I have generated a description of the image called `{file_name}` into my memory." + except Exception as e: + logging.error(f"Error getting vision response: {e}") + response = ( + f"I was unable to view the image called `{file_name}`." + ) + else: + response = f"I was unable to view the image called `{file_name}`." else: - response = f"I was unable to read the file called {file_name}." + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Reading file `{file_name}` into memory.", + ) + res = await file_reader.write_file_to_memory(file_path=file_path) + if res == True: + response = f"I have read the entire content of the file called {file_name} into my memory." + else: + response = f"I was unable to read the file called {file_name}." if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, @@ -616,7 +676,6 @@ async def chat_completions(self, prompt: ChatCompletions): dict: Chat completion response """ conversation_name = prompt.user - images = [] urls = [] files = [] new_prompt = "" @@ -742,8 +801,22 @@ async def chat_completions(self, prompt: ChatCompletions): image_path = os.path.join( os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" ) + if "file_name" in msg: + file_name = str(msg["file_name"]) + if file_name == "": + file_name = f"{uuid.uuid4().hex}.jpg" + file_name = "".join( + c if c.isalnum() else "_" for c in file_name + ) + else: + file_name = f"{uuid.uuid4().hex}.jpg" if url.startswith("http"): - images.append(url) + files.append( + { + "file_name": file_name, + "file_url": url, + } + ) else: file_type = url.split(",")[0].split("/")[1].split(";")[0] if file_type == "jpeg": @@ -763,7 +836,12 @@ async def chat_completions(self, prompt: ChatCompletions): image = base64.b64decode(url.split(",")[1]) with open(image_path, "wb") as f: f.write(image) - images.append(f"{self.outputs}/{file_name}") + files.append( + { + "file_name": file_name, + "file_url": f"{self.outputs}/{file_name}", + } + ) if "audio_url" in msg: audio_url = str( msg["audio_url"]["url"] @@ -840,14 +918,7 @@ async def chat_completions(self, prompt: ChatCompletions): ) file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) if file_url.startswith("http"): - if "." in file_url[-5:]: - file_type = file_url.split(".")[-1] - if file_type in ["jpg", "jpeg", "png"]: - images.append(file_url) - else: - files.append(file_url) - else: - urls.append(file_url) + files.append({"file_name": file_name, "file_url": file_url}) else: file_type = ( file_url.split(",")[0].split("/")[1].split(";")[0] @@ -857,13 +928,15 @@ async def chat_completions(self, prompt: ChatCompletions): with open(file_path, "wb") as f: f.write(file_data) file_url = f"{self.outputs}/{file_name}" - files.append(file_url) + files.append({"file_name": file_name, "file_url": file_url}) # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction(role="USER", message=new_prompt) for file in files: await self.learn_from_file( - file_path=file, + file_url=file["file_url"], + file_name=file["file_name"], + user_input=new_prompt, collection_number=1, conversation_name=conversation_name, ) @@ -918,7 +991,6 @@ async def chat_completions(self, prompt: ChatCompletions): shots=prompt.n, browse_links=browse_links, voice_response=tts, - images=images, log_user_input=False, **prompt_args, ) From 496d3fcd478e7eec70c9a138ad7b0ed8c336dd49 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 12:27:19 -0400 Subject: [PATCH 0076/1256] use image url instead of file name --- agixt/XT.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 22ff9e5dcbf7..a547400e2d93 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -628,7 +628,7 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Viewing image `{file_name}`.", + message=f"[ACTIVITY] Viewing image at {file_url}.", ) try: vision_response = await self.agent.inference( @@ -808,6 +808,9 @@ async def chat_completions(self, prompt: ChatCompletions): file_name = "".join( c if c.isalnum() else "_" for c in file_name ) + file_name = file_name.replace("_jpg", ".jpg") + if "." not in file_name: + file_name = f"{file_name}.jpg" else: file_name = f"{uuid.uuid4().hex}.jpg" if url.startswith("http"): From e91de84c2b14e8ac5f485f30178096dc6b9f04e2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 12:37:51 -0400 Subject: [PATCH 0077/1256] fix file extension on upload --- agixt/XT.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index a547400e2d93..80bf3e1c72a2 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -808,9 +808,10 @@ async def chat_completions(self, prompt: ChatCompletions): file_name = "".join( c if c.isalnum() else "_" for c in file_name ) - file_name = file_name.replace("_jpg", ".jpg") - if "." not in file_name: - file_name = f"{file_name}.jpg" + file_extension = file_name.split("_")[-1] + file_name = file_name.replace( + f"_{file_extension}", f".{file_extension}" + ) else: file_name = f"{uuid.uuid4().hex}.jpg" if url.startswith("http"): @@ -831,6 +832,10 @@ async def chat_completions(self, prompt: ChatCompletions): file_name = "".join( c if c.isalnum() else "_" for c in file_name ) + file_extension = file_name.split("_")[-1] + file_name = file_name.replace( + f"_{file_extension}", f".{file_extension}" + ) else: file_name = f"{uuid.uuid4().hex}.{file_type}" image_path = os.path.join( From 7aef69c5919f388ebe6bb5d4c45a14dffff74f7b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 12:42:43 -0400 Subject: [PATCH 0078/1256] add error activity --- agixt/XT.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 80bf3e1c72a2..b89994aab799 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -642,11 +642,11 @@ async def learn_from_file( response = f"I have generated a description of the image called `{file_name}` into my memory." except Exception as e: logging.error(f"Error getting vision response: {e}") - response = ( - f"I was unable to view the image called `{file_name}`." - ) + response = f"[ERROR] I was unable to view the image called `{file_name}`." else: - response = f"I was unable to view the image called `{file_name}`." + response = ( + f"[ERROR] I was unable to view the image called `{file_name}`." + ) else: if conversation_name != "" and conversation_name != None: c.log_interaction( @@ -661,7 +661,11 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] {response}", + message=( + f"[ACTIVITY] {response}" + if "[ERROR]" not in response + else f"[ACTIVITY]{response}" + ), ) return response From 35024e0e2821b2b85c157f0be94783bb5656663c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 13:52:55 -0400 Subject: [PATCH 0079/1256] add file types for image viewing --- agixt/XT.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index b89994aab799..8428a442b6c4 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -617,7 +617,16 @@ async def learn_from_file( f"I have transcribed the audio from `{file_name}` into my memory." ) # If it is an image, generate a description then save to memory - elif file_type in ["jpg", "jpeg", "png", "gif"]: + elif file_type in [ + "jpg", + "jpeg", + "png", + "gif", + "webp", + "tiff", + "bmp", + "svg", + ]: if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] if ( From 07ebe3b8d6231e60288419730ab2faae34d67621 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 13:55:53 -0400 Subject: [PATCH 0080/1256] remove unused ref --- agixt/XT.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 8428a442b6c4..7c0816996620 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -811,9 +811,6 @@ async def chat_completions(self, prompt: ChatCompletions): if "url" in msg["image_url"] else msg["image_url"] ) - image_path = os.path.join( - os.getcwd(), "WORKSPACE", f"{uuid.uuid4().hex}.jpg" - ) if "file_name" in msg: file_name = str(msg["file_name"]) if file_name == "": From 4db7d640adf12d7df4d4000e0efb88194c95a351 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Fri, 7 Jun 2024 20:03:34 -0400 Subject: [PATCH 0081/1256] Code reduction on chat completions function (#1207) * Code reduction on chat completions function * Add docs about file_name * handle potentially uncontrolled paths * add full path check --- agixt/XT.py | 260 +++++++++++-------------- docs/2-Concepts/04-Chat Completions.md | 3 + 2 files changed, 114 insertions(+), 149 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 7c0816996620..1b738b22463d 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -25,7 +25,6 @@ def __init__(self, user: str, agent_name: str, api_key: str): self.api_key = api_key self.agent_name = agent_name self.uri = getenv("AGIXT_URI") - self.outputs = f"{self.uri}/outputs/" self.ApiClient = get_api_client(api_key) self.agent_interactions = Interactions( agent_name=self.agent_name, user=self.user_email, ApiClient=self.ApiClient @@ -37,6 +36,10 @@ def __init__(self, user: str, agent_name: str, api_key: str): else DEFAULT_SETTINGS ) self.chain = Chain(user=self.user_email) + self.agent_id = str(self.agent.get_agent_id()) + self.agent_workspace = os.path.join(os.getcwd(), "WORKSPACE", self.agent_id) + os.makedirs(self.agent_workspace, exist_ok=True) + self.outputs = f"{self.uri}/outputs/{self.agent_id}" async def prompts(self, prompt_category: str = "Default"): """ @@ -201,7 +204,10 @@ async def text_to_speech(self, text: str, conversation_name: str = ""): if not str(tts_url).startswith("http"): file_type = "wav" file_name = f"{uuid.uuid4().hex}.{file_type}" - audio_path = f"./WORKSPACE/{file_name}" + audio_path = os.path.join(self.agent_workspace, file_name) + full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") audio_data = base64.b64decode(tts_url) with open(audio_path, "wb") as f: f.write(audio_data) @@ -570,9 +576,12 @@ async def learn_from_file( if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): - file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) + file_path = os.path.join(self.agent_workspace, file_name) else: - file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) + file_path = os.path.join(self.agent_workspace, file_name) + full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") with open(file_path, "wb") as f: f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: @@ -678,6 +687,44 @@ async def learn_from_file( ) return response + async def download_file_to_workspace(self, url: str, file_name: str = ""): + """ + Download a file from a URL to the workspace + + Args: + url (str): URL of the file + file_name (str): Name of the file + + Returns: + str: URL of the downloaded file + """ + if url.startswith("data:"): + file_type = url.split(",")[0].split("/")[1].split(";")[0] + else: + file_type = url.split(".")[-1] + if not file_type: + file_type = "txt" + file_name = f"{uuid.uuid4().hex}.{file_type}" if file_name == "" else file_name + file_name = "".join(c if c.isalnum() else "_" for c in file_name) + file_extension = file_name.split("_")[-1] + file_name = file_name.replace(f"_{file_extension}", f".{file_extension}") + file_path = os.path.join(self.agent_workspace, file_name) + full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") + if url.startswith("http"): + return {"file_name": file_name, "file_url": url} + else: + file_type = url.split(",")[0].split("/")[1].split(";")[0] + file_data = base64.b64decode(url.split(",")[1]) + full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") + with open(file_path, "wb") as f: + f.write(file_data) + url = f"{self.outputs}/{file_name}" + return {"file_name": file_name, "file_url": url} + async def chat_completions(self, prompt: ChatCompletions): """ Generate an OpenAI style chat completion response with a ChatCompletion prompt @@ -694,7 +741,6 @@ async def chat_completions(self, prompt: ChatCompletions): new_prompt = "" browse_links = True tts = False - base_path = os.path.join(os.getcwd(), "WORKSPACE") if "mode" in self.agent_settings: mode = self.agent_settings["mode"] else: @@ -805,148 +851,53 @@ async def chat_completions(self, prompt: ChatCompletions): role = message["role"] if "role" in message else "User" if role.lower() == "user": new_prompt += f"{msg['text']}\n\n" - if "image_url" in msg: - url = str( - msg["image_url"]["url"] - if "url" in msg["image_url"] - else msg["image_url"] - ) - if "file_name" in msg: - file_name = str(msg["file_name"]) - if file_name == "": - file_name = f"{uuid.uuid4().hex}.jpg" - file_name = "".join( - c if c.isalnum() else "_" for c in file_name - ) - file_extension = file_name.split("_")[-1] - file_name = file_name.replace( - f"_{file_extension}", f".{file_extension}" - ) - else: - file_name = f"{uuid.uuid4().hex}.jpg" - if url.startswith("http"): - files.append( - { - "file_name": file_name, - "file_url": url, - } - ) - else: - file_type = url.split(",")[0].split("/")[1].split(";")[0] - if file_type == "jpeg": - file_type = "jpg" - if "file_name" in msg: - file_name = str(msg["file_name"]) - if file_name == "": - file_name = f"{uuid.uuid4().hex}.{file_type}" - file_name = "".join( - c if c.isalnum() else "_" for c in file_name - ) - file_extension = file_name.split("_")[-1] - file_name = file_name.replace( - f"_{file_extension}", f".{file_extension}" - ) - else: - file_name = f"{uuid.uuid4().hex}.{file_type}" - image_path = os.path.join( - os.getcwd(), "WORKSPACE", file_name - ) - image = base64.b64decode(url.split(",")[1]) - with open(image_path, "wb") as f: - f.write(image) - files.append( - { - "file_name": file_name, - "file_url": f"{self.outputs}/{file_name}", - } - ) - if "audio_url" in msg: - audio_url = str( - msg["audio_url"]["url"] - if "url" in msg["audio_url"] - else msg["audio_url"] - ) - # If it is not a url, we need to find the file type and convert with pydub - if not audio_url.startswith("http"): - file_type = ( - audio_url.split(",")[0].split("/")[1].split(";")[0] - ) - audio_data = base64.b64decode(audio_url.split(",")[1]) - audio_path = os.path.join( - os.getcwd(), - "WORKSPACE", - f"{uuid.uuid4().hex}.{file_type}", - ) - with open(audio_path, "wb") as f: - f.write(audio_data) - audio_url = audio_path - else: - # Download the audio file from the url, get the file type and convert to wav - audio_type = audio_url.split(".")[-1] - audio_url = os.path.join( - os.getcwd(), - "WORKSPACE", - f"{uuid.uuid4().hex}.{audio_type}", - ) - audio_data = requests.get(audio_url).content - with open(audio_url, "wb") as f: - f.write(audio_data) - if audio_url.startswith(base_path): - wav_file = f"./WORKSPACE/{uuid.uuid4().hex}.wav" - AudioSegment.from_file(audio_url).set_frame_rate( - 16000 - ).export(wav_file, format="wav") - transcribed_audio = await self.audio_to_text( - audio_path=wav_file, - conversation_name=conversation_name, - ) - new_prompt += transcribed_audio - if "video_url" in msg: - video_url = str( - msg["video_url"]["url"] - if "url" in msg["video_url"] - else msg["video_url"] - ) - if video_url.startswith("http"): - urls.append(video_url) - if ( - "file_url" in msg - or "application_url" in msg - or "text_url" in msg - or "url" in msg - ): - file_url = str( - msg["file_url"]["url"] - if "url" in msg["file_url"] - else msg["file_url"] - ) - if file_url.startswith("data:"): - file_type = ( - file_url.split(",")[0].split("/")[1].split(";")[0] - ) - else: - file_type = file_url.split(".")[-1] - file_name = f"{uuid.uuid4().hex}.{file_type}" - if "file_name" in msg: - file_name = str(msg["file_name"]) - if file_name == "": - file_name = f"{uuid.uuid4().hex}.{file_type}" - file_name = "".join( - c if c.isalnum() else "_" for c in file_name - ) - file_path = os.path.join(os.getcwd(), "WORKSPACE", file_name) - if file_url.startswith("http"): - files.append({"file_name": file_name, "file_url": file_url}) - else: - file_type = ( - file_url.split(",")[0].split("/")[1].split(";")[0] - ) - file_data = base64.b64decode(file_url.split(",")[1]) - if file_path.startswith(base_path): - with open(file_path, "wb") as f: - f.write(file_data) - file_url = f"{self.outputs}/{file_name}" - files.append({"file_name": file_name, "file_url": file_url}) + # Iterate over the msg to find _url in one of the keys then use the value of that key unless it has a "url" under it + if isinstance(msg, dict): + for key, value in msg.items(): + if "_url" in key: + url = str(value["url"] if "url" in value else value) + if "file_name" in msg: + file_name = str(msg["file_name"]) + else: + file_name = "" + if key != "audio_url": + files.append( + await self.download_file_to_workspace( + url=url, file_name=file_name + ) + ) + else: + # If there is an audio_url, it is the user's voice input that needs transcribed before running inference + audio_file_info = ( + await self.download_file_to_workspace(url=url) + ) + full_path = os.path.normpath( + os.path.join( + self.agent_workspace, + audio_file_info["file_name"], + ) + ) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") + audio_file_path = os.path.join( + self.agent_workspace, + audio_file_info["file_name"], + ) + if url.startswith(self.agent_workspace): + wav_file = os.path.join( + self.agent_workspace, + f"{uuid.uuid4().hex}.wav", + ) + AudioSegment.from_file( + audio_file_path + ).set_frame_rate(16000).export( + wav_file, format="wav" + ) + transcribed_audio = await self.audio_to_text( + audio_path=wav_file, + conversation_name=conversation_name, + ) + new_prompt += transcribed_audio # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction(role="USER", message=new_prompt) @@ -1170,8 +1121,19 @@ async def create_dataset_from_memories(self, batch_size: int = 10): } # Save messages to a json file to be used as a dataset agent_id = self.agent_interactions.agent.get_agent_id() - os.makedirs(f"./WORKSPACE/{agent_id}/datasets", exist_ok=True) - with open(f"./WORKSPACE/{agent_id}/datasets/{dataset_name}.json", "w") as f: + dataset_dir = os.path.join(self.agent_workspace, "datasets") + + os.makedirs(dataset_dir, exist_ok=True) + dataset_name = "".join( + [c for c in dataset_name if c.isalpha() or c.isdigit() or c == " "] + ) + dataset_filename = f"{dataset_name}.json" + full_path = os.path.normpath( + os.path.join(self.agent_workspace, dataset_filename) + ) + if not full_path.startswith(self.agent_workspace): + raise Exception("Path given not allowed") + with open(os.path.join(dataset_dir, dataset_filename), "w") as f: f.write(json.dumps(dpo_dataset)) self.agent_settings["training"] = False self.agent_interactions.agent.update_agent_config( diff --git a/docs/2-Concepts/04-Chat Completions.md b/docs/2-Concepts/04-Chat Completions.md index 3f6e1fb6baa9..1fff1a7ea296 100644 --- a/docs/2-Concepts/04-Chat Completions.md +++ b/docs/2-Concepts/04-Chat Completions.md @@ -62,12 +62,14 @@ response = openai.chat.completions.create( {"type": "text", "text": "YOUR USER INPUT TO THE AGENT GOES HERE"}, { "type": "image_url", + "file_name": "funny_cat.jpg", # Optional field, defaults to a random name. "image_url": { # Will download the image and send it to the vision model "url": f"https://www.visualwatermark.com/images/add-text-to-photos/add-text-to-image-3.webp" }, }, { "type": "text_url", # Or just "url" + "file_name": "agixt_com_website_text.txt", # Optional field, defaults to a random name. "text_url": { # Content of the text or URL for it to be scraped "url": "https://agixt.com" }, @@ -75,6 +77,7 @@ response = openai.chat.completions.create( }, { "type": "application_url", + "file_name": "important_document.pdf", # Optional field, defaults to a random name. "application_url": { # Will scrape mime type `application` into the agent's memory "url": "data:application/pdf;base64,base64_encoded_pdf_here" }, From 3aeba3851ed3ce102227fdc258676c47a5b6f709 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 23:52:09 -0400 Subject: [PATCH 0082/1256] add decision maker to websearch --- agixt/Websearch.py | 13 +++++++++++++ agixt/prompts/Default/WebSearch Decision.txt | 11 +++++++++++ agixt/prompts/Default/WebSearch.txt | 3 --- 3 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 agixt/prompts/Default/WebSearch Decision.txt diff --git a/agixt/Websearch.py b/agixt/Websearch.py index b1e2e2ceee89..2d0bac2a7658 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -548,6 +548,18 @@ async def websearch_agent( websearch_timeout = 0 if websearch_depth > 0: if len(user_input) > 0: + to_search_or_not = self.ApiClient.prompt_agent( + agent_name=self.agent_name, + prompt_name="WebSearch Decision", + prompt_args={ + "user_input": user_input, + "websearch": "false", + "browse_links": "false", + "tts": "false", + }, + ) + if not str(to_search_or_not).lower().startswith("y"): + return search_string = self.ApiClient.prompt_agent( agent_name=self.agent_name, prompt_name="WebSearch", @@ -555,6 +567,7 @@ async def websearch_agent( "user_input": user_input, "browse_links": "false", "websearch": "false", + "tts": "false", }, ) keywords = extract_keywords(text=user_input, limit=5) diff --git a/agixt/prompts/Default/WebSearch Decision.txt b/agixt/prompts/Default/WebSearch Decision.txt new file mode 100644 index 000000000000..836d13365ec5 --- /dev/null +++ b/agixt/prompts/Default/WebSearch Decision.txt @@ -0,0 +1,11 @@ + {context} + +Today's date is {date}. + +User's input: {user_input} + +The assistant needs to decide if the user's input merits searching the web to assist them before responding, or if the assistant can respond directly. + +If the assistant decides to search the web, say `Yes`. If the assistant decides not to search the web, say `No`. + +**The assistant responds only with Yes or No.** \ No newline at end of file diff --git a/agixt/prompts/Default/WebSearch.txt b/agixt/prompts/Default/WebSearch.txt index 02f3a000d3d9..3a10659a5220 100644 --- a/agixt/prompts/Default/WebSearch.txt +++ b/agixt/prompts/Default/WebSearch.txt @@ -1,8 +1,5 @@ {context} -Recent conversation history for context: - {conversation_history} - Today's date is {date}. User's input: {user_input} From d2b94ae589156c944c0a6340f13a3e95d1bd2c71 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 7 Jun 2024 23:59:47 -0400 Subject: [PATCH 0083/1256] Improve timestamps --- agixt/Conversations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 420e546ecdb1..677aacba564c 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -151,7 +151,7 @@ def log_interaction(self, role, message): conversation = self.new_conversation() session.close() session = get_session() - timestamp = datetime.now().strftime("%B %d, %Y %I:%M %p") + timestamp = datetime.now().strftime("%Y-%m-%d-%I:%M:%S%p") try: new_message = Message( role=role, From 153c5255898e5916ddc081ec43dedde1906957c3 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:01:27 -0400 Subject: [PATCH 0084/1256] use 24hr format --- agixt/Conversations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 677aacba564c..19e211ad87e9 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -151,7 +151,7 @@ def log_interaction(self, role, message): conversation = self.new_conversation() session.close() session = get_session() - timestamp = datetime.now().strftime("%Y-%m-%d-%I:%M:%S%p") + timestamp = datetime.now().strftime("%Y/%m/%d-%H:%M:%S") try: new_message = Message( role=role, From 977bb8de13bb94b8fd1db9ff85ea2a55da8d0799 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:02:17 -0400 Subject: [PATCH 0085/1256] no slashes --- agixt/Conversations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 19e211ad87e9..5a90907b3e98 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -151,7 +151,7 @@ def log_interaction(self, role, message): conversation = self.new_conversation() session.close() session = get_session() - timestamp = datetime.now().strftime("%Y/%m/%d-%H:%M:%S") + timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S") try: new_message = Message( role=role, From c7cb0a6a3bcf9d9f146b21af23473569cfa3f40c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:07:41 -0400 Subject: [PATCH 0086/1256] add space --- agixt/Conversations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 5a90907b3e98..cccccce1d711 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -151,7 +151,7 @@ def log_interaction(self, role, message): conversation = self.new_conversation() session.close() session = get_session() - timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") try: new_message = Message( role=role, From fc215e5628fac33b43c6e190631ac1ed62af240e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:11:50 -0400 Subject: [PATCH 0087/1256] automatic update timestamp --- agixt/Conversations.py | 4 ---- agixt/DB.py | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index cccccce1d711..9fc352365c35 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -126,7 +126,6 @@ def new_conversation(self, conversation_content=[]): new_message = Message( role=interaction["role"], content=interaction["message"], - timestamp=interaction["timestamp"], conversation_id=conversation.id, ) session.add(new_message) @@ -151,12 +150,10 @@ def log_interaction(self, role, message): conversation = self.new_conversation() session.close() session = get_session() - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") try: new_message = Message( role=role, content=message, - timestamp=timestamp, conversation_id=conversation.id, ) except Exception as e: @@ -166,7 +163,6 @@ def log_interaction(self, role, message): new_message = Message( role=role, content=message, - timestamp=timestamp, conversation_id=conversation.id, ) session.add(new_message) diff --git a/agixt/DB.py b/agixt/DB.py index 8586f859c263..4720d06a455e 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -162,7 +162,7 @@ class AgentBrowsedLink(Base): nullable=False, ) link = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) + timestamp = Column(DateTime, server_default=func.now()) class Agent(Base): @@ -250,7 +250,7 @@ class Message(Base): ) role = Column(Text, nullable=False) content = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) + timestamp = Column(DateTime, server_default=func.now()) conversation_id = Column( UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, ForeignKey("conversation.id"), @@ -397,7 +397,7 @@ class ChainRun(Base): ForeignKey("user.id"), nullable=True, ) - timestamp = Column(DateTime, server_default=text("now()")) + timestamp = Column(DateTime, server_default=func.now()) chain_step_responses = relationship( "ChainStepResponse", backref="chain_run", cascade="all, delete-orphan" ) @@ -420,7 +420,7 @@ class ChainStepResponse(Base): ForeignKey("chain_run.id", ondelete="CASCADE"), nullable=True, ) - timestamp = Column(DateTime, server_default=text("now()")) + timestamp = Column(DateTime, server_default=func.now()) content = Column(Text, nullable=False) From 331d99a012fc3b9ebd3d9a39a3f79c377f99a781 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:48:13 -0400 Subject: [PATCH 0088/1256] add activity verbosity --- agixt/Interactions.py | 72 +++++++++++++++++++++++++++++++++++-------- agixt/Websearch.py | 71 +++++++++++++++++++++++++----------------- 2 files changed, 103 insertions(+), 40 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index b4452ad264c6..92356d6f67b0 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -320,7 +320,12 @@ async def format_prompt( tokens_used = get_tokens( f"{prompt}{user_input}{all_files_content}{context}" ) - if tokens_used > int(self.agent.MAX_TOKENS) or files == []: + agent_max_tokens = int( + self.agent.AGENT_CONFIG["settings"]["MAX_TOKENS"] + if "MAX_TOKENS" in self.agent.AGENT_CONFIG["settings"] + else 8192 + ) + if tokens_used > agent_max_tokens or files == []: fragmented_content = await file_reader.get_memories( user_input=f"{user_input} {file_list}", min_relevance_score=0.3, @@ -379,6 +384,7 @@ async def run( persist_context_in_history: bool = False, images: list = [], log_user_input: bool = True, + log_output: bool = True, **kwargs, ): global AGIXT_URI @@ -455,6 +461,13 @@ async def run( conversation_name=conversation_name, ) if websearch: + if browse_links != False: + await self.websearch.scrape_websites( + user_input=user_input, + search_depth=websearch_depth, + summarize_content=False, + conversation_name=conversation_name, + ) if user_input == "": if "primary_objective" in kwargs and "task" in kwargs: user_input = f"Primary Objective: {kwargs['primary_objective']}\n\nTask: {kwargs['task']}" @@ -463,16 +476,50 @@ async def run( if user_input != "": c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Searching the web.", + message=f"[ACTIVITY] Searching for information.", ) - # try: - await self.websearch.websearch_agent( + to_search_or_not_to_search = await self.run( + prompt_name="Websearch Decision", + prompt_category="Default", user_input=user_input, - websearch_depth=websearch_depth, - websearch_timeout=websearch_timeout, + context_results=context_results, + conversation_name=conversation_name, + log_user_input=False, + log_output=False, + browse_links=False, + websearch=False, + tts=False, ) - # except Exception as e: - # logging.warning(f"Failed to websearch. Error: {e}") + logging.info(f"Search Decision: {to_search_or_not_to_search[:10]}") + if str(to_search_or_not_to_search).lower().startswith("y"): + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Searching the web.", + ) + search_string = await self.run( + prompt_name="WebSearch", + prompt_category="Default", + user_input=user_input, + context_results=context_results, + conversation_name=conversation_name, + log_user_input=False, + log_output=False, + browse_links=False, + websearch=False, + tts=False, + ) + await self.websearch.websearch_agent( + user_input=user_input, + search_string=search_string, + websearch_depth=websearch_depth, + websearch_timeout=websearch_timeout, + conversation_name=conversation_name, + ) + else: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Decided searching the web is not necessary.", + ) vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] @@ -642,10 +689,11 @@ async def run( logging.warning( f"Failed to generate image for prompt: {image_generation_prompt}" ) - c.log_interaction( - role=self.agent_name, - message=self.response, - ) + if log_output: + c.log_interaction( + role=self.agent_name, + message=self.response, + ) if shots > 1: responses = [self.response] for shot in range(shots - 1): diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 2d0bac2a7658..8b6bd23f91d9 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -277,7 +277,9 @@ async def get_web_content(self, url: str, summarize_content=False): except: return None, None - async def recursive_browsing(self, user_input, links): + async def recursive_browsing(self, user_input, links, conversation_name: str = ""): + if conversation_name != "" and conversation_name is not None: + c = Conversations(conversation_name=conversation_name, user=self.user) try: words = links.split() links = [ @@ -311,6 +313,11 @@ async def recursive_browsing(self, user_input, links): url = link url = re.sub(r"^.*?(http)", r"http", url) if self.verify_link(link=url): + if conversation_name != "" and conversation_name is not None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Browsing link: {url}", + ) ( collected_data, link_list, @@ -319,6 +326,14 @@ async def recursive_browsing(self, user_input, links): if len(link_list) > 0: if len(link_list) > 5: link_list = link_list[:3] + if ( + conversation_name != "" + and conversation_name is not None + ): + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Found {len(link_list)} links on {url} . Choosing one to browse next.", + ) try: pick_a_link = self.ApiClient.prompt_agent( agent_name=self.agent_name, @@ -341,10 +356,20 @@ async def recursive_browsing(self, user_input, links): f"AI has decided to click: {pick_a_link}" ) await self.recursive_browsing( - user_input=user_input, links=pick_a_link + user_input=user_input, + links=pick_a_link, + conversation_name=conversation_name, ) except: logging.info(f"Issues reading {url}. Moving on...") + if ( + conversation_name != "" + and conversation_name is not None + ): + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Issues reading {url}. Moving on...", + ) async def scrape_websites( self, @@ -534,10 +559,11 @@ async def web_search(self, query: str) -> List[str]: async def websearch_agent( self, user_input: str = "What are the latest breakthroughs in AI?", + search_string: str = "", websearch_depth: int = 0, websearch_timeout: int = 0, + conversation_name: str = "", ): - await self.scrape_websites(user_input=user_input, search_depth=websearch_depth) try: websearch_depth = int(websearch_depth) except: @@ -548,33 +574,19 @@ async def websearch_agent( websearch_timeout = 0 if websearch_depth > 0: if len(user_input) > 0: - to_search_or_not = self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="WebSearch Decision", - prompt_args={ - "user_input": user_input, - "websearch": "false", - "browse_links": "false", - "tts": "false", - }, - ) - if not str(to_search_or_not).lower().startswith("y"): - return - search_string = self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="WebSearch", - prompt_args={ - "user_input": user_input, - "browse_links": "false", - "websearch": "false", - "tts": "false", - }, - ) - keywords = extract_keywords(text=user_input, limit=5) + keywords = extract_keywords(text=search_string, limit=5) if keywords: search_string = " ".join(keywords) # add month and year to the end of the search string search_string += f" {datetime.now().strftime('%B %Y')}" + if conversation_name != "" and conversation_name is not None: + c = Conversations( + conversation_name=conversation_name, user=self.user + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Searching for `{search_string}`.", + ) google_api_key = ( self.agent_settings["GOOGLE_API_KEY"] if "GOOGLE_API_KEY" in self.agent_settings @@ -610,10 +622,13 @@ async def websearch_agent( links = links[:websearch_depth] if links is not None and len(links) > 0: task = asyncio.create_task( - self.recursive_browsing(user_input=user_input, links=links) + self.recursive_browsing( + user_input=user_input, + links=links, + conversation_name=conversation_name, + ) ) self.tasks.append(task) - if int(websearch_timeout) == 0: await asyncio.gather(*self.tasks) else: From 33a31a8668c84aa8662d3c48980f5e1d9b9d1ac7 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 00:54:48 -0400 Subject: [PATCH 0089/1256] fix typo --- agixt/Interactions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 92356d6f67b0..ca880dd2b503 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -479,7 +479,7 @@ async def run( message=f"[ACTIVITY] Searching for information.", ) to_search_or_not_to_search = await self.run( - prompt_name="Websearch Decision", + prompt_name="WebSearch Decision", prompt_category="Default", user_input=user_input, context_results=context_results, From 16b3c8deccc50e43e6aa693405f24b38341b43c2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 01:05:23 -0400 Subject: [PATCH 0090/1256] use regex to find a yes --- agixt/Interactions.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index ca880dd2b503..c1696c077cb9 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -490,8 +490,10 @@ async def run( websearch=False, tts=False, ) - logging.info(f"Search Decision: {to_search_or_not_to_search[:10]}") - if str(to_search_or_not_to_search).lower().startswith("y"): + to_search = re.search( + r"\byes\b", str(to_search_or_not_to_search).lower() + ) + if to_search: c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Searching the web.", @@ -667,7 +669,8 @@ async def run( create_img = await self.agent.inference(prompt=img_gen_prompt) create_img = str(create_img).lower() logging.info(f"Image Generation Decision Response: {create_img}") - if "yes" in create_img or "es," in create_img: + to_create_image = re.search(r"\byes\b", str(create_img).lower()) + if to_create_image: img_prompt = f"**The assistant is acting as a Stable Diffusion Prompt Generator.**\n\nUsers message: {user_input} \nAssistant response: {self.response} \n\nImportant rules to follow:\n- Describe subjects in detail, specify image type (e.g., digital illustration), art style (e.g., steampunk), and background. Include art inspirations (e.g., Art Station, specific artists). Detail lighting, camera (type, lens, view), and render (resolution, style). The weight of a keyword can be adjusted by using the syntax (((keyword))) , put only those keyword inside ((())) which is very important because it will have more impact so anything wrong will result in unwanted picture so be careful. Realistic prompts: exclude artist, specify lens. Separate with double lines. Max 60 words, avoiding 'real' for fantastical.\n- Based on the message from the user and response of the assistant, you will need to generate one detailed stable diffusion image generation prompt based on the context of the conversation to accompany the assistant response.\n- The prompt can only be up to 60 words long, so try to be concise while using enough descriptive words to make a proper prompt.\n- Following all rules will result in a $2000 tip that you can spend on anything!\n- Must be in markdown code block to be parsed out and only provide prompt in the code block, nothing else.\nStable Diffusion Prompt Generator: " image_generation_prompt = await self.agent.inference( prompt=img_prompt From c62723ca920683c511443498bb9cf7a302c3b451 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 02:21:03 -0400 Subject: [PATCH 0091/1256] fix ssrf --- agixt/XT.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 1b738b22463d..e59159801ccc 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -582,8 +582,9 @@ async def learn_from_file( full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) if not full_path.startswith(self.agent_workspace): raise Exception("Path given not allowed") - with open(file_path, "wb") as f: - f.write(requests.get(file_url).content) + if file_url.startswith("http"): + with open(file_path, "wb") as f: + f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( From b6140299dbae6ce71a2ff970a89432d58c91fa7e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 02:22:53 -0400 Subject: [PATCH 0092/1256] use full_path --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index e59159801ccc..00028e065194 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -583,7 +583,7 @@ async def learn_from_file( if not full_path.startswith(self.agent_workspace): raise Exception("Path given not allowed") if file_url.startswith("http"): - with open(file_path, "wb") as f: + with open(full_path, "wb") as f: f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) From 5b70dbc309f0bc3264af66bfc15284726da9b28c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 10:34:49 -0400 Subject: [PATCH 0093/1256] only allow https --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 00028e065194..99278783f0a4 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -582,7 +582,7 @@ async def learn_from_file( full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) if not full_path.startswith(self.agent_workspace): raise Exception("Path given not allowed") - if file_url.startswith("http"): + if file_url.startswith("https"): with open(full_path, "wb") as f: f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: From a30a4996e38380e95369b4149dac4519a233ad95 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 10:36:12 -0400 Subject: [PATCH 0094/1256] revert https requirement --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 99278783f0a4..00028e065194 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -582,7 +582,7 @@ async def learn_from_file( full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) if not full_path.startswith(self.agent_workspace): raise Exception("Path given not allowed") - if file_url.startswith("https"): + if file_url.startswith("http"): with open(full_path, "wb") as f: f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: From 05c4722191803fea962cf9de20d5d26c9279fd6d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 8 Jun 2024 19:06:32 -0400 Subject: [PATCH 0095/1256] use download_file_to_workspace in learn_from_file --- agixt/XT.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 00028e065194..4c5046977e50 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -578,13 +578,11 @@ async def learn_from_file( if file_url.startswith(self.outputs): file_path = os.path.join(self.agent_workspace, file_name) else: + file_data = await self.download_file_to_workspace( + url=file_url, file_name=file_name + ) + file_name = file_data["file_name"] file_path = os.path.join(self.agent_workspace, file_name) - full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) - if not full_path.startswith(self.agent_workspace): - raise Exception("Path given not allowed") - if file_url.startswith("http"): - with open(full_path, "wb") as f: - f.write(requests.get(file_url).content) if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( From 16b8606c3e23b2546455a878ee2707ee51c77c09 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 08:22:20 -0400 Subject: [PATCH 0096/1256] remove statement --- agixt/Memories.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 41ef38eca36b..6b2a32f68a95 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -422,10 +422,8 @@ async def get_memories( default_collection_name = self.collection_name if self.user != DEFAULT_USER: self.collection_name = snake(f"{snake(DEFAULT_USER)}_{self.agent_name}") - if self.collection_number > 0: - self.collection_name = ( - f"{self.collection_name}_{self.collection_number}" - ) + # if self.collection_number > 0: + self.collection_name = f"{self.collection_name}_{self.collection_number}" try: default_results = await self.get_memories_data( user_input=user_input, From 6c490317f5dbc9ce70583846d138ca9119829314 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 08:24:23 -0400 Subject: [PATCH 0097/1256] use _0 for collection 0 --- agixt/Memories.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 6b2a32f68a95..005538a2785b 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -170,8 +170,7 @@ def __init__( self.collection_name = snake(f"{snake(DEFAULT_USER)}_{agent_name}") self.user = user self.collection_number = collection_number - if collection_number > 0: - self.collection_name = f"{self.collection_name}_{collection_number}" + self.collection_name = f"{self.collection_name}_{collection_number}" if agent_config is None: agent_config = ApiClient.get_agentconfig(agent_name=agent_name) self.agent_config = ( @@ -251,8 +250,7 @@ async def import_collections_from_json(self, json_data: List[dict]): collection_number = 0 self.collection_number = collection_number self.collection_name = snake(self.agent_name) - if collection_number > 0: - self.collection_name = f"{self.collection_name}_{collection_number}" + self.collection_name = f"{self.collection_name}_{collection_number}" for val in value[self.collection_name]: try: await self.write_text_to_memory( @@ -422,7 +420,6 @@ async def get_memories( default_collection_name = self.collection_name if self.user != DEFAULT_USER: self.collection_name = snake(f"{snake(DEFAULT_USER)}_{self.agent_name}") - # if self.collection_number > 0: self.collection_name = f"{self.collection_name}_{self.collection_number}" try: default_results = await self.get_memories_data( From aa29fadfed64856040e962077763081e9bc17e10 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 08:29:51 -0400 Subject: [PATCH 0098/1256] use gpt-4o as openai default --- agixt/providers/openai.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/agixt/providers/openai.py b/agixt/providers/openai.py index 6a095d5b5b0d..6f1755c2d67d 100644 --- a/agixt/providers/openai.py +++ b/agixt/providers/openai.py @@ -21,9 +21,9 @@ class OpenaiProvider: def __init__( self, OPENAI_API_KEY: str = "", - AI_MODEL: str = "gpt-3.5-turbo-16k-0613", + AI_MODEL: str = "gpt-4o", API_URI: str = "https://api.openai.com/v1", - MAX_TOKENS: int = 16000, + MAX_TOKENS: int = 4096, AI_TEMPERATURE: float = 0.7, AI_TOP_P: float = 0.7, WAIT_BETWEEN_REQUESTS: int = 1, @@ -34,10 +34,10 @@ def __init__( **kwargs, ): self.requirements = ["openai"] - self.AI_MODEL = AI_MODEL if AI_MODEL else "gpt-3.5-turbo-16k-0613" + self.AI_MODEL = AI_MODEL if AI_MODEL else "gpt-4o" self.AI_TEMPERATURE = AI_TEMPERATURE if AI_TEMPERATURE else 0.7 self.AI_TOP_P = AI_TOP_P if AI_TOP_P else 0.7 - self.MAX_TOKENS = MAX_TOKENS if MAX_TOKENS else 16000 + self.MAX_TOKENS = MAX_TOKENS if MAX_TOKENS else 4096 self.API_URI = API_URI if API_URI else "https://api.openai.com/v1" self.WAIT_AFTER_FAILURE = WAIT_AFTER_FAILURE if WAIT_AFTER_FAILURE else 3 self.WAIT_BETWEEN_REQUESTS = ( @@ -84,8 +84,8 @@ def rotate_uri(self): async def inference(self, prompt, tokens: int = 0, images: list = []): if images != []: - if "vision" not in self.AI_MODEL: - self.AI_MODEL = "gpt-4-vision-preview" + if "vision" not in self.AI_MODEL and self.AI_MODEL != "gpt-4o": + self.AI_MODEL = "gpt-4o" if not self.API_URI.endswith("/"): self.API_URI += "/" openai.base_url = self.API_URI if self.API_URI else "https://api.openai.com/v1/" From 3208e9d39b99c1d53c0f67504d8d5ed36fc43053 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 08:31:13 -0400 Subject: [PATCH 0099/1256] undo change --- agixt/Memories.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 005538a2785b..41ef38eca36b 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -170,7 +170,8 @@ def __init__( self.collection_name = snake(f"{snake(DEFAULT_USER)}_{agent_name}") self.user = user self.collection_number = collection_number - self.collection_name = f"{self.collection_name}_{collection_number}" + if collection_number > 0: + self.collection_name = f"{self.collection_name}_{collection_number}" if agent_config is None: agent_config = ApiClient.get_agentconfig(agent_name=agent_name) self.agent_config = ( @@ -250,7 +251,8 @@ async def import_collections_from_json(self, json_data: List[dict]): collection_number = 0 self.collection_number = collection_number self.collection_name = snake(self.agent_name) - self.collection_name = f"{self.collection_name}_{collection_number}" + if collection_number > 0: + self.collection_name = f"{self.collection_name}_{collection_number}" for val in value[self.collection_name]: try: await self.write_text_to_memory( @@ -420,7 +422,10 @@ async def get_memories( default_collection_name = self.collection_name if self.user != DEFAULT_USER: self.collection_name = snake(f"{snake(DEFAULT_USER)}_{self.agent_name}") - self.collection_name = f"{self.collection_name}_{self.collection_number}" + if self.collection_number > 0: + self.collection_name = ( + f"{self.collection_name}_{self.collection_number}" + ) try: default_results = await self.get_memories_data( user_input=user_input, From f80bf43a59aeb1277335038ea78a84b4b9030613 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 08:33:12 -0400 Subject: [PATCH 0100/1256] initialize collection 0 in interactions --- agixt/Interactions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index c1696c077cb9..f1d5f3261ea5 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -45,7 +45,13 @@ def __init__( user=self.user, ApiClient=self.ApiClient, ) - self.agent_memory = self.websearch.agent_memory + self.agent_memory = FileReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=0, + ApiClient=self.ApiClient, + user=self.user, + ) self.positive_feedback_memories = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, From 5c019a7655e2f8f669b28d2481ff1604afd12556 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 09:20:21 -0400 Subject: [PATCH 0101/1256] Add expert examples --- examples/AGiXT-Expert-OAI.ipynb | 175 +++++++++++++++++++++++++ examples/AGiXT-Expert-ezLocalai.ipynb | 178 ++++++++++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 examples/AGiXT-Expert-OAI.ipynb create mode 100644 examples/AGiXT-Expert-ezLocalai.ipynb diff --git a/examples/AGiXT-Expert-OAI.ipynb b/examples/AGiXT-Expert-OAI.ipynb new file mode 100644 index 000000000000..15a4389e078a --- /dev/null +++ b/examples/AGiXT-Expert-OAI.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train an Expert Agent in AGiXT with OpenAI Provider\n", + "\n", + "This example assumes that you have your AGiXT server set up and running and that you have an OpenAI API Key to use for setting up this agent. You can use any other provider, but this example specifically uses OpenAI for the agent.\n", + "\n", + "## Create the Agent\n", + "\n", + "For this example, we will create an expert on AGiXT. We will use the OpenAI Provider and `gpt-4o` model in this example. These settings can be easily changed in the streamlit app or over API.\n", + "\n", + "Modify the `agixt_server`, `api_key`, `agent_name`, `OPENAI_API_KEY`, `persona`, and any others as needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from agixtsdk import AGiXTSDK\n", + "\n", + "agixt_server = \"http://localhost:7437\" # Change this to your AGiXT server URL\n", + "api_key = \"None\" # Change this to your AGiXT API key\n", + "\n", + "agixt = AGiXTSDK(base_uri=agixt_server, api_key=api_key)\n", + "\n", + "agent_name = \"AGiXT\" # Change this if desired\n", + "\n", + "agixt.add_agent(\n", + " agent_name=agent_name,\n", + " settings={\n", + " \"provider\": \"openai\", # LLM Provider\n", + " \"transcription_provider\": \"default\", # Voice transcription provider, default uses the built in transcription in AGiXT.\n", + " \"translation_provider\": \"default\", # Voice translation provider, default uses the built in translation in AGiXT.\n", + " \"embeddings_provider\": \"default\", # Embeddings provider, default uses the built in embeddings in AGiXT.\n", + " \"image_provider\": \"None\", # If set, AGiXT will autonomously create images if it chooses to do so based on the conversation.\n", + " \"vision_provider\": \"openai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", + " \"tts_provider\": \"None\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", + " \"AI_MODEL\": \"gpt-4o\", # GPT-4o is OpenAI's most capable model currently, we will use it for best results.\n", + " \"OPENAI_API_KEY\": \"YOUR_OPENAI_API_KEY\", # Get your OpenAI API key from https://platform.openai.com/account/api-keys\n", + " \"MAX_TOKENS\": 4096,\n", + " \"AI_TEMPERATURE\": \"0.7\",\n", + " \"AI_TOP_P\": \"0.95\",\n", + " \"mode\": \"prompt\", # For info about chat completion modes, go to https://josh-xt.github.io/AGiXT/2-Concepts/04-Chat%20Completions.html\n", + " \"prompt_name\": \"Chat\",\n", + " \"prompt_category\": \"Default\",\n", + " \"persona\": \"AGiXT is an expert on the AGiXT AI agent automation platform and supports the users of AGiXT.\", # Use this field to set persona for the AI model\n", + " \"context_results\": 20, # How many memories from training to inject with each interaction.\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zip your training data\n", + "\n", + "Creates a zip file called `training_data.zip` of the AGiXT `docs` folder. You can change this to any folder that you would like to use as training data, or skip this step and use an existing zip file.\n", + "\n", + "A good example of what to use for training data would be any PDF, word document, text file, or any other kind of file with information in it that you would like the agent to learn from." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from zipfile import ZipFile\n", + "import os\n", + "\n", + "os.chdir(\"../\")\n", + "with ZipFile(\"examples/training_data.zip\", \"w\") as zipObj:\n", + " for foldername, subfolders, filenames in os.walk(\"docs\"):\n", + " for filename in filenames:\n", + " file_path = os.path.join(foldername, filename)\n", + " zipObj.write(file_path)\n", + "os.chdir(\"examples/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the Agent on the training data\n", + "\n", + "This will train the agent on the training data that you have provided. This will take some time to complete depending on the size of the training data. A zip file around 70MB in size takes around 3 minutes to complete. The AGiXT docs should complete very quickly since it is all markdown files totaling around 3MB." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "\n", + "zip_file_name = \"training_data.zip\"\n", + "training_data = base64.b64encode(open(zip_file_name, \"rb\").read()).decode(\"utf-8\")\n", + "\n", + "agixt.learn_file(\n", + " agent_name=agent_name,\n", + " file_name=zip_file_name,\n", + " file_content=training_data,\n", + " collection_number=0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chat with your trained expert AGiXT agent\n", + "\n", + "AGiXT has direct support for using the OpenAI API for chat completions. See this link for more information to take advantage of the abilities of this endpoint: https://josh-xt.github.io/AGiXT/2-Concepts/04-Chat%20Completions.html\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "prompt = \"What can you tell me about AGiXT?\"\n", + "\n", + "openai.base_url = agixt_server\n", + "openai.api_key = api_key\n", + "\n", + "response = openai.chat.completions.create(\n", + " model=agent_name,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " user=\"Tell me about AGiXT\", # This field is used for the conversation name, if empty, it will use today's date\n", + ")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# That is all! \n", + "\n", + "You now have a trained expert agent in AGiXT. This agent will be able to support users by answering questions, providing information, and more about AGiXT." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb new file mode 100644 index 000000000000..06d96876dd46 --- /dev/null +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -0,0 +1,178 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train an Expert Agent in AGiXT with ezLocalai Provider\n", + "\n", + "This example assumes that you have your AGiXT and ezLocalai servers set up and running. This example is specifically for running AGiXT with ezLocalai and training a local agent. If you do not wish to run local models, we have the same example set up for using an agent with the OpenAI API.\n", + "\n", + "If you do not have ezLocalai set up and you want to set it up to run local models, go to https://github.com/DevXT-LLC/ezlocalai and follow the instructions there to set up the server, then continue with this example.\n", + "\n", + "## Create the Agent\n", + "\n", + "For this example, we will create an expert on AGiXT. These settings can be easily changed in the streamlit app or over API.\n", + "\n", + "Modify the `agixt_server`, `api_key`, `agent_name`, `EZLOCALAI_API_URL`, `EZLOCALAI_API_KEY`, `persona`, and any others as needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from agixtsdk import AGiXTSDK\n", + "\n", + "agixt_server = \"http://localhost:7437\" # Change this to your AGiXT server URL\n", + "api_key = \"None\" # Change this to your AGiXT API key\n", + "\n", + "agixt = AGiXTSDK(base_uri=agixt_server, api_key=api_key)\n", + "\n", + "agent_name = \"AGiXT\" # Change this if desired\n", + "\n", + "agixt.add_agent(\n", + " agent_name=agent_name,\n", + " settings={\n", + " \"provider\": \"ezlocalai\", # LLM Provider\n", + " \"transcription_provider\": \"ezlocalai\", # Voice transcription provider, default uses the built in transcription in AGiXT.\n", + " \"translation_provider\": \"ezlocalai\", # Voice translation provider, default uses the built in translation in AGiXT.\n", + " \"embeddings_provider\": \"default\", # Embeddings provider, default uses the built in embeddings in AGiXT.\n", + " \"image_provider\": \"None\", # If set, AGiXT will autonomously create images if it chooses to do so based on the conversation.\n", + " \"vision_provider\": \"ezlocalai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", + " \"tts_provider\": \"None\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", + " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", + " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL\n", + " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", + " \"MAX_TOKENS\": 4096,\n", + " \"AI_TEMPERATURE\": \"0.7\",\n", + " \"AI_TOP_P\": \"0.95\",\n", + " \"mode\": \"prompt\", # For info about chat completion modes, go to https://josh-xt.github.io/AGiXT/2-Concepts/04-Chat%20Completions.html\n", + " \"prompt_name\": \"Chat\",\n", + " \"prompt_category\": \"Default\",\n", + " \"persona\": \"AGiXT is an expert on the AGiXT AI agent automation platform and supports the users of AGiXT.\", # Use this field to set persona for the AI model\n", + " \"context_results\": 20, # How many memories from training to inject with each interaction.\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Zip your training data\n", + "\n", + "Creates a zip file called `training_data.zip` of the AGiXT `docs` folder. You can change this to any folder that you would like to use as training data, or skip this step and use an existing zip file.\n", + "\n", + "A good example of what to use for training data would be any PDF, word document, text file, or any other kind of file with information in it that you would like the agent to learn from." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from zipfile import ZipFile\n", + "import os\n", + "\n", + "os.chdir(\"../\")\n", + "with ZipFile(\"examples/training_data.zip\", \"w\") as zipObj:\n", + " for foldername, subfolders, filenames in os.walk(\"docs\"):\n", + " for filename in filenames:\n", + " file_path = os.path.join(foldername, filename)\n", + " zipObj.write(file_path)\n", + "os.chdir(\"examples/\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the Agent on the training data\n", + "\n", + "This will train the agent on the training data that you have provided. This will take some time to complete depending on the size of the training data. A zip file around 70MB in size takes around 3 minutes to complete. The AGiXT docs should complete very quickly since it is all markdown files totaling around 3MB." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import base64\n", + "\n", + "zip_file_name = \"training_data.zip\"\n", + "training_data = base64.b64encode(open(zip_file_name, \"rb\").read()).decode(\"utf-8\")\n", + "\n", + "agixt.learn_file(\n", + " agent_name=agent_name,\n", + " file_name=zip_file_name,\n", + " file_content=training_data,\n", + " collection_number=0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chat with your trained expert AGiXT agent\n", + "\n", + "AGiXT has direct support for using the OpenAI API for chat completions. See this link for more information to take advantage of the abilities of this endpoint: https://josh-xt.github.io/AGiXT/2-Concepts/04-Chat%20Completions.html\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "prompt = \"What can you tell me about AGiXT?\"\n", + "\n", + "openai.base_url = agixt_server\n", + "openai.api_key = api_key\n", + "\n", + "response = openai.chat.completions.create(\n", + " model=agent_name,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " user=\"Tell me about AGiXT\", # This field is used for the conversation name, if empty, it will use today's date\n", + ")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# That is all! \n", + "\n", + "You now have a trained expert agent in AGiXT. This agent will be able to support users by answering questions, providing information, and more about AGiXT." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 0c0204b32993add3ad6a984bdb354b17c9d78995 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 09:23:25 -0400 Subject: [PATCH 0102/1256] add tts to example --- examples/AGiXT-Expert-ezLocalai.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb index 06d96876dd46..3a14952d6da2 100644 --- a/examples/AGiXT-Expert-ezLocalai.ipynb +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -41,7 +41,8 @@ " \"embeddings_provider\": \"default\", # Embeddings provider, default uses the built in embeddings in AGiXT.\n", " \"image_provider\": \"None\", # If set, AGiXT will autonomously create images if it chooses to do so based on the conversation.\n", " \"vision_provider\": \"ezlocalai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", - " \"tts_provider\": \"None\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", + " \"tts_provider\": \"ezlocalai\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", + " \"VOICE\": \"Morgan_Freeman\", # Voice for TTS, change this to the voice you want to use from ezlocalai.\n", " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL\n", " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", From 0f79736edad9580b7c63e8497e6023ebb7e33d94 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 09:49:31 -0400 Subject: [PATCH 0103/1256] add postgres chat example --- examples/Postgres-Chat-ezLocalai.ipynb | 112 +++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 examples/Postgres-Chat-ezLocalai.ipynb diff --git a/examples/Postgres-Chat-ezLocalai.ipynb b/examples/Postgres-Chat-ezLocalai.ipynb new file mode 100644 index 000000000000..caf832b20eac --- /dev/null +++ b/examples/Postgres-Chat-ezLocalai.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Create a Postgres Chat Agent in AGiXT with ezLocalai Provider\n", + "\n", + "This example assumes that you have your AGiXT and ezLocalai servers set up and running. This example is specifically for running AGiXT with ezLocalai and training a local agent.\n", + "\n", + "If you do not have ezLocalai set up and you want to set it up to run local models, go to https://github.com/DevXT-LLC/ezlocalai and follow the instructions there to set up the server, then continue with this example.\n", + "\n", + "## Create the Agent\n", + "\n", + "For this example, we will create an agent that will turn your natural language questions into SQL queries that get executed, then a response with either a CSV of data or a string response.\n", + "\n", + "Connect to any Postgres database by updating the agent's settings. Modify the following and any others as needed:\n", + "\n", + "- `agixt_server`\n", + "- `api_key`\n", + "- `agent_name`\n", + "- `EZLOCALAI_API_URL`\n", + "- `EZLOCALAI_API_KEY`\n", + "- `POSTGRES_DATABASE_HOST`\n", + "- `POSTGRES_DATABASE_PORT`\n", + "- `POSTGRES_DATABASE_NAME`\n", + "- `POSTGRES_DATABASE_USERNAME`\n", + "- `POSTGRES_DATABASE_PASSWORD`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from agixtsdk import AGiXTSDK\n", + "\n", + "agixt_server = \"http://localhost:7437\" # Change this to your AGiXT server URL\n", + "api_key = \"None\" # Change this to your AGiXT API key\n", + "\n", + "agixt = AGiXTSDK(base_uri=agixt_server, api_key=api_key)\n", + "\n", + "agent_name = \"Postgres\" # Change this if desired\n", + "\n", + "agixt.add_agent(\n", + " agent_name=agent_name,\n", + " settings={\n", + " \"provider\": \"ezlocalai\", # LLM Provider\n", + " \"transcription_provider\": \"ezlocalai\", # Voice transcription provider, default uses the built in transcription in AGiXT.\n", + " \"translation_provider\": \"ezlocalai\", # Voice translation provider, default uses the built in translation in AGiXT.\n", + " \"embeddings_provider\": \"default\", # Embeddings provider, default uses the built in embeddings in AGiXT.\n", + " \"image_provider\": \"None\", # If set, AGiXT will autonomously create images if it chooses to do so based on the conversation.\n", + " \"vision_provider\": \"ezlocalai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", + " \"tts_provider\": \"None\", # The responses in Postgres Chat will often be CSV format, so we don't need TTS.\n", + " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", + " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL\n", + " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", + " \"MAX_TOKENS\": 4096,\n", + " \"AI_TEMPERATURE\": \"0.7\",\n", + " \"AI_TOP_P\": \"0.95\",\n", + " \"mode\": \"chain\", # For info about chat completion modes, go to https://josh-xt.github.io/AGiXT/2-Concepts/04-Chat%20Completions.html\n", + " \"chain_name\": \"Postgres Chat\",\n", + " \"chain_args\": \"{}\",\n", + " \"POSTGRES_DATABASE_HOST\": \"postgres\", # Change this to your Postgres database host\n", + " \"POSTGRES_DATABASE_NAME\": \"postgres\", # Change this to your Postgres database name\n", + " \"POSTGRES_DATABASE_PORT\": \"5432\", # Change this to your Postgres database port\n", + " \"POSTGRES_DATABASE_USERNAME\": \"postgres\", # Change this to your Postgres database username\n", + " \"POSTGRES_DATABASE_PASSWORD\": \"postgres\", # Change this to your Postgres database password\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chat with the Agent\n", + "\n", + "Turn your natural language queries into SQL queries that are executed on your Postgres database, then return the results in a CSV format or a string response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "prompt = \"How many people bought services in 2024?\"\n", + "\n", + "openai.base_url = agixt_server\n", + "openai.api_key = api_key\n", + "\n", + "response = openai.chat.completions.create(\n", + " model=agent_name,\n", + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", + " user=\"Services query\", # This field is used for the conversation name, if empty, it will use today's date\n", + ")\n", + "print(response.choices[0].message.content)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 237449a4b382d2fb4c206fc4bfb9e499468e0951 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 09:53:01 -0400 Subject: [PATCH 0104/1256] add notes --- examples/Postgres-Chat-ezLocalai.ipynb | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/Postgres-Chat-ezLocalai.ipynb b/examples/Postgres-Chat-ezLocalai.ipynb index caf832b20eac..2812e72936e6 100644 --- a/examples/Postgres-Chat-ezLocalai.ipynb +++ b/examples/Postgres-Chat-ezLocalai.ipynb @@ -100,6 +100,15 @@ ")\n", "print(response.choices[0].message.content)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# That is all!\n", + "\n", + "You now have a Postgres Chat Agent in AGiXT with ezLocalai Provider. You can now chat with your agent and get responses from your Postgres database." + ] } ], "metadata": { From 61034172adc80f7a9eea6170d9dbf2ca6eab0e28 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 09:57:14 -0400 Subject: [PATCH 0105/1256] add notes --- examples/AGiXT-Expert-ezLocalai.ipynb | 2 +- examples/Postgres-Chat-ezLocalai.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb index 3a14952d6da2..8592580fc55e 100644 --- a/examples/AGiXT-Expert-ezLocalai.ipynb +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -44,7 +44,7 @@ " \"tts_provider\": \"ezlocalai\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", " \"VOICE\": \"Morgan_Freeman\", # Voice for TTS, change this to the voice you want to use from ezlocalai.\n", " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", - " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL\n", + " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", " \"MAX_TOKENS\": 4096,\n", " \"AI_TEMPERATURE\": \"0.7\",\n", diff --git a/examples/Postgres-Chat-ezLocalai.ipynb b/examples/Postgres-Chat-ezLocalai.ipynb index 2812e72936e6..ba274020770b 100644 --- a/examples/Postgres-Chat-ezLocalai.ipynb +++ b/examples/Postgres-Chat-ezLocalai.ipynb @@ -54,7 +54,7 @@ " \"vision_provider\": \"ezlocalai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", " \"tts_provider\": \"None\", # The responses in Postgres Chat will often be CSV format, so we don't need TTS.\n", " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", - " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL\n", + " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", " \"MAX_TOKENS\": 4096,\n", " \"AI_TEMPERATURE\": \"0.7\",\n", From 37b817190d0ef62c65e22e722b41b3be8ae75267 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:16:24 -0400 Subject: [PATCH 0106/1256] fix example --- agixt/XT.py | 4 +++- examples/AGiXT-Expert-OAI.ipynb | 2 +- examples/AGiXT-Expert-ezLocalai.ipynb | 2 +- examples/Postgres-Chat-ezLocalai.ipynb | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 4c5046977e50..47b6db0d4e1d 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -502,7 +502,9 @@ async def check_dependencies_met(dependencies): ) tasks.append(task) step_responses = await asyncio.gather(*tasks) - response = step_responses[-1] + logging.info(f"Step responses: {step_responses}") + if step_responses: + response = step_responses[-1] if response == None: return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." if conversation_name != "": diff --git a/examples/AGiXT-Expert-OAI.ipynb b/examples/AGiXT-Expert-OAI.ipynb index 15a4389e078a..e04854d3c1b7 100644 --- a/examples/AGiXT-Expert-OAI.ipynb +++ b/examples/AGiXT-Expert-OAI.ipynb @@ -130,7 +130,7 @@ "\n", "prompt = \"What can you tell me about AGiXT?\"\n", "\n", - "openai.base_url = agixt_server\n", + "openai.base_url = f\"{agixt_server}/v1/\"\n", "openai.api_key = api_key\n", "\n", "response = openai.chat.completions.create(\n", diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb index 8592580fc55e..559b34faf2f9 100644 --- a/examples/AGiXT-Expert-ezLocalai.ipynb +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -134,7 +134,7 @@ "\n", "prompt = \"What can you tell me about AGiXT?\"\n", "\n", - "openai.base_url = agixt_server\n", + "openai.base_url = f\"{agixt_server}/v1/\"\n", "openai.api_key = api_key\n", "\n", "response = openai.chat.completions.create(\n", diff --git a/examples/Postgres-Chat-ezLocalai.ipynb b/examples/Postgres-Chat-ezLocalai.ipynb index ba274020770b..cfca9935544e 100644 --- a/examples/Postgres-Chat-ezLocalai.ipynb +++ b/examples/Postgres-Chat-ezLocalai.ipynb @@ -90,7 +90,7 @@ "\n", "prompt = \"How many people bought services in 2024?\"\n", "\n", - "openai.base_url = agixt_server\n", + "openai.base_url = f\"{agixt_server}/v1/\"\n", "openai.api_key = api_key\n", "\n", "response = openai.chat.completions.create(\n", From 7863e7a4f5cd0f5105f68b166dffa3636ca7f315 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:30:46 -0400 Subject: [PATCH 0107/1256] handle no connection --- agixt/extensions/postgres_database.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/agixt/extensions/postgres_database.py b/agixt/extensions/postgres_database.py index 4c4ec687fa8d..001a966baa98 100644 --- a/agixt/extensions/postgres_database.py +++ b/agixt/extensions/postgres_database.py @@ -66,6 +66,8 @@ async def execute_sql(self, query: str): query = query.strip() logging.info(f"Executing SQL Query: {query}") connection = self.get_connection() + if not connection: + return "Error connecting to Postgres Database" cursor = connection.cursor(cursor_factory=psycopg2.extras.DictCursor) try: cursor.execute(query) @@ -119,6 +121,8 @@ async def get_schema(self): logging.info(f"Getting schema for database '{self.POSTGRES_DATABASE_NAME}'") connection = self.get_connection() + if not connection: + return "Error connecting to Postgres Database" cursor = connection.cursor(cursor_factory=psycopg2.extras.DictCursor) cursor.execute( f"SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'information_schema');" From aa5154719b52456a96f8e4f75724b8403573eb92 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:40:06 -0400 Subject: [PATCH 0108/1256] clean up refs --- agixt/XT.py | 50 +++++-------------------------- agixt/endpoints/Memory.py | 2 +- agixt/extensions/agixt_actions.py | 2 +- 3 files changed, 10 insertions(+), 44 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 47b6db0d4e1d..163f89e353b4 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -424,24 +424,6 @@ async def execute_chain( voice_response=False, ): chain_data = self.chain.get_chain(chain_name=chain_name) - chain_dependencies = self.chain.get_chain_step_dependencies( - chain_name=chain_name - ) - - async def check_dependencies_met(dependencies): - for dependency in dependencies: - try: - step_responses = self.chain.get_step_response( - chain_name=chain_name, - chain_run_id=chain_run_id, - step_number=int(dependency), - ) - except: - return False - if not step_responses: - return False - return True - if not chain_run_id: chain_run_id = await self.chain.get_chain_run_id(chain_name=chain_name) if chain_data == {}: @@ -476,32 +458,16 @@ async def check_dependencies_met(dependencies): step["prompt_type"] = step_data["prompt_type"] step["prompt"] = step_data["prompt"] step["step"] = step_data["step"] - # Get the step dependencies from chain_dependencies then check if the dependencies are - # met before running the step - step_dependencies = chain_dependencies[str(step["step"])] - dependencies_met = await check_dependencies_met(step_dependencies) - while not dependencies_met: - await asyncio.sleep(1) - if step_responses == []: - step_responses = await asyncio.gather(*tasks) - else: - step_responses += await asyncio.gather(*tasks) - dependencies_met = await check_dependencies_met( - step_dependencies - ) - task = asyncio.create_task( - self.run_chain_step( - chain_run_id=chain_run_id, - step=step, - chain_name=chain_name, - user_input=user_input, - agent_override=agent_override, - chain_args=chain_args, - conversation_name=conversation_name, - ) + task = await self.run_chain_step( + chain_run_id=chain_run_id, + step=step, + chain_name=chain_name, + user_input=user_input, + agent_override=agent_override, + chain_args=chain_args, + conversation_name=conversation_name, ) tasks.append(task) - step_responses = await asyncio.gather(*tasks) logging.info(f"Step responses: {step_responses}") if step_responses: response = step_responses[-1] diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index c584fe397457..f97efad02749 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -471,7 +471,7 @@ async def create_dataset( raise HTTPException(status_code=403, detail="Access Denied") batch_size = dataset.batch_size if dataset.batch_size < (int(WORKERS) - 2) else 4 asyncio.create_task( - await AGiXT( + AGiXT( agent_name=agent_name, user=user, api_key=authorization, diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index a972fe9c230a..d2f6615445fc 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -837,7 +837,7 @@ async def convert_questions_to_dataset(self, response): await asyncio.gather(*tasks) tasks = [] task = asyncio.create_task( - await self.ApiClient.prompt_agent( + self.ApiClient.prompt_agent( agent_name=self.agent_name, prompt_name="Basic With Memory", prompt_args={ From f3713bcb4e22362316d3f65ba9d72d39756cec94 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:45:12 -0400 Subject: [PATCH 0109/1256] add logging --- agixt/XT.py | 1 + 1 file changed, 1 insertion(+) diff --git a/agixt/XT.py b/agixt/XT.py index 163f89e353b4..fd084e4a75ca 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -446,6 +446,7 @@ async def execute_chain( response = "" tasks = [] step_responses = [] + logging.info(f"Chain data: {chain_data}") for step_data in chain_data["steps"]: if int(step_data["step"]) >= int(from_step): if "prompt" in step_data and "step" in step_data: From e41c80bfd51cd18a8add1ec37b3fc2aac05fe9a9 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:53:51 -0400 Subject: [PATCH 0110/1256] improve import chain func --- agixt/Chain.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 5b73d2ffc046..92cdae8ec93e 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -520,9 +520,11 @@ def import_chain(self, chain_name: str, steps: dict): # Handle the case where agent not found based on agent_name # You can choose to skip this step or raise an exception continue - prompt = step_data["prompt"] - if "prompt_name" in prompt: + if "prompt_type" not in prompt: + prompt["prompt_type"] = "prompt" + prompt_type = prompt["prompt_type"].lower() + if prompt_type == "prompt": argument_key = "prompt_name" prompt_category = prompt.get("prompt_category", "Default") target_id = ( @@ -536,7 +538,7 @@ def import_chain(self, chain_name: str, steps: dict): .id ) target_type = "prompt" - elif "chain_name" in prompt: + elif prompt_type == "chain": argument_key = "chain_name" target_id = ( self.session.query(Chain) @@ -548,7 +550,7 @@ def import_chain(self, chain_name: str, steps: dict): .id ) target_type = "chain" - elif "command_name" in prompt: + elif prompt_type == "command": argument_key = "command_name" target_id = ( self.session.query(Command) @@ -561,11 +563,9 @@ def import_chain(self, chain_name: str, steps: dict): # Handle the case where the argument key is not found # You can choose to skip this step or raise an exception continue - argument_value = prompt[argument_key] prompt_arguments = prompt.copy() del prompt_arguments[argument_key] - chain_step = ChainStep( chain_id=chain.id, step_number=step_data["step"], @@ -578,7 +578,6 @@ def import_chain(self, chain_name: str, steps: dict): ) self.session.add(chain_step) self.session.commit() - for argument_name, argument_value in prompt_arguments.items(): argument = ( self.session.query(Argument) @@ -597,7 +596,6 @@ def import_chain(self, chain_name: str, steps: dict): ) self.session.add(chain_step_argument) self.session.commit() - return f"Imported chain: {chain_name}" def get_chain_step_dependencies(self, chain_name): From c270f55220dc467422c4c6694365f03159e2d39a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 10:59:08 -0400 Subject: [PATCH 0111/1256] remove handling to expose error --- agixt/SeedImports.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index ff854e0a4430..f7b31005d9bc 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -194,16 +194,16 @@ def import_chains(user=DEFAULT_USER): file_path = os.path.join(chain_dir, file) with open(file_path, "r") as f: - try: - chain_data = json.load(f) - result = chain_importer.import_chain(chain_name, chain_data) - logging.info(result) - except json.JSONDecodeError as e: - logging.info( - f"Error importing chain from '{file}': Invalid JSON format." - ) - except Exception as e: - logging.info(f"Error importing chain from '{file}': {str(e)}") + # try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + # except json.JSONDecodeError as e: + # logging.info( + # f"Error importing chain from '{file}': Invalid JSON format." + # ) + # except Exception as e: + # logging.info(f"Error importing chain from '{file}': {str(e)}") def import_prompts(user=DEFAULT_USER): From 1ee27f0202294190d78f8f3db9f8a94f2ecb5869 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:04:36 -0400 Subject: [PATCH 0112/1256] add logging --- agixt/Chain.py | 1 + 1 file changed, 1 insertion(+) diff --git a/agixt/Chain.py b/agixt/Chain.py index 92cdae8ec93e..a47ad4b2b613 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -510,6 +510,7 @@ def import_chain(self, chain_name: str, steps: dict): steps = steps["steps"] if "steps" in steps else steps for step_data in steps: + logging.info(f"chain step: {step_data}") agent_name = step_data["agent_name"] agent = ( self.session.query(Agent) From c2858cce6c27c1a71e570fcfbfab597e4a419740 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:09:36 -0400 Subject: [PATCH 0113/1256] fix step data --- agixt/Chain.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index a47ad4b2b613..24a84a2ec6f1 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -510,7 +510,6 @@ def import_chain(self, chain_name: str, steps: dict): steps = steps["steps"] if "steps" in steps else steps for step_data in steps: - logging.info(f"chain step: {step_data}") agent_name = step_data["agent_name"] agent = ( self.session.query(Agent) @@ -518,13 +517,16 @@ def import_chain(self, chain_name: str, steps: dict): .first() ) if not agent: - # Handle the case where agent not found based on agent_name - # You can choose to skip this step or raise an exception - continue + # Use the first agent in the database + agent = ( + self.session.query(Agent) + .filter(Agent.user_id == self.user_id) + .first() + ) prompt = step_data["prompt"] - if "prompt_type" not in prompt: - prompt["prompt_type"] = "prompt" - prompt_type = prompt["prompt_type"].lower() + if "prompt_type" not in step_data: + step_data["prompt_type"] = "prompt" + prompt_type = step_data["prompt_type"].lower() if prompt_type == "prompt": argument_key = "prompt_name" prompt_category = prompt.get("prompt_category", "Default") From 5c74e865969a7760f916a79faaec7f370eaafbdd Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:15:42 -0400 Subject: [PATCH 0114/1256] readd handling --- agixt/SeedImports.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index f7b31005d9bc..563987142060 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -192,18 +192,17 @@ def import_chains(user=DEFAULT_USER): for file in chain_files: chain_name = os.path.splitext(file)[0] file_path = os.path.join(chain_dir, file) - with open(file_path, "r") as f: - # try: - chain_data = json.load(f) - result = chain_importer.import_chain(chain_name, chain_data) - logging.info(result) - # except json.JSONDecodeError as e: - # logging.info( - # f"Error importing chain from '{file}': Invalid JSON format." - # ) - # except Exception as e: - # logging.info(f"Error importing chain from '{file}': {str(e)}") + try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + except json.JSONDecodeError as e: + logging.info( + f"Error importing chain from '{file}': Invalid JSON format." + ) + except Exception as e: + logging.info(f"Error importing chain from '{file}': {str(e)}") def import_prompts(user=DEFAULT_USER): @@ -419,7 +418,7 @@ def import_all_data(): logging.info("Importing agents...") import_agents() logging.info("Importing chains...") - import_chains() # Partially works + import_chains() logging.info("Importing conversations...") import_conversations() logging.info("Imports complete.") From c8170dd30b4be9cfcc93b3c31d424fff3d1f8ff5 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:23:24 -0400 Subject: [PATCH 0115/1256] expose error --- agixt/Chain.py | 1 - agixt/SeedImports.py | 20 ++++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 24a84a2ec6f1..01e62a899763 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -507,7 +507,6 @@ def import_chain(self, chain_name: str, steps: dict): chain = ChainDB(name=chain_name, user_id=self.user_id) self.session.add(chain) self.session.commit() - steps = steps["steps"] if "steps" in steps else steps for step_data in steps: agent_name = step_data["agent_name"] diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index 563987142060..1f867e1aa93e 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -193,16 +193,16 @@ def import_chains(user=DEFAULT_USER): chain_name = os.path.splitext(file)[0] file_path = os.path.join(chain_dir, file) with open(file_path, "r") as f: - try: - chain_data = json.load(f) - result = chain_importer.import_chain(chain_name, chain_data) - logging.info(result) - except json.JSONDecodeError as e: - logging.info( - f"Error importing chain from '{file}': Invalid JSON format." - ) - except Exception as e: - logging.info(f"Error importing chain from '{file}': {str(e)}") + # try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + # except json.JSONDecodeError as e: + # logging.info( + # f"Error importing chain from '{file}': Invalid JSON format." + # ) + # except Exception as e: + # logging.info(f"Error importing chain from '{file}': {str(e)}") def import_prompts(user=DEFAULT_USER): From cf65fca1248e8c484136b835fb7723e59249ee89 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:27:47 -0400 Subject: [PATCH 0116/1256] use chaindb --- agixt/Chain.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 01e62a899763..3dda83086ca1 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -543,10 +543,10 @@ def import_chain(self, chain_name: str, steps: dict): elif prompt_type == "chain": argument_key = "chain_name" target_id = ( - self.session.query(Chain) + self.session.query(ChainDB) .filter( - Chain.name == prompt["chain_name"], - Chain.user_id == self.user_id, + ChainDB.name == prompt["chain_name"], + ChainDB.user_id == self.user_id, ) .first() .id From a4cf858b65ab8e722cebbd3dc084a22476c75917 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:32:21 -0400 Subject: [PATCH 0117/1256] add handling back --- agixt/Chain.py | 7 +++---- agixt/SeedImports.py | 20 ++++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 3dda83086ca1..3034a99cec61 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -13,7 +13,6 @@ ) from Globals import getenv, DEFAULT_USER from Prompts import Prompts -from Conversations import Conversations from Extensions import Extensions import logging import asyncio @@ -532,7 +531,7 @@ def import_chain(self, chain_name: str, steps: dict): target_id = ( self.session.query(Prompt) .filter( - Prompt.name == prompt["prompt_name"], + Prompt.name == prompt[argument_key], Prompt.user_id == self.user_id, Prompt.prompt_category.has(name=prompt_category), ) @@ -545,7 +544,7 @@ def import_chain(self, chain_name: str, steps: dict): target_id = ( self.session.query(ChainDB) .filter( - ChainDB.name == prompt["chain_name"], + ChainDB.name == prompt[argument_key], ChainDB.user_id == self.user_id, ) .first() @@ -556,7 +555,7 @@ def import_chain(self, chain_name: str, steps: dict): argument_key = "command_name" target_id = ( self.session.query(Command) - .filter(Command.name == prompt["command_name"]) + .filter(Command.name == prompt[argument_key]) .first() .id ) diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index 1f867e1aa93e..563987142060 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -193,16 +193,16 @@ def import_chains(user=DEFAULT_USER): chain_name = os.path.splitext(file)[0] file_path = os.path.join(chain_dir, file) with open(file_path, "r") as f: - # try: - chain_data = json.load(f) - result = chain_importer.import_chain(chain_name, chain_data) - logging.info(result) - # except json.JSONDecodeError as e: - # logging.info( - # f"Error importing chain from '{file}': Invalid JSON format." - # ) - # except Exception as e: - # logging.info(f"Error importing chain from '{file}': {str(e)}") + try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + except json.JSONDecodeError as e: + logging.info( + f"Error importing chain from '{file}': Invalid JSON format." + ) + except Exception as e: + logging.info(f"Error importing chain from '{file}': {str(e)}") def import_prompts(user=DEFAULT_USER): From c68ac5f6bf2e2525a60f81aec37c6858c9d7170e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:33:43 -0400 Subject: [PATCH 0118/1256] use chain for chain name import --- agixt/Chain.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/Chain.py b/agixt/Chain.py index 3034a99cec61..2df4e80fe30a 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -541,6 +541,8 @@ def import_chain(self, chain_name: str, steps: dict): target_type = "prompt" elif prompt_type == "chain": argument_key = "chain_name" + if "chain" in prompt: + argument_key = "chain" target_id = ( self.session.query(ChainDB) .filter( From 95236b1e81468b8fce27d229f42a9e16ed41a71b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:39:08 -0400 Subject: [PATCH 0119/1256] add retry on failure for chain imports --- agixt/SeedImports.py | 46 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index 563987142060..e2068926464f 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -189,6 +189,7 @@ def import_chains(user=DEFAULT_USER): from Chain import Chain chain_importer = Chain(user=user) + failures = [] for file in chain_files: chain_name = os.path.splitext(file)[0] file_path = os.path.join(chain_dir, file) @@ -199,10 +200,51 @@ def import_chains(user=DEFAULT_USER): logging.info(result) except json.JSONDecodeError as e: logging.info( - f"Error importing chain from '{file}': Invalid JSON format." + f"(1/3) Error importing chain from '{file}': Invalid JSON format." ) except Exception as e: - logging.info(f"Error importing chain from '{file}': {str(e)}") + logging.info(f"(1/3) Error importing chain from '{file}': {str(e)}") + failures.append(file) + if failures: + # Try each that failed again just in case it had a dependency on another chain + for file in failures: + chain_name = os.path.splitext(file)[0] + file_path = os.path.join(chain_dir, file) + with open(file_path, "r") as f: + try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + failures.remove(file) + except json.JSONDecodeError as e: + logging.info( + f"(2/3) Error importing chain from '{file}': Invalid JSON format." + ) + except Exception as e: + logging.info(f"(2/3) Error importing chain from '{file}': {str(e)}") + if failures: + # Try one more time. + for file in failures: + chain_name = os.path.splitext(file)[0] + file_path = os.path.join(chain_dir, file) + with open(file_path, "r") as f: + try: + chain_data = json.load(f) + result = chain_importer.import_chain(chain_name, chain_data) + logging.info(result) + failures.remove(file) + except json.JSONDecodeError as e: + logging.info( + f"(3/3) Error importing chain from '{file}': Invalid JSON format." + ) + except Exception as e: + logging.info( + f"(3/3) Error importing chain from '{file}': {str(e)}" + ) + if failures: + logging.info( + f"Failed to import the following chains: {', '.join([os.path.splitext(file)[0] for file in failures])}" + ) def import_prompts(user=DEFAULT_USER): From 021e06d18949fff82df36c1b47dcfc0b84d6ae9c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:42:34 -0400 Subject: [PATCH 0120/1256] remove chain, functionality is built in now. --- agixt/chains/Generate Image.json | 60 -------------------------------- 1 file changed, 60 deletions(-) delete mode 100644 agixt/chains/Generate Image.json diff --git a/agixt/chains/Generate Image.json b/agixt/chains/Generate Image.json deleted file mode 100644 index c2e345b14c9a..000000000000 --- a/agixt/chains/Generate Image.json +++ /dev/null @@ -1,60 +0,0 @@ -{ - "chain_name": "Generate Image", - "steps": [ - { - "step": 1, - "agent_name": "gpt4free", - "prompt_type": "Prompt", - "prompt": { - "prompt_name": "AGiXT SD Generator_V3", - "prompt_category": "Default", - "user_input": "{user_input}", - "shots": 1, - "context_results": 0, - "browse_links": false, - "websearch": false, - "websearch_depth": 0, - "disable_memory": true, - "inject_memories_from_collection_number": 0, - "conversation_results": 0, - "conversation": "AGiXT Terminal" - } - }, - { - "step": 2, - "agent_name": "gpt4free", - "prompt_type": "Command", - "prompt": { - "command_name": "Generate Image with Stable Diffusion", - "prompt": "{STEP1}", - "filename": "", - "negative_prompt": "", - "batch_size": "", - "cfg_scale": "", - "denoising_strength": "", - "enable_hr": "", - "eta": "", - "firstphase_height": "", - "firstphase_width": "", - "height": "", - "n_iter": "", - "restore_faces": "", - "s_churn": "", - "s_noise": "", - "s_tmax": "", - "s_tmin": "", - "sampler_index": "", - "seed": "", - "seed_resize_from_h": "", - "seed_resize_from_w": "", - "steps": "", - "styles": "", - "subseed": "", - "subseed_strength": "", - "tiling": "", - "width": "", - "conversation": "AGiXT Terminal" - } - } - ] -} \ No newline at end of file From 7979bf2c6197dec57e76899f4a690dad745a1da3 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 11:59:20 -0400 Subject: [PATCH 0121/1256] improve imports --- agixt/Chain.py | 7 +++++++ agixt/SeedImports.py | 44 ++------------------------------------------ 2 files changed, 9 insertions(+), 42 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 2df4e80fe30a..c6b3abe15675 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -503,6 +503,13 @@ def get_chain_responses(self, chain_name): return responses def import_chain(self, chain_name: str, steps: dict): + chain = ( + self.session.query(ChainDB) + .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) + .first() + ) + if chain: + return None chain = ChainDB(name=chain_name, user_id=self.user_id) self.session.add(chain) self.session.commit() diff --git a/agixt/SeedImports.py b/agixt/SeedImports.py index e2068926464f..c5757ee3e737 100644 --- a/agixt/SeedImports.py +++ b/agixt/SeedImports.py @@ -54,20 +54,6 @@ def import_extensions(): # Get the existing extensions and commands from the database existing_extensions = session.query(Extension).all() existing_commands = session.query(Command).all() - - # Delete commands that don't exist in the extensions data - for command in existing_commands: - command_exists = any( - extension_data["extension_name"] == command.extension.name - and any( - cmd["friendly_name"] == command.name - for cmd in extension_data["commands"] - ) - for extension_data in extensions_data - ) - if not command_exists: - session.delete(command) - # Add new extensions and commands, and update existing commands for extension_data in extensions_data: extension_name = extension_data["extension_name"] @@ -85,15 +71,11 @@ def import_extensions(): session.add(extension) session.flush() existing_extensions.append(extension) - commands = extension_data["commands"] - for command_data in commands: if "friendly_name" not in command_data: continue - command_name = command_data["friendly_name"] - # Find the existing command or create a new one command = next( ( @@ -112,7 +94,6 @@ def import_extensions(): session.flush() existing_commands.append(command) logging.info(f"Imported command: {command_name}") - # Add command arguments if "command_args" in command_data: command_args = command_data["command_args"] @@ -129,7 +110,6 @@ def import_extensions(): ) session.add(command_arg) logging.info(f"Imported argument: {arg} to command: {command_name}") - session.commit() # Add extensions to the database if they don't exist @@ -197,11 +177,8 @@ def import_chains(user=DEFAULT_USER): try: chain_data = json.load(f) result = chain_importer.import_chain(chain_name, chain_data) - logging.info(result) - except json.JSONDecodeError as e: - logging.info( - f"(1/3) Error importing chain from '{file}': Invalid JSON format." - ) + if result: + logging.info(result) except Exception as e: logging.info(f"(1/3) Error importing chain from '{file}': {str(e)}") failures.append(file) @@ -216,10 +193,6 @@ def import_chains(user=DEFAULT_USER): result = chain_importer.import_chain(chain_name, chain_data) logging.info(result) failures.remove(file) - except json.JSONDecodeError as e: - logging.info( - f"(2/3) Error importing chain from '{file}': Invalid JSON format." - ) except Exception as e: logging.info(f"(2/3) Error importing chain from '{file}': {str(e)}") if failures: @@ -233,10 +206,6 @@ def import_chains(user=DEFAULT_USER): result = chain_importer.import_chain(chain_name, chain_data) logging.info(result) failures.remove(file) - except json.JSONDecodeError as e: - logging.info( - f"(3/3) Error importing chain from '{file}': Invalid JSON format." - ) except Exception as e: logging.info( f"(3/3) Error importing chain from '{file}': {str(e)}" @@ -397,23 +366,14 @@ def import_conversations(user=DEFAULT_USER): def import_providers(): session = get_session() providers = get_providers() - existing_providers = session.query(Provider).all() - existing_provider_names = [provider.name for provider in existing_providers] - for provider in existing_providers: - if provider.name not in providers: - session.delete(provider) - for provider_name in providers: provider_options = get_provider_options(provider_name) - provider = session.query(Provider).filter_by(name=provider_name).one_or_none() - if provider: logging.info(f"Updating provider: {provider_name}") else: provider = Provider(name=provider_name) session.add(provider) - existing_provider_names.append(provider_name) logging.info(f"Imported provider: {provider_name}") session.commit() From 97a7abb5347c43aba9d7ce700a4ede8cb05a32d1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 12:21:52 -0400 Subject: [PATCH 0122/1256] fix ref --- examples/AGiXT-Expert-ezLocalai.ipynb | 2 +- examples/Postgres-Chat-ezLocalai.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb index 559b34faf2f9..46566943f430 100644 --- a/examples/AGiXT-Expert-ezLocalai.ipynb +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -44,7 +44,7 @@ " \"tts_provider\": \"ezlocalai\", # Change this to `default` or whichever TTS provider you want to use. None means no voice response.\n", " \"VOICE\": \"Morgan_Freeman\", # Voice for TTS, change this to the voice you want to use from ezlocalai.\n", " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", - " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", + " \"EZLOCALAI_API_URI\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", " \"MAX_TOKENS\": 4096,\n", " \"AI_TEMPERATURE\": \"0.7\",\n", diff --git a/examples/Postgres-Chat-ezLocalai.ipynb b/examples/Postgres-Chat-ezLocalai.ipynb index cfca9935544e..e7175e98125d 100644 --- a/examples/Postgres-Chat-ezLocalai.ipynb +++ b/examples/Postgres-Chat-ezLocalai.ipynb @@ -54,7 +54,7 @@ " \"vision_provider\": \"ezlocalai\", # Vision provider, None means no vision capabilities. We will use OpenAI's since we're using GPT-4o.\n", " \"tts_provider\": \"None\", # The responses in Postgres Chat will often be CSV format, so we don't need TTS.\n", " \"AI_MODEL\": \"ezlocalai\", # It doesn't matter which model you put here, ezlocalai uses the model it was started with.\n", - " \"EZLOCALAI_API_URL\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", + " \"EZLOCALAI_API_URI\": \"http://ezlocalai:8091/v1/\", # URL for the EZLOCALAI API, change this to your EZLOCALAI API URL. Never use localhost here, it is a different container.\n", " \"EZLOCALAI_API_KEY\": \"Your EZLOCALAI API key\", # Change this to your EZLOCALAI API key\n", " \"MAX_TOKENS\": 4096,\n", " \"AI_TEMPERATURE\": \"0.7\",\n", From f909798a485deff85ac6ef1828b00f85c9754988 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:02:16 -0400 Subject: [PATCH 0123/1256] add handling --- agixt/XT.py | 29 ++++++++++++++++++--------- agixt/extensions/postgres_database.py | 1 + 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index fd084e4a75ca..131611283114 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -322,7 +322,7 @@ async def run_chain_step( agent_name = agent_override else: agent_name = step["agent_name"] - prompt_type = step["prompt_type"] + prompt_type = str(step["prompt_type"]).lower() step_number = step["step"] if "prompt_name" in step["prompt"]: prompt_name = step["prompt"]["prompt_name"] @@ -346,7 +346,7 @@ async def run_chain_step( args["conversation_name"] = f"Chain Execution History: {chain_name}" if "conversation" in args: args["conversation_name"] = args["conversation"] - if prompt_type == "Command": + if prompt_type == "command": if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, @@ -358,7 +358,7 @@ async def run_chain_step( conversation_name=args["conversation_name"], voice_response=False, ) - elif prompt_type == "Prompt": + elif prompt_type == "prompt": if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, @@ -366,18 +366,23 @@ async def run_chain_step( ) if "prompt_name" not in args: args["prompt_name"] = prompt_name - result = await self.inference( - agent_name=agent_name, - user_input=user_input, - log_user_input=False, - **args, - ) - elif prompt_type == "Chain": + if prompt_name != "": + result = await self.inference( + agent_name=agent_name, + user_input=user_input, + log_user_input=False, + **args, + ) + elif prompt_type == "chain": if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Running chain: {args['chain']} with args: {args}", ) + if "chain_name" in args: + args["chain"] = args["chain_name"] + if "user_input" in args: + args["input"] = args["user_input"] result = await self.execute_chain( chain_name=args["chain"], user_input=args["input"], @@ -447,6 +452,10 @@ async def execute_chain( tasks = [] step_responses = [] logging.info(f"Chain data: {chain_data}") + if "steps" not in chain_data: + return f"Chain `{chain_name}` has no steps." + if len(chain_data["steps"]) == 0: + return f"Chain `{chain_name}` has no steps." for step_data in chain_data["steps"]: if int(step_data["step"]) >= int(from_step): if "prompt" in step_data and "step" in step_data: diff --git a/agixt/extensions/postgres_database.py b/agixt/extensions/postgres_database.py index 001a966baa98..0aef2a899c13 100644 --- a/agixt/extensions/postgres_database.py +++ b/agixt/extensions/postgres_database.py @@ -64,6 +64,7 @@ async def execute_sql(self, query: str): query = query.split("```sql")[1].split("```")[0] query = query.replace("\n", " ") query = query.strip() + query = query.replace("```", "") logging.info(f"Executing SQL Query: {query}") connection = self.get_connection() if not connection: From 68740629e981cf8f765fd5f7c806313bb459c8cd Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:10:34 -0400 Subject: [PATCH 0124/1256] fix step responses --- agixt/XT.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 131611283114..c86857468d6b 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -449,7 +449,6 @@ async def execute_chain( message=f"[ACTIVITY] Running chain `{chain_name}`.", ) response = "" - tasks = [] step_responses = [] logging.info(f"Chain data: {chain_data}") if "steps" not in chain_data: @@ -477,7 +476,7 @@ async def execute_chain( chain_args=chain_args, conversation_name=conversation_name, ) - tasks.append(task) + step_responses.append(task) logging.info(f"Step responses: {step_responses}") if step_responses: response = step_responses[-1] From 323d01ad3c7167828fd87eac28c454cfeab15216 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:16:41 -0400 Subject: [PATCH 0125/1256] improve csv code block function --- agixt/extensions/agixt_actions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index d2f6615445fc..460f610f07f9 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -756,7 +756,9 @@ async def make_csv_code_block(self, data: str) -> str: Returns: str: The CSV code block """ - return f"```csv\n{data}\n```" + if "," in data or "\n" in data: + return f"```csv\n{data}\n```" + return data async def get_csv_preview(self, filename: str): """ From 60ada247835e4e28d2579dea50a5eb41b1b12320 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:21:13 -0400 Subject: [PATCH 0126/1256] remove recursion --- agixt/Interactions.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index f1d5f3261ea5..e780b18421fa 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -590,29 +590,7 @@ async def run( error += f"{err.args}\n{err.name}\n{err.msg}\n" logging.error(f"{self.agent.PROVIDER} Error: {error}") logging.info(f"TOKENS: {tokens} PROMPT CONTENT: {formatted_prompt}") - self.failures += 1 - if self.failures == 5: - self.failures == 0 - logging.warning("Failed to get a response 5 times in a row.") - return None - logging.warning(f"Retrying in 10 seconds...") - time.sleep(10) - if context_results > 0: - context_results = context_results - 1 - prompt_args = { - "shots": shots, - "disable_memory": disable_memory, - "user_input": user_input, - "context_results": context_results, - "conversation_name": conversation_name, - **kwargs, - } - return await self.run( - prompt_name=prompt, - prompt_category=prompt_category, - log_user_input=log_user_input, - **prompt_args, - ) + return f"Unable to retrieve response." # Handle commands if the prompt contains the {COMMANDS} placeholder # We handle command injection that DOESN'T allow command execution by using {command_list} in the prompt if "{COMMANDS}" in unformatted_prompt: From d70cc58486f605ba496d98543a4588f8affd1990 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:39:08 -0400 Subject: [PATCH 0127/1256] fix conversation name for activity on command execution --- agixt/XT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index c86857468d6b..eee387666487 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -280,7 +280,7 @@ async def execute_command( response = await Extensions( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - conversation_name=conversation_name, + conversation_name=f"{command_name} Execution History", ApiClient=self.ApiClient, api_key=self.api_key, user=self.user_email, @@ -355,7 +355,7 @@ async def run_chain_step( result = await self.execute_command( command_name=step["prompt"]["command_name"], command_args=args, - conversation_name=args["conversation_name"], + conversation_name=conversation_name, voice_response=False, ) elif prompt_type == "prompt": From cf215c62987af96551710c2b7ac5dd5cdf34743e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 13:41:51 -0400 Subject: [PATCH 0128/1256] improve activity log for command execution --- agixt/XT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index eee387666487..ed21565cdb95 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -275,7 +275,7 @@ async def execute_command( c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Executing command: {command_name} with args: {command_args}", + message=f"[ACTIVITY] Executing command `{command_name}` with args:\n```json\n{json.dumps(command_args, indent=2)}```", ) response = await Extensions( agent_name=self.agent_name, @@ -350,7 +350,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Executing command: {step['prompt']['command_name']} with args: {args}", + message=f"[ACTIVITY] Executing command `{step['prompt']['command_name']}` with args:\n```json\n{json.dumps(args, indent=2)}```", ) result = await self.execute_command( command_name=step["prompt"]["command_name"], From 8ac73c5ebc706701bceb7780484592b87f8021c5 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 14:11:18 -0400 Subject: [PATCH 0129/1256] reduce duplicate user inputs. --- agixt/XT.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/agixt/XT.py b/agixt/XT.py index ed21565cdb95..148c22b7fd98 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -366,6 +366,9 @@ async def run_chain_step( ) if "prompt_name" not in args: args["prompt_name"] = prompt_name + if "user_input" in args: + user_input = args["user_input"] + del args["user_input"] if prompt_name != "": result = await self.inference( agent_name=agent_name, From 3c6eea887bb6494c2a951c0ed7015dec3ea15538 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 14:22:04 -0400 Subject: [PATCH 0130/1256] use requests 2.31.0 --- static-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/static-requirements.txt b/static-requirements.txt index 8009022ad887..cda4bd35e26b 100644 --- a/static-requirements.txt +++ b/static-requirements.txt @@ -7,7 +7,7 @@ pdfplumber==0.11.0 playwright==1.44.0 pandas==2.1.4 PyYAML==6.0.1 -requests==2.32.0 +requests==2.31.0 python-dotenv==1.0.0 ffmpeg-python==0.2.0 cryptography==42.0.5 From 6ea7b4055e8bad39623cfb66dcc4ea5eb83f71fe Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 14:40:31 -0400 Subject: [PATCH 0131/1256] add feedback endpoint --- agixt/Models.py | 7 +++++++ agixt/endpoints/Memory.py | 43 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/agixt/Models.py b/agixt/Models.py index a903de7c39c8..1bcc93b6d16c 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -171,6 +171,13 @@ class TextMemoryInput(BaseModel): collection_number: int = 0 +class FeedbackInput(BaseModel): + user_input: str + feedback: str + positive: Optional[bool] = True + conversation_name: Optional[str] = "" + + class TaskOutput(BaseModel): output: str message: Optional[str] = None diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index f97efad02749..702ecb5ece57 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -7,6 +7,7 @@ from Websearch import Websearch from XT import AGiXT from Memories import Memories +from Conversations import Conversations from readers.github import GithubReader from readers.file import FileReader from readers.arxiv import ArxivReader @@ -24,6 +25,7 @@ FinetuneAgentModel, ExternalSource, UserInput, + FeedbackInput, ) app = APIRouter() @@ -589,3 +591,44 @@ async def get_unique_external_sources( user=user, ).get_external_data_sources() return {"external_sources": external_sources} + + +# RLHF endpoint +@app.post( + "/api/agent/{agent_name}/feedback", + tags=["Memory"], + dependencies=[Depends(verify_api_key)], +) +async def rlhf( + agent_name: str, + data: FeedbackInput, + user=Depends(verify_api_key), + authorization: str = Header(None), +) -> ResponseMessage: + ApiClient = get_api_client(authorization=authorization) + agent_config = Agent( + agent_name=agent_name, user=user, ApiClient=ApiClient + ).get_agent_config() + if data.positive == True: + collection_number = 2 + else: + collection_number = 3 + if data.conversation_name != "" and data.conversation_name != None: + c = Conversations(conversation_name=data.conversation_name, user=user) + c.log_interaction( + role=agent_name, + message=f"[ACTIVITY] Added {'positive' if data.positive == True else 'negative'} feedback to memory.", + ) + + memory = Memories( + agent_name=agent_name, + agent_config=agent_config, + collection_number=collection_number, + ApiClient=ApiClient, + user=user, + ) + await memory.write_text_to_memory( + user_input=data.user_input, text=data.feedback, external_source="user feedback" + ) + response_message = f"{'Positive' if data.positive == True else 'Negative'} feedback received for agent {agent_name}." + return ResponseMessage(message=response_message) From 22f3aa35f9353c27931d83cf9c25c4446697afc1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 14:53:57 -0400 Subject: [PATCH 0132/1256] fix log --- agixt/endpoints/Memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 702ecb5ece57..569bc2f8a839 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -617,7 +617,7 @@ async def rlhf( c = Conversations(conversation_name=data.conversation_name, user=user) c.log_interaction( role=agent_name, - message=f"[ACTIVITY] Added {'positive' if data.positive == True else 'negative'} feedback to memory.", + message=f"Added {'positive' if data.positive == True else 'negative'} feedback to memory.", ) memory = Memories( From d7649f356a8ce4951fc45cc04904af09eb82635a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 15:54:22 -0400 Subject: [PATCH 0133/1256] improve rlhf --- agixt/Interactions.py | 3 ++ agixt/Models.py | 1 + agixt/XT.py | 2 + agixt/endpoints/Memory.py | 47 ++++++++++++-------- agixt/prompts/Default/Summarize Feedback.txt | 23 ++++++++++ 5 files changed, 57 insertions(+), 19 deletions(-) create mode 100644 agixt/prompts/Default/Summarize Feedback.txt diff --git a/agixt/Interactions.py b/agixt/Interactions.py index e780b18421fa..dbd6c63bf1ea 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -79,6 +79,9 @@ def __init__( self.agent_commands = "" self.websearch = None self.agent_memory = None + self.positive_feedback_memories = None + self.negative_feedback_memories = None + self.github_memories = None self.response = "" self.failures = 0 self.chain = Chain(user=user) diff --git a/agixt/Models.py b/agixt/Models.py index 1bcc93b6d16c..6a65acf3e527 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -173,6 +173,7 @@ class TextMemoryInput(BaseModel): class FeedbackInput(BaseModel): user_input: str + message: str feedback: str positive: Optional[bool] = True conversation_name: Optional[str] = "" diff --git a/agixt/XT.py b/agixt/XT.py index 148c22b7fd98..d4f14a426796 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -130,6 +130,7 @@ async def inference( browse_links: bool = False, voice_response: bool = False, log_user_input: bool = True, + log_output: bool = True, **kwargs, ): """ @@ -160,6 +161,7 @@ async def inference( images=images, tts=voice_response, log_user_input=log_user_input, + log_output=log_output, **kwargs, ) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 569bc2f8a839..6812235c094a 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -605,30 +605,39 @@ async def rlhf( user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: - ApiClient = get_api_client(authorization=authorization) - agent_config = Agent( - agent_name=agent_name, user=user, ApiClient=ApiClient - ).get_agent_config() if data.positive == True: - collection_number = 2 + memory = agixt.agent_interactions.positive_feedback_memories else: - collection_number = 3 + memory = agixt.agent_interactions.negative_feedback_memories + agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) + reflection = await agixt.inference( + user_input=data.user_input, + input_kind="positive" if data.positive == True else "negative", + assistant_response=data.message, + feedback=data.feedback, + conversation_name=data.conversation_name, + log_user_input=False, + log_output=False, + ) + memory_message = f""" +# Feedback received from a similar interaction in the past: +User Input: {data.user_input} +Assistant Response: {data.message} +Feedback: {data.feedback} +Reflection on the feedback: {reflection} +""" + await memory.write_text_to_memory( + user_input=data.user_input, + text=memory_message, + external_source="reflection from user feedback", + ) + response_message = ( + f"{'Positive' if data.positive == True else 'Negative'} feedback received." + ) if data.conversation_name != "" and data.conversation_name != None: c = Conversations(conversation_name=data.conversation_name, user=user) c.log_interaction( role=agent_name, - message=f"Added {'positive' if data.positive == True else 'negative'} feedback to memory.", + message=response_message, ) - - memory = Memories( - agent_name=agent_name, - agent_config=agent_config, - collection_number=collection_number, - ApiClient=ApiClient, - user=user, - ) - await memory.write_text_to_memory( - user_input=data.user_input, text=data.feedback, external_source="user feedback" - ) - response_message = f"{'Positive' if data.positive == True else 'Negative'} feedback received for agent {agent_name}." return ResponseMessage(message=response_message) diff --git a/agixt/prompts/Default/Summarize Feedback.txt b/agixt/prompts/Default/Summarize Feedback.txt new file mode 100644 index 000000000000..d2eef12787e6 --- /dev/null +++ b/agixt/prompts/Default/Summarize Feedback.txt @@ -0,0 +1,23 @@ +## Context + {context} + +Recent conversation history for context: + {conversation_history} + +Today's date is {date} . + +## System +The assistant is receiving feedback from the user. As an exercise in reflection, summarize the error that the assistant made to receive this feedback. + +The user's message prior to the one receiving {input_kind} was: + {user_input} + +The assistant's response was: + {assistant_response} + +The user felt this was {input_kind} and provided the following feedback: + {feedback} + +This message is not to the user, it is to store into the assitants long term training data and reflect back on for future interactions. + + From 856a911d0d68d85da85701576d362467bcdca495 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:00:43 -0400 Subject: [PATCH 0134/1256] improve format --- agixt/endpoints/Memory.py | 18 ++++++++++++------ agixt/prompts/Default/Summarize Feedback.txt | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 6812235c094a..487ba8f25b7d 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -619,12 +619,18 @@ async def rlhf( log_user_input=False, log_output=False, ) - memory_message = f""" -# Feedback received from a similar interaction in the past: -User Input: {data.user_input} -Assistant Response: {data.message} -Feedback: {data.feedback} -Reflection on the feedback: {reflection} + memory_message = f"""## Feedback received from a similar interaction in the past: +### User +{data.user_input} + +### Assistant +{data.message} + +### Feedback from User +{data.feedback} + +### Reflection on the feedback +{reflection} """ await memory.write_text_to_memory( user_input=data.user_input, diff --git a/agixt/prompts/Default/Summarize Feedback.txt b/agixt/prompts/Default/Summarize Feedback.txt index d2eef12787e6..9d1a6ab6f46c 100644 --- a/agixt/prompts/Default/Summarize Feedback.txt +++ b/agixt/prompts/Default/Summarize Feedback.txt @@ -20,4 +20,4 @@ The user felt this was {input_kind} and provided the following feedback: This message is not to the user, it is to store into the assitants long term training data and reflect back on for future interactions. - +Based on the interaction, what should the assistant do differently to improve the user experience? From dac2f77576611b5c669e89a6d90bdfe6006744c3 Mon Sep 17 00:00:00 2001 From: Jameson Grieve <37882431+JamesonRGrieve@users.noreply.github.com> Date: Sun, 9 Jun 2024 14:06:24 -0600 Subject: [PATCH 0135/1256] Update README.md Signed-off-by: Jameson Grieve <37882431+JamesonRGrieve@users.noreply.github.com> --- docs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/README.md b/docs/README.md index 68a3a311c68b..69e1411519e6 100644 --- a/docs/README.md +++ b/docs/README.md @@ -2,7 +2,7 @@ [![GitHub](https://img.shields.io/badge/GitHub-Sponsor%20Josh%20XT-blue?logo=github&style=plastic)](https://github.com/sponsors/Josh-XT) [![PayPal](https://img.shields.io/badge/PayPal-Sponsor%20Josh%20XT-blue.svg?logo=paypal&style=plastic)](https://paypal.me/joshxt) [![Ko-Fi](https://img.shields.io/badge/Kofi-Sponsor%20Josh%20XT-blue.svg?logo=kofi&style=plastic)](https://ko-fi.com/joshxt) -[![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20Core-blue?logo=github&style=plastic)](https://github.com/Josh-XT/AGiXT) [![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20Web%20UI-blue?logo=github&style=plastic)](https://github.com/AGiXT/streamlit) +[![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20Core-blue?logo=github&style=plastic)](https://github.com/Josh-XT/AGiXT) [![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20Interactive%20UI-blue?logo=github&style=plastic)](https://github.com/JamesonRGrieve/AGiXT-Interactive) [![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20StreamLit%20UI-blue?logo=github&style=plastic)](https://github.com/AGiXT/streamlit) [![GitHub](https://img.shields.io/badge/GitHub-AGiXT%20Python%20SDK-blue?logo=github&style=plastic)](https://github.com/AGiXT/python-sdk) [![pypi](https://img.shields.io/badge/pypi-AGiXT%20Python%20SDK-blue?logo=pypi&style=plastic)](https://pypi.org/project/agixtsdk/) From 9d881b6d128c21e2997a62a3a939837dd2d61ccd Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:11:11 -0400 Subject: [PATCH 0136/1256] fix tab --- agixt/XT.py | 122 ++++++++++++++++++++++++++-------------------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index d4f14a426796..599874d860b6 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -879,71 +879,71 @@ async def chat_completions(self, prompt: ChatCompletions): conversation_name=conversation_name, ) new_prompt += transcribed_audio - # Add user input to conversation - c = Conversations(conversation_name=conversation_name, user=self.user_email) - c.log_interaction(role="USER", message=new_prompt) - for file in files: - await self.learn_from_file( - file_url=file["file_url"], - file_name=file["file_name"], - user_input=new_prompt, - collection_number=1, - conversation_name=conversation_name, - ) - await self.learn_from_websites( - urls=urls, - scrape_depth=3, - summarize_content=False, + # Add user input to conversation + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction(role="USER", message=new_prompt) + for file in files: + await self.learn_from_file( + file_url=file["file_url"], + file_name=file["file_name"], + user_input=new_prompt, + collection_number=1, conversation_name=conversation_name, ) - if mode == "command" and command_name and command_variable: - try: - command_args = ( - json.loads(self.agent_settings["command_args"]) - if isinstance(self.agent_settings["command_args"], str) - else self.agent_settings["command_args"] - ) - except Exception as e: - command_args = {} - command_args[self.agent_settings["command_variable"]] = new_prompt - response = await self.execute_command( - command_name=self.agent_settings["command_name"], - command_args=command_args, - conversation_name=conversation_name, - voice_response=tts, - ) - elif mode == "chain" and chain_name: - chain_name = self.agent_settings["chain_name"] - try: - chain_args = ( - json.loads(self.agent_settings["chain_args"]) - if isinstance(self.agent_settings["chain_args"], str) - else self.agent_settings["chain_args"] - ) - except Exception as e: - chain_args = {} - response = await self.execute_chain( - chain_name=chain_name, - user_input=new_prompt, - agent_override=self.agent_name, - chain_args=chain_args, - log_user_input=False, - conversation_name=conversation_name, - voice_response=tts, + await self.learn_from_websites( + urls=urls, + scrape_depth=3, + summarize_content=False, + conversation_name=conversation_name, + ) + if mode == "command" and command_name and command_variable: + try: + command_args = ( + json.loads(self.agent_settings["command_args"]) + if isinstance(self.agent_settings["command_args"], str) + else self.agent_settings["command_args"] ) - elif mode == "prompt": - response = await self.inference( - user_input=new_prompt, - prompt_name=prompt_name, - prompt_category=prompt_category, - conversation_name=conversation_name, - injected_memories=context_results, - shots=prompt.n, - browse_links=browse_links, - voice_response=tts, - log_user_input=False, - **prompt_args, + except Exception as e: + command_args = {} + command_args[self.agent_settings["command_variable"]] = new_prompt + response = await self.execute_command( + command_name=self.agent_settings["command_name"], + command_args=command_args, + conversation_name=conversation_name, + voice_response=tts, + ) + elif mode == "chain" and chain_name: + chain_name = self.agent_settings["chain_name"] + try: + chain_args = ( + json.loads(self.agent_settings["chain_args"]) + if isinstance(self.agent_settings["chain_args"], str) + else self.agent_settings["chain_args"] ) + except Exception as e: + chain_args = {} + response = await self.execute_chain( + chain_name=chain_name, + user_input=new_prompt, + agent_override=self.agent_name, + chain_args=chain_args, + log_user_input=False, + conversation_name=conversation_name, + voice_response=tts, + ) + elif mode == "prompt": + response = await self.inference( + user_input=new_prompt, + prompt_name=prompt_name, + prompt_category=prompt_category, + conversation_name=conversation_name, + injected_memories=context_results, + shots=prompt.n, + browse_links=browse_links, + voice_response=tts, + log_user_input=False, + **prompt_args, + ) prompt_tokens = get_tokens(new_prompt) completion_tokens = get_tokens(response) total_tokens = int(prompt_tokens) + int(completion_tokens) From 64d5fdb7fe3cabd0794bccd58f74956df7ff425b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:12:07 -0400 Subject: [PATCH 0137/1256] fix ref --- agixt/endpoints/Memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 487ba8f25b7d..ef685f4c9102 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -605,11 +605,11 @@ async def rlhf( user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: + agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) if data.positive == True: memory = agixt.agent_interactions.positive_feedback_memories else: memory = agixt.agent_interactions.negative_feedback_memories - agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) reflection = await agixt.inference( user_input=data.user_input, input_kind="positive" if data.positive == True else "negative", From a9fc09bdbee71002d998890c1c8d4b8490c21430 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:15:36 -0400 Subject: [PATCH 0138/1256] try workspace --- agixt/extensions/agixt_actions.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 460f610f07f9..a10d8e81ee3a 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -25,15 +25,9 @@ def install_docker_image(): return client -def execute_python_code(code: str, working_directory: str = None) -> str: +def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" - if working_directory is None: - working_directory = os.path.join(os.getcwd(), "WORKSPACE") - docker_working_dir = working_directory - if os.environ.get("DOCKER_CONTAINER", False): - docker_working_dir = os.environ.get("WORKING_DIRECTORY", working_directory) - if not os.path.exists(working_directory): - os.makedirs(working_directory) + docker_working_dir = "WORKSPACE" # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) # Strip out python code blocks if they exist in the code @@ -724,7 +718,7 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: filepath = os.path.join(self.WORKING_DIRECTORY, filename) with open(filepath, "w") as f: f.write(text) - return execute_python_code(code=code, working_directory=working_dir) + return execute_python_code(code=code) async def get_mindmap(self, task: str): """ From af226490f1793a6b0d54322dc8cbc600db758770 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:24:23 -0400 Subject: [PATCH 0139/1256] add logging --- agixt/extensions/agixt_actions.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index a10d8e81ee3a..41d41f29e57d 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -27,13 +27,14 @@ def install_docker_image(): def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" - docker_working_dir = "WORKSPACE" + docker_working_dir = "/agixt/WORKSPACE" # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) # Strip out python code blocks if they exist in the code if "```python" in code: code = code.split("```python")[1].split("```")[0] - temp_file = os.path.join(os.getcwd(), "WORKSPACE", "temp.py") + temp_file = "/agixt/WORKSPACE/temp.py" + logging.info(f"Writing Python code to temporary file: {temp_file}") with open(temp_file, "w") as f: f.write(code) os.chmod(temp_file, 0o755) # Set executable permissions @@ -49,11 +50,11 @@ def execute_python_code(code: str) -> str: f"pip install {package}", volumes={ os.path.abspath(docker_working_dir): { - "bind": "/workspace", + "bind": "/WORKSPACE", "mode": "rw", } }, - working_dir="/workspace", + working_dir="/WORKSPACE", stderr=True, stdout=True, detach=True, @@ -64,14 +65,14 @@ def execute_python_code(code: str) -> str: # Run the Python code in the container container = client.containers.run( docker_image, - f"python /workspace/temp.py", + f"python /WORKSPACE/temp.py", volumes={ os.path.abspath(docker_working_dir): { - "bind": "/workspace", + "bind": "/WORKSPACE", "mode": "rw", } }, - working_dir="/workspace", + working_dir="/WORKSPACE", stderr=True, stdout=True, detach=True, From 4533baeb10ea2dd0ab87d7c87a77954e11d5ba5f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 16:54:00 -0400 Subject: [PATCH 0140/1256] improve prompt --- agixt/Interactions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index dbd6c63bf1ea..e3a43b4f4ec8 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -652,7 +652,7 @@ async def run( and agent_settings["image_provider"] != None and agent_settings["image_provider"] != "default" ): - img_gen_prompt = f"Users message: {user_input} \n\n{'The user uploaded an image, one does not need generated unless the user is specifically asking.' if images else ''} **The assistant is acting as sentiment analysis expert and only responds with a concise YES or NO answer on if the user would like an image as visual or a picture generated. No other explanation is needed!**\nWould the user potentially like an image generated based on their message?\nAssistant: " + img_gen_prompt = f"Users message: {user_input} \n\n{'The user uploaded an image, one does not need generated unless the user is specifically asking.' if images else ''} **The assistant is acting as sentiment analysis expert and only responds with a concise YES or NO answer on if the user would like a creative generated image to be generated by AI in their request. No other explanation is needed!**\nWould the user potentially like an image generated based on their message?\nAssistant: " create_img = await self.agent.inference(prompt=img_gen_prompt) create_img = str(create_img).lower() logging.info(f"Image Generation Decision Response: {create_img}") From e5d8843fea16e9a46e9c67a7a037ccfe56f5ac99 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:01:36 -0400 Subject: [PATCH 0141/1256] fix attempt --- agixt/extensions/agixt_actions.py | 57 +++++++++++++------------------ 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 41d41f29e57d..0aa3982908c7 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -28,51 +28,42 @@ def install_docker_image(): def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" docker_working_dir = "/agixt/WORKSPACE" + host_working_dir = os.path.abspath("WORKSPACE") + # Ensure the host working directory exists + os.makedirs(host_working_dir, exist_ok=True) # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) # Strip out python code blocks if they exist in the code if "```python" in code: code = code.split("```python")[1].split("```")[0] - temp_file = "/agixt/WORKSPACE/temp.py" + temp_file = os.path.join(host_working_dir, "temp.py") logging.info(f"Writing Python code to temporary file: {temp_file}") with open(temp_file, "w") as f: f.write(code) - os.chmod(temp_file, 0o755) # Set executable permissions try: client = install_docker_image() - if package_requirements: - # Install the required packages in the container - for package in package_requirements: - try: - logging.info(f"Installing package '{package}' in container") - client.containers.run( - docker_image, - f"pip install {package}", - volumes={ - os.path.abspath(docker_working_dir): { - "bind": "/WORKSPACE", - "mode": "rw", - } - }, - working_dir="/WORKSPACE", - stderr=True, - stdout=True, - detach=True, - ) - except Exception as e: - logging.error(f"Error installing package '{package}': {str(e)}") - return f"Error: {str(e)}" - # Run the Python code in the container + for package in package_requirements: + try: + logging.info(f"Installing package '{package}' in container") + client.containers.run( + docker_image, + f"pip install {package}", + volumes={ + host_working_dir: {"bind": docker_working_dir, "mode": "rw"} + }, + working_dir=docker_working_dir, + stderr=True, + stdout=True, + remove=True, + ) + except Exception as e: + logging.error(f"Error installing package '{package}': {str(e)}") + return f"Error: {str(e)}" container = client.containers.run( docker_image, - f"python /WORKSPACE/temp.py", - volumes={ - os.path.abspath(docker_working_dir): { - "bind": "/WORKSPACE", - "mode": "rw", - } - }, - working_dir="/WORKSPACE", + f"python {os.path.join(docker_working_dir, 'temp.py')}", + volumes={host_working_dir: {"bind": docker_working_dir, "mode": "rw"}}, + working_dir=docker_working_dir, stderr=True, stdout=True, detach=True, From 87f01dcb0aaf096fc41547cea5e6e0b920b12023 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:27:36 -0400 Subject: [PATCH 0142/1256] add logging --- agixt/extensions/agixt_actions.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 0aa3982908c7..0d3333e411ff 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -29,19 +29,27 @@ def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" docker_working_dir = "/agixt/WORKSPACE" host_working_dir = os.path.abspath("WORKSPACE") + # Ensure the host working directory exists os.makedirs(host_working_dir, exist_ok=True) + # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) + # Strip out python code blocks if they exist in the code if "```python" in code: code = code.split("```python")[1].split("```")[0] + temp_file = os.path.join(host_working_dir, "temp.py") logging.info(f"Writing Python code to temporary file: {temp_file}") + with open(temp_file, "w") as f: f.write(code) + try: client = install_docker_image() + + # Install the required packages in the container for package in package_requirements: try: logging.info(f"Installing package '{package}' in container") @@ -59,6 +67,8 @@ def execute_python_code(code: str) -> str: except Exception as e: logging.error(f"Error installing package '{package}': {str(e)}") return f"Error: {str(e)}" + + # Run the Python code in the container container = client.containers.run( docker_image, f"python {os.path.join(docker_working_dir, 'temp.py')}", @@ -68,10 +78,19 @@ def execute_python_code(code: str) -> str: stdout=True, detach=True, ) - container.wait() + + # Wait for the container to finish and capture the logs + result = container.wait() logs = container.logs().decode("utf-8") container.remove() + + # Clean up the temporary file os.remove(temp_file) + + if result["StatusCode"] != 0: + logging.error(f"Error executing Python code: {logs}") + return f"Error: {logs}" + logging.info(f"Python code executed successfully. Logs: {logs}") return logs except Exception as e: From 43de7208a9d37444e645c383b82b5ed6903e8853 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:31:59 -0400 Subject: [PATCH 0143/1256] add logging --- agixt/extensions/agixt_actions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 0d3333e411ff..a75869fd77d1 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -32,6 +32,7 @@ def execute_python_code(code: str) -> str: # Ensure the host working directory exists os.makedirs(host_working_dir, exist_ok=True) + logging.info(f"Host working directory: {host_working_dir}") # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) @@ -46,6 +47,10 @@ def execute_python_code(code: str) -> str: with open(temp_file, "w") as f: f.write(code) + logging.info( + f"Temporary file written. Checking if the file exists: {os.path.exists(temp_file)}" + ) + try: client = install_docker_image() @@ -68,6 +73,7 @@ def execute_python_code(code: str) -> str: logging.error(f"Error installing package '{package}': {str(e)}") return f"Error: {str(e)}" + logging.info(f"Running the Python code in the container") # Run the Python code in the container container = client.containers.run( docker_image, @@ -86,6 +92,7 @@ def execute_python_code(code: str) -> str: # Clean up the temporary file os.remove(temp_file) + logging.info(f"Temporary file removed") if result["StatusCode"] != 0: logging.error(f"Error executing Python code: {logs}") @@ -718,7 +725,6 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: Returns: str: The result of the Python code """ - working_dir = os.environ.get("WORKING_DIRECTORY", self.WORKING_DIRECTORY) if text: csv_content_header = text.split("\n")[0] # Remove any trailing spaces from any headers From c7f4020618d7569732cc8f04a60a018f6112d9d0 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:38:29 -0400 Subject: [PATCH 0144/1256] add logging --- agixt/extensions/agixt_actions.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index a75869fd77d1..7c56c947c4f7 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -73,6 +73,24 @@ def execute_python_code(code: str) -> str: logging.error(f"Error installing package '{package}': {str(e)}") return f"Error: {str(e)}" + # Debugging: List files in the container's working directory + logging.info( + "Listing files in the container's working directory before executing code" + ) + list_files_cmd = f"ls -la {docker_working_dir}" + output = client.containers.run( + docker_image, + list_files_cmd, + volumes={host_working_dir: {"bind": docker_working_dir, "mode": "rw"}}, + working_dir=docker_working_dir, + stderr=True, + stdout=True, + remove=True, + ) + logging.info( + f"Files in container's working directory:\n{output.decode('utf-8')}" + ) + logging.info(f"Running the Python code in the container") # Run the Python code in the container container = client.containers.run( From 77d19857f8b10eff6e876f6aae060672169fe0a5 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:46:36 -0400 Subject: [PATCH 0145/1256] get host working dir --- agixt/extensions/agixt_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 7c56c947c4f7..a6d7da783c72 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -28,7 +28,7 @@ def install_docker_image(): def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" docker_working_dir = "/agixt/WORKSPACE" - host_working_dir = os.path.abspath("WORKSPACE") + host_working_dir = os.getenv("WORKING_DIRECTORY", "/agixt/WORKSPACE") # Ensure the host working directory exists os.makedirs(host_working_dir, exist_ok=True) From b615315cfbfaeacbcae3f9a0d9d54f94592b1958 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:48:08 -0400 Subject: [PATCH 0146/1256] flip dir --- agixt/extensions/agixt_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index a6d7da783c72..47f601d8b621 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -41,7 +41,7 @@ def execute_python_code(code: str) -> str: if "```python" in code: code = code.split("```python")[1].split("```")[0] - temp_file = os.path.join(host_working_dir, "temp.py") + temp_file = os.path.join(docker_working_dir, "temp.py") logging.info(f"Writing Python code to temporary file: {temp_file}") with open(temp_file, "w") as f: From f9184201884bff178244e2228aab5971a2a224c9 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 17:49:17 -0400 Subject: [PATCH 0147/1256] fix path --- agixt/extensions/agixt_actions.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 47f601d8b621..633aadccdd71 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -29,11 +29,6 @@ def execute_python_code(code: str) -> str: docker_image = "joshxt/safeexecute:latest" docker_working_dir = "/agixt/WORKSPACE" host_working_dir = os.getenv("WORKING_DIRECTORY", "/agixt/WORKSPACE") - - # Ensure the host working directory exists - os.makedirs(host_working_dir, exist_ok=True) - logging.info(f"Host working directory: {host_working_dir}") - # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) From fde11e86ed27487fa06107d6f5301684d7b0c890 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:03:49 -0400 Subject: [PATCH 0148/1256] add agent id --- agixt/Agent.py | 2 +- agixt/extensions/agixt_actions.py | 47 +++++++++---------------------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 3a5452ce85f1..29764ff9ed81 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -161,7 +161,7 @@ def get_agents(user=DEFAULT_USER): # Check if the agent is in the output already if agent.name in [a["name"] for a in output]: continue - output.append({"name": agent.name, "status": False}) + output.append({"name": agent.name, "id": agent.id, "status": False}) return output diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 633aadccdd71..24c421d38930 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -1,5 +1,6 @@ import datetime import json +import uuid import requests import os import re @@ -25,30 +26,26 @@ def install_docker_image(): return client -def execute_python_code(code: str) -> str: +def execute_python_code(code: str, agent_id: str = "") -> str: docker_image = "joshxt/safeexecute:latest" - docker_working_dir = "/agixt/WORKSPACE" + docker_working_dir = f"/agixt/WORKSPACE/{agent_id}" host_working_dir = os.getenv("WORKING_DIRECTORY", "/agixt/WORKSPACE") + host_working_dir = os.path.join(host_working_dir, agent_id) # Check if there are any package requirements in the code to install package_requirements = re.findall(r"pip install (.*)", code) - # Strip out python code blocks if they exist in the code if "```python" in code: code = code.split("```python")[1].split("```")[0] - - temp_file = os.path.join(docker_working_dir, "temp.py") + temp_file_name = f"{str(uuid.uuid4())}.py" + temp_file = os.path.join(docker_working_dir, temp_file_name) logging.info(f"Writing Python code to temporary file: {temp_file}") - with open(temp_file, "w") as f: f.write(code) - logging.info( f"Temporary file written. Checking if the file exists: {os.path.exists(temp_file)}" ) - try: client = install_docker_image() - # Install the required packages in the container for package in package_requirements: try: @@ -67,46 +64,23 @@ def execute_python_code(code: str) -> str: except Exception as e: logging.error(f"Error installing package '{package}': {str(e)}") return f"Error: {str(e)}" - - # Debugging: List files in the container's working directory - logging.info( - "Listing files in the container's working directory before executing code" - ) - list_files_cmd = f"ls -la {docker_working_dir}" - output = client.containers.run( - docker_image, - list_files_cmd, - volumes={host_working_dir: {"bind": docker_working_dir, "mode": "rw"}}, - working_dir=docker_working_dir, - stderr=True, - stdout=True, - remove=True, - ) - logging.info( - f"Files in container's working directory:\n{output.decode('utf-8')}" - ) - - logging.info(f"Running the Python code in the container") # Run the Python code in the container container = client.containers.run( docker_image, - f"python {os.path.join(docker_working_dir, 'temp.py')}", + f"python {os.path.join(docker_working_dir, temp_file_name)}", volumes={host_working_dir: {"bind": docker_working_dir, "mode": "rw"}}, working_dir=docker_working_dir, stderr=True, stdout=True, detach=True, ) - # Wait for the container to finish and capture the logs result = container.wait() logs = container.logs().decode("utf-8") container.remove() - # Clean up the temporary file os.remove(temp_file) logging.info(f"Temporary file removed") - if result["StatusCode"] != 0: logging.error(f"Error executing Python code: {logs}") return f"Error: {logs}" @@ -748,7 +722,12 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: filepath = os.path.join(self.WORKING_DIRECTORY, filename) with open(filepath, "w") as f: f.write(text) - return execute_python_code(code=code) + agents = self.ApiClient.get_agents() + agent_id = "" + for agent in agents: + if agent["name"] == self.agent_name: + agent_id = agent["id"] + return execute_python_code(code=code, agent_id=agent_id) async def get_mindmap(self, task: str): """ From 852ea87ebe27bf481d2cb19a3b1bf209296c060e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:16:16 -0400 Subject: [PATCH 0149/1256] add agent id to workspaces --- agixt/Agent.py | 3 +++ agixt/Interactions.py | 13 ++++++------- agixt/XT.py | 6 ++---- agixt/endpoints/Agent.py | 5 +++-- agixt/extensions/agixt_actions.py | 2 +- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 29764ff9ed81..890ac69f1630 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -251,6 +251,9 @@ def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): ApiClient=ApiClient, user=self.user, ).get_available_commands() + self.agent_id = str(self.get_agent_id()) + self.working_directory = os.path.join(os.getcwd(), "WORKSPACE", self.agent_id) + os.makedirs(self.working_directory, exist_ok=True) def load_config_keys(self): config_keys = [ diff --git a/agixt/Interactions.py b/agixt/Interactions.py index e3a43b4f4ec8..2261ac97a9c0 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -202,10 +202,7 @@ async def format_prompt( context = f"The user's input causes you remember these things:\n{context}\n" else: context = "" - try: - working_directory = self.agent.AGENT_CONFIG["settings"]["WORKING_DIRECTORY"] - except: - working_directory = "./WORKSPACE" + working_directory = self.agent.working_directory helper_agent_name = self.agent_name if "helper_agent_name" not in kwargs: if "helper_agent_name" in self.agent.AGENT_CONFIG["settings"]: @@ -294,7 +291,7 @@ async def format_prompt( file_name = file["file_name"] file_list.append(file_name) file_name = regex.sub(r"(\[.*?\])", "", file_name) - file_path = os.path.normpath(os.getcwd(), working_directory, file_name) + file_path = os.path.normpath(working_directory, file_name) if not file_path.startswith(os.getcwd()): pass if not os.path.exists(file_path): @@ -628,11 +625,13 @@ async def run( if not str(tts_response).startswith("http"): file_type = "wav" file_name = f"{uuid.uuid4().hex}.{file_type}" - audio_path = f"./WORKSPACE/{file_name}" + audio_path = os.path.join( + self.agent.working_directory, file_name + ) audio_data = base64.b64decode(tts_response) with open(audio_path, "wb") as f: f.write(audio_data) - tts_response = f'' + tts_response = f'' self.response = f"{self.response}\n\n{tts_response}" except Exception as e: logging.warning(f"Failed to get TTS response: {e}") diff --git a/agixt/XT.py b/agixt/XT.py index 599874d860b6..fe4d08b08214 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -36,10 +36,9 @@ def __init__(self, user: str, agent_name: str, api_key: str): else DEFAULT_SETTINGS ) self.chain = Chain(user=self.user_email) - self.agent_id = str(self.agent.get_agent_id()) - self.agent_workspace = os.path.join(os.getcwd(), "WORKSPACE", self.agent_id) + self.agent_workspace = self.agent.working_directory os.makedirs(self.agent_workspace, exist_ok=True) - self.outputs = f"{self.uri}/outputs/{self.agent_id}" + self.outputs = f"{self.uri}/outputs/{self.agent.agent_id}" async def prompts(self, prompt_category: str = "Default"): """ @@ -1101,7 +1100,6 @@ async def create_dataset_from_memories(self, batch_size: int = 10): "rejected": bad_answers, } # Save messages to a json file to be used as a dataset - agent_id = self.agent_interactions.agent.get_agent_id() dataset_dir = os.path.join(self.agent_workspace, "datasets") os.makedirs(dataset_dir, exist_ok=True) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index d033ed8ba92a..6fcc0d9aa465 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -26,6 +26,7 @@ ) import base64 import uuid +import os app = APIRouter() @@ -326,9 +327,9 @@ async def text_to_speech( if not str(tts_response).startswith("http"): file_type = "wav" file_name = f"{uuid.uuid4().hex}.{file_type}" - audio_path = f"./WORKSPACE/{file_name}" + audio_path = os.path.join(agent.working_directory, file_name) audio_data = base64.b64decode(tts_response) with open(audio_path, "wb") as f: f.write(audio_data) - tts_response = f"{AGIXT_URI}/outputs/{file_name}" + tts_response = f"{AGIXT_URI}/outputs/{agent.agent_id}/{file_name}" return {"url": tts_response} diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 24c421d38930..2d832f89e05b 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -29,6 +29,7 @@ def install_docker_image(): def execute_python_code(code: str, agent_id: str = "") -> str: docker_image = "joshxt/safeexecute:latest" docker_working_dir = f"/agixt/WORKSPACE/{agent_id}" + os.makedirs(docker_working_dir, exist_ok=True) host_working_dir = os.getenv("WORKING_DIRECTORY", "/agixt/WORKSPACE") host_working_dir = os.path.join(host_working_dir, agent_id) # Check if there are any package requirements in the code to install @@ -84,7 +85,6 @@ def execute_python_code(code: str, agent_id: str = "") -> str: if result["StatusCode"] != 0: logging.error(f"Error executing Python code: {logs}") return f"Error: {logs}" - logging.info(f"Python code executed successfully. Logs: {logs}") return logs except Exception as e: From 9c9639c246ddc72370e6577a085f3177e87c002a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:17:43 -0400 Subject: [PATCH 0150/1256] fix agent id --- agixt/Agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 890ac69f1630..7a5c011022ee 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -152,7 +152,7 @@ def get_agents(user=DEFAULT_USER): agents = session.query(AgentModel).filter(AgentModel.user.has(email=user)).all() output = [] for agent in agents: - output.append({"name": agent.name, "status": False}) + output.append({"name": agent.name, "id": agent.id, "status": False}) # Get global agents that belong to DEFAULT_USER global_agents = ( session.query(AgentModel).filter(AgentModel.user.has(email=DEFAULT_USER)).all() From 2493cae18b81f5f2a11c485af731f7d997553c09 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:22:11 -0400 Subject: [PATCH 0151/1256] persist scripts --- agixt/extensions/agixt_actions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 2d832f89e05b..e4d306d130c4 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -79,9 +79,6 @@ def execute_python_code(code: str, agent_id: str = "") -> str: result = container.wait() logs = container.logs().decode("utf-8") container.remove() - # Clean up the temporary file - os.remove(temp_file) - logging.info(f"Temporary file removed") if result["StatusCode"] != 0: logging.error(f"Error executing Python code: {logs}") return f"Error: {logs}" From 8fcae16229651c041ac563e5a485f97236bb76a2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:33:33 -0400 Subject: [PATCH 0152/1256] fix path --- agixt/extensions/agixt_actions.py | 12 +++++++++++- agixt/prompts/Default/Code Interpreter.txt | 2 +- agixt/prompts/Default/Verify Code Interpreter.txt | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index e4d306d130c4..66d59f81ee41 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -724,7 +724,17 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: for agent in agents: if agent["name"] == self.agent_name: agent_id = agent["id"] - return execute_python_code(code=code, agent_id=agent_id) + execution_response = execute_python_code(code=code, agent_id=agent_id) + if "Error:" in execution_response: + # Ask the LLM for clarification on the code it wrote, show it the error. + clarification = await self.ApiClient.prompt_agent( + agent_name=self.agent_name, + prompt_name="Clarify Code", + prompt_args={ + "user_input": execution_response, + "conversation_name": self.conversation_name, + }, + ) async def get_mindmap(self, task: str): """ diff --git a/agixt/prompts/Default/Code Interpreter.txt b/agixt/prompts/Default/Code Interpreter.txt index f6fdd9341382..a0ebeae3eae4 100644 --- a/agixt/prompts/Default/Code Interpreter.txt +++ b/agixt/prompts/Default/Code Interpreter.txt @@ -24,7 +24,7 @@ If the user's input doesn't request any specific analysis or asks to surprise th **Make sure the final output of the code is a visualization. The functions final return should be a print of base64 image markdown string that can be displayed on a website parsing markdown code. Example `print('![Generated Image](data:image/png;base64,IMAGE_CONTENT)')`** -You are working the with file at `{import_file}`, use this exact file path in any code that will analyze it. CSV file preview: +You are working the with file at {import_file} . Use this exact file path in any code that will analyze it. CSV file preview: ```csv {file_preview} ``` diff --git a/agixt/prompts/Default/Verify Code Interpreter.txt b/agixt/prompts/Default/Verify Code Interpreter.txt index df310d7e1792..55c536e3af70 100644 --- a/agixt/prompts/Default/Verify Code Interpreter.txt +++ b/agixt/prompts/Default/Verify Code Interpreter.txt @@ -1,6 +1,6 @@ The date today is {date} . -You are working the with file at `{import_file}`, use this exact file path in any code that will analyze it. CSV file preview: +You are working the with file at {import_file} . Use this exact file path in any code that will analyze it. CSV file preview: ```csv {file_preview} ``` From 0284c05bed90672af567abdcd7305ca54d91fb2b Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:35:25 -0400 Subject: [PATCH 0153/1256] remove trailing new line --- agixt/extensions/agixt_actions.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 66d59f81ee41..d33ece9d6ee6 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -725,16 +725,9 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: if agent["name"] == self.agent_name: agent_id = agent["id"] execution_response = execute_python_code(code=code, agent_id=agent_id) - if "Error:" in execution_response: - # Ask the LLM for clarification on the code it wrote, show it the error. - clarification = await self.ApiClient.prompt_agent( - agent_name=self.agent_name, - prompt_name="Clarify Code", - prompt_args={ - "user_input": execution_response, - "conversation_name": self.conversation_name, - }, - ) + if str(execution_response).endswith("\n"): + execution_response = execution_response[:-1] + return execution_response async def get_mindmap(self, task: str): """ From 8da0c7cefd137cfe66f13a8f197050bde81bca9d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:43:56 -0400 Subject: [PATCH 0154/1256] fix output --- agixt/endpoints/Extension.py | 9 +++++++-- agixt/extensions/agixt_actions.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/agixt/endpoints/Extension.py b/agixt/endpoints/Extension.py index 7acd633dee31..7e227ed0bee8 100644 --- a/agixt/endpoints/Extension.py +++ b/agixt/endpoints/Extension.py @@ -61,8 +61,13 @@ async def run_command( ).execute_command( command_name=command.command_name, command_args=command.command_args ) - c = Conversations(conversation_name=command.conversation_name, user=user) - c.log_interaction(role=agent_name, message=command_output) + if ( + command.conversation_name != "" + and command.conversation_name != None + and command_output != None + ): + c = Conversations(conversation_name=command.conversation_name, user=user) + c.log_interaction(role=agent_name, message=command_output) return { "response": command_output, } diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index d33ece9d6ee6..3d1ab3d1567d 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -723,7 +723,7 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: agent_id = "" for agent in agents: if agent["name"] == self.agent_name: - agent_id = agent["id"] + agent_id = str(agent["id"]) execution_response = execute_python_code(code=code, agent_id=agent_id) if str(execution_response).endswith("\n"): execution_response = execution_response[:-1] From e771f16dee5963aa95af1cfcf5932b9dfd3e8027 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:46:00 -0400 Subject: [PATCH 0155/1256] change ref --- agixt/extensions/agixt_actions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 3d1ab3d1567d..38395b5e03da 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -725,8 +725,6 @@ async def execute_python_code_internal(self, code: str, text: str = "") -> str: if agent["name"] == self.agent_name: agent_id = str(agent["id"]) execution_response = execute_python_code(code=code, agent_id=agent_id) - if str(execution_response).endswith("\n"): - execution_response = execution_response[:-1] return execution_response async def get_mindmap(self, task: str): From 00ef4898e9e6b8b9e0e50324af20ef90f3467d1e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 9 Jun 2024 18:47:43 -0400 Subject: [PATCH 0156/1256] use string for logs return --- agixt/extensions/agixt_actions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 38395b5e03da..992160c15101 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -83,6 +83,9 @@ def execute_python_code(code: str, agent_id: str = "") -> str: logging.error(f"Error executing Python code: {logs}") return f"Error: {logs}" logging.info(f"Python code executed successfully. Logs: {logs}") + logs = str(logs) + if logs.endswith("\n"): + logs = logs[:-1] return logs except Exception as e: logging.error(f"Error executing Python code: {str(e)}") From cf4389b1af27798f34fc4af8d81522848e336fe2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 10 Jun 2024 08:06:26 -0400 Subject: [PATCH 0157/1256] remove warning --- agixt/Memories.py | 1 - 1 file changed, 1 deletion(-) diff --git a/agixt/Memories.py b/agixt/Memories.py index 41ef38eca36b..5b7568dd2475 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -389,7 +389,6 @@ async def get_memories_data( ) embedding_array = array(results["embeddings"][0]) if len(embedding_array) == 0: - logging.warning("Embedding collection is empty.") return [] embedding_array = embedding_array.reshape(embedding_array.shape[0], -1) if len(embedding.shape) == 2: From c096874a83217f555385237a1befe3ffc2ced58f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 10 Jun 2024 12:08:22 -0400 Subject: [PATCH 0158/1256] add to message --- agixt/Conversations.py | 3 +++ agixt/DB.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 9fc352365c35..dbf574a90ad1 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -96,9 +96,12 @@ def get_conversation(self, limit=100, page=1): return_messages = [] for message in messages: msg = { + "id": message.id, "role": message.role, "message": message.content, "timestamp": message.timestamp, + "updated_at": message.updated_at, + "updated_by": message.updated_by, } return_messages.append(msg) return {"interactions": return_messages} diff --git a/agixt/DB.py b/agixt/DB.py index 4720d06a455e..6eac65ffc947 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -256,6 +256,12 @@ class Message(Base): ForeignKey("conversation.id"), nullable=False, ) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + updated_by = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) class Setting(Base): From f3927e5f5a5a9498f9dfa05abb0ba8cb61a47cd8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 10 Jun 2024 12:18:58 -0400 Subject: [PATCH 0159/1256] add feedback toggle --- agixt/Conversations.py | 88 +++++++++++++++++++++++++++++++++++++++ agixt/DB.py | 1 + agixt/endpoints/Memory.py | 16 ++++--- 3 files changed, 99 insertions(+), 6 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index dbf574a90ad1..62d53efe1372 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -102,6 +102,7 @@ def get_conversation(self, limit=100, page=1): "timestamp": message.timestamp, "updated_at": message.updated_at, "updated_by": message.updated_by, + "feedback_received": message.feedback_received, } return_messages.append(msg) return {"interactions": return_messages} @@ -245,6 +246,93 @@ def delete_message(self, message): session.delete(message) session.commit() + def toggle_feedback_received(self, message): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + + conversation = ( + session.query(Conversation) + .filter( + Conversation.name == self.conversation_name, + Conversation.user_id == user_id, + ) + .first() + ) + + if not conversation: + logging.info(f"No conversation found.") + return + message_id = ( + session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.content == message, + ) + .first() + ).id + + message = ( + session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.id == message_id, + ) + .first() + ) + + if not message: + logging.info( + f"No message found with ID '{message_id}' in conversation '{self.conversation_name}'." + ) + return + + message.feedback_received = not message.feedback_received + session.commit() + + def has_received_feedback(self, message): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + + conversation = ( + session.query(Conversation) + .filter( + Conversation.name == self.conversation_name, + Conversation.user_id == user_id, + ) + .first() + ) + + if not conversation: + logging.info(f"No conversation found.") + return + message_id = ( + session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.content == message, + ) + .first() + ).id + + message = ( + session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.id == message_id, + ) + .first() + ) + + if not message: + logging.info( + f"No message found with ID '{message_id}' in conversation '{self.conversation_name}'." + ) + return + + return message.feedback_received + def update_message(self, message, new_message): session = get_session() user_data = session.query(User).filter(User.email == self.user).first() diff --git a/agixt/DB.py b/agixt/DB.py index 6eac65ffc947..1a5046af5dfd 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -262,6 +262,7 @@ class Message(Base): ForeignKey("user.id"), nullable=True, ) + feedback_received = Column(Boolean, default=False) class Setting(Base): diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index ef685f4c9102..1f3c6c342de1 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -605,6 +605,11 @@ async def rlhf( user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: + c = Conversations(conversation_name=data.conversation_name, user=user) + if c.has_received_feedback(message=data.message): + return ResponseMessage( + message="Feedback already received for this interaction." + ) agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) if data.positive == True: memory = agixt.agent_interactions.positive_feedback_memories @@ -640,10 +645,9 @@ async def rlhf( response_message = ( f"{'Positive' if data.positive == True else 'Negative'} feedback received." ) - if data.conversation_name != "" and data.conversation_name != None: - c = Conversations(conversation_name=data.conversation_name, user=user) - c.log_interaction( - role=agent_name, - message=response_message, - ) + c.log_interaction( + role=agent_name, + message=response_message, + ) + c.toggle_feedback_received(message=data.message) return ResponseMessage(message=response_message) From 5a416fc00f6c255b5bed2f10c5a61f9ec5a05279 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 10 Jun 2024 12:26:09 -0400 Subject: [PATCH 0160/1256] add activity feedback note --- agixt/endpoints/Memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 1f3c6c342de1..daeb77cb4c8e 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -647,7 +647,7 @@ async def rlhf( ) c.log_interaction( role=agent_name, - message=response_message, + message=f"[ACTIVITY][FEEDBACK] {response_message}", ) c.toggle_feedback_received(message=data.message) return ResponseMessage(message=response_message) From 54b3d4541cd4492c91fe28275314c17353bd132e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 11 Jun 2024 20:41:26 -0400 Subject: [PATCH 0161/1256] handle missing type --- agixt/Chain.py | 7 +++++++ agixt/endpoints/Chain.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index c6b3abe15675..7d1293a0cf6d 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -191,6 +191,13 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp target_type = "command" else: prompt["prompt_name"] = "User Input" + prompt["prompt_category"] = "Default" + prompt["prompt_type"] = "Prompt" + prompt["user_input"] = ( + prompt["input"] + if "input" in prompt + else prompt["user_input"] if "user_input" in prompt else "" + ) argument_key = "prompt_name" target_id = ( self.session.query(Prompt) diff --git a/agixt/endpoints/Chain.py b/agixt/endpoints/Chain.py index b5c70001d77b..86929639eaa1 100644 --- a/agixt/endpoints/Chain.py +++ b/agixt/endpoints/Chain.py @@ -203,7 +203,8 @@ async def add_step( ) -> ResponseMessage: if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") - Chain(user=user).add_chain_step( + ApiClient = get_api_client(authorization=authorization) + Chain(user=user, ApiClient=ApiClient).add_chain_step( chain_name=chain_name, step_number=step_info.step_number, prompt_type=step_info.prompt_type, From 0de4cc68313d15a0ebc1c65cf8abcf4917d9f7e2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 11 Jun 2024 20:57:30 -0400 Subject: [PATCH 0162/1256] del duplicate key --- agixt/Chain.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/Chain.py b/agixt/Chain.py index 7d1293a0cf6d..e6a72f990d33 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -198,6 +198,8 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp if "input" in prompt else prompt["user_input"] if "user_input" in prompt else "" ) + if "input" in prompt: + del prompt["input"] argument_key = "prompt_name" target_id = ( self.session.query(Prompt) From 5148a8d5fcf14824af6fcc8c5cada51510c3a98e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 11 Jun 2024 20:58:59 -0400 Subject: [PATCH 0163/1256] add warning --- agixt/Chain.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index e6a72f990d33..569c9f52ff2d 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -190,9 +190,12 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp ) target_type = "command" else: + logging.warning( + f"Invalid prompt {prompt} with prompt type {prompt_type}. Using default prompt." + ) prompt["prompt_name"] = "User Input" prompt["prompt_category"] = "Default" - prompt["prompt_type"] = "Prompt" + prompt_type = "Prompt" prompt["user_input"] = ( prompt["input"] if "input" in prompt From 0226a4937240bc805fb4b14a843cd98a91aab3da Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 11 Jun 2024 21:09:28 -0400 Subject: [PATCH 0164/1256] use prompt type --- agixt/Chain.py | 17 +++++++++++++---- agixt/extensions/agixt_actions.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/agixt/Chain.py b/agixt/Chain.py index 569c9f52ff2d..ad661c1f35d4 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -140,7 +140,14 @@ def rename_chain(self, chain_name, new_name): chain.name = new_name self.session.commit() - def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, prompt): + def add_chain_step( + self, + chain_name: str, + step_number: int, + agent_name: str, + prompt_type: str, + prompt: dict, + ): chain = ( self.session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) @@ -156,7 +163,7 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp else: prompt_category = "Default" argument_key = None - if "prompt_name" in prompt: + if prompt_type.lower() == "prompt": argument_key = "prompt_name" target_id = ( self.session.query(Prompt) @@ -169,8 +176,10 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp .id ) target_type = "prompt" - elif "chain_name" in prompt: + elif prompt_type.lower() == "chain": argument_key = "chain_name" + if argument_key not in prompt: + argument_key = "chain" target_id = ( self.session.query(Chain) .filter( @@ -180,7 +189,7 @@ def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, promp .id ) target_type = "chain" - elif "command_name" in prompt: + elif prompt_type.lower() == "command": argument_key = "command_name" target_id = ( self.session.query(Command) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 992160c15101..10677c8bc293 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -317,7 +317,7 @@ async def create_task_chain( step_number=i, prompt_type="Chain", prompt={ - "chain": "Smart Prompt", + "chain_name": "Smart Prompt", "input": f"Primary Objective to keep in mind while working on the task: {primary_objective} \nThe only task to complete to move towards the objective: {task}", }, ) From 36aa067eec6dff74082b97a8c9897c201298f604 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:54:06 -0400 Subject: [PATCH 0165/1256] Add Task Planning functions, Conversational Memories, and Single Sign-On (#1208) * add task planning functions * update sdk, add endpoint * use collection IDs * fix ref * add collections and improve browsed links * inject conversational memories into context * add get conversations with ids * remove tab * sso sso much code * add auth env vars * fix compose * get or create collection * set default * fix ref * add not provided option * add ALLOW_EMAIL_SIGN_IN --- agixt/Agent.py | 21 +- agixt/Chain.py | 33 ++ agixt/Conversations.py | 31 ++ agixt/DB.py | 30 +- agixt/Interactions.py | 43 +- agixt/MagicalAuth.py | 89 +++- agixt/Memories.py | 49 +- agixt/Models.py | 32 +- agixt/OAuth2Providers.py | 442 ++++++++++++++++++ agixt/Websearch.py | 58 ++- agixt/XT.py | 217 ++++++++- agixt/endpoints/Agent.py | 44 +- agixt/endpoints/Auth.py | 17 + agixt/endpoints/Conversation.py | 21 +- agixt/endpoints/Memory.py | 53 +-- agixt/extensions/agixt_actions.py | 47 +- agixt/readers/arxiv.py | 4 +- agixt/readers/file.py | 4 +- agixt/readers/github.py | 11 +- agixt/readers/youtube.py | 4 +- agixt/sso/amazon.py | 153 ++++++ agixt/sso/aol.py | 151 ++++++ agixt/sso/apple.py | 102 ++++ agixt/sso/autodesk.py | 143 ++++++ agixt/sso/battlenet.py | 101 ++++ agixt/sso/bitbucket.py | 114 +++++ agixt/sso/bitly.py | 94 ++++ agixt/sso/clearscore.py | 166 +++++++ agixt/sso/cloud_foundry.py | 105 +++++ agixt/sso/deutsche_telekom.py | 154 ++++++ agixt/sso/deviantart.py | 139 ++++++ agixt/sso/discord.py | 106 +++++ agixt/sso/dropbox.py | 141 ++++++ agixt/sso/facebook.py | 130 ++++++ agixt/sso/fatsecret.py | 100 ++++ agixt/sso/fitbit.py | 147 ++++++ agixt/sso/formstack.py | 126 +++++ agixt/sso/foursquare.py | 96 ++++ agixt/sso/github.py | 110 +++++ agixt/sso/gitlab.py | 111 +++++ agixt/sso/google.py | 141 ++++++ agixt/sso/huddle.py | 155 ++++++ agixt/sso/imgur.py | 140 ++++++ agixt/sso/instagram.py | 112 +++++ agixt/sso/intel_cloud_services.py | 155 ++++++ agixt/sso/jive.py | 147 ++++++ agixt/sso/keycloak.py | 102 ++++ agixt/sso/linkedin.py | 124 +++++ agixt/sso/microsoft.py | 155 ++++++ agixt/sso/netiq.py | 156 +++++++ agixt/sso/okta.py | 105 +++++ agixt/sso/openam.py | 105 +++++ agixt/sso/openstreetmap.py | 105 +++++ agixt/sso/orcid.py | 108 +++++ agixt/sso/paypal.py | 154 ++++++ agixt/sso/ping_identity.py | 155 ++++++ agixt/sso/pixiv.py | 110 +++++ agixt/sso/reddit.py | 136 ++++++ agixt/sso/salesforce.py | 118 +++++ agixt/sso/sina_weibo.py | 135 ++++++ agixt/sso/spotify.py | 117 +++++ agixt/sso/stack_exchange.py | 114 +++++ agixt/sso/strava.py | 148 ++++++ agixt/sso/stripe.py | 98 ++++ agixt/sso/twitch.py | 109 +++++ agixt/sso/viadeo.py | 146 ++++++ agixt/sso/vimeo.py | 163 +++++++ agixt/sso/vk.py | 86 ++++ agixt/sso/wechat.py | 96 ++++ agixt/sso/withings.py | 114 +++++ agixt/sso/xero.py | 110 +++++ agixt/sso/xing.py | 156 +++++++ agixt/sso/yahoo.py | 140 ++++++ agixt/sso/yammer.py | 132 ++++++ agixt/sso/yandex.py | 142 ++++++ agixt/sso/yelp.py | 109 +++++ agixt/sso/zendesk.py | 153 ++++++ docker-compose-dev.yml | 321 ++++++++++++- docs/2-Concepts/09-Agent Training.md | 4 +- docs/4-Authentication/amazon.md | 39 ++ docs/4-Authentication/aol.md | 33 ++ docs/4-Authentication/apple.md | 86 ++++ docs/4-Authentication/autodesk.md | 45 ++ docs/4-Authentication/battlenet.md | 46 ++ docs/4-Authentication/bitbucket.md | 43 ++ docs/4-Authentication/bitly.md | 38 ++ docs/4-Authentication/clearscore.md | 33 ++ docs/4-Authentication/cloud_foundry.md | 51 ++ docs/4-Authentication/deutsche_telekom.md | 63 +++ docs/4-Authentication/deviantart.md | 39 ++ docs/4-Authentication/discord.md | 44 ++ docs/4-Authentication/dropbox.md | 44 ++ docs/4-Authentication/facebook.md | 37 ++ docs/4-Authentication/fatsecret.md | 25 + docs/4-Authentication/fitbit.md | 45 ++ docs/4-Authentication/formstack.md | 40 ++ docs/4-Authentication/foursquare.md | 41 ++ docs/4-Authentication/github.md | 43 ++ docs/4-Authentication/gitlab.md | 42 ++ docs/4-Authentication/google.md | 63 +++ docs/4-Authentication/huddle.md | 41 ++ docs/4-Authentication/imgur.md | 36 ++ docs/4-Authentication/instagram.md | 70 +++ docs/4-Authentication/intel_cloud_services.md | 48 ++ docs/4-Authentication/jive.md | 64 +++ docs/4-Authentication/keycloak.md | 20 + docs/4-Authentication/linkedin.md | 54 +++ docs/4-Authentication/microsoft.md | 61 +++ docs/4-Authentication/netiq.md | 41 ++ docs/4-Authentication/okta.md | 52 +++ docs/4-Authentication/openam.md | 47 ++ docs/4-Authentication/openstreetmap.md | 42 ++ docs/4-Authentication/orcid.md | 28 ++ docs/4-Authentication/paypal.md | 55 +++ docs/4-Authentication/ping_identity.md | 47 ++ docs/4-Authentication/pixiv.md | 37 ++ docs/4-Authentication/reddit.md | 29 ++ docs/4-Authentication/salesforce.md | 54 +++ docs/4-Authentication/sina_weibo.md | 44 ++ docs/4-Authentication/spotify.md | 46 ++ docs/4-Authentication/stack_exchange.md | 37 ++ docs/4-Authentication/strava.md | 36 ++ docs/4-Authentication/stripe.md | 31 ++ docs/4-Authentication/twitch.md | 46 ++ docs/4-Authentication/viadeo.md | 31 ++ docs/4-Authentication/vimeo.md | 29 ++ docs/4-Authentication/vk.md | 35 ++ docs/4-Authentication/wechat.md | 36 ++ docs/4-Authentication/withings.md | 35 ++ docs/4-Authentication/xero.md | 41 ++ docs/4-Authentication/xing.md | 39 ++ docs/4-Authentication/yahoo.md | 48 ++ docs/4-Authentication/yammer.md | 40 ++ docs/4-Authentication/yandex.md | 20 + docs/4-Authentication/yelp.md | 38 ++ docs/4-Authentication/zendesk.md | 51 ++ examples/AGiXT-Expert-OAI.ipynb | 2 +- examples/AGiXT-Expert-ezLocalai.ipynb | 2 +- examples/Chatbot.ipynb | 2 +- requirements.txt | 4 +- tests/tests.ipynb | 14 +- 141 files changed, 11098 insertions(+), 190 deletions(-) create mode 100644 agixt/OAuth2Providers.py create mode 100644 agixt/sso/amazon.py create mode 100644 agixt/sso/aol.py create mode 100644 agixt/sso/apple.py create mode 100644 agixt/sso/autodesk.py create mode 100644 agixt/sso/battlenet.py create mode 100644 agixt/sso/bitbucket.py create mode 100644 agixt/sso/bitly.py create mode 100644 agixt/sso/clearscore.py create mode 100644 agixt/sso/cloud_foundry.py create mode 100644 agixt/sso/deutsche_telekom.py create mode 100644 agixt/sso/deviantart.py create mode 100644 agixt/sso/discord.py create mode 100644 agixt/sso/dropbox.py create mode 100644 agixt/sso/facebook.py create mode 100644 agixt/sso/fatsecret.py create mode 100644 agixt/sso/fitbit.py create mode 100644 agixt/sso/formstack.py create mode 100644 agixt/sso/foursquare.py create mode 100644 agixt/sso/github.py create mode 100644 agixt/sso/gitlab.py create mode 100644 agixt/sso/google.py create mode 100644 agixt/sso/huddle.py create mode 100644 agixt/sso/imgur.py create mode 100644 agixt/sso/instagram.py create mode 100644 agixt/sso/intel_cloud_services.py create mode 100644 agixt/sso/jive.py create mode 100644 agixt/sso/keycloak.py create mode 100644 agixt/sso/linkedin.py create mode 100644 agixt/sso/microsoft.py create mode 100644 agixt/sso/netiq.py create mode 100644 agixt/sso/okta.py create mode 100644 agixt/sso/openam.py create mode 100644 agixt/sso/openstreetmap.py create mode 100644 agixt/sso/orcid.py create mode 100644 agixt/sso/paypal.py create mode 100644 agixt/sso/ping_identity.py create mode 100644 agixt/sso/pixiv.py create mode 100644 agixt/sso/reddit.py create mode 100644 agixt/sso/salesforce.py create mode 100644 agixt/sso/sina_weibo.py create mode 100644 agixt/sso/spotify.py create mode 100644 agixt/sso/stack_exchange.py create mode 100644 agixt/sso/strava.py create mode 100644 agixt/sso/stripe.py create mode 100644 agixt/sso/twitch.py create mode 100644 agixt/sso/viadeo.py create mode 100644 agixt/sso/vimeo.py create mode 100644 agixt/sso/vk.py create mode 100644 agixt/sso/wechat.py create mode 100644 agixt/sso/withings.py create mode 100644 agixt/sso/xero.py create mode 100644 agixt/sso/xing.py create mode 100644 agixt/sso/yahoo.py create mode 100644 agixt/sso/yammer.py create mode 100644 agixt/sso/yandex.py create mode 100644 agixt/sso/yelp.py create mode 100644 agixt/sso/zendesk.py create mode 100644 docs/4-Authentication/amazon.md create mode 100644 docs/4-Authentication/aol.md create mode 100644 docs/4-Authentication/apple.md create mode 100644 docs/4-Authentication/autodesk.md create mode 100644 docs/4-Authentication/battlenet.md create mode 100644 docs/4-Authentication/bitbucket.md create mode 100644 docs/4-Authentication/bitly.md create mode 100644 docs/4-Authentication/clearscore.md create mode 100644 docs/4-Authentication/cloud_foundry.md create mode 100644 docs/4-Authentication/deutsche_telekom.md create mode 100644 docs/4-Authentication/deviantart.md create mode 100644 docs/4-Authentication/discord.md create mode 100644 docs/4-Authentication/dropbox.md create mode 100644 docs/4-Authentication/facebook.md create mode 100644 docs/4-Authentication/fatsecret.md create mode 100644 docs/4-Authentication/fitbit.md create mode 100644 docs/4-Authentication/formstack.md create mode 100644 docs/4-Authentication/foursquare.md create mode 100644 docs/4-Authentication/github.md create mode 100644 docs/4-Authentication/gitlab.md create mode 100644 docs/4-Authentication/google.md create mode 100644 docs/4-Authentication/huddle.md create mode 100644 docs/4-Authentication/imgur.md create mode 100644 docs/4-Authentication/instagram.md create mode 100644 docs/4-Authentication/intel_cloud_services.md create mode 100644 docs/4-Authentication/jive.md create mode 100644 docs/4-Authentication/keycloak.md create mode 100644 docs/4-Authentication/linkedin.md create mode 100644 docs/4-Authentication/microsoft.md create mode 100644 docs/4-Authentication/netiq.md create mode 100644 docs/4-Authentication/okta.md create mode 100644 docs/4-Authentication/openam.md create mode 100644 docs/4-Authentication/openstreetmap.md create mode 100644 docs/4-Authentication/orcid.md create mode 100644 docs/4-Authentication/paypal.md create mode 100644 docs/4-Authentication/ping_identity.md create mode 100644 docs/4-Authentication/pixiv.md create mode 100644 docs/4-Authentication/reddit.md create mode 100644 docs/4-Authentication/salesforce.md create mode 100644 docs/4-Authentication/sina_weibo.md create mode 100644 docs/4-Authentication/spotify.md create mode 100644 docs/4-Authentication/stack_exchange.md create mode 100644 docs/4-Authentication/strava.md create mode 100644 docs/4-Authentication/stripe.md create mode 100644 docs/4-Authentication/twitch.md create mode 100644 docs/4-Authentication/viadeo.md create mode 100644 docs/4-Authentication/vimeo.md create mode 100644 docs/4-Authentication/vk.md create mode 100644 docs/4-Authentication/wechat.md create mode 100644 docs/4-Authentication/withings.md create mode 100644 docs/4-Authentication/xero.md create mode 100644 docs/4-Authentication/xing.md create mode 100644 docs/4-Authentication/yahoo.md create mode 100644 docs/4-Authentication/yammer.md create mode 100644 docs/4-Authentication/yandex.md create mode 100644 docs/4-Authentication/yelp.md create mode 100644 docs/4-Authentication/zendesk.md diff --git a/agixt/Agent.py b/agixt/Agent.py index 7a5c011022ee..7d7368632e5a 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -415,7 +415,7 @@ def update_agent_config(self, new_config, config_key): return f"Agent {self.agent_name} configuration updated." - def get_browsed_links(self): + def get_browsed_links(self, conversation_id=None): """ Get the list of URLs that have been browsed by the agent. @@ -433,7 +433,7 @@ def get_browsed_links(self): return [] browsed_links = ( self.session.query(AgentBrowsedLink) - .filter_by(agent_id=agent.id) + .filter_by(agent_id=agent.id, conversation_id=conversation_id) .order_by(AgentBrowsedLink.id.desc()) .all() ) @@ -441,7 +441,7 @@ def get_browsed_links(self): return [] return browsed_links - def browsed_recently(self, url) -> bool: + def browsed_recently(self, url, conversation_id=None) -> bool: """ Check if the given URL has been browsed by the agent within the last 24 hours. @@ -451,7 +451,7 @@ def browsed_recently(self, url) -> bool: Returns: bool: True if the URL has been browsed within the last 24 hours, False otherwise. """ - browsed_links = self.get_browsed_links() + browsed_links = self.get_browsed_links(conversation_id=conversation_id) if not browsed_links: return False for link in browsed_links: @@ -460,7 +460,7 @@ def browsed_recently(self, url) -> bool: return True return False - def add_browsed_link(self, url): + def add_browsed_link(self, url, conversation_id=None): """ Add a URL to the list of browsed links for the agent. @@ -479,12 +479,14 @@ def add_browsed_link(self, url): ) if not agent: return f"Agent {self.agent_name} not found." - browsed_link = AgentBrowsedLink(agent_id=agent.id, url=url) + browsed_link = AgentBrowsedLink( + agent_id=agent.id, url=url, conversation_id=conversation_id + ) self.session.add(browsed_link) self.session.commit() return f"Link {url} added to browsed links." - def delete_browsed_link(self, url): + def delete_browsed_link(self, url, conversation_id=None): """ Delete a URL from the list of browsed links for the agent. @@ -497,7 +499,8 @@ def delete_browsed_link(self, url): agent = ( self.session.query(AgentModel) .filter( - AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id + AgentModel.name == self.agent_name, + AgentModel.user_id == self.user_id, ) .first() ) @@ -505,7 +508,7 @@ def delete_browsed_link(self, url): return f"Agent {self.agent_name} not found." browsed_link = ( self.session.query(AgentBrowsedLink) - .filter_by(agent_id=agent.id, url=url) + .filter_by(agent_id=agent.id, url=url, conversation_id=conversation_id) .first() ) if not browsed_link: diff --git a/agixt/Chain.py b/agixt/Chain.py index ad661c1f35d4..174571c8bffe 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -10,6 +10,8 @@ Prompt, Command, User, + TaskCategory, + TaskItem, ) from Globals import getenv, DEFAULT_USER from Prompts import Prompts @@ -865,3 +867,34 @@ def get_chain_args(self, chain_name): except Exception as e: logging.error(f"Error getting chain args: {e}") return prompt_args + + def new_task( + self, + conversation_id, + chain_name, + task_category, + task_description, + estimated_hours, + ): + task_category = ( + self.session.query(TaskCategory) + .filter( + TaskCategory.name == task_category, TaskCategory.user_id == self.user_id + ) + .first() + ) + if not task_category: + task_category = TaskCategory(name=task_category, user_id=self.user_id) + self.session.add(task_category) + self.session.commit() + task = TaskItem( + user_id=self.user_id, + category_id=task_category.id, + title=chain_name, + description=task_description, + estimated_hours=estimated_hours, + memory_collection=str(conversation_id), + ) + self.session.add(task) + self.session.commit() + return task.id diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 62d53efe1372..520b3876fc58 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -63,6 +63,21 @@ def get_conversations(self): ) return [conversation.name for conversation in conversations] + def get_conversations_with_ids(self): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + conversations = ( + session.query(Conversation) + .filter( + Conversation.user_id == user_id, + ) + .all() + ) + return { + str(conversation.id): conversation.name for conversation in conversations + } + def get_conversation(self, limit=100, page=1): session = get_session() user_data = session.query(User).filter(User.email == self.user).first() @@ -376,3 +391,19 @@ def update_message(self, message, new_message): message.content = new_message session.commit() + + def get_conversation_id(self): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + conversation = ( + session.query(Conversation) + .filter( + Conversation.name == self.conversation_name, + Conversation.user_id == user_id, + ) + .first() + ) + if not conversation: + return None + return str(conversation.id) diff --git a/agixt/DB.py b/agixt/DB.py index 1a5046af5dfd..c9fcb984051e 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -66,6 +66,25 @@ class User(Base): is_active = Column(Boolean, default=True) +class UserOAuth(Base): + __tablename__ = "user_oauth" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("user.id")) + user = relationship("User") + provider_id = Column(UUID(as_uuid=True), ForeignKey("oauth_provider.id")) + provider = relationship("OAuthProvider") + access_token = Column(String, default="", nullable=False) + refresh_token = Column(String, default="", nullable=False) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + + +class OAuthProvider(Base): + __tablename__ = "oauth_provider" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String, default="", nullable=False) + + class FailedLogins(Base): __tablename__ = "failed_logins" id = Column( @@ -161,6 +180,11 @@ class AgentBrowsedLink(Base): ForeignKey("agent.id"), nullable=False, ) + conversation_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("conversation.id"), + nullable=True, + ) link = Column(Text, nullable=False) timestamp = Column(DateTime, server_default=func.now()) @@ -488,15 +512,17 @@ class TaskCategory(Base): primary_key=True, default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), ) + user_id = Column(UUID(as_uuid=True), ForeignKey("user.id")) name = Column(String) description = Column(String) - memory_collection = Column(Integer, default=0) + memory_collection = Column(String, default="0") created_at = Column(DateTime, server_default=func.now()) updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) category_id = Column( UUID(as_uuid=True), ForeignKey("task_category.id"), nullable=True ) parent_category = relationship("TaskCategory", remote_side=[id]) + user = relationship("User", backref="task_category") class TaskItem(Base): @@ -511,7 +537,7 @@ class TaskItem(Base): category = relationship("TaskCategory") title = Column(String) description = Column(String) - memory_collection = Column(Integer, default=0) + memory_collection = Column(String, default="0") # agent_id is the action item owner. If it is null, it is an item for the user agent_id = Column( UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 2261ac97a9c0..b117cb043028 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -40,7 +40,7 @@ def __init__( self.agent = Agent(self.agent_name, user=user, ApiClient=self.ApiClient) self.agent_commands = self.agent.get_commands_string() self.websearch = Websearch( - collection_number=1, + collection_number="1", agent=self.agent, user=self.user, ApiClient=self.ApiClient, @@ -48,28 +48,28 @@ def __init__( self.agent_memory = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=0, + collection_number="0", ApiClient=self.ApiClient, user=self.user, ) self.positive_feedback_memories = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=2, + collection_number="2", ApiClient=self.ApiClient, user=self.user, ) self.negative_feedback_memories = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=3, + collection_number="3", ApiClient=self.ApiClient, user=self.user, ) self.github_memories = GithubReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=7, + collection_number="7", user=self.user, ApiClient=self.ApiClient, ) @@ -129,6 +129,12 @@ async def format_prompt( ) prompt = prompt_name prompt_args = [] + if "conversation_name" in kwargs: + conversation_name = kwargs["conversation_name"] + if conversation_name == "": + conversation_name = f"{str(datetime.now())} Conversation" + c = Conversations(conversation_name=conversation_name, user=self.user) + conversation = c.get_conversation() if top_results == 0: context = [] else: @@ -175,13 +181,12 @@ async def format_prompt( joined_feedback = "\n".join(negative_feedback) context.append(f"Negative Feedback:\n{joined_feedback}\n") if "inject_memories_from_collection_number" in kwargs: - if int(kwargs["inject_memories_from_collection_number"]) > 5: + collection_id = kwargs["inject_memories_from_collection_number"] + if collection_id not in ["0", "1", "2", "3", "4", "5", "6", "7"]: context += await FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=int( - kwargs["inject_memories_from_collection_number"] - ), + collection_number=collection_id, ApiClient=self.ApiClient, user=self.user, ).get_memories( @@ -189,6 +194,17 @@ async def format_prompt( limit=top_results, min_relevance_score=min_relevance_score, ) + context += await FileReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=c.get_conversation_id(), + ApiClient=self.ApiClient, + user=self.user, + ).get_memories( + user_input=user_input, + limit=top_results, + min_relevance_score=min_relevance_score, + ) else: context = [] if "context" in kwargs: @@ -209,12 +225,7 @@ async def format_prompt( helper_agent_name = self.agent.AGENT_CONFIG["settings"][ "helper_agent_name" ] - if "conversation_name" in kwargs: - conversation_name = kwargs["conversation_name"] - if conversation_name == "": - conversation_name = f"{str(datetime.now())} Conversation" - c = Conversations(conversation_name=conversation_name, user=self.user) - conversation = c.get_conversation() + if "conversation_results" in kwargs: conversation_results = int(kwargs["conversation_results"]) else: @@ -697,7 +708,7 @@ async def run( agent_name=self.agent_name, prompt_name=prompt, prompt_category=prompt_category, - log_user_interaction=False, + log_user_input=False, **prompt_args, ) time.sleep(1) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index 471580bbf525..d789c3cb5bda 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -1,4 +1,5 @@ -from DB import User, FailedLogins, get_session +from DB import User, FailedLogins, UserOAuth, OAuthProvider, get_session +from OAuth2Providers import get_sso_provider from Models import UserInfo, Register, Login from fastapi import Header, HTTPException from Globals import getenv @@ -243,7 +244,13 @@ def count_failed_logins(self): session.close() return failed_logins - def send_magic_link(self, ip_address, login: Login, referrer=None): + def send_magic_link( + self, + ip_address, + login: Login, + referrer=None, + send_link: bool = True, + ): self.email = login.email.lower() session = get_session() user = session.query(User).filter(User.email == self.email).first() @@ -302,6 +309,7 @@ def send_magic_link(self, ip_address, login: Login, referrer=None): and str(getenv("SENDGRID_API_KEY")).lower() != "none" and getenv("SENDGRID_FROM_EMAIL") != "" and str(getenv("SENDGRID_FROM_EMAIL")).lower() != "none" + and send_link ): send_email( email=self.email, @@ -423,3 +431,80 @@ def delete_user(self): session.commit() session.close() return "User deleted successfully" + + def sso( + self, + provider, + code, + ip_address, + referrer=None, + ): + if not referrer: + referrer = getenv("MAGIC_LINK_URL") + provider = str(provider).lower() + sso_data = None + sso_data = get_sso_provider(provider=provider, code=code, redirect_uri=referrer) + if not sso_data: + raise HTTPException( + status_code=400, + detail=f"Failed to get user data from {provider.capitalize()}.", + ) + if not sso_data.access_token: + raise HTTPException( + status_code=400, + detail=f"Failed to get access token from {provider.capitalize()}.", + ) + user_data = sso_data.user_info + access_token = sso_data.access_token + refresh_token = sso_data.refresh_token + self.email = str(user_data["email"]).lower() + if not user_data: + logging.warning(f"Error on {provider.capitalize()}: {user_data}") + raise HTTPException( + status_code=400, + detail=f"Failed to get user data from {provider.capitalize()}.", + ) + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if not user: + register = Register( + email=self.email, + first_name=user_data["first_name"] if "first_name" in user_data else "", + last_name=user_data["last_name"] if "last_name" in user_data else "", + ) + mfa_token = self.register(new_user=register) + # Create the UserOAuth record + user = session.query(User).filter(User.email == self.email).first() + provider = ( + session.query(OAuthProvider) + .filter(OAuthProvider.name == provider) + .first() + ) + if not provider: + provider = OAuthProvider(name=provider) + session.add(provider) + user_oauth = UserOAuth( + user_id=user.id, + provider_id=provider.id, + access_token=access_token, + refresh_token=refresh_token, + ) + session.add(user_oauth) + else: + mfa_token = user.mfa_token + user_oauth = ( + session.query(UserOAuth).filter(UserOAuth.user_id == user.id).first() + ) + if user_oauth: + user_oauth.access_token = access_token + user_oauth.refresh_token = refresh_token + session.commit() + session.close() + totp = pyotp.TOTP(mfa_token) + login = Login(email=self.email, token=totp.now()) + return self.send_magic_link( + ip_address=ip_address, + login=login, + referrer=referrer, + send_link=False, + ) diff --git a/agixt/Memories.py b/agixt/Memories.py index 5b7568dd2475..3a725f6586e8 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -153,7 +153,7 @@ def __init__( self, agent_name: str = "AGiXT", agent_config=None, - collection_number: int = 0, + collection_number: str = "0", # Is now actually a collection ID and a string to allow conversational memories. ApiClient=None, summarize_content: bool = False, user=DEFAULT_USER, @@ -170,8 +170,9 @@ def __init__( self.collection_name = snake(f"{snake(DEFAULT_USER)}_{agent_name}") self.user = user self.collection_number = collection_number - if collection_number > 0: - self.collection_name = f"{self.collection_name}_{collection_number}" + # Check if collection_number is a number, it might be a string + if collection_number != "0": + self.collection_name = snake(f"{self.collection_name}_{collection_number}") if agent_config is None: agent_config = ApiClient.get_agentconfig(agent_name=agent_name) self.agent_config = ( @@ -245,14 +246,12 @@ async def export_collections_to_json(self): async def import_collections_from_json(self, json_data: List[dict]): for data in json_data: for key, value in data.items(): - try: - collection_number = int(key) - except: - collection_number = 0 - self.collection_number = collection_number - self.collection_name = snake(self.agent_name) - if collection_number > 0: - self.collection_name = f"{self.collection_name}_{collection_number}" + self.collection_number = key if key else "0" + self.collection_name = snake(f"{self.user}_{self.agent_name}") + if str(self.collection_number) != "0": + self.collection_name = ( + f"{self.collection_name}_{self.collection_number}" + ) for val in value[self.collection_name]: try: await self.write_text_to_memory( @@ -266,9 +265,8 @@ async def import_collections_from_json(self, json_data: List[dict]): # get collections that start with the collection name async def get_collections(self): collections = self.chroma_client.list_collections() - if int(self.collection_number) > 0: - collection_name = snake(self.agent_name) - collection_name = f"{collection_name}_{self.collection_number}" + if str(self.collection_number) != "0": + collection_name = snake(f"{self.user}_{self.agent_name}") else: collection_name = self.collection_name return [ @@ -279,22 +277,17 @@ async def get_collections(self): async def get_collection(self): try: - return self.chroma_client.get_collection( + return self.chroma_client.get_or_create_collection( name=self.collection_name, embedding_function=self.embedder ) except: try: - return self.chroma_client.create_collection( - name=self.collection_name, - embedding_function=self.embedder, - get_or_create=True, + return self.chroma_client.get_or_create_collection( + name=self.collection_name, embedding_function=self.embedder ) except: - # Collection already exists - pass - return self.chroma_client.get_collection( - name=self.collection_name, embedding_function=self.embedder - ) + logging.warning(f"Error getting collection: {self.collection_name}") + return None async def delete_memory(self, key: str): collection = await self.get_collection() @@ -419,9 +412,11 @@ async def get_memories( ) -> List[str]: global DEFAULT_USER default_collection_name = self.collection_name + default_results = [] if self.user != DEFAULT_USER: + # Get global memories for the agent first self.collection_name = snake(f"{snake(DEFAULT_USER)}_{self.agent_name}") - if self.collection_number > 0: + if str(self.collection_number) != "0": self.collection_name = ( f"{self.collection_name}_{self.collection_number}" ) @@ -439,6 +434,10 @@ async def get_memories( limit=limit, min_relevance_score=min_relevance_score, ) + if isinstance(user_results, str): + user_results = [user_results] + if isinstance(default_results, str): + default_results = [default_results] results = user_results + default_results response = [] if results: diff --git a/agixt/Models.py b/agixt/Models.py index 6a65acf3e527..e30d0f4fd20d 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -156,19 +156,19 @@ class ResponseMessage(BaseModel): class UrlInput(BaseModel): url: str - collection_number: int = 0 + collection_number: Optional[str] = "0" class FileInput(BaseModel): file_name: str file_content: str - collection_number: int = 0 + collection_number: Optional[str] = "0" class TextMemoryInput(BaseModel): user_input: str text: str - collection_number: int = 0 + collection_number: Optional[str] = "0" class FeedbackInput(BaseModel): @@ -226,7 +226,7 @@ class HistoryModel(BaseModel): class ExternalSource(BaseModel): external_source: str - collection_number: int = 0 + collection_number: Optional[str] = "0" class ConversationHistoryModel(BaseModel): @@ -252,6 +252,24 @@ class UpdateConversationHistoryMessageModel(BaseModel): new_message: str +class TaskPlanInput(BaseModel): + user_input: str + websearch: Optional[bool] = False + websearch_depth: Optional[int] = 3 + conversation_name: Optional[str] = "AGiXT Task Planning" + log_user_input: Optional[bool] = True + log_output: Optional[bool] = True + enable_new_command: Optional[bool] = True + + +class TasksToDo(BaseModel): + tasks: List[str] + + +class ChainCommandName(BaseModel): + command_name: str + + class GitHubInput(BaseModel): github_repo: str github_user: Optional[str] = None @@ -270,7 +288,7 @@ class ArxivInput(BaseModel): class YoutubeInput(BaseModel): video_id: str - collection_number: int = 0 + collection_number: Optional[str] = "0" class CommandExecution(BaseModel): @@ -296,8 +314,8 @@ class Login(BaseModel): class Register(BaseModel): email: str - first_name: str - last_name: str + first_name: Optional[str] = "" + last_name: Optional[str] = "" class UserInfo(BaseModel): diff --git a/agixt/OAuth2Providers.py b/agixt/OAuth2Providers.py new file mode 100644 index 000000000000..50253a17fbe0 --- /dev/null +++ b/agixt/OAuth2Providers.py @@ -0,0 +1,442 @@ +from sso.amazon import amazon_sso +from sso.aol import aol_sso +from sso.apple import apple_sso +from sso.autodesk import autodesk_sso +from sso.battlenet import battlenet_sso +from sso.bitbucket import bitbucket_sso +from sso.bitly import bitly_sso +from sso.clearscore import clearscore_sso +from sso.cloud_foundry import cloud_foundry_sso +from sso.deutsche_telekom import deutsche_telekom_sso +from sso.deviantart import deviantart_sso +from sso.discord import discord_sso +from sso.dropbox import dropbox_sso +from sso.facebook import facebook_sso +from sso.fatsecret import fatsecret_sso +from sso.fitbit import fitbit_sso +from sso.formstack import formstack_sso +from sso.foursquare import foursquare_sso +from sso.github import github_sso +from sso.gitlab import gitlab_sso +from sso.google import google_sso +from sso.huddle import huddle_sso +from sso.imgur import imgur_sso +from sso.instagram import instagram_sso +from sso.intel_cloud_services import intel_cloud_services_sso +from sso.jive import jive_sso +from sso.keycloak import keycloak_sso +from sso.linkedin import linkedin_sso +from sso.microsoft import microsoft_sso +from sso.netiq import netiq_sso +from sso.okta import okta_sso +from sso.openam import openam_sso +from sso.openstreetmap import openstreetmap_sso +from sso.orcid import orcid_sso +from sso.paypal import paypal_sso +from sso.ping_identity import ping_identity_sso +from sso.pixiv import pixiv_sso +from sso.reddit import reddit_sso +from sso.salesforce import salesforce_sso +from sso.sina_weibo import sina_weibo_sso +from sso.spotify import spotify_sso +from sso.stack_exchange import stack_exchange_sso +from sso.strava import strava_sso +from sso.stripe import stripe_sso +from sso.twitch import twitch_sso +from sso.viadeo import viadeo_sso +from sso.vimeo import vimeo_sso +from sso.vk import vk_sso +from sso.wechat import wechat_sso +from sso.withings import withings_sso +from sso.xero import xero_sso +from sso.xing import xing_sso +from sso.yahoo import yahoo_sso +from sso.yammer import yammer_sso +from sso.yandex import yandex_sso +from sso.yelp import yelp_sso +from sso.zendesk import zendesk_sso +from Globals import getenv + + +def get_provider_info(provider): + providers = { + "amazon": { + "scopes": ["openid", "email", "profile"], + "authorization_url": f"https://{getenv('AWS_USER_POOL_ID')}.auth.{getenv('AWS_REGION')}.amazoncognito.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/a/a9/Amazon_logo.svg", + "function": amazon_sso, + }, + "aol": { + "scopes": [ + "https://api.aol.com/userinfo.profile", + "https://api.aol.com/userinfo.email", + "https://api.aol.com/mail.send", + ], + "authorization_url": "https://api.login.aol.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/51/AOL.svg", + "function": aol_sso, + }, + "apple": { + "scopes": ["name", "email"], + "authorization_url": "https://appleid.apple.com/auth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Apple_logo_black.svg", + "function": apple_sso, + }, + "autodesk": { + "scopes": ["data:read", "data:write", "bucket:read", "bucket:create"], + "authorization_url": "https://developer.api.autodesk.com/authentication/v1/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/d/d7/Autodesk_logo_2019.svg", + "function": autodesk_sso, + }, + "battlenet": { + "scopes": ["openid", "email"], + "authorization_url": "https://oauth.battle.net/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/en/1/1b/Battle.net_Icon.svg", + "function": battlenet_sso, + }, + "bitbucket": { + "scopes": ["account", "email"], + "authorization_url": "https://bitbucket.org/site/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/0/0e/Bitbucket-blue-logomark-only.svg", + "function": bitbucket_sso, + }, + "bitly": { + "scopes": ["bitly:read", "bitly:write"], + "authorization_url": "https://bitly.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/56/Bitly_logo.svg", + "function": bitly_sso, + }, + "clearscore": { + "scopes": ["user.info.read", "email.send"], + "authorization_url": "https://auth.clearscore.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/en/5/57/ClearScore_logo.png", + "function": clearscore_sso, + }, + "cloud_foundry": { + "scopes": ["cloud_controller.read", "openid", "email"], + "authorization_url": "https://login.system.example.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/thumb/7/75/Cloud_Foundry_Logo.svg/512px-Cloud_Foundry_Logo.svg.png", + "function": cloud_foundry_sso, + }, + "deutsche_telekom": { + "scopes": ["t-online-profile", "t-online-email"], + "authorization_url": "https://www.telekom.com/ssoservice/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/d/d2/Logo_telekom_2013.svg", + "function": deutsche_telekom_sso, + }, + "deviantart": { + "scopes": ["user", "browse", "stash", "send_message"], + "authorization_url": "https://www.deviantart.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/b/b5/DeviantArt_Logo.svg", + "function": deviantart_sso, + }, + "discord": { + "scopes": ["identify", "email"], + "authorization_url": "https://discord.com/api/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/98/Discord_logo.svg", + "function": discord_sso, + }, + "dropbox": { + "scopes": ["account_info.read", "files.metadata.read"], + "authorization_url": "https://www.dropbox.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/7/7e/Dropbox_Icon.svg", + "function": dropbox_sso, + }, + "facebook": { + "scopes": ["public_profile", "email", "pages_messaging"], + "authorization_url": "https://www.facebook.com/v10.0/dialog/oauth", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg", + "function": facebook_sso, + }, + "fatsecret": { + "scopes": ["profile.get"], + "authorization_url": "https://oauth.fatsecret.com/connect/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/en/2/20/FatSecret.png", + "function": fatsecret_sso, + }, + "fitbit": { + "scopes": [ + "activity", + "heartrate", + "location", + "nutrition", + "profile", + "settings", + "sleep", + "social", + "weight", + ], + "authorization_url": "https://www.fitbit.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/60/Fitbit_logo_2016.svg", + "function": fitbit_sso, + }, + "formstack": { + "scopes": ["formstack:read", "formstack:write"], + "authorization_url": "https://www.formstack.com/api/v2/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/en/0/09/Formstack_logo.png", + "function": formstack_sso, + }, + "foursquare": { + "scopes": [], + "authorization_url": "https://foursquare.com/oauth2/authenticate", + "icon": "https://upload.wikimedia.org/wikipedia/en/1/12/Foursquare_logo.svg", + "function": foursquare_sso, + }, + "github": { + "scopes": ["user:email", "read:user"], + "authorization_url": "https://github.com/login/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg", + "function": github_sso, + }, + "gitlab": { + "scopes": ["read_user", "api", "email"], + "authorization_url": "https://gitlab.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/1/18/GitLab_Logo.svg", + "function": gitlab_sso, + }, + "google": { + "scopes": [ + "https://www.googleapis.com/auth/gmail.send", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + ], + "authorization_url": "https://accounts.google.com/o/oauth2/auth", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/53/Google_%22G%22_Logo.svg", + "function": google_sso, + }, + "huddle": { + "scopes": ["user_info", "send_email"], + "authorization_url": "https://login.huddle.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/1/1c/Huddle_logo.png", + "function": huddle_sso, + }, + "imgur": { + "scopes": ["read", "write"], + "authorization_url": "https://api.imgur.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/1/1e/Imgur_logo.svg", + "function": imgur_sso, + }, + "instagram": { + "scopes": ["user_profile", "user_media"], + "authorization_url": "https://api.instagram.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/a/a5/Instagram_icon.png", + "function": instagram_sso, + }, + "intel_cloud_services": { + "scopes": [ + "https://api.intel.com/userinfo.read", + "https://api.intel.com/mail.send", + ], + "authorization_url": "https://auth.intel.com/oauth2/v2.0/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/1/1f/Intel_logo_%282006-2020%29.svg", + "function": intel_cloud_services_sso, + }, + "jive": { + "scopes": ["user", "email"], + "authorization_url": "https://example.jive.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/0/0e/Jive_Software_logo.svg", + "function": jive_sso, + }, + "keycloak": { + "scopes": ["openid", "email", "profile"], + "authorization_url": "https://your-keycloak-server/auth/realms/your-realm/protocol/openid-connect/auth", + "icon": "https://upload.wikimedia.org/wikipedia/commons/0/05/Keycloak_Logo.png", + "function": keycloak_sso, + }, + "linkedin": { + "scopes": ["r_liteprofile", "r_emailaddress", "w_member_social"], + "authorization_url": "https://www.linkedin.com/oauth/v2/authorization", + "icon": "https://upload.wikimedia.org/wikipedia/commons/c/ca/LinkedIn_logo_initials.png", + "function": linkedin_sso, + }, + "microsoft": { + "scopes": [ + "https://graph.microsoft.com/User.Read", + "https://graph.microsoft.com/Mail.Send", + "https://graph.microsoft.com/Calendars.ReadWrite.Shared", + ], + "authorization_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/4/44/Microsoft_logo.svg", + "function": microsoft_sso, + }, + "netiq": { + "scopes": ["profile", "email", "openid", "user.info"], + "authorization_url": "https://your-netiq-domain.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/4/4d/NetIQ_logo.png", + "function": netiq_sso, + }, + "okta": { + "scopes": ["openid", "profile", "email"], + "authorization_url": "https://your-okta-domain/oauth2/v1/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/6b/Okta_logo.png", + "function": okta_sso, + }, + "openam": { + "scopes": ["profile", "email"], + "authorization_url": "https://your-openam-base-url/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/7/7a/OpenAM_logo.png", + "function": openam_sso, + }, + "openstreetmap": { + "scopes": ["read_prefs"], + "authorization_url": "https://www.openstreetmap.org/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/7/7e/OpenStreetMap_logo.svg", + "function": openstreetmap_sso, + }, + "orcid": { + "scopes": ["/authenticate", "/activities/update"], + "authorization_url": "https://orcid.org/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/0/0e/ORCID_logo.png", + "function": orcid_sso, + }, + "paypal": { + "scopes": ["email openid"], + "authorization_url": "https://www.paypal.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/b/b5/PayPal.svg", + "function": paypal_sso, + }, + "ping_identity": { + "scopes": ["profile", "email", "openid"], + "authorization_url": "https://your-ping-identity-domain/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/8/8e/Ping_Identity_logo.png", + "function": ping_identity_sso, + }, + "pixiv": { + "scopes": ["pixiv.scope.profile.read"], + "authorization_url": "https://oauth.secure.pixiv.net/auth/token", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/6a/Pixiv_logo.svg", + "function": pixiv_sso, + }, + "reddit": { + "scopes": ["identity", "submit", "read"], + "authorization_url": "https://www.reddit.com/api/v1/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/en/8/82/Reddit_logo_and_wordmark.svg", + "function": reddit_sso, + }, + "salesforce": { + "scopes": ["refresh_token full email"], + "authorization_url": "https://login.salesforce.com/services/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/51/Salesforce_logo.svg", + "function": salesforce_sso, + }, + "sina_weibo": { + "scopes": ["email", "statuses_update"], + "authorization_url": "https://api.weibo.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/8/86/Sina_Weibo_logo.svg", + "function": sina_weibo_sso, + }, + "spotify": { + "scopes": ["user-read-email", "user-read-private", "playlist-read-private"], + "authorization_url": "https://accounts.spotify.com/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/2/26/Spotify_logo_with_text.png", + "function": spotify_sso, + }, + "stack_exchange": { + "scopes": ["read_inbox no_expiry private_info write_access"], + "authorization_url": "https://stackexchange.com/oauth", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/6f/Stack_Exchange_logo.png", + "function": stack_exchange_sso, + }, + "strava": { + "scopes": ["read", "activity:write"], + "authorization_url": "https://www.strava.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/2/29/Strava_Logo.png", + "function": strava_sso, + }, + "stripe": { + "scopes": ["read_write"], + "authorization_url": "https://connect.stripe.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/3/30/Stripe_Logo%2C_revised_2016.png", + "function": stripe_sso, + }, + "twitch": { + "scopes": ["user:read:email"], + "authorization_url": "https://id.twitch.tv/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/98/Twitch_logo.svg", + "function": twitch_sso, + }, + "viadeo": { + "scopes": ["basic", "email"], + "authorization_url": "https://secure.viadeo.com/oauth-provider/authorize2", + "icon": "https://upload.wikimedia.org/wikipedia/commons/7/7a/Viadeo_logo.png", + "function": viadeo_sso, + }, + "vimeo": { + "scopes": ["public", "private", "video_files"], + "authorization_url": "https://api.vimeo.com/oauth/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/9b/Vimeo_logo.png", + "function": vimeo_sso, + }, + "vk": { + "scopes": ["email"], + "authorization_url": "https://oauth.vk.com/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/2/21/VK.com-logo.svg", + "function": vk_sso, + }, + "wechat": { + "scopes": ["snsapi_userinfo"], + "authorization_url": "https://open.weixin.qq.com/connect/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/8/8e/WeChat_logo.svg", + "function": wechat_sso, + }, + "withings": { + "scopes": ["user.info", "user.metrics", "user.activity"], + "authorization_url": "https://account.withings.com/oauth2_user/authorize2", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/9c/Withings_logo.png", + "function": withings_sso, + }, + "xero": { + "scopes": ["openid", "profile", "email", "offline_access"], + "authorization_url": "https://login.xero.com/identity/connect/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/5/5d/Xero_logo.svg", + "function": xero_sso, + }, + "xing": { + "scopes": [ + "https://api.xing.com/v1/users/me", + "https://api.xing.com/v1/authorize", + ], + "authorization_url": "https://api.xing.com/v1/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/2/2f/Xing_logo.svg", + "function": xing_sso, + }, + "yahoo": { + "scopes": ["profile", "email", "mail-w"], + "authorization_url": "https://api.login.yahoo.com/oauth2/request_auth", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/63/Yahoo%21_logo.svg", + "function": yahoo_sso, + }, + "yammer": { + "scopes": ["messages:email", "messages:post"], + "authorization_url": "https://www.yammer.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/4/4b/Yammer_logo.png", + "function": yammer_sso, + }, + "yandex": { + "scopes": ["login:info login:email", "mail.send"], + "authorization_url": "https://oauth.yandex.com/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/9/96/Yandex_logo.png", + "function": yandex_sso, + }, + "yelp": { + "scopes": ["business"], + "authorization_url": "https://www.yelp.com/oauth2/authorize", + "icon": "https://upload.wikimedia.org/wikipedia/commons/6/62/Yelp_Logo.svg", + "function": yelp_sso, + }, + "zendesk": { + "scopes": ["read", "write"], + "authorization_url": "https://your-zendesk-domain/oauth/authorizations/new", + "icon": "https://upload.wikimedia.org/wikipedia/commons/2/2e/Zendesk_logo.png", + "function": zendesk_sso, + }, + } + return providers[provider] if provider in providers else None + + +def get_sso_provider(provider: str, code, redirect_uri=None): + provider_info = get_provider_info(provider) + if provider_info: + return provider_info["function"](code=code, redirect_uri=redirect_uri) + else: + return None diff --git a/agixt/Websearch.py b/agixt/Websearch.py index 8b6bd23f91d9..a0909dcc13d5 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -29,7 +29,7 @@ class Websearch: def __init__( self, - collection_number: int = 1, + collection_number: str = "1", agent: Agent = None, user: str = None, ApiClient=None, @@ -52,7 +52,7 @@ def __init__( self.agent_memory = YoutubeReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=int(collection_number), + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ) @@ -130,7 +130,9 @@ async def summarize_web_content(self, url, content): # If the content isn't too long, we will ask AI to resummarize the combined chunks. return await self.summarize_web_content(url=url, content=new_content) - async def get_web_content(self, url: str, summarize_content=False): + async def get_web_content( + self, url: str, summarize_content=False, conversation_id="1" + ): if url.startswith("https://arxiv.org/") or url.startswith( "https://www.arxiv.org/" ): @@ -149,7 +151,9 @@ async def get_web_content(self, url: str, summarize_content=False): video_id = video_id.split("&")[0] content = await self.agent_memory.get_transcription(video_id=video_id) self.browsed_links.append(url) - self.agent.add_browsed_link(url=url) + self.agent.add_browsed_link( + url=url, conversation_id=conversation_id + ) # add conversation ID if summarize_content: content = await self.summarize_web_content(url=url, content=content) await self.agent_memory.write_text_to_memory( @@ -187,7 +191,7 @@ async def get_web_content(self, url: str, summarize_content=False): res = await GithubReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=7, + collection_number="7", user=self.user, ApiClient=self.ApiClient, ).write_github_repository_to_memory( @@ -206,7 +210,7 @@ async def get_web_content(self, url: str, summarize_content=False): ) if res: self.browsed_links.append(url) - self.agent.add_browsed_link(url=url) + self.agent.add_browsed_link(url=url, conversation_id=conversation_id) return ( f"Content from GitHub repository at {url} has been added to memory.", None, @@ -272,14 +276,15 @@ async def get_web_content(self, url: str, summarize_content=False): external_source=url, ) self.browsed_links.append(url) - self.agent.add_browsed_link(url=url) + self.agent.add_browsed_link(url=url, conversation_id=conversation_id) return text_content, link_list except: return None, None - async def recursive_browsing(self, user_input, links, conversation_name: str = ""): - if conversation_name != "" and conversation_name is not None: - c = Conversations(conversation_name=conversation_name, user=self.user) + async def recursive_browsing( + self, user_input, links, conversation_name: str = "", conversation_id="1" + ): + c = Conversations(conversation_name=conversation_name, user=self.user) try: words = links.split() links = [ @@ -301,7 +306,9 @@ async def recursive_browsing(self, user_input, links, conversation_name: str = " ( collected_data, link_list, - ) = await self.get_web_content(url=url) + ) = await self.get_web_content( + url=url, conversation_id=conversation_id + ) if links is not None: for link in links: if "href" in link: @@ -321,7 +328,9 @@ async def recursive_browsing(self, user_input, links, conversation_name: str = " ( collected_data, link_list, - ) = await self.get_web_content(url=url) + ) = await self.get_web_content( + url=url, conversation_id=conversation_id + ) if link_list is not None: if len(link_list) > 0: if len(link_list) > 5: @@ -359,6 +368,7 @@ async def recursive_browsing(self, user_input, links, conversation_name: str = " user_input=user_input, links=pick_a_link, conversation_name=conversation_name, + conversation_id=conversation_id, ) except: logging.info(f"Issues reading {url}. Moving on...") @@ -385,6 +395,14 @@ async def scrape_websites( return "" if conversation_name != "" and conversation_name is not None: c = Conversations(conversation_name=conversation_name, user=self.user) + conversation_id = c.get_conversation_id() + self.agent_memory = YoutubeReader( + agent_name=self.agent_name, + agent_config=self.agent.AGENT_CONFIG, + collection_number=conversation_id, + ApiClient=self.ApiClient, + user=self.user, + ) c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Researching online.", @@ -538,7 +556,7 @@ async def update_search_provider(self): self.websearch_endpoint = websearch_endpoint return websearch_endpoint - async def web_search(self, query: str) -> List[str]: + async def web_search(self, query: str, conversation_id: str = "1") -> List[str]: endpoint = self.websearch_endpoint if endpoint.endswith("/"): endpoint = endpoint[:-1] @@ -546,14 +564,14 @@ async def web_search(self, query: str) -> List[str]: endpoint = endpoint[:-6] logging.info(f"Websearching for {query} on {endpoint}") text_content, link_list = await self.get_web_content( - url=f"{endpoint}/search?q={query}" + url=f"{endpoint}/search?q={query}", conversation_id=conversation_id ) if link_list is None: link_list = [] if len(link_list) < 5: self.failures.append(self.websearch_endpoint) await self.update_search_provider() - return await self.web_search(query=query) + return await self.web_search(query=query, conversation_id=conversation_id) return text_content, link_list async def websearch_agent( @@ -579,10 +597,9 @@ async def websearch_agent( search_string = " ".join(keywords) # add month and year to the end of the search string search_string += f" {datetime.now().strftime('%B %Y')}" + c = Conversations(conversation_name=conversation_name, user=self.user) + conversation_id = c.get_conversation_id() if conversation_name != "" and conversation_name is not None: - c = Conversations( - conversation_name=conversation_name, user=self.user - ) c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Searching for `{search_string}`.", @@ -614,7 +631,9 @@ async def websearch_agent( links = await self.ddg_search(query=search_string) if links == [] or links is None: links = [] - content, links = await self.web_search(query=search_string) + content, links = await self.web_search( + query=search_string, conversation_id=conversation_id + ) logging.info( f"Found {len(links)} results for {search_string} using DDG." ) @@ -626,6 +645,7 @@ async def websearch_agent( user_input=user_input, links=links, conversation_name=conversation_name, + conversation_id=conversation_id, ) ) self.tasks.append(task) diff --git a/agixt/XT.py b/agixt/XT.py index fe4d08b08214..51fe1172f484 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -4,7 +4,7 @@ from Extensions import Extensions from pydub import AudioSegment from Globals import getenv, get_tokens, DEFAULT_SETTINGS -from Models import ChatCompletions +from Models import ChatCompletions, TasksToDo, ChainCommandName from datetime import datetime from typing import Type, get_args, get_origin, Union, List from enum import Enum @@ -95,7 +95,7 @@ async def memories( user_input: str = "", limit_per_collection: int = 5, minimum_relevance_score: float = 0.3, - additional_collection_number: int = 0, + additional_collection: str = "0", ): """ Get a list of memories @@ -104,7 +104,7 @@ async def memories( user_input (str): User input to the agent limit_per_collection (int): Number of memories to return per collection minimum_relevance_score (float): Minimum relevance score for memories - additional_collection_number (int): Additional collection number to pull memories from. Collections 0-5 are injected automatically. + additional_collection (int): Additional collection number to pull memories from. Collections 0-5 are injected automatically. Returns: str: Agents relevant memories from the user input from collections 0-5 and the additional collection number if provided @@ -113,7 +113,7 @@ async def memories( user_input=user_input if user_input else "*", top_results=limit_per_collection, min_relevance_score=minimum_relevance_score, - inject_memories_from_collection_number=int(additional_collection_number), + inject_memories_from_collection_number=additional_collection, ) return formatted_prompt @@ -539,7 +539,7 @@ async def learn_from_file( file_url: str = "", file_name: str = "", user_input: str = "", - collection_number: int = 1, + collection_id: str = "1", conversation_name: str = "", ): """ @@ -548,7 +548,7 @@ async def learn_from_file( Args: file_url (str): URL of the file file_path (str): Path to the file - collection_number (int): Collection number to store the file + collection_id (str): Collection ID to save the file to conversation_name (str): Name of the conversation Returns: @@ -576,7 +576,7 @@ async def learn_from_file( file_reader = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, - collection_number=collection_number, + collection_number=collection_id, ApiClient=self.ApiClient, user=self.user_email, ) @@ -705,6 +705,199 @@ async def download_file_to_workspace(self, url: str, file_name: str = ""): url = f"{self.outputs}/{file_name}" return {"file_name": file_name, "file_url": url} + async def plan_task( + self, + user_input: str, + websearch: bool = False, + websearch_depth: int = 3, + conversation_name: str = "", + log_user_input: bool = True, + log_output: bool = True, + enable_new_command: bool = True, + ): + """ + Plan a task from a user input, create and enable a new command to execute the plan + + Args: + user_input (str): User input to the agent + websearch (bool): Whether to include web research in the chain + websearch_depth (int): Depth of web research to include + conversation_name (str): Name of the conversation to log activity to + log_user_input (bool): Whether to log the user input + log_output (bool): Whether to log the output + enable_new_command (bool): Whether to enable the new command for the agent + + Returns: + str: The name of the created chain + """ + c = Conversations(conversation_name=conversation_name, user=self.user_email) + if log_user_input: + c.log_interaction( + role="USER", + message=user_input, + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Determining primary objective.", + ) + # primary_objective = Step 1, execute chain "Smart Prompt" with the user input to get Primary Objective + primary_objective = await self.execute_chain( + chain_name="Smart Prompt", + user_input=user_input, + agent_override=self.agent_name, + log_user_input=False, + conversation_name=conversation_name, + ) + chain_name = await self.inference( + user_input=user_input, + introduction=primary_objective, + prompt_category="Default", + prompt_name="Title a Chain", + log_output=False, + log_user_input=False, + conversation_name=conversation_name, + ) + chain_title = await self.convert_to_pydantic_model( + input_string=chain_name, + model=ChainCommandName, + ) + chain_name = chain_title.command_name + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Breaking objective into a list of tasks.", + ) + # numbered_list_of_tasks = Step 2, Execute prompt "Break into Steps" with `introduction` being step 1 response, websearch true if researching + # Note - Should do this more than once to get a better list of tasks + numbered_list_of_tasks = await self.inference( + user_input=user_input, + introduction=primary_objective, + prompt_category="Default", + prompt_name="Break into Steps", + websearch=websearch, + websearch_depth=websearch_depth, + injected_memories=10, + log_output=False, + log_user_input=False, + conversation_name=conversation_name, + ) + task_list = await self.convert_to_pydantic_model( + input_string=numbered_list_of_tasks, + model=TasksToDo, + ) + self.chain.add_chain(chain_name=chain_name) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Creating new command `{chain_name}`.", + ) + i = 1 + total_tasks = len(task_list.tasks) + x = 1 + # First step in the chain should be to disable the command so that the agent doesn't try to execute it while executing it + self.chain.add_chain_step( + chain_name=chain_name, + agent_name=self.agent_name, + step_number=i, + prompt_type="Command", + prompt={ + "command_name": "Disable Command", + "command_args": { + "agent_name": self.agent_name, + "command_name": chain_name, + }, + }, + ) + i += 1 + for task in task_list.tasks: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Planning task `{x}` of `{total_tasks}`.", + ) + x += 1 + # Create a smart prompt with the objective and task in context + self.chain.add_chain_step( + chain_name=chain_name, + agent_name=self.agent_name, + step_number=i, + prompt_type="Chain", + prompt={ + "chain_name": "Smart Prompt", + "input": f"Primary Objective to keep in mind while working on the task: {primary_objective} \nThe only task to complete to move towards the objective: {task}", + }, + ) + i += 1 + self.chain.add_chain_step( + chain_name=chain_name, + agent_name=self.agent_name, + step_number=i, + prompt_type="Chain", + prompt={ + "chain": ( + "Smart Instruct" + if websearch + else "Smart Instruct - No Research" + ), + "input": "{STEP" + str(i - 1) + "}", + }, + ) + i += 1 + list_of_tasks = "\n".join( + [f"{i}. {task}" for i, task in enumerate(task_list.tasks, 1)] + ) + # Enable the command of the chain name + if enable_new_command: + self.agent.update_agent_config( + new_config={chain_name: True}, config_key="commands" + ) + message = f"I have created a new command called `{chain_name}`. The tasks will be executed in the following order:\n{list_of_tasks}\n\nWould you like me to execute `{chain_name}` now?" + else: + message = f"I have created a new command called `{chain_name}`. The tasks will be executed in the following order:\n{list_of_tasks}\n\nIf you are able to enable the command, I can execute it for you. Alternatively, you can execute the command manually." + if log_output: + c.log_interaction( + role=self.agent_name, + message=message, + ) + return { + "chain_name": chain_name, + "message": message, + "tasks": list_of_tasks, + } + + async def update_planned_task( + self, + chain_name: str, + user_input: str, + conversation_name: str = "", + log_user_input: bool = True, + log_output: bool = True, + enable_new_command: bool = True, + ): + """ + Modify the chain based on user input + + Args: + chain_name (str): Name of the chain to update + user_input (str): User input to the agent + conversation_name (str): Name of the conversation + log_user_input (bool): Whether to log the user input + log_output (bool): Whether to log the output + + Returns: + str: Response from the agent + """ + # Basically just delete the old chain after we extract the tasks and then run the plan_task function with more input from the user. + current_chain = self.chain.get_chain(chain_name=chain_name) + # This function is still a work in progress + # Need to + + self.chain.delete_chain(chain_name=chain_name) + return await self.plan_task( + user_input=user_input, + conversation_name=conversation_name, + log_user_input=log_user_input, + log_output=log_output, + enable_new_command=enable_new_command, + ) + async def chat_completions(self, prompt: ChatCompletions): """ Generate an OpenAI style chat completion response with a ChatCompletion prompt @@ -881,12 +1074,13 @@ async def chat_completions(self, prompt: ChatCompletions): # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction(role="USER", message=new_prompt) + conversation_id = c.get_conversation_id() for file in files: await self.learn_from_file( file_url=file["file_url"], file_name=file["file_name"], user_input=new_prompt, - collection_number=1, + collection_id=conversation_id, conversation_name=conversation_name, ) await self.learn_from_websites( @@ -1120,7 +1314,7 @@ async def create_dataset_from_memories(self, batch_size: int = 10): ) return dpo_dataset - def convert_to_pydantic_model( + async def convert_to_pydantic_model( self, input_string: str, model: Type[BaseModel], @@ -1140,12 +1334,13 @@ def convert_to_pydantic_model( description += f" (Enum values: {enum_values})" field_descriptions.append(description) schema = "\n".join(field_descriptions) - response = self.inference( + response = await self.inference( user_input=input_string, schema=schema, prompt_category="Default", prompt_name="Convert to Pydantic Model", log_user_input=False, + log_output=False, ) if "```json" in response: response = response.split("```json")[1].split("```")[0].strip() @@ -1174,7 +1369,7 @@ def convert_to_pydantic_model( logging.warning( f"Error: {e} . Failed to convert the response to the model, trying again. {failures}/3 failures. Response: {response}" ) - return self.convert_to_pydantic_model( + return await self.convert_to_pydantic_model( input_string=input_string, model=model, max_failures=max_failures, diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index 6fcc0d9aa465..e9cdc271f8d8 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -1,6 +1,7 @@ from typing import Dict from fastapi import APIRouter, HTTPException, Depends, Header from Interactions import Interactions +from XT import AGiXT from Websearch import Websearch from Globals import getenv from ApiClient import ( @@ -23,6 +24,7 @@ ResponseMessage, UrlInput, TTSInput, + TaskPlanInput, ) import base64 import uuid @@ -51,7 +53,7 @@ async def addagent( ApiClient = get_api_client(authorization=authorization) _agent = Agent(agent_name=agent.agent_name, user=user, ApiClient=ApiClient) reader = Websearch( - collection_number=0, + collection_number="0", agent=_agent, user=user, ApiClient=ApiClient, @@ -148,7 +150,7 @@ async def deleteagent( ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) await Websearch( - collection_number=0, + collection_number="0", agent=agent, user=user, ApiClient=ApiClient, @@ -268,18 +270,21 @@ async def toggle_command( # Get agent browsed links @app.get( - "/api/agent/{agent_name}/browsed_links", + "/api/agent/{agent_name}/browsed_links/{collection_number}", tags=["Agent", "Admin"], dependencies=[Depends(verify_api_key)], ) async def get_agent_browsed_links( - agent_name: str, user=Depends(verify_api_key), authorization: str = Header(None) + agent_name: str, + collection_number: str = "0", + user=Depends(verify_api_key), + authorization: str = Header(None), ): if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) - return {"links": agent.get_browsed_links()} + return {"links": agent.get_browsed_links(conversation_id=collection_number)} # Delete browsed link from memory @@ -299,13 +304,13 @@ async def delete_browsed_link( ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) websearch = Websearch( - collection_number=url.collection_number, + collection_number=str(url.collection_number), agent=agent, user=user, ApiClient=ApiClient, ) websearch.agent_memory.delete_memories_from_external_source(url=url.url) - agent.delete_browsed_link(url=url.url) + agent.delete_browsed_link(url=url.url, conversation_id=url.collection_number) return {"message": "Browsed links deleted."} @@ -333,3 +338,28 @@ async def text_to_speech( f.write(audio_data) tts_response = f"{AGIXT_URI}/outputs/{agent.agent_id}/{file_name}" return {"url": tts_response} + + +# Plan task +@app.post( + "/api/agent/{agent_name}/plan/task", + tags=["Agent"], + dependencies=[Depends(verify_api_key)], +) +async def plan_task( + agent_name: str, + task: TaskPlanInput, + user=Depends(verify_api_key), + authorization: str = Header(None), +) -> ResponseMessage: + agent = AGiXT(user=user, agent_name=agent_name, api_key=authorization) + planned_task = await agent.plan_task( + user_input=task.user_input, + websearch=task.websearch, + websearch_depth=task.websearch_depth, + conversation_name=task.conversation_name, + log_user_input=task.log_user_input, + log_output=task.log_output, + enable_new_command=task.enable_new_command, + ) + return {"response": planned_task} diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 3f561624d0e8..10a8005c01f8 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -105,3 +105,20 @@ async def createuser( github_repos=account.github_repos, ApiClient=ApiClient, ) + + +@app.post( + "/v1/oauth2/{provider}", + response_model=Detail, + summary="Login using OAuth2 provider", +) +async def oauth_login(request: Request, provider: str): + data = await request.json() + auth = MagicalAuth() + magic_link = auth.sso( + provider=provider.lower(), + code=data["code"], + ip_address=request.client.host, + referrer=data["referrer"] if "referrer" in data else getenv("MAGIC_LINK_URL"), + ) + return {"detail": magic_link, "email": auth.email, "token": auth.token} diff --git a/agixt/endpoints/Conversation.py b/agixt/endpoints/Conversation.py index 81c4455c3406..f8a951e4e992 100644 --- a/agixt/endpoints/Conversation.py +++ b/agixt/endpoints/Conversation.py @@ -12,28 +12,21 @@ app = APIRouter() -@app.get( - "/api/{agent_name}/conversations", - tags=["Conversation"], - dependencies=[Depends(verify_api_key)], -) -async def get_conversations_list(user=Depends(verify_api_key)): - conversations = Conversations(user=user).get_conversations() - if conversations is None: - conversations = [] - return {"conversations": conversations} - - @app.get( "/api/conversations", tags=["Conversation"], dependencies=[Depends(verify_api_key)], ) async def get_conversations_list(user=Depends(verify_api_key)): - conversations = Conversations(user=user).get_conversations() + c = Conversations(user=user) + conversations = c.get_conversations() if conversations is None: conversations = [] - return {"conversations": conversations} + conversations_with_ids = c.get_conversations_with_ids() + return { + "conversations": conversations, + "conversations_with_ids": conversations_with_ids, + } @app.get( diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index daeb77cb4c8e..4c8383b2eb78 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -39,22 +39,18 @@ async def query_memories( agent_name: str, memory: AgentMemoryQuery, - collection_number=0, + collection_number="0", user=Depends(verify_api_key), authorization: str = Header(None), ) -> Dict[str, Any]: ApiClient = get_api_client(authorization=authorization) - try: - collection_number = int(collection_number) - except: - collection_number = 0 agent_config = Agent( agent_name=agent_name, user=user, ApiClient=ApiClient ).get_agent_config() memories = await Memories( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ).get_memories_data( @@ -120,14 +116,10 @@ async def learn_text( agent_config = Agent( agent_name=agent_name, user=user, ApiClient=ApiClient ).get_agent_config() - try: - collection_number = int(data.collection_number) - except: - collection_number = 0 memory = Memories( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(data.collection_number), ApiClient=ApiClient, user=user, ) @@ -170,7 +162,7 @@ async def learn_file( await FileReader( agent_name=agent_name, agent_config=agent_config, - collection_number=file.collection_number, + collection_number=str(file.collection_number), ApiClient=ApiClient, user=user, ).write_file_to_memory(file_path=file_path) @@ -202,7 +194,7 @@ async def learn_url( agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) url.url = url.url.replace(" ", "%20") response = await Websearch( - collection_number=url.collection_number, + collection_number=str(url.collection_number), agent=agent, user=user, ApiClient=ApiClient, @@ -231,7 +223,7 @@ async def learn_github_repo( await GithubReader( agent_name=agent_name, agent_config=agent_config, - collection_number=git.collection_number, + collection_number=str(git.collection_number), use_agent_settings=git.use_agent_settings, ApiClient=ApiClient, user=user, @@ -264,7 +256,7 @@ async def learn_arxiv( await ArxivReader( agent_name=agent_name, agent_config=agent_config, - collection_number=arxiv_input.collection_number, + collection_number=str(arxiv_input.collection_number), ApiClient=ApiClient, ).write_arxiv_articles_to_memory( query=arxiv_input.query, @@ -292,7 +284,7 @@ async def learn_youtube( await YoutubeReader( agent_name=agent_name, agent_config=agent_config, - collection_number=youtube_input.collection_number, + collection_number=str(youtube_input.collection_number), ApiClient=ApiClient, ).write_youtube_captions_to_memory(video_id=youtube_input.video_id) return ResponseMessage(message="Agent learned the content from the YouTube video.") @@ -313,7 +305,9 @@ async def agent_reader( ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) agent_config = agent.AGENT_CONFIG - collection_number = data["collection_number"] if "collection_number" in data else 0 + collection_number = ( + str(data["collection_number"]) if "collection_number" in data else "0" + ) if reader_name == "file": response = await FileReader( agent_name=agent_name, @@ -390,7 +384,7 @@ async def wipe_agent_memories( await Memories( agent_name=agent_name, agent_config=agent.AGENT_CONFIG, - collection_number=0, + collection_number="0", ApiClient=ApiClient, user=user, ).wipe_memory() @@ -404,17 +398,13 @@ async def wipe_agent_memories( ) async def wipe_agent_memories( agent_name: str, - collection_number=0, + collection_number: str = "0", user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: if is_admin(email=user, api_key=authorization) != True: raise HTTPException(status_code=403, detail="Access Denied") ApiClient = get_api_client(authorization=authorization) - try: - collection_number = int(collection_number) - except: - collection_number = 0 agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) await Memories( agent_name=agent_name, @@ -433,16 +423,12 @@ async def wipe_agent_memories( ) async def delete_agent_memory( agent_name: str, - collection_number=0, + collection_number: str = "0", memory_id="", user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: ApiClient = get_api_client(authorization=authorization) - try: - collection_number = int(collection_number) - except: - collection_number = 0 agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) await Memories( agent_name=agent_name, @@ -561,7 +547,7 @@ async def delete_memories_from_external_source( await Memories( agent_name=agent_name, agent_config=agent.AGENT_CONFIG, - collection_number=external_source.collection_number, + collection_number=str(external_source.collection_number), ApiClient=ApiClient, user=user, ).delete_memories_from_external_source( @@ -574,19 +560,22 @@ async def delete_memories_from_external_source( # Get unique external sources @app.get( - "/api/agent/{agent_name}/memory/external_sources", + "/api/agent/{agent_name}/memory/external_sources/{collection_number}", tags=["Memory"], dependencies=[Depends(verify_api_key)], ) async def get_unique_external_sources( - agent_name: str, user=Depends(verify_api_key), authorization: str = Header(None) + agent_name: str, + collection_number: str = "0", + user=Depends(verify_api_key), + authorization: str = Header(None), ) -> Dict[str, Any]: ApiClient = get_api_client(authorization=authorization) agent = Agent(agent_name=agent_name, user=user, ApiClient=ApiClient) external_sources = await Memories( agent_name=agent_name, agent_config=agent.AGENT_CONFIG, - collection_number=0, + collection_number=collection_number, ApiClient=ApiClient, user=user, ).get_external_data_sources() diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 10677c8bc293..20421ec1f63b 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -160,6 +160,8 @@ def __init__(self, **kwargs): "Get CSV Preview Text": self.get_csv_preview_text, "Strip CSV Data from Code Block": self.get_csv_from_response, "Convert a string to a Pydantic model": self.convert_string_to_pydantic_model, + "Disable Command": self.disable_command, + "Plan Multistep Task": self.plan_multistep_task, } user = kwargs["user"] if "user" in kwargs else "user" for chain in Chain(user=user).get_chains(): @@ -193,7 +195,7 @@ async def read_file_content(self, file_path: str): agent_name=self.agent_name, file_name=filename, file_content=file_content, - collection_number=0, + collection_number="0", ) async def write_website_to_memory(self, url: str): @@ -209,7 +211,7 @@ async def write_website_to_memory(self, url: str): return self.ApiClient.learn_url( agent_name=self.agent_name, url=url, - collection_number=0, + collection_number="0", ) async def store_long_term_memory( @@ -246,7 +248,7 @@ async def search_arxiv(self, query: str, max_articles: int = 5): query=query, article_ids=None, max_articles=max_articles, - collection_number=0, + collection_number="0", ) async def read_github_repository(self, repository_url: str): @@ -263,9 +265,46 @@ async def read_github_repository(self, repository_url: str): agent_name=self.agent_name, github_repo=repository_url, use_agent_settings=True, - collection_number=0, + collection_number="0", ) + async def disable_command(self, command_name: str): + """ + Disable a command + + Args: + command_name (str): The name of the command to disable + + Returns: + str: Success message + """ + return self.ApiClient.toggle_command( + agent_name=self.agent_name, commands_name=command_name, enable=False + ) + + async def plan_multistep_task(self, assumed_scope_of_work: str): + """ + Plan a multi-step task + + Args: + assumed_scope_of_work (str): The assumed scope of work + + Returns: + str: The name of the new chain + """ + user_input = assumed_scope_of_work + new_chain = self.ApiClient.plan_task( + agent_name=self.agent_name, + user_input=user_input, + websearch=True, + websearch_depth=3, + conversation_name=self.conversation_name, + log_user_input=False, + log_output=False, + enable_new_command=True, + ) + return new_chain["message"] + async def create_task_chain( self, agent: str, diff --git a/agixt/readers/arxiv.py b/agixt/readers/arxiv.py index 4803e7f164fd..91438623bf63 100644 --- a/agixt/readers/arxiv.py +++ b/agixt/readers/arxiv.py @@ -10,7 +10,7 @@ def __init__( self, agent_name: str = "AGiXT", agent_config=None, - collection_number: int = 0, + collection_number: str = "0", ApiClient=None, user=None, **kwargs, @@ -18,7 +18,7 @@ def __init__( super().__init__( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ) diff --git a/agixt/readers/file.py b/agixt/readers/file.py index 8080e4379ef4..8a60468b6ed8 100644 --- a/agixt/readers/file.py +++ b/agixt/readers/file.py @@ -13,7 +13,7 @@ def __init__( self, agent_name: str = "AGiXT", agent_config=None, - collection_number: int = 0, + collection_number: str = "0", ApiClient=None, user=None, **kwargs, @@ -21,7 +21,7 @@ def __init__( super().__init__( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ) diff --git a/agixt/readers/github.py b/agixt/readers/github.py index e1d2b3952358..0e34f0d3e778 100644 --- a/agixt/readers/github.py +++ b/agixt/readers/github.py @@ -9,7 +9,7 @@ def __init__( self, agent_name: str = "AGiXT", agent_config=None, - collection_number: int = 0, + collection_number: str = "0", use_agent_settings: bool = False, ApiClient=None, user=None, @@ -18,12 +18,17 @@ def __init__( super().__init__( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ) self.file_reader = FileReader( - agent_name=self.agent_name, agent_config=self.agent_config, user=user + agent_name=self.agent_name, + agent_config=self.agent_config, + collection_number=str(collection_number), + ApiClient=ApiClient, + user=user, + **kwargs, ) self.use_agent_settings = use_agent_settings if ( diff --git a/agixt/readers/youtube.py b/agixt/readers/youtube.py index 93c8a8f07e02..c062ea88fcf9 100644 --- a/agixt/readers/youtube.py +++ b/agixt/readers/youtube.py @@ -7,7 +7,7 @@ def __init__( self, agent_name: str = "AGiXT", agent_config=None, - collection_number: int = 0, + collection_number: str = 0, ApiClient=None, user=None, **kwargs, @@ -15,7 +15,7 @@ def __init__( super().__init__( agent_name=agent_name, agent_config=agent_config, - collection_number=collection_number, + collection_number=str(collection_number), ApiClient=ApiClient, user=user, ) diff --git a/agixt/sso/amazon.py b/agixt/sso/amazon.py new file mode 100644 index 000000000000..650507cc9364 --- /dev/null +++ b/agixt/sso/amazon.py @@ -0,0 +1,153 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- AWS_CLIENT_ID: AWS Cognito OAuth client ID +- AWS_CLIENT_SECRET: AWS Cognito OAuth client secret +- AWS_USER_POOL_ID: AWS Cognito User Pool ID +- AWS_REGION: AWS Cognito Region + +Required scopes for AWS OAuth + +- openid +- email +- profile +""" + + +class AmazonSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("AWS_CLIENT_ID") + self.client_secret = getenv("AWS_CLIENT_SECRET") + self.user_pool_id = getenv("AWS_USER_POOL_ID") + self.region = getenv("AWS_REGION") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"https://{self.user_pool_id}.auth.{self.region}.amazoncognito.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "openid email profile", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"https://{self.user_pool_id}.auth.{self.region}.amazoncognito.com/oauth2/userInfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data.get("given_name", "") + last_name = data.get("family_name", "") + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from AWS", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "Source": self.email_address, + "Destination": { + "ToAddresses": [to], + }, + "Message": { + "Subject": { + "Data": subject, + }, + "Body": { + "Text": { + "Data": message_text, + }, + }, + }, + } + response = requests.post( + f"https://email.{self.region}.amazonaws.com/v2/email/outbound-emails", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + f"https://email.{self.region}.amazonaws.com/v2/email/outbound-emails", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def amazon_sso(code, redirect_uri=None) -> AmazonSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://{getenv('AWS_USER_POOL_ID')}.auth.{getenv('AWS_REGION')}.amazoncognito.com/oauth2/token", + data={ + "client_id": getenv("AWS_CLIENT_ID"), + "client_secret": getenv("AWS_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting AWS access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return AmazonSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/aol.py b/agixt/sso/aol.py new file mode 100644 index 000000000000..dfca8f0cfa19 --- /dev/null +++ b/agixt/sso/aol.py @@ -0,0 +1,151 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- AOL_CLIENT_ID: AOL OAuth client ID +- AOL_CLIENT_SECRET: AOL OAuth client secret + +Note: This example assumes hypothetical OAuth and API endpoints for AOL since AOL does not provide OAuth for individual users publicly like Google or Microsoft. Replace the endpoints and scopes with the actual values if available. + +Required scopes for AOL OAuth + +- https://api.aol.com/userinfo.profile +- https://api.aol.com/userinfo.email +- https://api.aol.com/mail.send +""" + + +class AOLSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("AOL_CLIENT_ID") + self.client_secret = getenv("AOL_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.login.aol.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.aol.com/userinfo/v1/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["names"][0]["givenName"] + last_name = data["names"][0]["familyName"] + email = data["emailAddresses"][0]["value"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from AOL", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + } + response = requests.post( + "https://api.aol.com/mail/v1/messages/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.aol.com/mail/v1/messages/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def aol_sso(code, redirect_uri=None) -> AOLSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.login.aol.com/oauth2/token", + data={ + "client_id": getenv("AOL_CLIENT_ID"), + "client_secret": getenv("AOL_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "https://api.aol.com/userinfo.profile https://api.aol.com/userinfo.email https://api.aol.com/mail.send", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting AOL access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return AOLSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/apple.py b/agixt/sso/apple.py new file mode 100644 index 000000000000..32bfd87f5b5c --- /dev/null +++ b/agixt/sso/apple.py @@ -0,0 +1,102 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- APPLE_CLIENT_ID: Apple OAuth client ID +- APPLE_CLIENT_SECRET: Apple OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `APPLE_CLIENT_ID` and `APPLE_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Apple SSO + +- name +- email +""" + + +class AppleSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("APPLE_CLIENT_ID") + self.client_secret = getenv("APPLE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://appleid.apple.com/auth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + # Apple SSO does not have a straightforward URI to get user info post authentication like Google/Microsoft. + # User info should be captured during the initial token exchange. + try: + # Placeholder: Requires custom logic to handle user info retrieval from the initial login response. + # Capture name and email from initial response or authenticate/authorize endpoint. + first_name = "First" # replace with actual logic + last_name = "Last" # replace with actual logic + email = "email@example.com" # replace with actual logic + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Apple", + ) + + def send_email(self, to, subject, message_text): + # Note: Apple does not provide an email sending service in their APIs. + # Placeholder: Functionality should be implemented using SMTP or another email service. + raise NotImplementedError( + "Apple OAuth does not support sending emails directly via API" + ) + + +def apple_sso(code, redirect_uri=None) -> AppleSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://appleid.apple.com/auth/token", + data={ + "client_id": getenv("APPLE_CLIENT_ID"), + "client_secret": getenv("APPLE_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Apple access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return AppleSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/autodesk.py b/agixt/sso/autodesk.py new file mode 100644 index 000000000000..230190f2b79b --- /dev/null +++ b/agixt/sso/autodesk.py @@ -0,0 +1,143 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- AUTODESK_CLIENT_ID: Autodesk OAuth client ID +- AUTODESK_CLIENT_SECRET: Autodesk OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `AUTODESK_CLIENT_ID` and `AUTODESK_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Autodesk OAuth + +- data:read +- data:write +- bucket:read +- bucket:create +""" + + +class AutodeskSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("AUTODESK_CLIENT_ID") + self.client_secret = getenv("AUTODESK_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://developer.api.autodesk.com/authentication/v1/refreshtoken", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://developer.api.autodesk.com/userprofile/v1/users/@me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["firstName"] + last_name = data["lastName"] + email = data["emailId"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Autodesk", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError( + "Autodesk API does not support sending emails via OAuth tokens" + ) + + if not self.email_address: + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + message = {"raw": raw} + response = requests.post( + "https://developer.api.autodesk.com/email/v1/send", # Placeholder URL + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://developer.api.autodesk.com/email/v1/send", # Placeholder URL + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message), + ) + return response.json() + + +def autodesk_sso(code, redirect_uri=None) -> AutodeskSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://developer.api.autodesk.com/authentication/v1/gettoken", + data={ + "client_id": getenv("AUTODESK_CLIENT_ID"), + "client_secret": getenv("AUTODESK_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Autodesk access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return AutodeskSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/battlenet.py b/agixt/sso/battlenet.py new file mode 100644 index 000000000000..084cf8e18b49 --- /dev/null +++ b/agixt/sso/battlenet.py @@ -0,0 +1,101 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- BATTLENET_CLIENT_ID: Battle.net OAuth client ID +- BATTLENET_CLIENT_SECRET: Battle.net OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `BATTLENET_CLIENT_ID` and `BATTLENET_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Battle.net OAuth + +- openid +- email +""" + + +class BattleNetSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("BATTLENET_CLIENT_ID") + self.client_secret = getenv("BATTLENET_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://oauth.battle.net/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"https://oauth.battle.net/oauth/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + battletag = data["battletag"] + email = data["email"] + return { + "email": email, + "battletag": battletag, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Battle.net", + ) + + +def battlenet_sso(code, redirect_uri=None) -> BattleNetSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://oauth.battle.net/token", + data={ + "client_id": getenv("BATTLENET_CLIENT_ID"), + "client_secret": getenv("BATTLENET_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Battle.net access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return BattleNetSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/bitbucket.py b/agixt/sso/bitbucket.py new file mode 100644 index 000000000000..1d07b126411f --- /dev/null +++ b/agixt/sso/bitbucket.py @@ -0,0 +1,114 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- BITBUCKET_CLIENT_ID: Bitbucket OAuth client ID +- BITBUCKET_CLIENT_SECRET: Bitbucket OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `BITBUCKET_CLIENT_ID` and `BITBUCKET_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Bitbucket SSO + +- account +- email +""" + + +class BitbucketSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("BITBUCKET_CLIENT_ID") + self.client_secret = getenv("BITBUCKET_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://bitbucket.org/site/oauth2/access_token", + data={ + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + auth=(self.client_id, self.client_secret), + ) + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing Bitbucket access token", + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.bitbucket.org/2.0/user" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + user_data = response.json() + email_response = requests.get( + "https://api.bitbucket.org/2.0/user/emails", + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + email_data = email_response.json() + email = next( + email["email"] for email in email_data["values"] if email["is_primary"] + ) + first_name = user_data.get("display_name", "").split()[0] + last_name = " ".join(user_data.get("display_name", "").split()[1:]) + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as e: + logging.error(f"Error getting Bitbucket user info: {str(e)}") + raise HTTPException( + status_code=400, + detail="Error getting user info from Bitbucket", + ) + + +def bitbucket_sso(code, redirect_uri=None) -> BitbucketSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://bitbucket.org/site/oauth2/access_token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + }, + auth=(getenv("BITBUCKET_CLIENT_ID"), getenv("BITBUCKET_CLIENT_SECRET")), + ) + if response.status_code != 200: + logging.error(f"Error getting Bitbucket access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return BitbucketSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/bitly.py b/agixt/sso/bitly.py new file mode 100644 index 000000000000..e595d07f490c --- /dev/null +++ b/agixt/sso/bitly.py @@ -0,0 +1,94 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- BITLY_CLIENT_ID: Bitly OAuth client ID +- BITLY_CLIENT_SECRET: Bitly OAuth client secret +- BITLY_ACCESS_TOKEN: Bitly access token (you can obtain it via OAuth or from the Bitly account settings) + +Required scopes for Bitly OAuth + +- `bitly:read`, `bitly:write` +""" + + +class Bitly: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token or getenv("BITLY_ACCESS_TOKEN") + self.client_id = getenv("BITLY_CLIENT_ID") + self.client_secret = getenv("BITLY_CLIENT_SECRET") + + def get_new_token(self): + response = requests.post( + "https://api-ssl.bitly.com/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + if response.status_code != 200: + logging.error(f"Error refreshing Bitly token: {response.text}") + raise HTTPException(status_code=400, detail="Error refreshing Bitly token") + return response.json()["access_token"] + + def shorten_url(self, long_url): + uri = "https://api-ssl.bitly.com/v4/shorten" + headers = { + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + } + data = { + "long_url": long_url, + } + response = requests.post(uri, headers=headers, json=data) + if response.status_code == 401: + self.access_token = self.get_new_token() + headers["Authorization"] = f"Bearer {self.access_token}" + response = requests.post(uri, headers=headers, json=data) + + if response.status_code != 200: + logging.error(f"Error shortening URL with Bitly: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error shortening URL with Bitly", + ) + return response.json()["link"] + + +def bitly_sso(code, redirect_uri=None) -> Bitly: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://bit.ly/oauth/access_token", + data={ + "client_id": getenv("BITLY_CLIENT_ID"), + "client_secret": getenv("BITLY_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Bitly access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else None + return Bitly(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/clearscore.py b/agixt/sso/clearscore.py new file mode 100644 index 000000000000..e3b3b504464d --- /dev/null +++ b/agixt/sso/clearscore.py @@ -0,0 +1,166 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- CLEAR_SCORE_CLIENT_ID: ClearScore OAuth client ID +- CLEAR_SCORE_CLIENT_SECRET: ClearScore OAuth client secret + +Required APIs + +Add the `CLEAR_SCORE_CLIENT_ID` and `CLEAR_SCORE_CLIENT_SECRET` environment variables to your `.env` file. + +Assumed Required scopes for ClearScore OAuth and email capabilities: + +- user.info.read +- email.send +""" + + +class ClearScoreSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("CLEAR_SCORE_CLIENT_ID") + self.client_secret = getenv("CLEAR_SCORE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://auth.clearscore.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "user.info.read email.send", + }, + ) + if response.status_code != 200: + logging.error(f"Error refreshing ClearScore access token: {response.text}") + raise HTTPException( + status_code=400, + detail="Error refreshing ClearScore access token", + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.clearscore.com/v1/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from ClearScore", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + self.user_info = self.get_user_info() + email_address = self.user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://api.clearscore.com/v1/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.clearscore.com/v1/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code != 202: + logging.error(f"Error sending email: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error sending email", + ) + return response.json() + + +def clearscore_sso(code, redirect_uri=None) -> ClearScoreSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://auth.clearscore.com/oauth2/token", + data={ + "code": code, + "client_id": getenv("CLEAR_SCORE_CLIENT_ID"), + "client_secret": getenv("CLEAR_SCORE_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + "scope": "user.info.read email.send", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting ClearScore access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return ClearScoreSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/cloud_foundry.py b/agixt/sso/cloud_foundry.py new file mode 100644 index 000000000000..b191da3ba6e7 --- /dev/null +++ b/agixt/sso/cloud_foundry.py @@ -0,0 +1,105 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- CF_CLIENT_ID: Cloud Foundry OAuth client ID +- CF_CLIENT_SECRET: Cloud Foundry OAuth client secret + +Required APIs and Scopes: + +- Cloud Foundry API (CF API) +- User Info API +""" + + +class CloudFoundrySSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("CF_CLIENT_ID") + self.client_secret = getenv("CF_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://login.system.example.com/oauth/token", # Update with your CF OAuth token URL + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://uaa.system.example.com/userinfo" # Update with your CF User Info URL + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Cloud Foundry", + ) + + def send_email(self, to, subject, message_text): + # Assuming you have a CF service for sending emails; use relevant API + # This part is highly dependent on what services you use in Cloud Foundry + raise NotImplementedError( + "Email sending not supported for Cloud Foundry SSO yet." + ) + + +def cloud_foundry_sso(code, redirect_uri=None) -> CloudFoundrySSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3A", ":") + .replace("%3F", "?") + ) + response = requests.post( + "https://login.system.example.com/oauth/token", # Update with your CF OAuth token URL + data={ + "client_id": getenv("CF_CLIENT_ID"), + "client_secret": getenv("CF_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Cloud Foundry access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return CloudFoundrySSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/deutsche_telekom.py b/agixt/sso/deutsche_telekom.py new file mode 100644 index 000000000000..67eb9a39253e --- /dev/null +++ b/agixt/sso/deutsche_telekom.py @@ -0,0 +1,154 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- DEUTSCHE_TELKOM_CLIENT_ID: Deutsche Telekom OAuth client ID +- DEUTSCHE_TELKOM_CLIENT_SECRET: Deutsche Telekom OAuth client secret + +Required APIs: + +- https://www.deutschetelekom.com/ldap-sso + +Required scopes for Deutsche Telekom SSO: + +- t-online-profile --> Access to profile data +- t-online-email --> Access to email services +""" + + +class DeutscheTelekomSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("DEUTSCHE_TELKOM_CLIENT_ID") + self.client_secret = getenv("DEUTSCHE_TELKOM_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.telekom.com/ssoservice/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "t-online-profile t-online-email", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://www.telekom.com/ssoservice/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["firstName"] + last_name = data["lastName"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Deutsche Telekom", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://www.telekom.com/ssoservice/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://www.telekom.com/ssoservice/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def deutsche_telekom_sso(code, redirect_uri=None) -> DeutscheTelekomSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://www.telekom.com/ssoservice/token", + data={ + "client_id": getenv("DEUTSCHE_TELKOM_CLIENT_ID"), + "client_secret": getenv("DEUTSCHE_TELKOM_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "t-online-profile t-online-email", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Deutsche Telekom access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return DeutscheTelekomSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/deviantart.py b/agixt/sso/deviantart.py new file mode 100644 index 000000000000..142787c72cd2 --- /dev/null +++ b/agixt/sso/deviantart.py @@ -0,0 +1,139 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- DEVIANTART_CLIENT_ID: deviantART OAuth client ID +- DEVIANTART_CLIENT_SECRET: deviantART OAuth client secret + +Required OAuth scopes for deviantART + +- user +- browse +- stash +- send_message +""" + + +class DeviantArtSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("DEVIANTART_CLIENT_ID") + self.client_secret = getenv("DEVIANTART_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.deviantart.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://www.deviantart.com/api/v1/oauth2/user/whoami" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data.get("username", "Unknown") + last_name = "" + email = data.get( + "usericon", "Unknown" + ) # deviantART doesn't provide email, using user icon as unique identifier. + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from deviantART", + ) + + def send_message(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + message_data = { + "subject": subject, + "body": message_text, + } + response = requests.post( + "https://www.deviantart.com/api/v1/oauth2/user/notes/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://www.deviantart.com/api/v1/oauth2/user/notes/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + return response.json() + + +def deviantart_sso(code, redirect_uri=None) -> DeviantArtSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://www.deviantart.com/oauth2/token", + data={ + "client_id": getenv("DEVIANTART_CLIENT_ID"), + "client_secret": getenv("DEVIANTART_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting deviantART access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return DeviantArtSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/discord.py b/agixt/sso/discord.py new file mode 100644 index 000000000000..580ec8a42725 --- /dev/null +++ b/agixt/sso/discord.py @@ -0,0 +1,106 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- DISCORD_CLIENT_ID: Discord OAuth client ID +- DISCORD_CLIENT_SECRET: Discord OAuth client secret + +Required APIs and Scopes + +Follow the links to confirm that you have the APIs enabled, +then add the `DISCORD_CLIENT_ID` and `DISCORD_CLIENT_SECRET` environment variables to your `.env` file. + +- OAuth2 API +- Email scope https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-scopes +""" + + +class DiscordSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("DISCORD_CLIENT_ID") + self.client_secret = getenv("DISCORD_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://discord.com/api/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://discord.com/api/users/@me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + username = data["username"] + discriminator = data["discriminator"] + email = data["email"] + return { + "username": username, + "discriminator": discriminator, + "email": email, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Discord", + ) + + # Discord doesn't have a direct send email API, but you could implement similar functionality + # using a bot or webhook if necessary. + + +def discord_sso(code, redirect_uri=None) -> DiscordSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://discord.com/api/oauth2/token", + data={ + "client_id": getenv("DISCORD_CLIENT_ID"), + "client_secret": getenv("DISCORD_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logging.error(f"Error getting Discord access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return DiscordSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/dropbox.py b/agixt/sso/dropbox.py new file mode 100644 index 000000000000..80e5036e3e39 --- /dev/null +++ b/agixt/sso/dropbox.py @@ -0,0 +1,141 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- DROPBOX_CLIENT_ID: Dropbox OAuth client ID +- DROPBOX_CLIENT_SECRET: Dropbox OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `DROPBOX_CLIENT_ID` and `DROPBOX_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Dropbox OAuth + +- account_info.read +- files.metadata.read +""" + + +class DropboxSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("DROPBOX_CLIENT_ID") + self.client_secret = getenv("DROPBOX_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.dropboxapi.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.dropboxapi.com/2/users/get_current_account" + response = requests.post( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["name"]["given_name"] + last_name = data["name"]["surname"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Error getting user info from Dropbox: {e}", + ) + + def list_files(self): + uri = "https://api.dropboxapi.com/2/files/list_folder" + data = { + "path": "", + "recursive": False, + "include_media_info": False, + "include_deleted": False, + "include_has_explicit_shared_members": False, + "include_mounted_folders": True, + "include_non_downloadable_files": True, + } + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(data), + ) + if response.status_code != 200: + logging.error(f"Error listing files from Dropbox: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error listing files from Dropbox", + ) + return response.json() + + +def dropbox_sso(code, redirect_uri=None) -> DropboxSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.dropboxapi.com/oauth2/token", + data={ + "code": code, + "client_id": getenv("DROPBOX_CLIENT_ID"), + "client_secret": getenv("DROPBOX_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Dropbox access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data.get("refresh_token", "") + return DropboxSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/facebook.py b/agixt/sso/facebook.py new file mode 100644 index 000000000000..e30f0bd3bfb9 --- /dev/null +++ b/agixt/sso/facebook.py @@ -0,0 +1,130 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- FACEBOOK_CLIENT_ID: Facebook OAuth client ID +- FACEBOOK_CLIENT_SECRET: Facebook OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `FACEBOOK_CLIENT_ID` and `FACEBOOK_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Facebook OAuth + +- public_profile +- email +- pages_messaging (for sending messages, if applicable) +""" + + +class FacebookSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("FACEBOOK_CLIENT_ID") + self.client_secret = getenv("FACEBOOK_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.get( + "https://graph.facebook.com/v10.0/oauth/access_token", + params={ + "grant_type": "fb_exchange_token", + "client_id": self.client_id, + "client_secret": self.client_secret, + "fb_exchange_token": self.refresh_token, + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://graph.facebook.com/v10.0/me?fields=id,first_name,last_name,email" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Facebook", + ) + + def send_message(self, to, message_text): + """ + Sending messages through Facebook may require the pages_messaging permission + and the user to be an admin of the page through which the message is sent. + This example assumes those permissions and settings are in place. + """ + uri = f"https://graph.facebook.com/v10.0/me/messages" + message_data = { + "recipient": {"id": to}, + "message": {"text": message_text}, + } + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + return response.json() + + +def facebook_sso(code, redirect_uri=None) -> FacebookSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + response = requests.get( + f"https://graph.facebook.com/v10.0/oauth/access_token", + params={ + "client_id": getenv("FACEBOOK_CLIENT_ID"), + "client_secret": getenv("FACEBOOK_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Facebook access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = ( + access_token # For simplicity, assigning access_token to refresh_token + ) + return FacebookSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/fatsecret.py b/agixt/sso/fatsecret.py new file mode 100644 index 000000000000..8b256728c0c6 --- /dev/null +++ b/agixt/sso/fatsecret.py @@ -0,0 +1,100 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- FATSECRET_CLIENT_ID: FatSecret OAuth client ID +- FATSECRET_CLIENT_SECRET: FatSecret OAuth client secret + +Required APIs + +Follow the API documentation: https://platform.fatsecret.com/api/ +Add the `FATSECRET_CLIENT_ID` and `FATSECRET_CLIENT_SECRET` environment variables to your `.env` file. +""" + + +class FatSecretSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("FATSECRET_CLIENT_ID") + self.client_secret = getenv("FATSECRET_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://oauth.fatsecret.com/connect/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://platform.fatsecret.com/rest/server.api?method=profile.get" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data.get("profile", {}).get("firstName", "") + last_name = data.get("profile", {}).get("lastName", "") + email = data.get("profile", {}).get( + "email", "" + ) # Assuming FatSecret provides email + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from FatSecret", + ) + + +def fatsecret_sso(code, redirect_uri=None) -> FatSecretSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://oauth.fatsecret.com/connect/token", + data={ + "client_id": getenv("FATSECRET_CLIENT_ID"), + "client_secret": getenv("FATSECRET_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting FatSecret access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return FatSecretSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/fitbit.py b/agixt/sso/fitbit.py new file mode 100644 index 000000000000..4c535e7f4e08 --- /dev/null +++ b/agixt/sso/fitbit.py @@ -0,0 +1,147 @@ +import base64 +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- FITBIT_CLIENT_ID: Fitbit OAuth client ID +- FITBIT_CLIENT_SECRET: Fitbit OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `FITBIT_CLIENT_ID` and `FITBIT_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Fitbit OAuth + +- activity +- heartrate +- location +- nutrition +- profile +- settings +- sleep +- social +- weight +""" + + +class FitbitSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("FITBIT_CLIENT_ID") + self.client_secret = getenv("FITBIT_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + decoded_token = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + response = requests.post( + "https://api.fitbit.com/oauth2/token", + headers={ + "Authorization": f"Basic {decoded_token}", + "Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + if response.status_code != 200: + logging.error(f"Error refreshing Fitbit access token: {response.text}") + raise HTTPException( + status_code=403, + detail="Error refreshing Fitbit access token", + ) + tokens = response.json() + self.access_token = tokens["access_token"] + self.refresh_token = tokens["refresh_token"] + return self.access_token + + def get_user_info(self): + uri = "https://api.fitbit.com/1/user/-/profile.json" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["user"]["firstName"] + last_name = data["user"]["lastName"] + email = data["user"][ + "fullName" + ] # Note: Fitbit may not provide email directly + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as e: + logging.error(f"Error fetching user info: {str(e)}") + raise HTTPException( + status_code=400, + detail="Error getting user info from Fitbit", + ) + + def get_activities(self): + uri = "https://api.fitbit.com/1/user/-/activities/date/today.json" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code != 200: + logging.error(f"Error fetching user's activities: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error fetching user's activities from Fitbit", + ) + return response.json() + + +def fitbit_sso(code, redirect_uri=None) -> FitbitSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + token = base64.b64encode( + f"{getenv('FITBIT_CLIENT_ID')}:{getenv('FITBIT_CLIENT_SECRET')}".encode() + ).decode() + response = requests.post( + "https://api.fitbit.com/oauth2/token", + headers={ + "Authorization": f"Basic {token}", + "Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "client_id": getenv("FITBIT_CLIENT_ID"), + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Fitbit access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return FitbitSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/formstack.py b/agixt/sso/formstack.py new file mode 100644 index 000000000000..4268de536a47 --- /dev/null +++ b/agixt/sso/formstack.py @@ -0,0 +1,126 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- FORMSTACK_CLIENT_ID: Formstack OAuth client ID +- FORMSTACK_CLIENT_SECRET: Formstack OAuth client secret + +Required APIs + +Ensure that you have the necessary APIs enabled on your Formstack account, +then add the `FORMSTACK_CLIENT_ID` and `FORMSTACK_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Formstack OAuth + +- formstack:read +- formstack:write +""" + + +class FormstackSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("FORMSTACK_CLIENT_ID") + self.client_secret = getenv("FORMSTACK_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.formstack.com/api/v2/oauth2/token", + data={ + "grant_type": "refresh_token", + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://www.formstack.com/api/v2/user.json" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Formstack", + ) + + def send_form_submission(self, form_id, submission_data): + form_submission_url = ( + f"https://www.formstack.com/api/v2/form/{form_id}/submission.json" + ) + headers = { + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + } + response = requests.post( + form_submission_url, + headers=headers, + data=json.dumps(submission_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + form_submission_url, + headers=headers, + data=json.dumps(submission_data), + ) + return response.json() + + +def formstack_sso(code, redirect_uri=None) -> FormstackSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://www.formstack.com/api/v2/oauth2/token", + data={ + "grant_type": "authorization_code", + "client_id": getenv("FORMSTACK_CLIENT_ID"), + "client_secret": getenv("FORMSTACK_CLIENT_SECRET"), + "code": code, + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Formstack access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return FormstackSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/foursquare.py b/agixt/sso/foursquare.py new file mode 100644 index 000000000000..c26426f92044 --- /dev/null +++ b/agixt/sso/foursquare.py @@ -0,0 +1,96 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- FOURSQUARE_CLIENT_ID: Foursquare OAuth client ID +- FOURSQUARE_CLIENT_SECRET: Foursquare OAuth client secret + +Required APIs: + +- Follow the links to confirm that you have the APIs enabled, + then add the `FOURSQUARE_CLIENT_ID` and `FOURSQUARE_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Foursquare OAuth: + +- No specific scope is needed for basic user info, as Foursquare uses a userless access approach for its APIs. +""" + + +class FoursquareSSO: + def __init__( + self, + access_token=None, + ): + self.access_token = access_token + self.client_id = getenv("FOURSQUARE_CLIENT_ID") + self.client_secret = getenv("FOURSQUARE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self, existing_refresh_token): + # Foursquare does not use a refresh token mechanism, re-authentication might be needed + raise NotImplementedError( + "Foursquare does not implement a refresh token mechanism" + ) + + def get_user_info(self): + uri = "https://api.foursquare.com/v2/users/self" + response = requests.get( + uri, + params={ + "oauth_token": self.access_token, + "v": "20230410", # Versioning date that Foursquare expects, can be current date + }, + ) + if response.status_code == 401: + raise HTTPException( + status_code=401, + detail="Unauthorized, please re-authenticate", + ) + try: + data = response.json() + user = data["response"]["user"] + first_name = user["firstName"] + last_name = user.get("lastName", "") # lastName may be optional + email = user["contact"]["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Foursquare", + ) + + +def foursquare_sso(code, redirect_uri=None) -> FoursquareSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://foursquare.com/oauth2/access_token", + params={ + "client_id": getenv("FOURSQUARE_CLIENT_ID"), + "client_secret": getenv("FOURSQUARE_CLIENT_SECRET"), + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Foursquare access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + return FoursquareSSO(access_token=access_token) diff --git a/agixt/sso/github.py b/agixt/sso/github.py new file mode 100644 index 000000000000..8b21711ae299 --- /dev/null +++ b/agixt/sso/github.py @@ -0,0 +1,110 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- GITHUB_CLIENT_ID: GitHub OAuth client ID +- GITHUB_CLIENT_SECRET: GitHub OAuth client secret + +Required scopes for GitHub OAuth + +- user:email +- read:user +""" + + +class GitHubSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("GITHUB_CLIENT_ID") + self.client_secret = getenv("GITHUB_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + # GitHub tokens do not support refresh tokens directly, we need to re-authorize. + response = requests.post( + "https://github.com/login/oauth/access_token", + headers={"Accept": "application/json"}, + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.github.com/user" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + email_response = requests.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + email_data = email_response.json() + primary_email = next( + email["email"] for email in email_data if email["primary"] + ) + return { + "email": primary_email, + "first_name": ( + data.get("name", "").split()[0] if data.get("name") else "" + ), + "last_name": ( + data.get("name", "").split()[-1] if data.get("name") else "" + ), + } + except Exception as e: + raise HTTPException( + status_code=400, + detail="Error getting user info from GitHub", + ) + + +def github_sso(code, redirect_uri=None) -> GitHubSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://github.com/login/oauth/access_token", + headers={"Accept": "application/json"}, + data={ + "client_id": getenv("GITHUB_CLIENT_ID"), + "client_secret": getenv("GITHUB_CLIENT_SECRET"), + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting GitHub access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data.get("refresh_token", "Not provided") + return GitHubSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/gitlab.py b/agixt/sso/gitlab.py new file mode 100644 index 000000000000..57fdc121aa35 --- /dev/null +++ b/agixt/sso/gitlab.py @@ -0,0 +1,111 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- GITLAB_CLIENT_ID: GitLab OAuth client ID +- GITLAB_CLIENT_SECRET: GitLab OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `GITLAB_CLIENT_ID` and `GITLAB_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for GitLab SSO + +- read_user +- api +- email +""" + + +class GitLabSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("GITLAB_CLIENT_ID") + self.client_secret = getenv("GITLAB_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://gitlab.com/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://gitlab.com/api/v4/user" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["name"].split()[0] + last_name = data["name"].split()[-1] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from GitLab", + ) + + def send_email(self, to, subject, message_text): + # Assuming that GitLab does not provide an email send capability directly. + # One could use another email service here if required. + raise NotImplementedError( + "GitLab SSO does not support sending emails directly." + ) + + +def gitlab_sso(code, redirect_uri=None) -> GitLabSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://gitlab.com/oauth/token", + data={ + "client_id": getenv("GITLAB_CLIENT_ID"), + "client_secret": getenv("GITLAB_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting GitLab access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return GitLabSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/google.py b/agixt/sso/google.py new file mode 100644 index 000000000000..94dae7f4c9f5 --- /dev/null +++ b/agixt/sso/google.py @@ -0,0 +1,141 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- GOOGLE_CLIENT_ID: Google OAuth client ID +- GOOGLE_CLIENT_SECRET: Google OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `GOOGLE_CLIENT_ID` and `GOOGLE_CLIENT_SECRET` environment variables to your `.env` file. + +- People API https://console.cloud.google.com/marketplace/product/google/people.googleapis.com +- Gmail API https://console.cloud.google.com/marketplace/product/google/gmail.googleapis.com + +Required scopes for Google SSO + +- https://www.googleapis.com/auth/userinfo.profile +- https://www.googleapis.com/auth/userinfo.email +- https://www.googleapis.com/auth/gmail.send +""" + + +class GoogleSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("GOOGLE_CLIENT_ID") + self.client_secret = getenv("GOOGLE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://oauth2.googleapis.com/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://people.googleapis.com/v1/people/me?personFields=names,emailAddresses" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["names"][0]["givenName"] + last_name = data["names"][0]["familyName"] + email = data["emailAddresses"][0]["value"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Google", + ) + + def send_email(self, to, subject, message_text): + if not self.email_address: + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + message = {"raw": raw} + response = requests.post( + "https://gmail.googleapis.com/gmail/v1/users/me/messages/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://gmail.googleapis.com/gmail/v1/users/me/messages/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message), + ) + return response.json() + + +def google_sso(code, redirect_uri=None) -> GoogleSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://accounts.google.com/o/oauth2/token", + params={ + "code": code, + "client_id": getenv("GOOGLE_CLIENT_ID"), + "client_secret": getenv("GOOGLE_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Google access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return GoogleSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/huddle.py b/agixt/sso/huddle.py new file mode 100644 index 000000000000..8d9c16833820 --- /dev/null +++ b/agixt/sso/huddle.py @@ -0,0 +1,155 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- HUDDLE_CLIENT_ID: Huddle OAuth client ID +- HUDDLE_CLIENT_SECRET: Huddle OAuth client secret + +Required APIs + +Ensure the necessary Huddle APIs are enabled, +then add the `HUDDLE_CLIENT_ID` and `HUDDLE_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Huddle OAuth + +- user_info +- send_email +""" + + +class HuddleSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("HUDDLE_CLIENT_ID") + self.client_secret = getenv("HUDDLE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://login.huddle.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "user_info send_email", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.huddle.com/1.0/user_info" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["firstName"] + last_name = data["lastName"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Huddle", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://api.huddle.com/1.0/send_email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.huddle.com/1.0/send_email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def huddle_sso(code, redirect_uri=None) -> HuddleSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://login.huddle.com/oauth2/token", + data={ + "client_id": getenv("HUDDLE_CLIENT_ID"), + "client_secret": getenv("HUDDLE_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "user_info send_email", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Huddle access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return HuddleSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/imgur.py b/agixt/sso/imgur.py new file mode 100644 index 000000000000..dcd24a92caff --- /dev/null +++ b/agixt/sso/imgur.py @@ -0,0 +1,140 @@ +import base64 +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- IMGUR_CLIENT_ID: Imgur OAuth client ID +- IMGUR_CLIENT_SECRET: Imgur OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `IMGUR_CLIENT_ID` and `IMGUR_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Imgur SSO + +- read +- write +""" + + +class ImgurSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("IMGUR_CLIENT_ID") + self.client_secret = getenv("IMGUR_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.imgur.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + data = response.json() + if response.status_code != 200: + logging.error(f"Error refreshing Imgur access token: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing Imgur access token", + ) + self.access_token = data["access_token"] + self.refresh_token = data["refresh_token"] + return self.access_token + + def get_user_info(self): + uri = "https://api.imgur.com/3/account/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json()["data"] + username = data["url"] + email = data["email"] if "email" in data else None + return { + "username": username, + "email": email, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Imgur", + ) + + def upload_image(self, image_path, title=None, description=None): + with open(image_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode() + payload = { + "image": image_data, + "type": "base64", + } + if title: + payload["title"] = title + if description: + payload["description"] = description + response = requests.post( + "https://api.imgur.com/3/image", + headers={ + "Authorization": f"Bearer {self.access_token}", + }, + data=payload, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.imgur.com/3/image", + headers={ + "Authorization": f"Bearer {self.access_token}", + }, + data=payload, + ) + return response.json() + + +def imgur_sso(code, redirect_uri=None) -> ImgurSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://api.imgur.com/oauth2/token", + data={ + "client_id": getenv("IMGUR_CLIENT_ID"), + "client_secret": getenv("IMGUR_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Imgur access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return ImgurSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/instagram.py b/agixt/sso/instagram.py new file mode 100644 index 000000000000..0a4df00db77d --- /dev/null +++ b/agixt/sso/instagram.py @@ -0,0 +1,112 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- INSTAGRAM_CLIENT_ID: Instagram OAuth client ID +- INSTAGRAM_CLIENT_SECRET: Instagram OAuth client secret + +Required APIs + +Make sure you have the Instagram Basic Display API enabled. +Add the `INSTAGRAM_CLIENT_ID` and `INSTAGRAM_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Instagram OAuth + +- user_profile +- user_media +""" + + +class InstagramSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("INSTAGRAM_CLIENT_ID") + self.client_secret = getenv("INSTAGRAM_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://graph.instagram.com/refresh_access_token", + params={ + "grant_type": "ig_refresh_token", + "access_token": self.access_token, + }, + ) + if response.status_code != 200: + logging.error(f"Error refreshing Instagram access token: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing Instagram access token", + ) + + return response.json()["access_token"] + + def get_user_info(self): + uri = f"https://graph.instagram.com/me?fields=id,username,media_count&access_token={self.access_token}" + response = requests.get(uri) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get(uri) + try: + data = response.json() + username = data["username"] + return { + "username": username, + "media_count": data.get("media_count", 0), + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Instagram", + ) + + def send_media_post(self, image_url, caption): + uri = f"https://graph.instagram.com/me/media?image_url={image_url}&caption={caption}&access_token={self.access_token}" + response = requests.post(uri) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post(uri) + return response.json() + + +def instagram_sso(code, redirect_uri=None) -> InstagramSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.instagram.com/oauth/access_token", + data={ + "client_id": getenv("INSTAGRAM_CLIENT_ID"), + "client_secret": getenv("INSTAGRAM_CLIENT_SECRET"), + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Instagram access token: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Error getting Instagram access token", + ) + + data = response.json() + access_token = data["access_token"] + refresh_token = "Not applicable for Instagram" # Instagram tokens last 60 days and refresh automatically every time a user interacts with the app + + return InstagramSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/intel_cloud_services.py b/agixt/sso/intel_cloud_services.py new file mode 100644 index 000000000000..8165340a279a --- /dev/null +++ b/agixt/sso/intel_cloud_services.py @@ -0,0 +1,155 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- INTEL_CLIENT_ID: Intel OAuth client ID +- INTEL_CLIENT_SECRET: Intel OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `INTEL_CLIENT_ID` and `INTEL_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Intel SSO + +- https://api.intel.com/userinfo.read +- https://api.intel.com/mail.send +""" + + +class IntelSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("INTEL_CLIENT_ID") + self.client_secret = getenv("INTEL_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://auth.intel.com/oauth2/v2.0/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "https://api.intel.com/userinfo.read https://api.intel.com/mail.send", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.intel.com/v1.0/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["givenName"] + last_name = data["surname"] + email = data["mail"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Intel", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://api.intel.com/v1.0/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.intel.com/v1.0/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def intel_cloud_services_sso(code, redirect_uri=None) -> IntelSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://auth.intel.com/oauth2/v2.0/token", + data={ + "client_id": getenv("INTEL_CLIENT_ID"), + "client_secret": getenv("INTEL_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "https://api.intel.com/userinfo.read https://api.intel.com/mail.send", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Intel access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return IntelSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/jive.py b/agixt/sso/jive.py new file mode 100644 index 000000000000..235ce20acc13 --- /dev/null +++ b/agixt/sso/jive.py @@ -0,0 +1,147 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- JIVE_CLIENT_ID: Jive OAuth client ID +- JIVE_CLIENT_SECRET: Jive OAuth client secret + +Required APIs: + +Ensure you have the necessary Jive API enabled. + +Required scopes for Jive OAuth: +These scopes will need to be accurate according to Jive�s API documentation. + +""" + + +class JiveSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("JIVE_CLIENT_ID") + self.client_secret = getenv("JIVE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://example.jive.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "your_required_scopes_here", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://example.jive.com/api/core/v3/people/@me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["name"]["givenName"] + last_name = data["name"]["familyName"] + email = data["emails"][0]["value"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Jive", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://example.jive.com/api/core/v3/messages", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://example.jive.com/api/core/v3/messages", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def jive_sso(code, redirect_uri=None) -> JiveSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = str(code).replace("%2F", "/").replace("%3D", "=").replace("%3F", "?") + response = requests.post( + "https://example.jive.com/oauth2/token", + data={ + "client_id": getenv("JIVE_CLIENT_ID"), + "client_secret": getenv("JIVE_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "your_required_scopes_here", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Jive access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return JiveSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/keycloak.py b/agixt/sso/keycloak.py new file mode 100644 index 000000000000..136ee353aa44 --- /dev/null +++ b/agixt/sso/keycloak.py @@ -0,0 +1,102 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- KEYCLOAK_CLIENT_ID: Keycloak OAuth client ID +- KEYCLOAK_CLIENT_SECRET: Keycloak OAuth client secret +- KEYCLOAK_REALM: Keycloak realm name +- KEYCLOAK_SERVER_URL: Keycloak server URL + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the KEYCLOAK_CLIENT_ID, KEYCLOAK_CLIENT_SECRET, KEYCLOAK_REALM, and KEYCLOAK_SERVER_URL environment variables to your `.env` file. + +Required scopes for Keycloak SSO: + +- openid +- email +- profile +""" + + +class KeycloakSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("KEYCLOAK_CLIENT_ID") + self.client_secret = getenv("KEYCLOAK_CLIENT_SECRET") + self.realm = getenv("KEYCLOAK_REALM") + self.server_url = getenv("KEYCLOAK_SERVER_URL") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"{self.server_url}/realms/{self.realm}/protocol/openid-connect/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"{self.server_url}/realms/{self.realm}/protocol/openid-connect/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data.get("given_name") + last_name = data.get("family_name") + email = data.get("email") + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Keycloak", + ) + + +def keycloak_sso(code, redirect_uri=None) -> KeycloakSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + response = requests.post( + f"{getenv('KEYCLOAK_SERVER_URL')}/realms/{getenv('KEYCLOAK_REALM')}/protocol/openid-connect/token", + data={ + "client_id": getenv("KEYCLOAK_CLIENT_ID"), + "client_secret": getenv("KEYCLOAK_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "openid email profile", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Keycloak access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return KeycloakSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/linkedin.py b/agixt/sso/linkedin.py new file mode 100644 index 000000000000..88e01f94aeed --- /dev/null +++ b/agixt/sso/linkedin.py @@ -0,0 +1,124 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- LINKEDIN_CLIENT_ID: LinkedIn OAuth client ID +- LINKEDIN_CLIENT_SECRET: LinkedIn OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `LINKEDIN_CLIENT_ID` and `LINKEDIN_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for LinkedIn OAuth + +- r_liteprofile +- r_emailaddress +- w_member_social +""" + + +class LinkedInSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("LINKEDIN_CLIENT_ID") + self.client_secret = getenv("LINKEDIN_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.linkedin.com/oauth/v2/accessToken", + data={ + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return response.json()["access_token"] + + def get_user_info(self): + profile_url = "https://api.linkedin.com/v2/me" + email_url = "https://api.linkedin.com/v2/emailAddress?q=members&projection=(elements*(handle~))" + + profile_response = requests.get( + profile_url, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + + email_response = requests.get( + email_url, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + + if profile_response.status_code == 401 or email_response.status_code == 401: + self.access_token = self.get_new_token() + profile_response = requests.get( + profile_url, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + email_response = requests.get( + email_url, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + + try: + profile_data = profile_response.json() + email_data = email_response.json() + first_name = profile_data["localizedFirstName"] + last_name = profile_data["localizedLastName"] + email = email_data["elements"][0]["handle~"]["emailAddress"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from LinkedIn", + ) + + def send_email(self, to, subject, message_text): + # LinkedIn API does not support sending emails directly + raise NotImplementedError("LinkedIn API does not support sending emails") + + +def linkedin_sso(code, redirect_uri=None) -> LinkedInSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://www.linkedin.com/oauth/v2/accessToken", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": getenv("LINKEDIN_CLIENT_ID"), + "client_secret": getenv("LINKEDIN_CLIENT_SECRET"), + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logging.error(f"Error getting LinkedIn access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return LinkedInSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/microsoft.py b/agixt/sso/microsoft.py new file mode 100644 index 000000000000..bb1639f544bc --- /dev/null +++ b/agixt/sso/microsoft.py @@ -0,0 +1,155 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- MICROSOFT_CLIENT_ID: Microsoft OAuth client ID +- MICROSOFT_CLIENT_SECRET: Microsoft OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Microsoft OAuth + +- https://graph.microsoft.com/User.Read +- https://graph.microsoft.com/Mail.Send +""" + + +class MicrosoftSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("MICROSOFT_CLIENT_ID") + self.client_secret = getenv("MICROSOFT_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://login.microsoftonline.com/common/oauth2/v2.0/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/Mail.Send", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://graph.microsoft.com/v1.0/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["givenName"] + last_name = data["surname"] + email = data["mail"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Microsoft", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://graph.microsoft.com/v1.0/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://graph.microsoft.com/v1.0/me/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def microsoft_sso(code, redirect_uri=None) -> MicrosoftSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://login.microsoftonline.com/common/oauth2/v2.0/token", + data={ + "client_id": getenv("MICROSOFT_CLIENT_ID"), + "client_secret": getenv("MICROSOFT_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/Mail.Send https://graph.microsoft.com/Calendars.ReadWrite.Shared", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Microsoft access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return MicrosoftSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/netiq.py b/agixt/sso/netiq.py new file mode 100644 index 000000000000..306c65c95a6f --- /dev/null +++ b/agixt/sso/netiq.py @@ -0,0 +1,156 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- NETIQ_CLIENT_ID: NetIQ OAuth client ID +- NETIQ_CLIENT_SECRET: NetIQ OAuth client secret + +Required APIs + +Ensure that the required APIs are enabled in your NetIQ settings, +then add the `NETIQ_CLIENT_ID` and `NETIQ_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for NetIQ OAuth + +- profile +- email +- openid +- user.info +""" + + +class NetIQSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("NETIQ_CLIENT_ID") + self.client_secret = getenv("NETIQ_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://your-netiq-domain.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://your-netiq-domain.com/oauth2/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from NetIQ", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://your-netiq-domain.com/api/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://your-netiq-domain.com/api/sendMail", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def netiq_sso(code, redirect_uri=None) -> NetIQSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://your-netiq-domain.com/oauth2/token", + data={ + "client_id": getenv("NETIQ_CLIENT_ID"), + "client_secret": getenv("NETIQ_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "profile email openid user.info", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting NetIQ access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data.get("refresh_token", "Not provided") + return NetIQSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/okta.py b/agixt/sso/okta.py new file mode 100644 index 000000000000..e24f8dcedd5c --- /dev/null +++ b/agixt/sso/okta.py @@ -0,0 +1,105 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- OKTA_CLIENT_ID: Okta OAuth client ID +- OKTA_CLIENT_SECRET: Okta OAuth client secret +- OKTA_DOMAIN: Okta domain (e.g., dev-123456.okta.com) + +Required scopes for Okta OAuth + +- openid +- profile +- email +""" + + +class OktaSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("OKTA_CLIENT_ID") + self.client_secret = getenv("OKTA_CLIENT_SECRET") + self.domain = getenv("OKTA_DOMAIN") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"https://{self.domain}/oauth2/v1/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"https://{self.domain}/oauth2/v1/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Okta", + ) + + def send_email(self, to, subject, message_text): + # Placeholder: Replace this with any specific email sending logic for Okta if available + raise NotImplementedError("send_email is not supported for Okta") + + +def okta_sso(code, redirect_uri=None) -> OktaSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://{getenv('OKTA_DOMAIN')}/oauth2/v1/token", + data={ + "client_id": getenv("OKTA_CLIENT_ID"), + "client_secret": getenv("OKTA_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Okta access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return OktaSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/openam.py b/agixt/sso/openam.py new file mode 100644 index 000000000000..8881595f11ee --- /dev/null +++ b/agixt/sso/openam.py @@ -0,0 +1,105 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- OPENAM_CLIENT_ID: OpenAM OAuth client ID +- OPENAM_CLIENT_SECRET: OpenAM OAuth client secret +- OPENAM_BASE_URL: Base URL for OpenAM server + +Required scopes for OpenAM OAuth + +- profile +- email +""" + + +class OpenAMSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("OPENAM_CLIENT_ID") + self.client_secret = getenv("OPENAM_CLIENT_SECRET") + self.base_url = getenv("OPENAM_BASE_URL") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"{self.base_url}/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"{self.base_url}/oauth2/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception: + raise HTTPException( + status_code=400, + detail="Error getting user info from OpenAM", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError( + "OpenAM SSO does not support sending emails by default" + ) + + +def openam_sso(code, redirect_uri=None) -> OpenAMSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"{getenv('OPENAM_BASE_URL')}/oauth2/token", + data={ + "client_id": getenv("OPENAM_CLIENT_ID"), + "client_secret": getenv("OPENAM_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting OpenAM access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return OpenAMSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/openstreetmap.py b/agixt/sso/openstreetmap.py new file mode 100644 index 000000000000..eb2841ce4657 --- /dev/null +++ b/agixt/sso/openstreetmap.py @@ -0,0 +1,105 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv +import xml.etree.ElementTree as ET + +""" +Required environment variables: + +- OSM_CLIENT_ID: OpenStreetMap OAuth client ID +- OSM_CLIENT_SECRET: OpenStreetMap OAuth client secret + +Required APIs + +Make sure you have appropriate OAuth configuration in OpenStreetMap and add the `OSM_CLIENT_ID` and `OSM_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for OpenStreetMap OAuth: + +- read_prefs +""" + + +class OpenStreetMapSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("OSM_CLIENT_ID") + self.client_secret = getenv("OSM_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.openstreetmap.org/oauth/access_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.openstreetmap.org/api/0.6/user/details" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = ET.fromstring(response.content) + user_info = data.find("user") + if user_info is None: + raise HTTPException( + status_code=400, + detail="Error getting user info from OpenStreetMap", + ) + user = { + "id": user_info.attrib.get("id"), + "username": user_info.attrib.get("display_name"), + } + return user + except: + raise HTTPException( + status_code=400, + detail="Error parsing user info from OpenStreetMap", + ) + + +def openstreetmap_sso(code, redirect_uri=None) -> OpenStreetMapSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://www.openstreetmap.org/oauth/access_token", + data={ + "client_id": getenv("OSM_CLIENT_ID"), + "client_secret": getenv("OSM_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting OpenStreetMap access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return OpenStreetMapSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/orcid.py b/agixt/sso/orcid.py new file mode 100644 index 000000000000..1a09788bf3d3 --- /dev/null +++ b/agixt/sso/orcid.py @@ -0,0 +1,108 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- ORCID_CLIENT_ID: ORCID OAuth client ID +- ORCID_CLIENT_SECRET: ORCID OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `ORCID_CLIENT_ID` and `ORCID_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for ORCID SSO + +- /authenticate (to read public profile information) +- /activities/update (optional, if you need to update activities) +""" + + +class ORCIDSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("ORCID_CLIENT_ID") + self.client_secret = getenv("ORCID_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://orcid.org/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://pub.orcid.org/v3.0/0000-0002-1825-0097" # Replace with actual ORCID ID endpoint after fetching authenticated ORCID ID + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["person"]["name"]["given-names"]["value"] + last_name = data["person"]["name"]["family-name"]["value"] + email = ( + data["person"]["emails"]["email"][0]["email"] + if data["person"]["emails"]["email"] + else None + ) + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as e: + logging.error(f"Error getting user info from ORCID: {str(e)}") + raise HTTPException( + status_code=400, + detail="Error getting user info from ORCID", + ) + + +def orcid_sso(code, redirect_uri=None) -> ORCIDSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://orcid.org/oauth/token", + data={ + "client_id": getenv("ORCID_CLIENT_ID"), + "client_secret": getenv("ORCID_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting ORCID access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return ORCIDSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/paypal.py b/agixt/sso/paypal.py new file mode 100644 index 000000000000..53862477f353 --- /dev/null +++ b/agixt/sso/paypal.py @@ -0,0 +1,154 @@ +import requests +import json +import logging +import base64 +import uuid +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- PAYPAL_CLIENT_ID: PayPal OAuth client ID +- PAYPAL_CLIENT_SECRET: PayPal OAuth client secret + +Required APIs + +Ensure you have PayPal REST API enabled and appropriate client credentials obtained. +Add the `PAYPAL_CLIENT_ID` and `PAYPAL_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for PayPal OAuth: + +- email +- openid +""" + + +class PayPalSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("PAYPAL_CLIENT_ID") + self.client_secret = getenv("PAYPAL_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + token = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode() + ).decode() + response = requests.post( + "https://api.paypal.com/v1/oauth2/token", + headers={ + "Authorization": f"Basic {token}", + "Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing PayPal OAuth token", + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.paypal.com/v1/identity/oauth2/userinfo?schema=openid" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Error getting user info from PayPal: {str(e)}", + ) + + def send_payment(self, recipient_email, amount, currency="USD"): + """This is an extra method for sending payment in PayPal.""" + payment_data = { + "sender_batch_header": { + "sender_batch_id": "batch_" + str(uuid.uuid4()), + "email_subject": "You have a payment", + }, + "items": [ + { + "recipient_type": "EMAIL", + "amount": { + "value": amount, + "currency": currency, + }, + "receiver": recipient_email, + "note": "Thank you.", + } + ], + } + + response = requests.post( + "https://api.paypal.com/v1/payments/payouts", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(payment_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.paypal.com/v1/payments/payouts", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(payment_data), + ) + return response.json() + + +def paypal_sso(code, redirect_uri=None) -> PayPalSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + token = base64.b64encode( + f"{getenv('PAYPAL_CLIENT_ID')}:{getenv('PAYPAL_CLIENT_SECRET')}".encode() + ).decode() + response = requests.post( + "https://api.paypal.com/v1/oauth2/token", + headers={ + "Authorization": f"Basic {token}", + "Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting PayPal access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return PayPalSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/ping_identity.py b/agixt/sso/ping_identity.py new file mode 100644 index 000000000000..c43de32942a5 --- /dev/null +++ b/agixt/sso/ping_identity.py @@ -0,0 +1,155 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- PING_IDENTITY_CLIENT_ID: Ping Identity OAuth client ID +- PING_IDENTITY_CLIENT_SECRET: Ping Identity OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `PING_IDENTITY_CLIENT_ID` and `PING_IDENTITY_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Ping Identity OAuth + +- profile +- email +- openid +""" + + +class PingIdentitySSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("PING_IDENTITY_CLIENT_ID") + self.client_secret = getenv("PING_IDENTITY_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://auth.pingidentity.com/as/token.oauth2", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://auth.pingidentity.com/idp/userinfo.openid" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Ping Identity", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://auth.pingidentity.com/v1/api/send-email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://auth.pingidentity.com/v1/api/send-email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def ping_identity_sso(code, redirect_uri=None) -> PingIdentitySSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://auth.pingidentity.com/as/token.oauth2", + data={ + "client_id": getenv("PING_IDENTITY_CLIENT_ID"), + "client_secret": getenv("PING_IDENTITY_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "profile email openid", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Ping Identity access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return PingIdentitySSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/pixiv.py b/agixt/sso/pixiv.py new file mode 100644 index 000000000000..942e1bd84681 --- /dev/null +++ b/agixt/sso/pixiv.py @@ -0,0 +1,110 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- PIXIV_CLIENT_ID: Pixiv OAuth client ID +- PIXIV_CLIENT_SECRET: Pixiv OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `PIXIV_CLIENT_ID` and `PIXIV_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Pixiv OAuth + +- pixiv.scope.profile.read +""" + + +class PixivSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("PIXIV_CLIENT_ID") + self.client_secret = getenv("PIXIV_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://oauth.secure.pixiv.net/auth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing Pixiv token", + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://app-api.pixiv.net/v1/user/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + user = data["user"] + first_name = user["name"] + email = user.get("mail_address", "No_Email_Provided") + return { + "email": email, + "first_name": first_name, + "last_name": "", + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Pixiv", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError("Pixiv does not support sending messages") + + +def pixiv_sso(code, redirect_uri=None) -> PixivSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://oauth.secure.pixiv.net/auth/token", + data={ + "code": code, + "client_id": getenv("PIXIV_CLIENT_ID"), + "client_secret": getenv("PIXIV_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Pixiv access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return PixivSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/reddit.py b/agixt/sso/reddit.py new file mode 100644 index 000000000000..5381e590b588 --- /dev/null +++ b/agixt/sso/reddit.py @@ -0,0 +1,136 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- REDDIT_CLIENT_ID: Reddit OAuth client ID +- REDDIT_CLIENT_SECRET: Reddit OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `REDDIT_CLIENT_ID` and `REDDIT_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Reddit OAuth + +- identity +- submit +- read +""" + + +class RedditSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("REDDIT_CLIENT_ID") + self.client_secret = getenv("REDDIT_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.reddit.com/api/v1/access_token", + auth=(self.client_id, self.client_secret), + data={ + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + headers={"User-Agent": "MyRedditApp"}, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://oauth.reddit.com/api/v1/me" + response = requests.get( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "User-Agent": "MyRedditApp", + }, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "User-Agent": "MyRedditApp", + }, + ) + try: + data = response.json() + username = data["name"] + email = data.get("email", "") + # Reddit API does not inherently provide first_name and last_name + return { + "email": email, + "username": username, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Reddit", + ) + + def submit_post(self, subreddit, title, content): + post_data = {"sr": subreddit, "title": title, "text": content, "kind": "self"} + response = requests.post( + "https://oauth.reddit.com/api/submit", + headers={ + "Authorization": f"Bearer {self.access_token}", + "User-Agent": "MyRedditApp", + "Content-Type": "application/json", + }, + data=json.dumps(post_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://oauth.reddit.com/api/submit", + headers={ + "Authorization": f"Bearer {self.access_token}", + "User-Agent": "MyRedditApp", + "Content-Type": "application/json", + }, + data=json.dumps(post_data), + ) + if response.status_code != 200: + logging.error(f"Error submitting post to Reddit: {response.text}") + return response.json() + + +def reddit_sso(code, redirect_uri=None) -> RedditSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%26", "&") + ) + response = requests.post( + "https://www.reddit.com/api/v1/access_token", + auth=(getenv("REDDIT_CLIENT_ID"), getenv("REDDIT_CLIENT_SECRET")), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + }, + headers={"User-Agent": "MyRedditApp"}, + ) + if response.status_code != 200: + logging.error(f"Error getting Reddit access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return RedditSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/salesforce.py b/agixt/sso/salesforce.py new file mode 100644 index 000000000000..c50979ca8f5d --- /dev/null +++ b/agixt/sso/salesforce.py @@ -0,0 +1,118 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- SALESFORCE_CLIENT_ID: Salesforce OAuth client ID +- SALESFORCE_CLIENT_SECRET: Salesforce OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `SALESFORCE_CLIENT_ID` and `SALESFORCE_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Salesforce OAuth + +- refresh_token +- full +- email +""" + + +class SalesforceSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + instance_url=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.instance_url = instance_url + self.client_id = getenv("SALESFORCE_CLIENT_ID") + self.client_secret = getenv("SALESFORCE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"{self.instance_url}/services/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"{self.instance_url}/services/oauth2/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Salesforce", + ) + + def send_email(self, to, subject, message_text): + # Salesforce does not have a direct email sending API in the same way Google and Microsoft do. + # This will need to be implemented according to the specific Salesforce instance and setup. + raise NotImplementedError( + "Send email functionality is dependent on the Salesforce instance configuration." + ) + + +def salesforce_sso(code, redirect_uri=None) -> SalesforceSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://login.salesforce.com/services/oauth2/token", + data={ + "grant_type": "authorization_code", + "client_id": getenv("SALESFORCE_CLIENT_ID"), + "client_secret": getenv("SALESFORCE_CLIENT_SECRET"), + "code": code, + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Salesforce access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + instance_url = data["instance_url"] + return SalesforceSSO( + access_token=access_token, + refresh_token=refresh_token, + instance_url=instance_url, + ) diff --git a/agixt/sso/sina_weibo.py b/agixt/sso/sina_weibo.py new file mode 100644 index 000000000000..ffd789731ad6 --- /dev/null +++ b/agixt/sso/sina_weibo.py @@ -0,0 +1,135 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- WEIBO_CLIENT_ID: Weibo OAuth client ID +- WEIBO_CLIENT_SECRET: Weibo OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `WEIBO_CLIENT_ID` and `WEIBO_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Weibo OAuth + +- email +- statuses_update +""" + + +class WeiboSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("WEIBO_CLIENT_ID") + self.client_secret = getenv("WEIBO_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.weibo.com/oauth2/access_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.weibo.com/2/account/get_uid.json" + response = requests.get( + uri, + params={"access_token": self.access_token}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + params={"access_token": self.access_token}, + ) + try: + uid_response = response.json() + uid = uid_response["uid"] + + uri_info = f"https://api.weibo.com/2/users/show.json?uid={uid}" + response = requests.get( + uri_info, + params={"access_token": self.access_token}, + ) + data = response.json() + + email = data.get( + "email", None + ) # Assuming you have permissions to email scope + first_name = data["name"] + last_name = "" # Weibo does not provide a separate field for last name + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Weibo", + ) + + def send_message(self, status): + uri = "https://api.weibo.com/2/statuses/update.json" + data = { + "access_token": self.access_token, + "status": status, + } + + response = requests.post( + uri, + data=data, + ) + + if response.status_code == 401: + self.access_token = self.get_new_token() + data["access_token"] = self.access_token + response = requests.post( + uri, + data=data, + ) + return response.json() + + +def sina_weibo_sso(code, redirect_uri=None) -> WeiboSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.weibo.com/oauth2/access_token", + data={ + "client_id": getenv("WEIBO_CLIENT_ID"), + "client_secret": getenv("WEIBO_CLIENT_SECRET"), + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Weibo access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return WeiboSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/spotify.py b/agixt/sso/spotify.py new file mode 100644 index 000000000000..563a04293ff4 --- /dev/null +++ b/agixt/sso/spotify.py @@ -0,0 +1,117 @@ +import base64 +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- SPOTIFY_CLIENT_ID: Spotify OAuth client ID +- SPOTIFY_CLIENT_SECRET: Spotify OAuth client secret + +Required APIs + +Ensure that you have the required APIs enabled in your Spotify developer account and add the `SPOTIFY_CLIENT_ID` and `SPOTIFY_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Spotify SSO + +- user-read-email +- user-read-private +- playlist-read-private +""" + + +class SpotifySSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("SPOTIFY_CLIENT_ID") + self.client_secret = getenv("SPOTIFY_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://accounts.spotify.com/api/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.spotify.com/v1/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name, last_name = data["display_name"].split(" ", 1) + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Spotify", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + # Since Spotify does not have a direct API for sending emails, we'll only prepare the message + raw = base64.urlsafe_b64encode(message.as_bytes()) + return {"raw_message": raw.decode()} + + +def spotify_sso(code, redirect_uri=None) -> SpotifySSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://accounts.spotify.com/api/token", + data={ + "client_id": getenv("SPOTIFY_CLIENT_ID"), + "client_secret": getenv("SPOTIFY_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logging.error(f"Error getting Spotify access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return SpotifySSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/stack_exchange.py b/agixt/sso/stack_exchange.py new file mode 100644 index 000000000000..48447d9bc26b --- /dev/null +++ b/agixt/sso/stack_exchange.py @@ -0,0 +1,114 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- STACKEXCHANGE_CLIENT_ID: Stack Exchange OAuth client ID +- STACKEXCHANGE_CLIENT_SECRET: Stack Exchange OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `STACKEXCHANGE_CLIENT_ID` and `STACKEXCHANGE_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Stack Exchange OAuth + +- read_inbox +- no_expiry +- private_info +- write_access +""" + + +class StackExchangeSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("STACKEXCHANGE_CLIENT_ID") + self.client_secret = getenv("STACKEXCHANGE_CLIENT_SECRET") + self.key = getenv("STACKEXCHANGE_KEY") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://stackexchange.com/oauth/access_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + "scope": "read_inbox no_expiry private_info write_access", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + response = requests.get( + f"https://api.stackexchange.com/2.2/me", + params={ + "access_token": self.access_token, + "key": self.key, + "site": "stackoverflow", + }, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + f"https://api.stackexchange.com/2.2/me", + params={ + "access_token": self.access_token, + "key": self.key, + "site": "stackoverflow", + }, + ) + try: + data = response.json()["items"][0] + display_name = data["display_name"] + email = data[ + "email" + ] # Note: Stack Exchange does not provide email directly + return { + "display_name": display_name, + "email": email, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Stack Exchange", + ) + + +def stack_exchange_sso(code, redirect_uri=None) -> StackExchangeSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://stackexchange.com/oauth/access_token", + data={ + "client_id": getenv("STACKEXCHANGE_CLIENT_ID"), + "client_secret": getenv("STACKEXCHANGE_CLIENT_SECRET"), + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Stack Exchange access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return StackExchangeSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/strava.py b/agixt/sso/strava.py new file mode 100644 index 000000000000..6bc5a5a119be --- /dev/null +++ b/agixt/sso/strava.py @@ -0,0 +1,148 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- STRAVA_CLIENT_ID: Strava OAuth client ID +- STRAVA_CLIENT_SECRET: Strava OAuth client secret + +Required APIs + +No additional APIs need to be enabled beyond the standard Strava API settings. + +Required scopes for Strava OAuth + +- read +- activity:write +""" + + +class StravaSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("STRAVA_CLIENT_ID") + self.client_secret = getenv("STRAVA_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.strava.com/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://www.strava.com/api/v3/athlete" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["firstname"] + last_name = data["lastname"] + email = data.get("email") # Strava API doesn't return email by default + return { + "first_name": first_name, + "last_name": last_name, + "email": email, # Might be None if not provided + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Strava", + ) + + def create_activity( + self, name, activity_type, start_date, elapsed_time, description=None + ): + """Create an activity on Strava. + + :param name: The name of the activity. + :param activity_type: Type of activity (e.g., "Run", "Ride"). + :param start_date: ISO 8601 formatted date-time when the activity took place. + :param elapsed_time: Activity duration in seconds. + :param description: Description of the activity. + """ + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + + activity_data = { + "name": name, + "type": activity_type, + "start_date_local": start_date, + "elapsed_time": elapsed_time, + } + + if description: + activity_data["description"] = description + + response = requests.post( + "https://www.strava.com/api/v3/activities", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(activity_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://www.strava.com/api/v3/activities", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(activity_data), + ) + return response.json() + + +def strava_sso(code, redirect_uri=None) -> StravaSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://www.strava.com/oauth/token", + data={ + "client_id": getenv("STRAVA_CLIENT_ID"), + "client_secret": getenv("STRAVA_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Strava access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return StravaSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/stripe.py b/agixt/sso/stripe.py new file mode 100644 index 000000000000..2178464a0b8f --- /dev/null +++ b/agixt/sso/stripe.py @@ -0,0 +1,98 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- STRIPE_CLIENT_ID: Stripe OAuth client ID +- STRIPE_CLIENT_SECRET: Stripe OAuth client secret + +Required scopes for Stripe SSO + +- read_write +""" + + +class StripeSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("STRIPE_CLIENT_ID") + self.client_secret = getenv("STRIPE_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://connect.stripe.com/oauth/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.stripe.com/v1/account" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + email = data["email"] + business_name = data["business_name"] + return { + "email": email, + "business_name": business_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Stripe", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError("Stripe does not support sending email directly.") + + +def stripe_sso(code, redirect_uri=None) -> StripeSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://connect.stripe.com/oauth/token", + data={ + "client_id": getenv("STRIPE_CLIENT_ID"), + "client_secret": getenv("STRIPE_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Stripe access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return StripeSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/twitch.py b/agixt/sso/twitch.py new file mode 100644 index 000000000000..1914df939fee --- /dev/null +++ b/agixt/sso/twitch.py @@ -0,0 +1,109 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- TWITCH_CLIENT_ID: Twitch OAuth client ID +- TWITCH_CLIENT_SECRET: Twitch OAuth client secret + +Required scope for Twitch OAuth + +- user:read:email +Follow the links to confirm that you have the APIs enabled, +then add the `TWITCH_CLIENT_ID` and `TWITCH_CLIENT_SECRET` environment variables to your `.env` file. +""" + + +class TwitchSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("TWITCH_CLIENT_ID") + self.client_secret = getenv("TWITCH_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://id.twitch.tv/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.twitch.tv/helix/users" + response = requests.get( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Client-Id": self.client_id, + }, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Client-Id": self.client_id, + }, + ) + try: + data = response.json()["data"][0] + first_name = data["display_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": "", # Twitch API does not provide surname in user info + } + except Exception as e: + raise HTTPException( + status_code=400, + detail="Error getting user info from Twitch", + ) + + # Twitch doesn't support email sending directly, implement your own function for sending messages if needed + def send_message(self, message_text): + # You can implement another way to notify the user, like a whisper or chat message + pass + + +def twitch_sso(code, redirect_uri=None) -> TwitchSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://id.twitch.tv/oauth2/token", + data={ + "client_id": getenv("TWITCH_CLIENT_ID"), + "client_secret": getenv("TWITCH_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Twitch access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return TwitchSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/viadeo.py b/agixt/sso/viadeo.py new file mode 100644 index 000000000000..fc70b968bb38 --- /dev/null +++ b/agixt/sso/viadeo.py @@ -0,0 +1,146 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- VIADEO_CLIENT_ID: Viadeo OAuth client ID +- VIADEO_CLIENT_SECRET: Viadeo OAuth client secret + +Required APIs + +Ensure you have the required APIs enabled and add the +`VIADEO_CLIENT_ID` and `VIADEO_CLIENT_SECRET` environment variables to your `.env` file. + +Viadeo API docs reference: +https://developer.viadeo.com/ + +Required scopes for Viadeo OAuth + +- basic (to access user profile) +- email (to access user email) +""" + + +class ViadeoSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("VIADEO_CLIENT_ID") + self.client_secret = getenv("VIADEO_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://secure.viadeo.com/oauth-provider/refreshToken", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.viadeo.com/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Viadeo", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "text": message_text, + "to": to, + }, + } + response = requests.post( + "https://api.viadeo.com/send_email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.viadeo.com/send_email", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def viadeo_sso(code, redirect_uri=None) -> ViadeoSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://secure.viadeo.com/oauth-provider/accessToken", + data={ + "client_id": getenv("VIADEO_CLIENT_ID"), + "client_secret": getenv("VIADEO_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Viadeo access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] if "refresh_token" in data else "Not provided" + return ViadeoSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/vimeo.py b/agixt/sso/vimeo.py new file mode 100644 index 000000000000..578d8615752a --- /dev/null +++ b/agixt/sso/vimeo.py @@ -0,0 +1,163 @@ +import os +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- VIMEO_CLIENT_ID: Vimeo OAuth client ID +- VIMEO_CLIENT_SECRET: Vimeo OAuth client secret + +Required APIs + +Ensure you have the necessary APIs enabled in Vimeo's developer platform, +then add the `VIMEO_CLIENT_ID` and `VIMEO_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Vimeo OAuth + +- public +- private +- video_files +""" + + +class VimeoSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("VIMEO_CLIENT_ID") + self.client_secret = getenv("VIMEO_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.vimeo.com/oauth/access_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.vimeo.com/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["name"].split()[0] + last_name = data["name"].split()[-1] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Vimeo", + ) + + def upload_video(self, video_file_path, video_title, video_description): + uri = "https://api.vimeo.com/me/videos" + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "upload": { + "approach": "tus", + "size": str(os.path.getsize(video_file_path)), + }, + "name": video_title, + "description": video_description, + } + ), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + uri, + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "upload": { + "approach": "tus", + "size": str(os.path.getsize(video_file_path)), + }, + "name": video_title, + "description": video_description, + } + ), + ) + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail=f"Error uploading video to Vimeo: {response.text}", + ) + upload_link = response.json()["upload"]["upload_link"] + with open(video_file_path, "rb") as video_file: + tus_response = requests.patch( + upload_link, + data=video_file, + headers={"Content-Type": "application/offset+octet-stream"}, + ) + if tus_response.status_code != 204: + raise HTTPException( + status_code=tus_response.status_code, + detail=f"Error uploading video to Vimeo using upload link: {tus_response.text}", + ) + return response.json() + + +def vimeo_sso(code, redirect_uri=None) -> VimeoSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + "https://api.vimeo.com/oauth/access_token", + data={ + "client_id": getenv("VIMEO_CLIENT_ID"), + "client_secret": getenv("VIMEO_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Vimeo access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return VimeoSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/vk.py b/agixt/sso/vk.py new file mode 100644 index 000000000000..c2f4e8621b46 --- /dev/null +++ b/agixt/sso/vk.py @@ -0,0 +1,86 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- VK_CLIENT_ID: VK OAuth client ID +- VK_CLIENT_SECRET: VK OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `VK_CLIENT_ID` and `VK_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for VK SSO + +- email +""" + + +class VKSSO: + def __init__( + self, + access_token=None, + user_id=None, + email=None, + ): + self.access_token = access_token + self.user_id = user_id + self.email = email + self.client_id = getenv("VK_CLIENT_ID") + self.client_secret = getenv("VK_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + raise NotImplementedError("VK API does not use refresh tokens.") + + def get_user_info(self): + uri = f"https://api.vk.com/method/users.get?user_ids={self.user_id}&fields=first_name,last_name&access_token={self.access_token}&v=5.131" + response = requests.get(uri) + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail="Error getting user info from VK", + ) + try: + data = response.json()["response"][0] + first_name = data["first_name"] + last_name = data["last_name"] + return { + "email": self.email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error parsing user info from VK", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError("VK API does not support sending emails.") + + +def vk_sso(code, redirect_uri=None) -> VKSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + response = requests.get( + "https://oauth.vk.com/access_token", + params={ + "client_id": getenv("VK_CLIENT_ID"), + "client_secret": getenv("VK_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "code": code, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting VK access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + user_id = data["user_id"] + email = data.get("email", "Not provided") + return VKSSO(access_token=access_token, user_id=user_id, email=email) diff --git a/agixt/sso/wechat.py b/agixt/sso/wechat.py new file mode 100644 index 000000000000..f956363bd1a4 --- /dev/null +++ b/agixt/sso/wechat.py @@ -0,0 +1,96 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- WECHAT_CLIENT_ID: WeChat OAuth client ID +- WECHAT_CLIENT_SECRET: WeChat OAuth client secret + +Required scopes for WeChat SSO: + +- snsapi_userinfo +""" + + +class WeChatSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("WECHAT_CLIENT_ID") + self.client_secret = getenv("WECHAT_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.get( + "https://api.weixin.qq.com/sns/oauth2/refresh_token", + params={ + "appid": self.client_id, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.weixin.qq.com/sns/userinfo" + response = requests.get( + uri, + params={ + "access_token": self.access_token, + "openid": self.client_id, + "lang": "en", + }, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + params={ + "access_token": self.access_token, + "openid": self.client_id, + "lang": "en", + }, + ) + try: + data = response.json() + first_name = data["nickname"] + last_name = "" # WeChat does not provide last name + email = data.get("email") # WeChat may not provide email + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from WeChat", + ) + + +def wechat_sso(code, redirect_uri=None) -> WeChatSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + response = requests.get( + "https://api.weixin.qq.com/sns/oauth2/access_token", + params={ + "appid": getenv("WECHAT_CLIENT_ID"), + "secret": getenv("WECHAT_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting WeChat access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return WeChatSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/withings.py b/agixt/sso/withings.py new file mode 100644 index 000000000000..30a8ef7333d0 --- /dev/null +++ b/agixt/sso/withings.py @@ -0,0 +1,114 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- WITHINGS_CLIENT_ID: Withings OAuth client ID +- WITHINGS_CLIENT_SECRET: Withings OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `WITHINGS_CLIENT_ID` and `WITHINGS_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Withings SSO + +- user.info +- user.metrics +- user.activity +""" + + +class WithingsSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("WITHINGS_CLIENT_ID") + self.client_secret = getenv("WITHINGS_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://wbsapi.withings.net/v2/oauth2", + data={ + "action": "requesttoken", + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + }, + ) + data = response.json()["body"] + return data["access_token"] + + def get_user_info(self): + uri = "https://wbsapi.withings.net/v2/user" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + params={"action": "getdevice"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + params={"action": "getdevice"}, + ) + try: + data = response.json()["body"]["devices"][0] + first_name = data.get("firstname", "") + last_name = data.get("lastname", "") + email = data.get("email", "") + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Withings", + ) + + def send_email(self, to, subject, message_text): + raise HTTPException( + status_code=501, detail="Withings API does not support sending email" + ) + + +def withings_sso(code, redirect_uri=None) -> WithingsSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://wbsapi.withings.net/v2/oauth2", + data={ + "action": "requesttoken", + "client_id": getenv("WITHINGS_CLIENT_ID"), + "client_secret": getenv("WITHINGS_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Withings access token: {response.text}") + return None, None + data = response.json()["body"] + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return WithingsSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/xero.py b/agixt/sso/xero.py new file mode 100644 index 000000000000..d5e0409bba55 --- /dev/null +++ b/agixt/sso/xero.py @@ -0,0 +1,110 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- XERO_CLIENT_ID: Xero OAuth client ID +- XERO_CLIENT_SECRET: Xero OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `XERO_CLIENT_ID` and `XERO_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Xero OAuth + +- openid +- profile +- email +- offline_access +""" + + +class XeroSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("XERO_CLIENT_ID") + self.client_secret = getenv("XERO_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://identity.xero.com/connect/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + if response.status_code != 200: + logging.error(f"Error refreshing Xero token: {response.text}") + raise HTTPException( + status_code=response.status_code, + detail="Unable to refresh token from Xero", + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.xero.com/connections" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json()[0] # Assuming you want the first connection info + first_name = data.get("name", "").split()[0] + last_name = " ".join(data.get("name", "").split()[1:]) + email = data.get("email", "") + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except Exception as exc: + logging.error(f"Error parsing user info from Xero: {exc}") + raise HTTPException( + status_code=400, + detail="Error getting user info from Xero", + ) + + def send_email(self, to, subject, message_text): + # Xero does not provide an email sending service. + raise NotImplementedError("Xero does not support sending emails via API.") + + +def xero_sso(code, redirect_uri=None) -> XeroSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + response = requests.post( + "https://identity.xero.com/connect/token", + data={ + "client_id": getenv("XERO_CLIENT_ID"), + "client_secret": getenv("XERO_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logging.error(f"Error getting Xero access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return XeroSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/xing.py b/agixt/sso/xing.py new file mode 100644 index 000000000000..c3397fdfc7a9 --- /dev/null +++ b/agixt/sso/xing.py @@ -0,0 +1,156 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- XING_CLIENT_ID: XING OAuth client ID +- XING_CLIENT_SECRET: XING OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `XING_CLIENT_ID` and `XING_CLIENT_SECRET` environment variables to your `.env` file. + +- Xing API https://dev.xing.com/ + +Required scopes for XING SSO + +- https://api.xing.com/v1/users/me +- https://api.xing.com/v1/authorize +""" + + +class XingSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("XING_CLIENT_ID") + self.client_secret = getenv("XING_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.xing.com/v1/oauth/token", + auth=(self.client_id, self.client_secret), + data={ + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.xing.com/v1/users/me" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + user_profile = data["users"][0] + first_name = user_profile["first_name"] + last_name = user_profile["last_name"] + email = user_profile["active_email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from XING", + ) + + def send_email(self, to, subject, message_text): + if not self.email_address: + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "message": { + "subject": subject, + "body": { + "contentType": "Text", + "content": message_text, + }, + "toRecipients": [ + { + "emailAddress": { + "address": to, + } + } + ], + }, + "saveToSentItems": "true", + } + response = requests.post( + "https://api.xing.com/v1/messages", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.xing.com/v1/messages", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def xing_sso(code, redirect_uri=None) -> XingSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.xing.com/v1/oauth/token", + data={ + "client_id": getenv("XING_CLIENT_ID"), + "client_secret": getenv("XING_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + "scope": "https://api.xing.com/v1/users/me https://api.xing.com/v1/messages", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting XING access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data.get("refresh_token", "Not provided") + return XingSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/yahoo.py b/agixt/sso/yahoo.py new file mode 100644 index 000000000000..8646b02737e6 --- /dev/null +++ b/agixt/sso/yahoo.py @@ -0,0 +1,140 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- YAHOO_CLIENT_ID: Yahoo OAuth client ID +- YAHOO_CLIENT_SECRET: Yahoo OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `YAHOO_CLIENT_ID` and `YAHOO_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Yahoo OAuth + +- profile +- email +- mail-w +""" + + +class YahooSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("YAHOO_CLIENT_ID") + self.client_secret = getenv("YAHOO_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.login.yahoo.com/oauth2/get_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://api.login.yahoo.com/openid/v1/userinfo" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["given_name"] + last_name = data["family_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Yahoo", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = {"raw": raw} + response = requests.post( + "https://api.login.yahoo.com/ws/mail/v3/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://api.login.yahoo.com/ws/mail/v3/send", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def yahoo_sso(code, redirect_uri=None) -> YahooSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.login.yahoo.com/oauth2/get_token", + data={ + "client_id": getenv("YAHOO_CLIENT_ID"), + "client_secret": getenv("YAHOO_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + logging.error(f"Error getting Yahoo access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return YahooSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/yammer.py b/agixt/sso/yammer.py new file mode 100644 index 000000000000..53f865f2a98e --- /dev/null +++ b/agixt/sso/yammer.py @@ -0,0 +1,132 @@ +import json +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- YAMMER_CLIENT_ID: Yammer OAuth client ID +- YAMMER_CLIENT_SECRET: Yammer OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `YAMMER_CLIENT_ID` and `YAMMER_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Yammer OAuth + +- messages:email +- messages:post +""" + + +class YammerSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("YAMMER_CLIENT_ID") + self.client_secret = getenv("YAMMER_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://www.yammer.com/oauth2/access_token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://www.yammer.com/api/v1/users/current.json" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Yammer", + ) + + def send_message(self, group_id, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message_data = { + "body": message_text, + "group_id": group_id, + } + response = requests.post( + "https://www.yammer.com/api/v1/messages.json", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://www.yammer.com/api/v1/messages.json", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(message_data), + ) + return response.json() + + +def yammer_sso(code, redirect_uri=None) -> YammerSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://www.yammer.com/oauth2/token", + data={ + "client_id": getenv("YAMMER_CLIENT_ID"), + "client_secret": getenv("YAMMER_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Yammer access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return YammerSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/yandex.py b/agixt/sso/yandex.py new file mode 100644 index 000000000000..980b562d5045 --- /dev/null +++ b/agixt/sso/yandex.py @@ -0,0 +1,142 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- YANDEX_CLIENT_ID: Yandex OAuth client ID +- YANDEX_CLIENT_SECRET: Yandex OAuth client secret + +Required APIs + +Follow the links to confirm that you have the APIs enabled, +then add the `YANDEX_CLIENT_ID` and `YANDEX_CLIENT_SECRET` environment variables to your `.env` file. + +Required scopes for Yandex OAuth + +- login:info +- login:email +- mail.send +""" + + +class YandexSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("YANDEX_CLIENT_ID") + self.client_secret = getenv("YANDEX_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://oauth.yandex.com/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = "https://login.yandex.ru/info" + response = requests.get( + uri, + headers={"Authorization": f"OAuth {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"OAuth {self.access_token}"}, + ) + try: + data = response.json() + first_name = data.get("first_name") + last_name = data.get("last_name") + email = data.get("default_email", data.get("emails", [None])[0]) + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Yandex", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + email_data = { + "to": to, + "subject": subject, + "text": message_text, + } + response = requests.post( + "https://smtp.yandex.ru/send", + headers={ + "Authorization": f"OAuth {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + "https://smtp.yandex.ru/send", + headers={ + "Authorization": f"OAuth {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def yandex_sso(code, redirect_uri=None) -> YandexSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://oauth.yandex.com/token", + data={ + "code": code, + "client_id": getenv("YANDEX_CLIENT_ID"), + "client_secret": getenv("YANDEX_CLIENT_SECRET"), + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Yandex access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data.get("refresh_token", "Not provided") + return YandexSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/yelp.py b/agixt/sso/yelp.py new file mode 100644 index 000000000000..581e234f1f5e --- /dev/null +++ b/agixt/sso/yelp.py @@ -0,0 +1,109 @@ +import requests +import logging +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- YELP_CLIENT_ID: Yelp OAuth client ID +- YELP_CLIENT_SECRET: Yelp OAuth client secret + +Required APIs + +Ensure you have registered your app with Yelp and obtained the CLIENT_ID and CLIENT_SECRET. + +Required scopes for Yelp OAuth + +- business +""" + + +class YelpSSO: + def __init__(self, access_token=None, refresh_token=None): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("YELP_CLIENT_ID") + self.client_secret = getenv("YELP_CLIENT_SECRET") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + "https://api.yelp.com/oauth2/token", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + response_data = response.json() + if response.status_code != 200: + logging.error(f"Error refreshing Yelp access token: {response_data}") + raise HTTPException( + status_code=response.status_code, + detail="Error refreshing Yelp access token.", + ) + return response_data["access_token"] + + def get_user_info(self): + uri = "https://api.yelp.com/v3/users/self" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json() + first_name = data["first_name"] + last_name = data["last_name"] + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Yelp", + ) + + def send_email(self, to, subject, message_text): + raise NotImplementedError("Yelp API does not support sending emails directly.") + + +def yelp_sso(code, redirect_uri=None) -> YelpSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://api.yelp.com/oauth2/token", + data={ + "client_id": getenv("YELP_CLIENT_ID"), + "client_secret": getenv("YELP_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Yelp access token: {response.text}") + raise HTTPException( + status_code=response.status_code, detail="Error getting Yelp access token." + ) + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return YelpSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/agixt/sso/zendesk.py b/agixt/sso/zendesk.py new file mode 100644 index 000000000000..00f747d3e23a --- /dev/null +++ b/agixt/sso/zendesk.py @@ -0,0 +1,153 @@ +import base64 +import json +import requests +import logging +from email.mime.text import MIMEText +from fastapi import HTTPException +from Globals import getenv + +""" +Required environment variables: + +- ZENDESK_CLIENT_ID: Zendesk OAuth client ID +- ZENDESK_CLIENT_SECRET: Zendesk OAuth client secret +- ZENDESK_SUBDOMAIN: Your Zendesk subdomain + +Required APIs + +Ensure you have the necessary APIs enabled, then add the `ZENDESK_CLIENT_ID`, `ZENDESK_CLIENT_SECRET`, and `ZENDESK_SUBDOMAIN` environment variables to your `.env` file. + +Required scopes for Zendesk OAuth + +- read +- write +""" + + +class ZendeskSSO: + def __init__( + self, + access_token=None, + refresh_token=None, + ): + self.access_token = access_token + self.refresh_token = refresh_token + self.client_id = getenv("ZENDESK_CLIENT_ID") + self.client_secret = getenv("ZENDESK_CLIENT_SECRET") + self.subdomain = getenv("ZENDESK_SUBDOMAIN") + self.user_info = self.get_user_info() + + def get_new_token(self): + response = requests.post( + f"https://{self.subdomain}.zendesk.com/oauth/tokens", + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + "grant_type": "refresh_token", + }, + ) + return response.json()["access_token"] + + def get_user_info(self): + uri = f"https://{self.subdomain}.zendesk.com/api/v2/users/me.json" + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.get( + uri, + headers={"Authorization": f"Bearer {self.access_token}"}, + ) + try: + data = response.json()["user"] + first_name, last_name = data["name"].split(" ", 1) + email = data["email"] + return { + "email": email, + "first_name": first_name, + "last_name": last_name, + } + except: + raise HTTPException( + status_code=400, + detail="Error getting user info from Zendesk", + ) + + def send_email(self, to, subject, message_text): + if not self.user_info.get("email"): + user_info = self.get_user_info() + self.email_address = user_info["email"] + message = MIMEText(message_text) + message["to"] = to + message["from"] = self.email_address + message["subject"] = subject + raw = base64.urlsafe_b64encode(message.as_bytes()) + raw = raw.decode() + + email_data = { + "request": { + "subject": subject, + "comment": {"body": message_text}, + "requester": { + "name": f"{self.user_info['first_name']} {self.user_info['last_name']}", + "email": self.user_info["email"], + }, + "email_ccs": [ + { + "user_email": to, + } + ], + } + } + + response = requests.post( + f"https://{self.subdomain}.zendesk.com/api/v2/requests.json", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + if response.status_code == 401: + self.access_token = self.get_new_token() + response = requests.post( + f"https://{self.subdomain}.zendesk.com/api/v2/requests.json", + headers={ + "Authorization": f"Bearer {self.access_token}", + "Content-Type": "application/json", + }, + data=json.dumps(email_data), + ) + return response.json() + + +def zendesk_sso(code, redirect_uri=None) -> ZendeskSSO: + if not redirect_uri: + redirect_uri = getenv("MAGIC_LINK_URL") + code = ( + str(code) + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%3F", "?") + .replace("%3D", "=") + ) + response = requests.post( + f"https://{getenv('ZENDESK_SUBDOMAIN')}.zendesk.com/oauth/tokens", + data={ + "client_id": getenv("ZENDESK_CLIENT_ID"), + "client_secret": getenv("ZENDESK_CLIENT_SECRET"), + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + if response.status_code != 200: + logging.error(f"Error getting Zendesk access token: {response.text}") + return None, None + data = response.json() + access_token = data["access_token"] + refresh_token = data["refresh_token"] + return ZendeskSSO(access_token=access_token, refresh_token=refresh_token) diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 5736bc2d89b1..fe8491418252 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -13,24 +13,147 @@ services: image: joshxt/agixt:main init: true environment: - - DATABASE_HOST=${DATABASE_HOST:-db} - - DATABASE_USER=${DATABASE_USER:-postgres} - - DATABASE_PASSWORD=${DATABASE_PASSWORD:-postgres} - - DATABASE_NAME=${DATABASE_NAME:-postgres} - - DATABASE_PORT=${DATABASE_PORT:-5432} - - UVICORN_WORKERS=${UVICORN_WORKERS:-10} - - USING_JWT=${USING_JWT:-false} - - AGIXT_API_KEY=${AGIXT_API_KEY} - - AGIXT_URI=${AGIXT_URI-http://agixt:7437} - - DISABLED_EXTENSIONS=${DISABLED_EXTENSIONS:-} - - DISABLED_PROVIDERS=${DISABLED_PROVIDERS:-} - - WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE} - - TOKENIZERS_PARALLELISM=False - - LOG_LEVEL=${LOG_LEVEL:-INFO} - - AUTH_PROVIDER=${AUTH_PROVIDER:-none} - - TZ=${TZ-America/New_York} + DATABASE_HOST: ${DATABASE_HOST:-db} + DATABASE_USER: ${DATABASE_USER:-postgres} + DATABASE_PASSWORD: ${DATABASE_PASSWORD:-postgres} + DATABASE_NAME: ${DATABASE_NAME:-postgres} + DATABASE_PORT: ${DATABASE_PORT:-5432} + UVICORN_WORKERS: ${UVICORN_WORKERS:-10} + USING_JWT: ${USING_JWT:-false} + AGIXT_API_KEY: ${AGIXT_API_KEY} + AGIXT_URI: ${AGIXT_URI-http://agixt:7437} + MAGIC_LINK_URL: ${AUTH_WEB-http://agixtinteractive:3437/user} + DISABLED_EXTENSIONS: ${DISABLED_EXTENSIONS:-} + DISABLED_PROVIDERS: ${DISABLED_PROVIDERS:-} + WORKING_DIRECTORY: ${WORKING_DIRECTORY:-/agixt/WORKSPACE} + TOKENIZERS_PARALLELISM: False + LOG_LEVEL: ${LOG_LEVEL:-INFO} + AUTH_PROVIDER: ${AUTH_PROVIDER:-none} + AOL_CLIENT_ID: ${AOL_CLIENT_ID} + AOL_CLIENT_SECRET: ${AOL_CLIENT_SECRET} + APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} + APPLE_CLIENT_SECRET: ${APPLE_CLIENT_SECRET} + AUTODESK_CLIENT_ID: ${AUTODESK_CLIENT_ID} + AUTODESK_CLIENT_SECRET: ${AUTODESK_CLIENT_SECRET} + AWS_CLIENT_ID: ${AWS_CLIENT_ID} + AWS_CLIENT_SECRET: ${AWS_CLIENT_SECRET} + AWS_REGION: ${AWS_REGION} + AWS_USER_POOL_ID: ${AWS_USER_POOL_ID} + BATTLENET_CLIENT_ID: ${BATTLENET_CLIENT_ID} + BATTLENET_CLIENT_SECRET: ${BATTLENET_CLIENT_SECRET} + BITBUCKET_CLIENT_ID: ${BITBUCKET_CLIENT_ID} + BITBUCKET_CLIENT_SECRET: ${BITBUCKET_CLIENT_SECRET} + BITLY_ACCESS_TOKEN: ${BITLY_ACCESS_TOKEN} + BITLY_CLIENT_ID: ${BITLY_CLIENT_ID} + BITLY_CLIENT_SECRET: ${BITLY_CLIENT_SECRET} + CF_CLIENT_ID: ${CF_CLIENT_ID} + CF_CLIENT_SECRET: ${CF_CLIENT_SECRET} + CLEAR_SCORE_CLIENT_ID: ${CLEAR_SCORE_CLIENT_ID} + CLEAR_SCORE_CLIENT_SECRET: ${CLEAR_SCORE_CLIENT_SECRET} + DEUTSCHE_TELKOM_CLIENT_ID: ${DEUTSCHE_TELKOM_CLIENT_ID} + DEUTSCHE_TELKOM_CLIENT_SECRET: ${DEUTSCHE_TELKOM_CLIENT_SECRET} + DEVIANTART_CLIENT_ID: ${DEVIANTART_CLIENT_ID} + DEVIANTART_CLIENT_SECRET: ${DEVIANTART_CLIENT_SECRET} + DISCORD_CLIENT_ID: ${DISCORD_CLIENT_ID} + DISCORD_CLIENT_SECRET: ${DISCORD_CLIENT_SECRET} + DROPBOX_CLIENT_ID: ${DROPBOX_CLIENT_ID} + DROPBOX_CLIENT_SECRET: ${DROPBOX_CLIENT_SECRET} + FACEBOOK_CLIENT_ID: ${FACEBOOK_CLIENT_ID} + FACEBOOK_CLIENT_SECRET: ${FACEBOOK_CLIENT_SECRET} + FATSECRET_CLIENT_ID: ${FATSECRET_CLIENT_ID} + FATSECRET_CLIENT_SECRET: ${FATSECRET_CLIENT_SECRET} + FITBIT_CLIENT_ID: ${FITBIT_CLIENT_ID} + FITBIT_CLIENT_SECRET: ${FITBIT_CLIENT_SECRET} + FORMSTACK_CLIENT_ID: ${FORMSTACK_CLIENT_ID} + FORMSTACK_CLIENT_SECRET: ${FORMSTACK_CLIENT_SECRET} + FOURSQUARE_CLIENT_ID: ${FOURSQUARE_CLIENT_ID} + FOURSQUARE_CLIENT_SECRET: ${FOURSQUARE_CLIENT_SECRET} + GITHUB_CLIENT_ID: ${GITHUB_CLIENT_ID} + GITHUB_CLIENT_SECRET: ${GITHUB_CLIENT_SECRET} + GITLAB_CLIENT_ID: ${GITLAB_CLIENT_ID} + GITLAB_CLIENT_SECRET: ${GITLAB_CLIENT_SECRET} + GOOGLE_CLIENT_ID: ${GOOGLE_CLIENT_ID} + GOOGLE_CLIENT_SECRET: ${GOOGLE_CLIENT_SECRET} + HUDDLE_CLIENT_ID: ${HUDDLE_CLIENT_ID} + HUDDLE_CLIENT_SECRET: ${HUDDLE_CLIENT_SECRET} + IMGUR_CLIENT_ID: ${IMGUR_CLIENT_ID} + IMGUR_CLIENT_SECRET: ${IMGUR_CLIENT_SECRET} + INSTAGRAM_CLIENT_ID: ${INSTAGRAM_CLIENT_ID} + INSTAGRAM_CLIENT_SECRET: ${INSTAGRAM_CLIENT_SECRET} + INTEL_CLIENT_ID: ${INTEL_CLIENT_ID} + INTEL_CLIENT_SECRET: ${INTEL_CLIENT_SECRET} + JIVE_CLIENT_ID: ${JIVE_CLIENT_ID} + JIVE_CLIENT_SECRET: ${JIVE_CLIENT_SECRET} + KEYCLOAK_CLIENT_ID: ${KEYCLOAK_CLIENT_ID} + KEYCLOAK_CLIENT_SECRET: ${KEYCLOAK_CLIENT_SECRET} + KEYCLOAK_REALM: ${KEYCLOAK_REALM} + KEYCLOAK_SERVER_URL: ${KEYCLOAK_SERVER_URL} + LINKEDIN_CLIENT_ID: ${LINKEDIN_CLIENT_ID} + LINKEDIN_CLIENT_SECRET: ${LINKEDIN_CLIENT_SECRET} + MICROSOFT_CLIENT_ID: ${MICROSOFT_CLIENT_ID} + MICROSOFT_CLIENT_SECRET: ${MICROSOFT_CLIENT_SECRET} + NETIQ_CLIENT_ID: ${NETIQ_CLIENT_ID} + NETIQ_CLIENT_SECRET: ${NETIQ_CLIENT_SECRET} + OKTA_CLIENT_ID: ${OKTA_CLIENT_ID} + OKTA_CLIENT_SECRET: ${OKTA_CLIENT_SECRET} + OKTA_DOMAIN: ${OKTA_DOMAIN} + OPENAM_BASE_URL: ${OPENAM_BASE_URL} + OPENAM_CLIENT_ID: ${OPENAM_CLIENT_ID} + OPENAM_CLIENT_SECRET: ${OPENAM_CLIENT_SECRET} + ORCID_CLIENT_ID: ${ORCID_CLIENT_ID} + ORCID_CLIENT_SECRET: ${ORCID_CLIENT_SECRET} + OSM_CLIENT_ID: ${OSM_CLIENT_ID} + OSM_CLIENT_SECRET: ${OSM_CLIENT_SECRET} + PAYPAL_CLIENT_ID: ${PAYPAL_CLIENT_ID} + PAYPAL_CLIENT_SECRET: ${PAYPAL_CLIENT_SECRET} + PING_IDENTITY_CLIENT_ID: ${PING_IDENTITY_CLIENT_ID} + PING_IDENTITY_CLIENT_SECRET: ${PING_IDENTITY_CLIENT_SECRET} + PIXIV_CLIENT_ID: ${PIXIV_CLIENT_ID} + PIXIV_CLIENT_SECRET: ${PIXIV_CLIENT_SECRET} + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + REDDIT_CLIENT_SECRET: ${REDDIT_CLIENT_SECRET} + SALESFORCE_CLIENT_ID: ${SALESFORCE_CLIENT_ID} + SALESFORCE_CLIENT_SECRET: ${SALESFORCE_CLIENT_SECRET} + SPOTIFY_CLIENT_ID: ${SPOTIFY_CLIENT_ID} + SPOTIFY_CLIENT_SECRET: ${SPOTIFY_CLIENT_SECRET} + STACKEXCHANGE_CLIENT_ID: ${STACKEXCHANGE_CLIENT_ID} + STACKEXCHANGE_CLIENT_SECRET: ${STACKEXCHANGE_CLIENT_SECRET} + STRAVA_CLIENT_ID: ${STRAVA_CLIENT_ID} + STRAVA_CLIENT_SECRET: ${STRAVA_CLIENT_SECRET} + STRIPE_CLIENT_ID: ${STRIPE_CLIENT_ID} + STRIPE_CLIENT_SECRET: ${STRIPE_CLIENT_SECRET} + TWITCH_CLIENT_ID: ${TWITCH_CLIENT_ID} + TWITCH_CLIENT_SECRET: ${TWITCH_CLIENT_SECRET} + VIADEO_CLIENT_ID: ${VIADEO_CLIENT_ID} + VIADEO_CLIENT_SECRET: ${VIADEO_CLIENT_SECRET} + VIMEO_CLIENT_ID: ${VIMEO_CLIENT_ID} + VIMEO_CLIENT_SECRET: ${VIMEO_CLIENT_SECRET} + VK_CLIENT_ID: ${VK_CLIENT_ID} + VK_CLIENT_SECRET: ${VK_CLIENT_SECRET} + WECHAT_CLIENT_ID: ${WECHAT_CLIENT_ID} + WECHAT_CLIENT_SECRET: ${WECHAT_CLIENT_SECRET} + WEIBO_CLIENT_ID: ${WEIBO_CLIENT_ID} + WEIBO_CLIENT_SECRET: ${WEIBO_CLIENT_SECRET} + WITHINGS_CLIENT_ID: ${WITHINGS_CLIENT_ID} + WITHINGS_CLIENT_SECRET: ${WITHINGS_CLIENT_SECRET} + XERO_CLIENT_ID: ${XERO_CLIENT_ID} + XERO_CLIENT_SECRET: ${XERO_CLIENT_SECRET} + XING_CLIENT_ID: ${XING_CLIENT_ID} + XING_CLIENT_SECRET: ${XING_CLIENT_SECRET} + YAHOO_CLIENT_ID: ${YAHOO_CLIENT_ID} + YAHOO_CLIENT_SECRET: ${YAHOO_CLIENT_SECRET} + YAMMER_CLIENT_ID: ${YAMMER_CLIENT_ID} + YAMMER_CLIENT_SECRET: ${YAMMER_CLIENT_SECRET} + YANDEX_CLIENT_ID: ${YANDEX_CLIENT_ID} + YANDEX_CLIENT_SECRET: ${YANDEX_CLIENT_SECRET} + YELP_CLIENT_ID: ${YELP_CLIENT_ID} + YELP_CLIENT_SECRET: ${YELP_CLIENT_SECRET} + ZENDESK_CLIENT_ID: ${ZENDESK_CLIENT_ID} + ZENDESK_CLIENT_SECRET: ${ZENDESK_CLIENT_SECRET} + ZENDESK_SUBDOMAIN: ${ZENDESK_SUBDOMAIN} + TZ: ${TZ-America/New_York} ports: - - "7437:7437" + - 7437:7437 volumes: - ./models:/agixt/models - ./agixt/WORKSPACE:/agixt/WORKSPACE @@ -45,9 +168,169 @@ services: depends_on: - agixt environment: - - AGIXT_URI=${AGIXT_URI-http://agixt:7437} - - AGIXT_API_KEY=${AGIXT_API_KEY} + AGIXT_URI: ${AGIXT_URI-http://agixt:7437} + AGIXT_API_KEY: ${AGIXT_API_KEY} + AOL_CLIENT_ID: ${AOL_CLIENT_ID} + APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} + AUTODESK_CLIENT_ID: ${AUTODESK_CLIENT_ID} + AWS_CLIENT_ID: ${AWS_CLIENT_ID} + AWS_REGION: ${AWS_REGION} + AWS_USER_POOL_ID: ${AWS_USER_POOL_ID} + BATTLENET_CLIENT_ID: ${BATTLENET_CLIENT_ID} + BITBUCKET_CLIENT_ID: ${BITBUCKET_CLIENT_ID} + BITLY_ACCESS_TOKEN: ${BITLY_ACCESS_TOKEN} + BITLY_CLIENT_ID: ${BITLY_CLIENT_ID} + CF_CLIENT_ID: ${CF_CLIENT_ID} + CLEAR_SCORE_CLIENT_ID: ${CLEAR_SCORE_CLIENT_ID} + DEUTSCHE_TELKOM_CLIENT_ID: ${DEUTSCHE_TELKOM_CLIENT_ID} + DEVIANTART_CLIENT_ID: ${DEVIANTART_CLIENT_ID} + DISCORD_CLIENT_ID: ${DISCORD_CLIENT_ID} + DROPBOX_CLIENT_ID: ${DROPBOX_CLIENT_ID} + FACEBOOK_CLIENT_ID: ${FACEBOOK_CLIENT_ID} + FATSECRET_CLIENT_ID: ${FATSECRET_CLIENT_ID} + FITBIT_CLIENT_ID: ${FITBIT_CLIENT_ID} + FORMSTACK_CLIENT_ID: ${FORMSTACK_CLIENT_ID} + FOURSQUARE_CLIENT_ID: ${FOURSQUARE_CLIENT_ID} + GITHUB_CLIENT_ID: ${GITHUB_CLIENT_ID} + GITLAB_CLIENT_ID: ${GITLAB_CLIENT_ID} + GOOGLE_CLIENT_ID: ${GOOGLE_CLIENT_ID} + HUDDLE_CLIENT_ID: ${HUDDLE_CLIENT_ID} + IMGUR_CLIENT_ID: ${IMGUR_CLIENT_ID} + INSTAGRAM_CLIENT_ID: ${INSTAGRAM_CLIENT_ID} + INTEL_CLIENT_ID: ${INTEL_CLIENT_ID} + JIVE_CLIENT_ID: ${JIVE_CLIENT_ID} + KEYCLOAK_CLIENT_ID: ${KEYCLOAK_CLIENT_ID} + KEYCLOAK_REALM: ${KEYCLOAK_REALM} + KEYCLOAK_SERVER_URL: ${KEYCLOAK_SERVER_URL} + LINKEDIN_CLIENT_ID: ${LINKEDIN_CLIENT_ID} + MICROSOFT_CLIENT_ID: ${MICROSOFT_CLIENT_ID} + NETIQ_CLIENT_ID: ${NETIQ_CLIENT_ID} + OKTA_CLIENT_ID: ${OKTA_CLIENT_ID} + OKTA_DOMAIN: ${OKTA_DOMAIN} + OPENAM_BASE_URL: ${OPENAM_BASE_URL} + OPENAM_CLIENT_ID: ${OPENAM_CLIENT_ID} + ORCID_CLIENT_ID: ${ORCID_CLIENT_ID} + OSM_CLIENT_ID: ${OSM_CLIENT_ID} + PAYPAL_CLIENT_ID: ${PAYPAL_CLIENT_ID} + PING_IDENTITY_CLIENT_ID: ${PING_IDENTITY_CLIENT_ID} + PIXIV_CLIENT_ID: ${PIXIV_CLIENT_ID} + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + SALESFORCE_CLIENT_ID: ${SALESFORCE_CLIENT_ID} + SPOTIFY_CLIENT_ID: ${SPOTIFY_CLIENT_ID} + STACKEXCHANGE_CLIENT_ID: ${STACKEXCHANGE_CLIENT_ID} + STRAVA_CLIENT_ID: ${STRAVA_CLIENT_ID} + STRIPE_CLIENT_ID: ${STRIPE_CLIENT_ID} + TWITCH_CLIENT_ID: ${TWITCH_CLIENT_ID} + VIADEO_CLIENT_ID: ${VIADEO_CLIENT_ID} + VIMEO_CLIENT_ID: ${VIMEO_CLIENT_ID} + VK_CLIENT_ID: ${VK_CLIENT_ID} + WECHAT_CLIENT_ID: ${WECHAT_CLIENT_ID} + WEIBO_CLIENT_ID: ${WEIBO_CLIENT_ID} + WITHINGS_CLIENT_ID: ${WITHINGS_CLIENT_ID} + XERO_CLIENT_ID: ${XERO_CLIENT_ID} + XING_CLIENT_ID: ${XING_CLIENT_ID} + YAHOO_CLIENT_ID: ${YAHOO_CLIENT_ID} + YAMMER_CLIENT_ID: ${YAMMER_CLIENT_ID} + YANDEX_CLIENT_ID: ${YANDEX_CLIENT_ID} + YELP_CLIENT_ID: ${YELP_CLIENT_ID} + ZENDESK_CLIENT_ID: ${ZENDESK_CLIENT_ID} + ZENDESK_SUBDOMAIN: ${ZENDESK_SUBDOMAIN} volumes: - ./agixt/WORKSPACE:/app/WORKSPACE ports: - "8501:8501" + agixtinteractive: + image: ghcr.io/jamesonrgrieve/agixt-interactive:main + init: true + environment: + NEXT_TELEMETRY_DISABLED: 1 + AGIXT_AGENT: ${AGIXT_AGENT-gpt4free} + AGIXT_FILE_UPLOAD_ENABLED: ${AGIXT_FILE_UPLOAD_ENABLED-true} + AGIXT_VOICE_INPUT_ENABLED: ${AGIXT_VOICE_INPUT_ENABLED-true} + AGIXT_FOOTER_MESSAGE: ${AGIXT_FOOTER_MESSAGE-Powered by AGiXT} + AGIXT_REQUIRE_API_KEY: ${AGIXT_REQUIRE_API_KEY-false} + AGIXT_RLHF: ${AGIXT_RLHF-true} + AGIXT_SERVER: ${AGIXT_URI-http://agixt:7437} + AGIXT_SHOW_AGENT_BAR: ${AGIXT_SHOW_AGENT_BAR-true} + AGIXT_SHOW_APP_BAR: ${AGIXT_SHOW_APP_BAR-true} + AGIXT_SHOW_CONVERSATION_BAR: ${AGIXT_SHOW_CONVERSATION_BAR-true} + AGIXT_CONVERSATION_MODE: ${AGIXT_CONVERSATION_MODE-select} + APP_DESCRIPTION: ${APP_DESCRIPTION-A chat powered by AGiXT.} + INTERACTIVE_MODE: ${INTERACTIVE_MODE-chat} + APP_NAME: ${APP_NAME-AGiXT} + APP_URI: ${APP_URI-http://agixtinteractive:3437} + AUTH_WEB: ${AUTH_WEB-http://agixtinteractive:3437/user} + LOG_VERBOSITY_SERVER: 3 + THEME_NAME: ${THEME_NAME} + ALLOW_EMAIL_SIGN_IN: ${ALLOW_EMAIL_SIGN_IN-true} + AOL_CLIENT_ID: ${AOL_CLIENT_ID} + APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} + AUTODESK_CLIENT_ID: ${AUTODESK_CLIENT_ID} + AWS_CLIENT_ID: ${AWS_CLIENT_ID} + AWS_REGION: ${AWS_REGION} + AWS_USER_POOL_ID: ${AWS_USER_POOL_ID} + BATTLENET_CLIENT_ID: ${BATTLENET_CLIENT_ID} + BITBUCKET_CLIENT_ID: ${BITBUCKET_CLIENT_ID} + BITLY_ACCESS_TOKEN: ${BITLY_ACCESS_TOKEN} + BITLY_CLIENT_ID: ${BITLY_CLIENT_ID} + CF_CLIENT_ID: ${CF_CLIENT_ID} + CLEAR_SCORE_CLIENT_ID: ${CLEAR_SCORE_CLIENT_ID} + DEUTSCHE_TELKOM_CLIENT_ID: ${DEUTSCHE_TELKOM_CLIENT_ID} + DEVIANTART_CLIENT_ID: ${DEVIANTART_CLIENT_ID} + DISCORD_CLIENT_ID: ${DISCORD_CLIENT_ID} + DROPBOX_CLIENT_ID: ${DROPBOX_CLIENT_ID} + FACEBOOK_CLIENT_ID: ${FACEBOOK_CLIENT_ID} + FATSECRET_CLIENT_ID: ${FATSECRET_CLIENT_ID} + FITBIT_CLIENT_ID: ${FITBIT_CLIENT_ID} + FORMSTACK_CLIENT_ID: ${FORMSTACK_CLIENT_ID} + FOURSQUARE_CLIENT_ID: ${FOURSQUARE_CLIENT_ID} + GITHUB_CLIENT_ID: ${GITHUB_CLIENT_ID} + GITLAB_CLIENT_ID: ${GITLAB_CLIENT_ID} + GOOGLE_CLIENT_ID: ${GOOGLE_CLIENT_ID} + HUDDLE_CLIENT_ID: ${HUDDLE_CLIENT_ID} + IMGUR_CLIENT_ID: ${IMGUR_CLIENT_ID} + INSTAGRAM_CLIENT_ID: ${INSTAGRAM_CLIENT_ID} + INTEL_CLIENT_ID: ${INTEL_CLIENT_ID} + JIVE_CLIENT_ID: ${JIVE_CLIENT_ID} + KEYCLOAK_CLIENT_ID: ${KEYCLOAK_CLIENT_ID} + KEYCLOAK_REALM: ${KEYCLOAK_REALM} + KEYCLOAK_SERVER_URL: ${KEYCLOAK_SERVER_URL} + LINKEDIN_CLIENT_ID: ${LINKEDIN_CLIENT_ID} + MICROSOFT_CLIENT_ID: ${MICROSOFT_CLIENT_ID} + NETIQ_CLIENT_ID: ${NETIQ_CLIENT_ID} + OKTA_CLIENT_ID: ${OKTA_CLIENT_ID} + OKTA_DOMAIN: ${OKTA_DOMAIN} + OPENAM_BASE_URL: ${OPENAM_BASE_URL} + OPENAM_CLIENT_ID: ${OPENAM_CLIENT_ID} + ORCID_CLIENT_ID: ${ORCID_CLIENT_ID} + OSM_CLIENT_ID: ${OSM_CLIENT_ID} + PAYPAL_CLIENT_ID: ${PAYPAL_CLIENT_ID} + PING_IDENTITY_CLIENT_ID: ${PING_IDENTITY_CLIENT_ID} + PIXIV_CLIENT_ID: ${PIXIV_CLIENT_ID} + REDDIT_CLIENT_ID: ${REDDIT_CLIENT_ID} + SALESFORCE_CLIENT_ID: ${SALESFORCE_CLIENT_ID} + SPOTIFY_CLIENT_ID: ${SPOTIFY_CLIENT_ID} + STACKEXCHANGE_CLIENT_ID: ${STACKEXCHANGE_CLIENT_ID} + STRAVA_CLIENT_ID: ${STRAVA_CLIENT_ID} + STRIPE_CLIENT_ID: ${STRIPE_CLIENT_ID} + TWITCH_CLIENT_ID: ${TWITCH_CLIENT_ID} + VIADEO_CLIENT_ID: ${VIADEO_CLIENT_ID} + VIMEO_CLIENT_ID: ${VIMEO_CLIENT_ID} + VK_CLIENT_ID: ${VK_CLIENT_ID} + WECHAT_CLIENT_ID: ${WECHAT_CLIENT_ID} + WEIBO_CLIENT_ID: ${WEIBO_CLIENT_ID} + WITHINGS_CLIENT_ID: ${WITHINGS_CLIENT_ID} + XERO_CLIENT_ID: ${XERO_CLIENT_ID} + XING_CLIENT_ID: ${XING_CLIENT_ID} + YAHOO_CLIENT_ID: ${YAHOO_CLIENT_ID} + YAMMER_CLIENT_ID: ${YAMMER_CLIENT_ID} + YANDEX_CLIENT_ID: ${YANDEX_CLIENT_ID} + YELP_CLIENT_ID: ${YELP_CLIENT_ID} + ZENDESK_CLIENT_ID: ${ZENDESK_CLIENT_ID} + ZENDESK_SUBDOMAIN: ${ZENDESK_SUBDOMAIN} + TZ: ${TZ-America/New_York} + ports: + - 3437:3437 + restart: unless-stopped + volumes: + - ./node_modules:/app/node_modules diff --git a/docs/2-Concepts/09-Agent Training.md b/docs/2-Concepts/09-Agent Training.md index 61576e82265c..f3fcd1a04418 100644 --- a/docs/2-Concepts/09-Agent Training.md +++ b/docs/2-Concepts/09-Agent Training.md @@ -47,7 +47,7 @@ agent_name="gpt4free" agixt.learn_github_repo( agent_name=agent_name, github_repo="Josh-XT/AGiXT", - collection_number=0, + collection_number="0", ) # Create a synthetic dataset in DPO/CPO/ORPO format. @@ -77,7 +77,7 @@ agent_name="gpt4free" agixt.learn_github_repo( agent_name=agent_name, github_repo="Josh-XT/AGiXT", - collection_number=0, + collection_number="0", ) # Train the desired model on a synthetic DPO dataset created based on the agents memories. diff --git a/docs/4-Authentication/amazon.md b/docs/4-Authentication/amazon.md new file mode 100644 index 000000000000..b3e6a53a4cc8 --- /dev/null +++ b/docs/4-Authentication/amazon.md @@ -0,0 +1,39 @@ +# Amazon SSO + +## Overview + +This module provides Single Sign-On (SSO) functionality using AWS Cognito, allowing users to authenticate and fetch user information. Additionally, it includes the functionality to send emails using Amazon SES (Simple Email Service). + +## Required Environment Variables + +To use the Amazon SSO module, you need to set up the following environment variables. These credentials can be acquired from the AWS Management Console. + +```plaintext +AWS_CLIENT_ID: AWS Cognito OAuth client ID +AWS_CLIENT_SECRET: AWS Cognito OAuth client secret +AWS_USER_POOL_ID: AWS Cognito User Pool ID +AWS_REGION: AWS Cognito Region +``` + +### Step-by-Step Guide to Acquire the Required Keys + +1. **AWS Client ID and Client Secret**: + - Navigate to the [Amazon Cognito Console](https://console.aws.amazon.com/cognito/home). + - Click on **Manage User Pools** and select the user pool you have set up for your application. + - Navigate to the **App integration** section. + - Under **App clients and analytics**, find your app client or create one by clicking **Add an app client**. + - Save the **App client id** and **App client secret** as they will be your `AWS_CLIENT_ID` and `AWS_CLIENT_SECRET`. + +2. **AWS User Pool ID**: + - In the Cognito User Pool you've set up, the **User Pool ID** is displayed at the top of the **General settings** section in the details page of your user pool. Assign this value to `AWS_USER_POOL_ID`. + +3. **AWS Region**: + - The region in which your Cognito User Pool is located, such as `us-west-2`. Assign this value to `AWS_REGION`. + +## Required Scopes for AWS OAuth + +Include the following scopes in your OAuth configuration to get relevant user information and send emails: + +- openid +- email +- profile diff --git a/docs/4-Authentication/aol.md b/docs/4-Authentication/aol.md new file mode 100644 index 000000000000..2213c2a1120f --- /dev/null +++ b/docs/4-Authentication/aol.md @@ -0,0 +1,33 @@ +# AOL SSO Integration Documentation + +This documentation guides you through setting up Single Sign-On (SSO) integration with AOL using OAuth. The OAuth tokens allow you to fetch user information and send emails on behalf of the user. Please note that the endpoints and scopes used in this example are hypothetical and may not reflect the actual endpoints provided by AOL. + +## Required Environment Variables + +To use AOL SSO, you need to provide specific environment variables. These are necessary for the OAuth client to function correctly. + +- `AOL_CLIENT_ID`: Your AOL OAuth client ID. +- `AOL_CLIENT_SECRET`: Your AOL OAuth client secret. + +Ensure these variables are set in your environment. + +## Steps to Set Up AOL SSO Integration + +1. **Acquire OAuth Client ID and Secret** + + - Visit the AOL Developer website and log in to your account. + - Navigate to the OAuth section and create a new OAuth application. + - Note down the `AOL_CLIENT_ID` and `AOL_CLIENT_SECRET` provided by AOL. + - Make sure to enable the following scopes for your application: + - `https://api.aol.com/userinfo.profile` + - `https://api.aol.com/userinfo.email` + - `https://api.aol.com/mail.send` + +2. **Set the Environment Variables** + + Add the acquired client ID and secret to your environment. This could be done by adding them to a `.env` file in the root of your project: + + ```plaintext + AOL_CLIENT_ID=your_aol_client_id_here + AOL_CLIENT_SECRET=your_aol_client_secret_here + ``` diff --git a/docs/4-Authentication/apple.md b/docs/4-Authentication/apple.md new file mode 100644 index 000000000000..5e7e91ec5fbc --- /dev/null +++ b/docs/4-Authentication/apple.md @@ -0,0 +1,86 @@ +# Apple + +Apple Sign-In (SSO) allows you to use Apple's secure authentication service for user login and data retrieval. Setting up Apple SSO involves creating the required OAuth credentials and configuring your environment variables. + +## Required Environment Variables + +Before you begin, ensure you have the following environment variables set up in your `.env` file: + +- `APPLE_CLIENT_ID`: Your Apple OAuth client ID +- `APPLE_CLIENT_SECRET`: Your Apple OAuth client secret + +## Acquiring Apple OAuth Client ID and Client Secret + +1. **Create an Apple Developer Account**: If you don't have an Apple Developer account, you will need to [register](https://developer.apple.com/programs/). + +2. **Create an App ID**: + - Go to [Apple Developer Portal](https://developer.apple.com/account/ios/identifier/bundle). + - Under **Certificates, Identifiers & Profiles**, select **Identifiers**. + - Click the "+" button to create a new App ID. + - Register your App ID and ensure it has the appropriate capabilities for Sign-In with Apple. + +3. **Configure Sign-In with Apple**: + - After creating the App ID, configure it to use Sign-In with Apple. + - Navigate to the **Keys** section in the [Apple Developer Portal](https://developer.apple.com/account/resources/authkeys/list). + - Click the "+" button to create a new key. + - Select the Sign In with Apple capability. + - Download the key after generating it and keep it secure. + +4. **Create a Service ID**: + - Go to **Certificates, Identifiers & Profiles** > **Identifiers** > **Service IDs**. + - Click the "+" button to create a new Service ID. + - Enable **Sign-In with Apple** for this Service ID. + +5. **Configure Redirect URI**: + - Under the Sign-In with Apple configuration for your Service ID, add a redirect URI that matches the one used in your OAuth flow. + +6. **Generate Apple Client Secret**: + - The `APPLE_CLIENT_SECRET` is a JWT generated using your Apple private key and other details. + - Use libraries such as `PyJWT` to generate this JWT. + + ```python + import jwt + import time + from uuid import uuid4 + + PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n...your private key...\n-----END PRIVATE KEY-----\n" + + def generate_client_secret(): + headers = { + "kid": "YOUR_KEY_ID", + "alg": "ES256", + } + claims = { + "iss": "YOUR_TEAM_ID", + "iat": int(time.time()), + "exp": int(time.time()) + 86400*180, + "aud": "https://appleid.apple.com", + "sub": "YOUR_SERVICE_ID", + } + client_secret = jwt.encode(claims, PRIVATE_KEY, headers=headers, algorithm="ES256") + return client_secret + + APPLE_CLIENT_SECRET = generate_client_secret() + ``` + +### Setting Up the Environment Variables + +Once you have the `APPLE_CLIENT_ID` and `APPLE_CLIENT_SECRET`, add them to your `.env` file: + +```env +APPLE_CLIENT_ID=your_apple_client_id +APPLE_CLIENT_SECRET=your_apple_client_secret +``` + +### Required Scopes for Apple SSO + +The required scopes for Apple SSO include: + +- `name`: To get the user's name. +- `email`: To get the user's email address. + +### Additional Notes + +- This implementation uses placeholder values for `first_name`, `last_name`, and `email` which should be replaced with actual logic to capture user information during the initial token exchange. +- The `send_email` function is not implemented because Apple OAuth does not support sending emails directly via API. +- Ensure all sensitive information such as private keys and client secrets are securely stored and managed. diff --git a/docs/4-Authentication/autodesk.md b/docs/4-Authentication/autodesk.md new file mode 100644 index 000000000000..5d4185ea82d6 --- /dev/null +++ b/docs/4-Authentication/autodesk.md @@ -0,0 +1,45 @@ +# Autodesk Single Sign-On (SSO) using OAuth + +The `AutodeskSSO` class provides a way to authenticate Autodesk users and retrieve their profile information from Autodesk's API. This guide describes how to configure and use the Autodesk SSO mechanism in your application. + +## Required Environment Variables + +To use Autodesk SSO in your project, you need to set up the Autodesk OAuth client credentials: + +- `AUTODESK_CLIENT_ID`: This is the client ID obtained from Autodesk when you create an OAuth application. +- `AUTODESK_CLIENT_SECRET`: This is the client secret obtained from Autodesk. + +## Steps to Obtain the Autodesk Client ID and Secret + +1. **Register your application**: + - Go to the [Autodesk Developer Portal](https://forge.autodesk.com). + - Sign in with your Autodesk account. + - Navigate to **My Apps**. + - Click **Create App**. + - Fill out the form with the necessary details and submit. + - Once the application is created, you will receive the `Client ID` and `Client Secret`. + +2. **Add Environment Variables**: + - Create or update your `.env` file to include: + + ```env + AUTODESK_CLIENT_ID=your_client_id_here + AUTODESK_CLIENT_SECRET=your_client_secret_here + ``` + +### Required APIs + +Ensure you have the following APIs enabled in your Autodesk Developer account: + +1. Click the links below to confirm you have enabled the APIs required for Autodesk's OAuth process: + - [Data Management API](https://forge.autodesk.com/en/docs/data/v2/developers_guide/overview/) + - [User Profile API](https://forge.autodesk.com/en/docs/oauth/v2/developers_guide/scopes/) + +### Required Scopes for Autodesk OAuth + +When setting up OAuth, the following scopes should be included: + +- `data:read` +- `data:write` +- `bucket:read` +- `bucket:create` diff --git a/docs/4-Authentication/battlenet.md b/docs/4-Authentication/battlenet.md new file mode 100644 index 000000000000..d9cfcb6c04b8 --- /dev/null +++ b/docs/4-Authentication/battlenet.md @@ -0,0 +1,46 @@ +# Battle.net SSO Integration + +This guide will help you set up Battle.net single sign-on (SSO) integration using OAuth2. Follow the instructions to acquire the necessary keys and configure your environment variables. + +## Required Environment Variables + +Before you begin, make sure to add the following environment variables to your `.env` file: + +- `BATTLENET_CLIENT_ID`: Battle.net OAuth client ID +- `BATTLENET_CLIENT_SECRET`: Battle.net OAuth client secret + +### Obtaining Battle.net Client ID and Client Secret + +To get your Battle.net Client ID and Client Secret, follow these steps: + +1. **Create a Battle.net Developer Account:** + - Go to the [Battle.net Developer Portal](https://develop.battle.net/access/). + - Sign in using your Battle.net account credentials. + +2. **Create an Application:** + - Navigate to the "Create Client" section and fill out the required details about your application. + - After creating the application, you will be provided with a `Client ID` and `Client Secret`. + +3. **Enable APIs:** + - Ensure that the necessary APIs are enabled for your Battle.net application. This generally includes the OAuth2 authentication API. + +4. **Add Redirect URI:** + - Configure the redirect URI to match the URL where you want to receive the authorization code. + +### Required Scopes for Battle.net OAuth + +Ensure you request the following scopes when setting up your SSO integration: + +- `openid` +- `email` + +These scopes will grant your application access to basic Battle.net profile information and the user's email address. + +### Environment Variables Configuration + +After you have your `Client ID` and `Client Secret`, add them to your `.env` file like so: + +```env +BATTLENET_CLIENT_ID=your_battlenet_client_id +BATTLENET_CLIENT_SECRET=your_battlenet_client_secret +``` diff --git a/docs/4-Authentication/bitbucket.md b/docs/4-Authentication/bitbucket.md new file mode 100644 index 000000000000..e5f1177b6400 --- /dev/null +++ b/docs/4-Authentication/bitbucket.md @@ -0,0 +1,43 @@ +# Bitbucket SSO Integration + +The `BitbucketSSO` class facilitates Single Sign-On (SSO) via Bitbucket. This integration enables your application to authenticate users with their Bitbucket accounts, providing an easy way for users to log in without creating a new account on your platform. + +## Required Environment Variables + +To set up Bitbucket SSO, you need to obtain and set the following environment variables: + +- `BITBUCKET_CLIENT_ID`: Bitbucket OAuth client ID +- `BITBUCKET_CLIENT_SECRET`: Bitbucket OAuth client secret + +## Step-by-Step Guide + +### 1. Register Your Application on Bitbucket + +1. Visit the Bitbucket developer portal: [Bitbucket OAuth Settings](https://bitbucket.org/account/settings/app-passwords/). +2. Log in with your Bitbucket account. +3. Navigate to "OAuth" under "Access Management." +4. Click on "Add consumer." +5. Fill in the required details: + - **Name**: A name for your application. + - **Description**: A brief description of what the application does. + - **Callback URL**: The URL to which Bitbucket will send users after they authorize. +6. Select the necessary scopes for your application. For Bitbucket SSO, you need at least: + - `account` + - `email` +7. Save the consumer to get the Client ID and Client Secret. + +### 2. Set Environment Variables + +Add the obtained credentials to your `.env` file: + +```env +BITBUCKET_CLIENT_ID=your_bitbucket_client_id +BITBUCKET_CLIENT_SECRET=your_bitbucket_client_secret +``` + +### 3. Required Scopes for Bitbucket SSO + +Ensure that you request the following scopes when redirecting users for authentication: + +- `account` +- `email` diff --git a/docs/4-Authentication/bitly.md b/docs/4-Authentication/bitly.md new file mode 100644 index 000000000000..b647db2394f6 --- /dev/null +++ b/docs/4-Authentication/bitly.md @@ -0,0 +1,38 @@ +# Bitly Integration + +The Bitly integration allows you to shorten URLs and manage Bitly tokens using the Bitly API. + +## Required Environment Variables + +- `BITLY_CLIENT_ID`: Bitly OAuth client ID +- `BITLY_CLIENT_SECRET`: Bitly OAuth client secret +- `BITLY_ACCESS_TOKEN`: Bitly access token (you can obtain it via OAuth or from the Bitly account settings) + +## Required Scopes for Bitly OAuth + +- `bitly:read` +- `bitly:write` + +## How to Acquire Required Keys and Tokens + +1. **Create a Bitly Account**: + - Sign up for a Bitly account at . + +2. **Create an OAuth App**: + - Navigate to . + - Click on "Registered OAuth Apps" and then "Add a New App". + - Fill in the required information to create a new app. The "App Name" and "App Description" can be anything, but for "Redirect URIs," you'll need to specify the URIs that Bitly will redirect to after authentication. + - After creating the app, you will receive a `Client ID` and `Client Secret`. + +3. **Generate an Access Token**: + - You can generate an access token from your Bitly account settings. + - Go to , and click on "Generic Access Token" to create a new token. + +4. **Set up Environment Variables**: + - Add the following environment variables to your `.env` file: + + ```plaintext + BITLY_CLIENT_ID=your_bitly_client_id + BITLY_CLIENT_SECRET=your_bitly_client_secret + BITLY_ACCESS_TOKEN=your_bitly_access_token + ``` diff --git a/docs/4-Authentication/clearscore.md b/docs/4-Authentication/clearscore.md new file mode 100644 index 000000000000..acb6ff524031 --- /dev/null +++ b/docs/4-Authentication/clearscore.md @@ -0,0 +1,33 @@ +# ClearScore Single Sign-On (SSO) Integration Documentation + +This document details how to integrate ClearScore SSO into your application. It includes setup steps, environment variable configurations, and API requirements. + +## Required Environment Variables + +To use ClearScore SSO, you need to configure the following environment variables in your `.env` file: + +- `CLEAR_SCORE_CLIENT_ID`: ClearScore OAuth client ID +- `CLEAR_SCORE_CLIENT_SECRET`: ClearScore OAuth client secret + +## Acquiring ClearScore OAuth Credentials + +1. **Register Your Application**: Visit the ClearScore API developer portal and register your application. +2. **Obtain Client Credentials**: After registration, you will receive a `Client ID` and `Client Secret`. +3. **Set Environment Variables**: Add the `CLEAR_SCORE_CLIENT_ID` and `CLEAR_SCORE_CLIENT_SECRET` values to your `.env` file in the following format: + +```plaintext +CLEAR_SCORE_CLIENT_ID=your_clear_score_client_id +CLEAR_SCORE_CLIENT_SECRET=your_clear_score_client_secret +``` + +## Required APIs + +To interact with ClearScore's OAuth and email sending capabilities, ensure your application requests the following scopes: + +- `user.info.read` +- `email.send` + +## Scope Descriptions + +- **`user.info.read`**: Allows reading of user profile information. +- **`email.send`**: Allows sending emails on behalf of the user. diff --git a/docs/4-Authentication/cloud_foundry.md b/docs/4-Authentication/cloud_foundry.md new file mode 100644 index 000000000000..548c08b645ae --- /dev/null +++ b/docs/4-Authentication/cloud_foundry.md @@ -0,0 +1,51 @@ +# Cloud Foundry Single Sign-On (SSO) Integration + +This document describes how to set up and use Cloud Foundry SSO in your application. Follow the steps below to configure the integration, acquire the necessary keys, and set up your environment. + +## Required Environment Variables + +- `CF_CLIENT_ID`: Cloud Foundry OAuth client ID +- `CF_CLIENT_SECRET`: Cloud Foundry OAuth client secret + +## Required APIs and Scopes + +You need to enable the following APIs and ensure the appropriate scopes are configured. + +- **Cloud Foundry API (CF API)** +- **User Info API** + +### Steps to Acquire Required Keys + +1. **Log in to your Cloud Foundry Account:** + Go to your Cloud Foundry provider�s management console and log in using your credentials. + +2. **Register an Application:** + Navigate to the OAuth Applications section. Create a new OAuth application. + +3. **Obtain Client ID and Client Secret:** + Once the application is created, you will receive a `CLIENT_ID` and `CLIENT_SECRET`. Make a note of these values as you will need to set them as environment variables. + +4. **Set up Redirect URIs:** + Specify the redirect URIs required for your application. These should point to the appropriate endpoints in your application handling OAuth redirects. + +5. **Enable CF OAuth and User Info API:** + Ensure that the Cloud Foundry OAuth and User Info APIs are enabled for your account or organization. This often involves checking specific settings in the Cloud Foundry management console. + +## Required Scopes for Cloud Foundry SSO + +Ensure that the following OAuth scopes are included in your application's authorization request: + +- `openid` +- `profile` +- `email` + +## Setting Up Environment Variables + +After acquiring the necessary credentials, set up your environment variables. Add the following lines to your application's `.env` file: + +```env +CF_CLIENT_ID=YOUR_CLOUD_FOUNDRY_CLIENT_ID +CF_CLIENT_SECRET=YOUR_CLOUD_FOUNDRY_CLIENT_SECRET +``` + +Replace `YOUR_CLOUD_FOUNDRY_CLIENT_ID`, `YOUR_CLOUD_FOUNDRY_CLIENT_SECRET`, and `YOUR_APPLICATION_REDIRECT_URI` with the actual values you obtained during the setup process. diff --git a/docs/4-Authentication/deutsche_telekom.md b/docs/4-Authentication/deutsche_telekom.md new file mode 100644 index 000000000000..16b670084016 --- /dev/null +++ b/docs/4-Authentication/deutsche_telekom.md @@ -0,0 +1,63 @@ +# Documentation for Deutsche Telekom SSO + +## Overview + +The provided script integrates the Deutsche Telekom Single Sign-On (SSO) service. It allows users to authenticate with their Deutsche Telekom credentials and access their profile and email services. + +## Requirements + +To successfully use the Deutsche Telekom SSO integration, you need to set up the following environment, APIs, and scopes. + +### Required Environment Variables + +- `DEUTSCHE_TELKOM_CLIENT_ID`: Deutsche Telekom OAuth client ID +- `DEUTSCHE_TELKOM_CLIENT_SECRET`: Deutsche Telekom OAuth client secret + +### Required APIs + +Ensure you have access to the following API endpoint for Deutsche Telekom: + +- `https://www.deutschetelekom.com/ldap-sso` + +### Required Scopes + +The following OAuth scopes are required for Deutsche Telekom SSO: + +- `t-online-profile`: Access to profile data +- `t-online-email`: Access to email services + +## Setup Instructions + +### 1. Registering Your Application + +Before using the Deutsche Telekom SSO service, you need to register your application to obtain the `CLIENT_ID` and `CLIENT_SECRET`. + +#### Steps to Register + +1. Navigate to the Deutsche Telekom Developer Portal. +2. Log in with your Deutsche Telekom account. +3. Register a new application to get the OAuth credentials. +4. Note down the `Client ID` and `Client Secret` provided by Deutsche Telekom after registering your application. + +### 2. Setting Up Environment Variables + +Once you have the `Client ID` and `Client Secret`, set up the following environment variables in your system: + +```sh +export DEUTSCHE_TELKOM_CLIENT_ID=your_client_id_here +export DEUTSCHE_TELKOM_CLIENT_SECRET=your_client_secret_here +``` + +Where: + +- `DEUTSCHE_TELKOM_CLIENT_ID` is the Client ID you obtained from the Deutsche Telekom Developer Portal. +- `DEUTSCHE_TELKOM_CLIENT_SECRET` is the Client Secret provided by Deutsche Telekom. + +> Ensure that you replace `your_client_id_here`, `your_client_secret_here`, and `your_redirect_uri_here` with the actual values. + +### 3. Required Scopes + +Make sure the Deutsche Telekom application is configured to request the following OAuth scopes: + +- `t-online-profile` +- `t-online-email` diff --git a/docs/4-Authentication/deviantart.md b/docs/4-Authentication/deviantart.md new file mode 100644 index 000000000000..9b4a585600f6 --- /dev/null +++ b/docs/4-Authentication/deviantart.md @@ -0,0 +1,39 @@ +# DeviantArt SSO Integration Guide + +This guide will help you set up and integrate deviantART Single Sign-On (SSO) using OAuth2 in your application. Please follow the steps carefully to ensure seamless integration. + +## Required Environment Variables + +To start using deviantART SSO, you need to set up the following environment variables in your environment. You can add these variables to your `.env` file: + +- `DEVIANTART_CLIENT_ID`: deviantART OAuth client ID +- `DEVIANTART_CLIENT_SECRET`: deviantART OAuth client secret + +## How to Acquire deviantART Client ID and Client Secret + +1. **Register your Application**: + - Log in to your deviantART account. + - Navigate to the [deviantART OAuth2 Application Registration page](https://www.deviantart.com/settings/applications). + - Click on "Register a new application" to create a new application. + - Fill out the necessary details such as application name, description, and set the redirect URI (e.g., `http://localhost:8000/callback` for local testing). + +2. **Retrieve Client ID and Client Secret**: + - After successfully registering your application, you will be provided with a `Client ID` and `Client Secret`. + - Store these credentials in a safe place. + - Add them to your `.env` file as follows: + + ```env + DEVIANTART_CLIENT_ID=your_client_id + DEVIANTART_CLIENT_SECRET=your_client_secret + ``` + +### Required OAuth Scopes for deviantART + +The following OAuth scopes are required for deviantART SSO to work properly: + +- `user` +- `browse` +- `stash` +- `send_message` + +These scopes allow the application to access user information, browse deviantART content, access stash, and send messages on behalf of the user. diff --git a/docs/4-Authentication/discord.md b/docs/4-Authentication/discord.md new file mode 100644 index 000000000000..f96f5c131610 --- /dev/null +++ b/docs/4-Authentication/discord.md @@ -0,0 +1,44 @@ +# Discord Single Sign-On (SSO) + +## Overview + +This module allows you to integrate Discord's OAuth2 functionality into your application, enabling Single Sign-On using Discord credentials. This can be useful for authentication and fetching user information such as their Discord username, discriminator, and email. + +## Required Environment Variables + +To use the Discord SSO functionality, you need to create a Discord application and obtain the necessary credentials. The required environment variables are: + +- `DISCORD_CLIENT_ID`: Discord OAuth client ID +- `DISCORD_CLIENT_SECRET`: Discord OAuth client secret + +These variables should be added to your `.env` file for ease of use and security. + +### How to Obtain DISCORD_CLIENT_ID and DISCORD_CLIENT_SECRET + +1. **Log in to the Discord Developer Portal:** + Go to the [Discord Developer Portal](https://discord.com/developers/applications) and log in with your Discord account. + +2. **Create a New Application:** + - Click on the "New Application" button. + - Provide a name for your application and click "Create". + +3. **Set Up OAuth2:** + - Navigate to the "OAuth2" tab in your application's settings. + - Under the "OAuth2" section, you will see your `CLIENT_ID` and `CLIENT_SECRET`. Copy these values. + +4. **Enable OAuth2 Scopes:** + - Scroll down to the "OAuth2 URL Generator" section. + - Select the `email` scope to ensure you have permission to access the user's email. + +5. **Add Credentials to Your `.env` File:** + + ```env + DISCORD_CLIENT_ID=your_client_id_here + DISCORD_CLIENT_SECRET=your_client_secret_here + ``` + +### Required APIs and Scopes + +Make sure you have the following scopes set up for your Discord application: + +- `email` (Refer to the [Discord OAuth2 Scopes Documentation](https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-scopes)) diff --git a/docs/4-Authentication/dropbox.md b/docs/4-Authentication/dropbox.md new file mode 100644 index 000000000000..31b4aae5c6b1 --- /dev/null +++ b/docs/4-Authentication/dropbox.md @@ -0,0 +1,44 @@ +# Dropbox SSO Integration + +This document describes how to integrate Dropbox Single Sign-On (SSO) with your application. By following these instructions, you will be able to allow users to authenticate with Dropbox and access their Dropbox account information and files. + +## Required Environment Variables + +Before you start, you need to obtain the necessary credentials and set up environment variables: + +1. **DROPBOX_CLIENT_ID**: Your Dropbox OAuth client ID. +2. **DROPBOX_CLIENT_SECRET**: Your Dropbox OAuth client secret. + +### Acquiring Dropbox OAuth Credentials + +To obtain the necessary credentials from Dropbox: + +1. **Create a Dropbox App**: + - Visit the [Dropbox App Console](https://www.dropbox.com/developers/apps). + - Click on "Create App". + - Choose an API (Scoped access). + - Select the type of access you need: "Full Dropbox" or "App Folder". + - Name your app and click "Create App". + +2. **Get Your App Credentials**: + - Navigate to the "Settings" tab of your app in the Dropbox App Console. + - You will find your `App key` (use this as `DROPBOX_CLIENT_ID`) and `App secret` (use this as `DROPBOX_CLIENT_SECRET`). + +3. **Set the Redirect URI**: + - In the "OAuth 2" section in the settings tab, add your redirect URI (e.g., `https://yourapp.com/auth/dropbox/callback`). + +### Required Scopes for Dropbox OAuth + +When setting up OAuth access, ensure that you enable the following scopes: + +- `account_info.read`: Required to access user account information. +- `files.metadata.read`: Required to read the metadata for files in the user's Dropbox. + +### Setting Environment Variables + +Add the following environment variables to your `.env` file: + +```env +DROPBOX_CLIENT_ID=your_dropbox_client_id +DROPBOX_CLIENT_SECRET=your_dropbox_client_secret +``` diff --git a/docs/4-Authentication/facebook.md b/docs/4-Authentication/facebook.md new file mode 100644 index 000000000000..8cf1637e4c2d --- /dev/null +++ b/docs/4-Authentication/facebook.md @@ -0,0 +1,37 @@ +# Facebook + +## Required Environment Variables + +To integrate Facebook's OAuth into your application, you need to provide the following environment variables: + +- `FACEBOOK_CLIENT_ID`: Facebook OAuth client ID +- `FACEBOOK_CLIENT_SECRET`: Facebook OAuth client secret + +These values should be added to your `.env` file for secure and convenient access by your application. + +## Steps to Obtain Facebook OAuth Credentials + +1. **Create a Facebook App:** + - Navigate to the [Facebook for Developers](https://developers.facebook.com/apps) page. + - Click on "Create App" and select an appropriate app type. + - Once created, go to your app's dashboard. + +2. **Get the Client ID and Client Secret:** + - In your app's dashboard, click on "Settings" and then "Basic." + - Here, you will find your App ID (Client ID) and App Secret (Client Secret). + +3. **Add these values to your .env File:** + Create a `.env` file in your project root (if it doesn't already exist) and add the following lines: + + ```env + FACEBOOK_CLIENT_ID=your_facebook_client_id + FACEBOOK_CLIENT_SECRET=your_facebook_client_secret + ``` + +## Required Scopes for Facebook OAuth + +To ensure proper integration and to access specific user data, the following scopes must be requested during the OAuth authorization process: + +- `public_profile` +- `email` +- `pages_messaging` (for sending messages, if applicable) diff --git a/docs/4-Authentication/fatsecret.md b/docs/4-Authentication/fatsecret.md new file mode 100644 index 000000000000..3c53b9ad9fd4 --- /dev/null +++ b/docs/4-Authentication/fatsecret.md @@ -0,0 +1,25 @@ +# FatSecret + +## Required environment variables + +- `FATSECRET_CLIENT_ID`: FatSecret OAuth client ID +- `FATSECRET_CLIENT_SECRET`: FatSecret OAuth client secret + +Ensure that these environment variables are added to your `.env` file. + +## Required APIs + +To use FatSecret's services, you need to register your application and obtain client credentials by following these steps: + +1. Go to the [FatSecret Platform](https://platform.fatsecret.com/api/). +2. Click on "Sign Up" to create an account or log in if you already have one. +3. Once logged in, create a new application to get your `client_id` and `client_secret`. + +## Setting up your environment variables + +After acquiring your `FATSECRET_CLIENT_ID` and `FATSECRET_CLIENT_SECRET`, add them to your `.env` file like this: + +```plaintext +FATSECRET_CLIENT_ID=your_fatsecret_client_id +FATSECRET_CLIENT_SECRET=your_fatsecret_client_secret +``` diff --git a/docs/4-Authentication/fitbit.md b/docs/4-Authentication/fitbit.md new file mode 100644 index 000000000000..deecb0300d56 --- /dev/null +++ b/docs/4-Authentication/fitbit.md @@ -0,0 +1,45 @@ +# Fitbit + +## Required environment variables + +- `FITBIT_CLIENT_ID`: Fitbit OAuth client ID +- `FITBIT_CLIENT_SECRET`: Fitbit OAuth client secret + +## Required APIs + +Before using the Fitbit SSO, you need to confirm that you have the necessary APIs enabled and have acquired the required environment variables (FITBIT_CLIENT_ID, FITBIT_CLIENT_SECRET). Follow the steps below to do this: + +1. **Create a Fitbit Developer Account**: + - Go to the Fitbit dev portal: [Fitbit Developer](https://dev.fitbit.com/) + - Create an account or log in if you already have one. + +2. **Register Your Application**: + - Navigate to the "Manage my Apps" section. + - Click on "Register a New Application". + - Fill out the application details. + - Set the OAuth 2.0 Application Type to "Personal" or "Server". + +3. **Obtain Client ID and Client Secret**: + - After registering your application, Fitbit will provide you with a **Client ID** and **Client Secret**. + +4. **Set Up Environment Variables**: + - Create or update your `.env` file with the following: + + ```env + FITBIT_CLIENT_ID=your_fitbit_client_id + FITBIT_CLIENT_SECRET=your_fitbit_client_secret + ``` + +## Required Scopes for Fitbit OAuth + +When configuring OAuth for Fitbit, you need to request the appropriate permissions (scopes). Below are the required scopes: + +- `activity` +- `heartrate` +- `location` +- `nutrition` +- `profile` +- `settings` +- `sleep` +- `social` +- `weight` diff --git a/docs/4-Authentication/formstack.md b/docs/4-Authentication/formstack.md new file mode 100644 index 000000000000..eb2d6c21f968 --- /dev/null +++ b/docs/4-Authentication/formstack.md @@ -0,0 +1,40 @@ +# Formstack + +## Required Environment Variables + +To use the FormstackSSO class and its methods, you need to set up the following environment variables in your `.env` file: + +- `FORMSTACK_CLIENT_ID`: Your Formstack OAuth client ID +- `FORMSTACK_CLIENT_SECRET`: Your Formstack OAuth client secret + +You can get these credentials by following these steps: + +1. **Create a Formstack Application**: + - Log in to your Formstack account. + - Navigate to the "Account" section and select "API" from the sidebar menu. + - Click on "Add Application" to create a new app. + - Fill in the necessary details such as the app name and description. + - Make sure to note down the generated `Client ID` and `Client Secret` as you will need them to set up the environment variables. + +2. **Add Environment Variables**: + - Open your `.env` file. + - Add the following lines, replacing the placeholder values with your actual Formstack credentials: + + ```plaintext + FORMSTACK_CLIENT_ID=your_formstack_client_id + FORMSTACK_CLIENT_SECRET=your_formstack_client_secret + ``` + +## Required APIs + +Ensure that the necessary APIs are enabled in your Formstack account: + +- **User API**: This API allows you to access user information such as first name, last name, and email address. +- **Form API**: This API allows you to manage forms, including sending form submissions. + +## Required Scopes for Formstack OAuth + +When setting up OAuth for Formstack, make sure you request the following scopes: + +- `formstack:read`: Allows you to read data from Formstack, such as user information. +- `formstack:write`: Allows you to write data to Formstack, such as submitting form data. diff --git a/docs/4-Authentication/foursquare.md b/docs/4-Authentication/foursquare.md new file mode 100644 index 000000000000..9244d8ede22a --- /dev/null +++ b/docs/4-Authentication/foursquare.md @@ -0,0 +1,41 @@ +# Foursquare SSO Integration + +The following documentation will guide you through the steps necessary to set up and use Foursquare Single Sign-On (SSO) in your application. + +## Required Environment Variables + +Before you start, you need to have the following environment variables set in your `.env` file: + +- `FOURSQUARE_CLIENT_ID`: Your Foursquare OAuth client ID. +- `FOURSQUARE_CLIENT_SECRET`: Your Foursquare OAuth client secret. + +## Steps to Acquire Foursquare OAuth Credentials + +To obtain the `FOURSQUARE_CLIENT_ID` and `FOURSQUARE_CLIENT_SECRET`, follow these steps: + +1. **Create a Foursquare Developer Account:** + - Go to the [Foursquare Developer Portal](https://developer.foursquare.com/). + - Sign up or log in to your Foursquare account. + +2. **Create a New App:** + - Once logged in, go to the "My Apps" section. + - Click on "Create a New App". + - Fill in the required details about your application. + - After filling in the details, submit the form to create the app. + +3. **Retrieve Your Credentials:** + - After creating the app, you will be taken to your app's details page. + - Your `Client ID` and `Client Secret` will be displayed on this page. These are the values you need to add to your `.env` file. + +## Required APIs + +The basic Foursquare API does not require any specific scopes for accessing basic user information. Foursquare uses a userless access approach for its APIs. + +## How to Set Up Foursquare SSO + +Add the `FOURSQUARE_CLIENT_ID` and `FOURSQUARE_CLIENT_SECRET` environment variables to your `.env` file: + +```env +FOURSQUARE_CLIENT_ID=your-client-id +FOURSQUARE_CLIENT_SECRET=your-client-secret +``` diff --git a/docs/4-Authentication/github.md b/docs/4-Authentication/github.md new file mode 100644 index 000000000000..237e65973d2b --- /dev/null +++ b/docs/4-Authentication/github.md @@ -0,0 +1,43 @@ +# GitHub Single Sign-On Implementation + +This documentation details how to implement GitHub Single Sign-On (SSO) in your application using the provided `GitHubSSO` class and related functions. + +## Required Environment Variables + +To use the `GitHubSSO` class, you need to have the following environment variables set: + +- `GITHUB_CLIENT_ID`: GitHub OAuth client ID +- `GITHUB_CLIENT_SECRET`: GitHub OAuth client secret + +## Required Scopes for GitHub OAuth + +Ensure your GitHub OAuth application requests the following scopes to access the necessary user information: + +- `user:email` +- `read:user` + +## How to Acquire GitHub OAuth Client ID and Client Secret + +1. **Register a new OAuth application on GitHub:** + - Go to GitHub's developer settings: [GitHub Developer Settings](https://github.com/settings/developers) + - Click on `New OAuth App`. + - Fill in the required fields: + - **Application name**: Your application�s name. + - **Homepage URL**: The URL to your application's homepage. + - **Authorization callback URL**: The redirect URI where users will be sent after authorization. This should match the `redirect_uri` parameter in your authorization request. + - Click `Register application`. + +2. **Get the client credentials:** + - After registering, you will see your new application listed on the OAuth Apps page. + - Click on the application to see its details. + - Copy the `Client ID` and `Client Secret` to use as environment variables in your application. + +3. **Set Environment Variables:** + - Add the `Client ID` and `Client Secret` to your environment variables. This can be done in your `.env` file like so: + + ```env + GITHUB_CLIENT_ID=your_client_id + GITHUB_CLIENT_SECRET=your_client_secret + ``` + + - Replace `your_client_id` and `your_client_secret` with the actual values you copied from GitHub. diff --git a/docs/4-Authentication/gitlab.md b/docs/4-Authentication/gitlab.md new file mode 100644 index 000000000000..59ae4eb07140 --- /dev/null +++ b/docs/4-Authentication/gitlab.md @@ -0,0 +1,42 @@ +# GitLab SSO Integration + +This guide walks you through integrating GitLab single sign-on (SSO) with your application. Using GitLab SSO, you can enable users to authenticate using their GitLab accounts. + +## Required Environment Variables + +To set up GitLab SSO, two key environment variables need to be configured: + +- `GITLAB_CLIENT_ID`: This is the OAuth client ID from your GitLab application. +- `GITLAB_CLIENT_SECRET`: This is the OAuth client secret from your GitLab application. + +## Steps to Acquire GitLab Client ID and Client Secret + +1. **Create a GitLab OAuth Application:** + - Go to [GitLab Sign-In](https://gitlab.com/users/sign_in) and log in using your credentials. + - Go to your GitLab [Profile Settings](https://gitlab.com/profile/applications). + - Click on `New application`. + +2. **Configure the Application:** + - Enter the `Name` for your application (e.g., "MyAppSSO"). + - Fill in the `Redirect URI` field with the URL to which your application will redirect after successful authentication (e.g., `http://localhost:8000/callback`). + - Under `Scopes`, select `read_user`, `api`, and `email`. + - Click on `Save application`. + +3. **Retrieve Your Credentials:** + - After saving, GitLab will provide a `Application ID` (which corresponds to `GITLAB_CLIENT_ID`) and `Secret` (which corresponds to `GITLAB_CLIENT_SECRET`). + - Set these values in your environment variables or `.env` file: + + ```env + GITLAB_CLIENT_ID=your_client_id + GITLAB_CLIENT_SECRET=your_client_secret + ``` + +## Required Scopes for GitLab SSO + +When creating your OAuth application on GitLab, ensure that you select the following scopes: + +- `read_user`: Allows reading the authenticated user�s profile data. +- `api`: Full access to the authenticated user's API. +- `email`: Access to the authenticated user's email address. + +These scopes are necessary for retrieving user information such as name and email. diff --git a/docs/4-Authentication/google.md b/docs/4-Authentication/google.md new file mode 100644 index 000000000000..59bf3e06129a --- /dev/null +++ b/docs/4-Authentication/google.md @@ -0,0 +1,63 @@ +# Google SSO Module Documentation + +This module allows you to implement Google Single Sign-On (SSO) and send emails using the Gmail API. + +## Setup Instructions + +### Prerequisites + +Ensure you have the following prerequisites before proceeding: + +1. Python environment with necessary dependencies. +2. Google Cloud project with the required APIs enabled. + +### Step-by-Step Guide + +#### 1. Enable Required APIs + +To use this module, you need to enable two APIs in your Google Cloud project: + +- **People API:** This API is required to fetch user information such as names and email addresses. Enable it [here](https://console.cloud.google.com/marketplace/product/google/people.googleapis.com). +- **Gmail API:** This API is needed to send emails using Gmail. Enable it [here](https://console.cloud.google.com/marketplace/product/google/gmail.googleapis.com). + +#### 2. Obtain OAuth 2.0 Credentials + +Follow these steps to get your OAuth 2.0 credentials: + +1. **Create a Google Cloud Project:** + - Go to the [Google Cloud Console](https://console.cloud.google.com/). + - Click on the project dropdown and select **New Project**. + - Enter the project name and other required information and click **Create**. + +2. **Configure OAuth Consent Screen:** + - In the [Google Cloud Console](https://console.cloud.google.com/), navigate to **APIs & Services > OAuth consent screen**. + - Select **External** for user type if you are making it publicly accessible. + - Fill in the required fields like App name, User support email, Authorized domains, etc. + - Save the details. + +3. **Create OAuth 2.0 Client ID:** + - Go to **APIs & Services > Credentials**. + - Click on **Create Credentials** and choose **OAuth 2.0 Client ID**. + - Configure the application type. For web applications, you need to specify the **Authorized redirect URIs**. + - Save the credentials and note down the **Client ID** and **Client Secret**. + +#### 3. Set Environment Variables + +Add the obtained credentials to your environment variables. Create a `.env` file in your project root directory with the following content: + +```dotenv +GOOGLE_CLIENT_ID=your_google_client_id +GOOGLE_CLIENT_SECRET=your_google_client_secret +``` + +Replace `your_google_client_id` and `your_google_client_secret` with the values you obtained in the previous step. + +### Required Scopes + +The following OAuth 2.0 scopes are required for the module to function correctly: + +- `https://www.googleapis.com/auth/userinfo.profile` +- `https://www.googleapis.com/auth/userinfo.email` +- `https://www.googleapis.com/auth/gmail.send` + +Ensure these scopes are specified when requesting user consent. diff --git a/docs/4-Authentication/huddle.md b/docs/4-Authentication/huddle.md new file mode 100644 index 000000000000..1817e85d3df9 --- /dev/null +++ b/docs/4-Authentication/huddle.md @@ -0,0 +1,41 @@ +# Huddle SSO + +This module facilitates Single Sign-On (SSO) with Huddle and provides functionalities to retrieve user information and send emails via Huddle's API. + +## Required Environment Variables + +To utilize the Huddle SSO functionalities in this module, you need to set up the following environment variables in your `.env` file: + +- `HUDDLE_CLIENT_ID`: Huddle OAuth client ID +- `HUDDLE_CLIENT_SECRET`: Huddle OAuth client secret + +### How to Acquire the Required Keys + +1. **Create a Huddle App:** + - Visit the Huddle Developer Portal [Huddle Dev Portal](https://www.huddle.com/developers/). + - Log in with your Huddle account. + - Navigate to the 'Apps' section and create a new application. + - Fill in the necessary details, including redirect URI and scopes. + - Upon creation, you will be provided with a `Client ID` and a `Client Secret`. + +2. **Add Keys to `.env` File:** + - Open your `.env` file in the root of your project. + - Add the following lines: + + ```env + HUDDLE_CLIENT_ID=your_client_id_here + HUDDLE_CLIENT_SECRET=your_client_secret_here + ``` + +## Required APIs + +Ensure you have the necessary Huddle APIs enabled: + +- Make sure your created application in the Huddle Developer Portal has permissions for the required scopes listed below. + +## Required Scopes for Huddle OAuth + +Generate access tokens with the following scopes: + +- `user_info` +- `send_email` diff --git a/docs/4-Authentication/imgur.md b/docs/4-Authentication/imgur.md new file mode 100644 index 000000000000..d91c6aea82c9 --- /dev/null +++ b/docs/4-Authentication/imgur.md @@ -0,0 +1,36 @@ +# Imgur SSO Integration Documentation + +This documentation provides details on setting up and using the Imgur Single Sign-On (SSO) integration. The Imgur SSO integration allows users to authenticate and interact with Imgur's API for actions like uploading images and retrieving user information. + +## Required Environment Variables + +To use the Imgur SSO functionality, you'll need to set the following environment variables: + +- `IMGUR_CLIENT_ID`: Imgur OAuth client ID +- `IMGUR_CLIENT_SECRET`: Imgur OAuth client secret + +## Steps to Acquire Required Keys + +1. **Create an Imgur Application:** + - Go to the [Imgur API Applications page](https://api.imgur.com/oauth2/addclient). + - Log in with your Imgur account. + - Fill out the required fields to register a new application. You'll need to provide: + - **Application Name**: Choose a name for your application. + - **Authorization Type**: Select `OAuth 2 authorization with a callback URL`. + - **Authorization callback URL**: Enter the URL where you want users to be redirected after authorization (e.g., `http://localhost:3000/callback`). + - After submitting the form, you will receive the `Client ID` and `Client Secret`. These values are required for environment variables. + +2. **Set Environment Variables:** + - Add the obtained `Client ID` and `Client Secret` to your environment configuration file (e.g., `.env` file). + + ```plaintext + IMGUR_CLIENT_ID=your_imgur_client_id + IMGUR_CLIENT_SECRET=your_imgur_client_secret + ``` + +## Required Scopes for Imgur SSO + +To enable the required functionalities, ensure that your application requests the following scopes when users authenticate: + +- `read`: Allows reading user data and images. +- `write`: Allows uploading images and other write operations. diff --git a/docs/4-Authentication/instagram.md b/docs/4-Authentication/instagram.md new file mode 100644 index 000000000000..e07cd2c69ddd --- /dev/null +++ b/docs/4-Authentication/instagram.md @@ -0,0 +1,70 @@ +# Instagram SSO Documentation + +## Overview + +This module provides Single Sign-On (SSO) capabilities for Instagram using OAuth. The integration allows you to authenticate users, fetch user profile information, and publish media posts on behalf of users. + +## Required Environment Variables + +To start using Instagram SSO, you need to provide the following environment variables: + +- `INSTAGRAM_CLIENT_ID`: Instagram OAuth client ID +- `INSTAGRAM_CLIENT_SECRET`: Instagram OAuth client secret + +Add these environment variables to your `.env` file. + +## Required APIs + +Ensure that the Instagram Basic Display API is enabled for your application. This is necessary to authenticate users and fetch user profile information. + +## Required Scopes for Instagram OAuth + +When setting up your Instagram OAuth client, make sure to include the following scopes to request necessary permissions: + +- `user_profile` +- `user_media` + +## Setup Steps + +Follow these steps to acquire your keys and set up the required environment variables: + +### 1. Create or Sign into an Instagram Developer Account + +- Go to the [Instagram Developer Documentation](https://developers.facebook.com/docs/instagram-basic-display-api/getting-started). +- Sign in using your Facebook account associated with your Instagram Business Account. + +### 2. Create a New Instagram App + +- Navigate to the "My Apps" section and click "Create App." +- Choose the "For Everything Else" option and click "Next." +- Provide an app name and your contact email, then click "Create App ID." + +### 3. Add Instagram Basic Display + +- In your newly created app dashboard, locate and click "Add Product" in the sidebar. +- Find "Instagram" in the list and click "Set Up." +- In the Instagram Basic Display section, click "Create New App". + +### 4. Configure Instagram OAuth + +- Once the app is created, visit the "Basic Display" settings under the "Instagram" product. +- Add the necessary OAuth redirect URI based on your application's configuration. + +### 5. Retrieve the Client ID and Client Secret + +- In the "Basic Display" section, you should see your `Client ID` and `Client Secret`. +- Copy these values and add them to your `.env` file as `INSTAGRAM_CLIENT_ID` and `INSTAGRAM_CLIENT_SECRET`. + +### 6. Finalize Setup + +- Ensure that you've configured the required scopes (`user_profile`, `user_media`) under the Instagram Basic Display settings. +- Save all changes and ensure that the app status is live. + +### 7. Environment Variable Configuration + +Ensure your `.env` file looks something like: + +```plaintext +INSTAGRAM_CLIENT_ID=your_client_id_here +INSTAGRAM_CLIENT_SECRET=your_client_secret_here +``` diff --git a/docs/4-Authentication/intel_cloud_services.md b/docs/4-Authentication/intel_cloud_services.md new file mode 100644 index 000000000000..c7247d3a8a19 --- /dev/null +++ b/docs/4-Authentication/intel_cloud_services.md @@ -0,0 +1,48 @@ +# Intel Cloud Services + +The `intel_cloud_services.py` script allows for the integration with Intel's OAuth services for Single Sign-On (SSO) and email sending capabilities through their API. + +## Required Environment Variables + +To use the Intel Cloud Services, you need to set up the following environment variables: + +- `INTEL_CLIENT_ID`: Intel OAuth client ID +- `INTEL_CLIENT_SECRET`: Intel OAuth client secret + +## How to Obtain the Required Environment Variables + +### Step 1: Create an Intel Developer Account + +1. Go to the [Intel Developer Zone](https://developer.intel.com). +2. Register for an account if you don't have one. +3. Log in to your Intel Developer account. + +### Step 2: Create an Application + +1. Navigate to the [Intel APIs](https://developer.intel.com/apis). +2. Create a new application and obtain its Client ID and Client Secret. + - **Client ID**: This will be your `INTEL_CLIENT_ID`. + - **Client Secret**: This will be your `INTEL_CLIENT_SECRET`. + +### Step 3: Enable Required APIs + +Make sure the following APIs are enabled for your application: + +- User Info API +- Mail Send API + +## Required Scopes for Intel SSO + +The following scopes must be added to your OAuth application settings: + +- `https://api.intel.com/userinfo.read` +- `https://api.intel.com/mail.send` + +## Setting Up Your .env File + +Add the acquired `INTEL_CLIENT_ID` and `INTEL_CLIENT_SECRET` to your `.env` file: + +```plaintext +INTEL_CLIENT_ID=your_client_id_here +INTEL_CLIENT_SECRET=your_client_secret_here +``` diff --git a/docs/4-Authentication/jive.md b/docs/4-Authentication/jive.md new file mode 100644 index 000000000000..bbbb01611df7 --- /dev/null +++ b/docs/4-Authentication/jive.md @@ -0,0 +1,64 @@ +# Jive SSO Integration + +This document provides a detailed guide on integrating Jive Single Sign-On (SSO) with your application. By following these instructions, you will be able to set up environment variables and acquire the necessary OAuth keys for Jive. + +## Prerequisites + +Before you begin, ensure that you have the following prerequisites: + +- A Jive account with the necessary permissions to create an OAuth application. +- Access to your Jive instance to perform API operations. +- Python environment set up with necessary libraries, including `requests` and `fastapi`. + +## Required Environment Variables + +Set up the following environment variables in your `.env` file or environment management system. These variables will be used for OAuth authentication with Jive. + +- `JIVE_CLIENT_ID`: Jive OAuth client ID. +- `JIVE_CLIENT_SECRET`: Jive OAuth client secret. + +## How to Acquire Jive OAuth Keys + +1. **Log in to your Jive instance.** + +2. **Navigate to the OAuth section:** + - Go to `Admin Console > System > Settings > OAuth`. + +3. **Create a new OAuth application:** + - Click on "Register a New Application". + - Fill out the application form with relevant information: + - **Application Name**: Choose a name for your application. + - **Client ID**: This will be provided by the system once you register the application. + - **Client Secret**: This will also be provided by the system upon registering the application. + - **Redirect URI**: Enter the URL where users will be redirected after authentication. This is typically your application's URL. + - **Scopes**: Select the scopes your application will require as per Jive's API documentation. + - Example Scopes: `read`, `write`, `admin`, etc. + +4. **Save the application:** + - Once you save the application, the `Client ID` and `Client Secret` will be generated. Note these down as you will need them for your environment variables. + +5. **Add environment variables:** + - Add the `JIVE_CLIENT_ID` and `JIVE_CLIENT_SECRET` to your `.env` file or manage them in your service's environment configuration. + +## Required APIs + +Ensure you have the necessary Jive API enabled to perform operations such as fetching user information and sending emails. + +## Required Scopes for Jive OAuth + +The required scopes for Jive OAuth will depend on what operations you wish to perform. Commonly used scopes include: + +- **Read**: To read user information. +- **Write**: To send emails or perform other writing operations. +- **Admin**: For administrative tasks. + +Refer to Jive�s API documentation for a detailed list of available scopes. + +## Setting Environment Variables + +Add the following lines to your `.env` file: + +```env +JIVE_CLIENT_ID=your_jive_client_id +JIVE_CLIENT_SECRET=your_jive_client_secret +``` diff --git a/docs/4-Authentication/keycloak.md b/docs/4-Authentication/keycloak.md new file mode 100644 index 000000000000..95790b2b6bb9 --- /dev/null +++ b/docs/4-Authentication/keycloak.md @@ -0,0 +1,20 @@ +# Keycloak Single Sign-On (SSO) Integration + +This script facilitates Single Sign-On (SSO) integration with Keycloak by enabling seamless user authentication and retrieval of user information using OAuth 2.0. + +## Required Environment Variables + +Before running the `keycloak.py` script, ensure you have the following environment variables set up: + +- `KEYCLOAK_CLIENT_ID`: Keycloak OAuth client ID +- `KEYCLOAK_CLIENT_SECRET`: Keycloak OAuth client secret +- `KEYCLOAK_REALM`: Name of the Keycloak realm +- `KEYCLOAK_SERVER_URL`: Base URL of the Keycloak server + +These variables can typically be added to your `.env` file. + +## Required Scopes for Keycloak SSO + +- `openid` +- `email` +- `profile` diff --git a/docs/4-Authentication/linkedin.md b/docs/4-Authentication/linkedin.md new file mode 100644 index 000000000000..c2f4dd688ed1 --- /dev/null +++ b/docs/4-Authentication/linkedin.md @@ -0,0 +1,54 @@ +# LinkedIn SSO Integration + +## Overview + +This document provides a guide for setting up LinkedIn Single Sign-On (SSO) integration in your application. Follow the steps below to configure the LinkedIn OAuth client and obtain the necessary keys and permissions. + +## Required Environment Variables + +To integrate LinkedIn SSO, you need to set the following environment variables in your `.env` file: + +- `LINKEDIN_CLIENT_ID`: LinkedIn OAuth client ID +- `LINKEDIN_CLIENT_SECRET`: LinkedIn OAuth client secret + +## Steps to Acquire LinkedIn OAuth Credentials + +Follow these steps to obtain the required LinkedIn OAuth credentials: + +1. **Create a LinkedIn Application** + - Navigate to [LinkedIn Developer Portal](https://www.linkedin.com/developers/) + - Log in with your LinkedIn account. + - Click on "Create App" and fill in the required details. + - After the app is created, you will be redirected to the app's dashboard. + +2. **Obtain Client ID and Client Secret** + - In the app's dashboard, locate `Client ID` and `Client Secret` under the "Auth" tab. + - Copy these values and add them to your `.env` file as `LINKEDIN_CLIENT_ID` and `LINKEDIN_CLIENT_SECRET`. + +3. **Set Up Redirect URI** + - Ensure you set up the correct Redirect URI. This URI should match the one used in your application. You can set this URI in the app's "Auth" tab. + +## Required APIs and Scopes + +To enable LinkedIn SSO, you must ensure that your application has access to the following APIs and scopes: + +### Required APIs + +- No additional APIs are needed other than LinkedIn's default OAuth APIs. + +### Required Scopes for LinkedIn OAuth + +- `r_liteprofile`: Grants access to retrieve the user's profile. +- `r_emailaddress`: Grants access to retrieve the user's email address. +- `w_member_social`: Grants access to post and share content on LinkedIn. + +Ensure that these scopes are requested during the OAuth authorization process. + +## Setting Environment Variables + +Add the obtained credentials and required environment variables to your `.env` file: + +```env +LINKEDIN_CLIENT_ID=your_linkedin_client_id +LINKEDIN_CLIENT_SECRET=your_linkedin_client_secret +``` diff --git a/docs/4-Authentication/microsoft.md b/docs/4-Authentication/microsoft.md new file mode 100644 index 000000000000..3f15266008b0 --- /dev/null +++ b/docs/4-Authentication/microsoft.md @@ -0,0 +1,61 @@ +# Microsoft Single Sign-On (SSO) Integration + +## Overview + +This module provides an integration with Microsoft's Single Sign-On (SSO) to allow your application to authenticate users through their Microsoft accounts and send emails using Microsoft's Graph API. + +## Required Environment Variables + +To use the Microsoft SSO integration, you'll need to set up the following environment variables: + +- `MICROSOFT_CLIENT_ID`: Microsoft OAuth client ID +- `MICROSOFT_CLIENT_SECRET`: Microsoft OAuth client secret + +These values can be obtained by registering your application in the Microsoft Azure portal. + +## Setting Up Microsoft SSO + +### Step 1: Register Your Application + +1. Go to the [Azure portal](https://portal.azure.com/). +2. Select **Azure Active Directory**. +3. In the left-hand navigation pane, select **App registrations**. +4. Select **New registration**. +5. Enter a name for your application. +6. Under **Redirect URI**, enter a redirect URI where the authentication response can be sent. This should match the `MAGIC_LINK_URL` environment variable in your `.env` file. +7. Click **Register**. + +### Step 2: Configure API Permissions + +1. Go to the **API permissions** section of your app's registration page. +2. Click on **Add a permission**. +3. Select **Microsoft Graph**. +4. Choose **Delegated permissions** and add the following permissions: + - `User.Read` + - `Mail.Send` + - `Calendars.ReadWrite.Shared` + +### Step 3: Obtain Client ID and Client Secret + +1. In the **Overview** section of your app registration, you will find the **Application (client) ID**. This is your `MICROSOFT_CLIENT_ID`. +2. Go to the **Certificates & secrets** section. +3. Under **Client secrets**, click on **New client secret**. +4. Add a description and choose an expiry period. Click on **Add**. +5. Copy the value of the client secret. This is your `MICROSOFT_CLIENT_SECRET`. Be sure to store it securely. + +### Step 4: Add Environment Variables + +Add the following environment variables to your `.env` file: + +```sh +MICROSOFT_CLIENT_ID=your_client_id +MICROSOFT_CLIENT_SECRET=your_client_secret +``` + +## Required Scopes for Microsoft OAuth + +- `https://graph.microsoft.com/User.Read` +- `https://graph.microsoft.com/Mail.Send` +- `https://graph.microsoft.com/Calendars.ReadWrite.Shared` + +These scopes are requested when obtaining access tokens, allowing your application to read user profile information, send emails on behalf of the user, and access shared calendars. diff --git a/docs/4-Authentication/netiq.md b/docs/4-Authentication/netiq.md new file mode 100644 index 000000000000..e9478bd33fa5 --- /dev/null +++ b/docs/4-Authentication/netiq.md @@ -0,0 +1,41 @@ +# NetIQ + +## Required Environment Variables + +To integrate NetIQ single sign-on (SSO) in your application, you need to set the following environment variables: + +- `NETIQ_CLIENT_ID`: NetIQ OAuth client ID +- `NETIQ_CLIENT_SECRET`: NetIQ OAuth client secret + +## Steps to Acquire NetIQ Client ID and Client Secret + +1. **Log in to your NetIQ Admin Console**: + - Access the NetIQ admin console through your provided administrative URL. + +2. **Register a New OAUTH Application**: + - Navigate to the `OAuth` section within the admin console. + - Add a new application by providing all the necessary details such as application name, redirect URIs, etc. + +3. **Obtain `Client ID` and `Client Secret`**: + - After successfully registering your application, NetIQ will provide you with a `Client ID` and `Client Secret`. + +4. **Set Environment Variables**: + - Add `NETIQ_CLIENT_ID` and `NETIQ_CLIENT_SECRET` to your .env file: + + ```dotenv + NETIQ_CLIENT_ID=your-netiq-client-id + NETIQ_CLIENT_SECRET=your-netiq-client-secret + ``` + +## Required APIs + +Ensure that the required APIs are enabled in your NetIQ settings. Usually, this can be found in the API Management section of the admin console. + +## Required Scopes for NetIQ OAuth + +Make sure to include the following scopes for your NetIQ OAuth authorization: + +- `profile` +- `email` +- `openid` +- `user.info` diff --git a/docs/4-Authentication/okta.md b/docs/4-Authentication/okta.md new file mode 100644 index 000000000000..bd60bcf78cd3 --- /dev/null +++ b/docs/4-Authentication/okta.md @@ -0,0 +1,52 @@ +# Okta SSO Setup and Usage + +This section provides detailed documentation on setting up and using Okta Single Sign-On (SSO) in your project. By following these steps, you will be able to authenticate users using Okta and retrieve their user information. + +## Required Environment Variables + +Before you begin, ensure you have the following environment variables set up in your `.env` file: + +- `OKTA_CLIENT_ID`: Okta OAuth client ID +- `OKTA_CLIENT_SECRET`: Okta OAuth client secret +- `OKTA_DOMAIN`: Okta domain (e.g., dev-123456.okta.com) + +## Required OAuth Scopes + +Ensure that your Okta OAuth application has the following scopes enabled: + +- `openid` +- `profile` +- `email` + +## Setting Up + +### Step 1: Creating an Okta Application + +1. Log in to your Okta Developer account at [developer.okta.com](https://developer.okta.com/). +2. From the dashboard, navigate to **Applications** -> **Applications**. +3. Click on **Create App Integration**. +4. Select **OAuth 2.0 / OIDC**, then click **Next**. +5. Choose **Web Application** and configure the following settings: + - **Sign-in redirect URIs**: Add the callback URI of your application (e.g., `http://localhost:8000/callback`) + - **Sign-out redirect URIs**: Optionally, add a sign-out URI. +6. Click **Save**. + +### Step 2: Retrieving Your Okta Client ID and Client Secret + +1. After saving the application, you will be redirected to the application settings page. +2. Scroll down to the **Client Credentials** section. +3. Copy the **Client ID** and **Client Secret** and add them to your `.env` file: + + ```plaintext + OKTA_CLIENT_ID=your_client_id + OKTA_CLIENT_SECRET=your_client_secret + ``` + +### Step 3: Configuring Your Okta Domain + +1. In the Okta dashboard, navigate to **Settings** -> **Customizations** -> **Domain**. +2. Copy your Okta domain (e.g., `dev-123456.okta.com`) and add it to your `.env` file: + + ```plaintext + OKTA_DOMAIN=your_okta_domain + ``` diff --git a/docs/4-Authentication/openam.md b/docs/4-Authentication/openam.md new file mode 100644 index 000000000000..ee931acd401c --- /dev/null +++ b/docs/4-Authentication/openam.md @@ -0,0 +1,47 @@ +# OpenAM SSO + +## Overview + +This module provides Single Sign-On (SSO) functionality using OpenAM's OAuth 2.0 service. It enables users to obtain tokens and user information from OpenAM, and also provides a mechanism for token refresh. + +## Required Environment Variables + +To use the OpenAM SSO module, you must set the following environment variables: + +- `OPENAM_CLIENT_ID`: OpenAM OAuth client ID +- `OPENAM_CLIENT_SECRET`: OpenAM OAuth client secret +- `OPENAM_BASE_URL`: Base URL for the OpenAM server (e.g., `https://openam.example.com`) + +## Required Scopes for OpenAM OAuth + +The following scopes are required for OpenAM OAuth: + +- `profile` +- `email` + +## Instructions to Acquire Keys and Set Up Environment Variables + +1. **Register the Client with OpenAM:** + + - **Navigate to Admin Console:** Log in to the OpenAM administrative console. + - **Register the Application:** + - Go to `Applications` > `Agents` > `OAuth 2.0 / OIDC` > `Clients`. + - Click `New Client`. + - Fill in details such as `Client ID`, `Client Secret`, and `Redirect URIs`. + - **Set Scopes:** + - Ensure that your client has the required scopes (`profile` and `email`). + +2. **Obtain Client ID and Secret:** + + - **Client ID**: Found in the client registration under OpenAM's administrative console. + - **Client Secret**: Found in the client registration under OpenAM's administrative console. + +3. **Set Environment Variables:** + + Add the obtained values to your environment file (`.env`). + + ```env + OPENAM_CLIENT_ID= + OPENAM_CLIENT_SECRET= + OPENAM_BASE_URL= + ``` diff --git a/docs/4-Authentication/openstreetmap.md b/docs/4-Authentication/openstreetmap.md new file mode 100644 index 000000000000..7d861af6e93a --- /dev/null +++ b/docs/4-Authentication/openstreetmap.md @@ -0,0 +1,42 @@ +# OpenStreetMap SSO Integration + +This guide explains how to integrate OpenStreetMap Single Sign-On (SSO) using OAuth. Follow the steps below to acquire the required keys and set up the necessary environment variables for OpenStreetMap SSO. + +## Required Environment Variables + +- `OSM_CLIENT_ID`: OpenStreetMap OAuth client ID +- `OSM_CLIENT_SECRET`: OpenStreetMap OAuth client secret + +## Steps to Acquire OpenStreetMap OAuth Credentials + +1. **Create an OpenStreetMap OAuth Application:** + + - Navigate to the [OpenStreetMap OAuth settings page](https://www.openstreetmap.org/user/{your_username}/oauth_clients). + - Log in with your OpenStreetMap account if you are not already logged in. + - Click on "Register your application". + - Fill out the form with the required information: + - **Name:** Give your application a name. + - **Main Application URL:** Provide the URL where your application is hosted. + - **Callback URL:** Provide the URL where the user will be redirected after authentication. + - **Support URL:** Provide the URL for support. + - Click the "Save" button. + +2. **Save Your OAuth Credentials:** + + - After registering your application, you will be given a **client ID** and **client secret**. Keep these credentials safe. + - Add them to your `.env` file as follows: + + ```env + OSM_CLIENT_ID=your_openstreetmap_client_id + OSM_CLIENT_SECRET=your_openstreetmap_client_secret + ``` + +### Required APIs + +Ensure you have the appropriate OAuth configuration in OpenStreetMap and that the `OSM_CLIENT_ID` and `OSM_CLIENT_SECRET` environment variables are properly set in your `.env` file. + +### Required Scopes for OpenStreetMap OAuth + +The required scope for OpenStreetMap OAuth integration is: + +- `read_prefs` diff --git a/docs/4-Authentication/orcid.md b/docs/4-Authentication/orcid.md new file mode 100644 index 000000000000..660b70469231 --- /dev/null +++ b/docs/4-Authentication/orcid.md @@ -0,0 +1,28 @@ +# ORCID + +## Required Environment Variables + +To set up ORCID Single Sign-On (SSO), you will need the following environment variables: + +- `ORCID_CLIENT_ID`: ORCID OAuth client ID +- `ORCID_CLIENT_SECRET`: ORCID OAuth client secret + +## How to Acquire ORCID Client ID and Secret + +1. **Register your application with ORCID**: If you haven't already, you need to register your application with ORCID. Visit the [ORCID Developer Tools](https://orcid.org/content/register-client-application). + +2. **Fill out the registration form**: Provide necessary details about your application. After the registration is complete, you will receive your `ORCID_CLIENT_ID` and `ORCID_CLIENT_SECRET`. + +3. **Add environment variables**: Once you have your `ORCID_CLIENT_ID` and `ORCID_CLIENT_SECRET`, add them to your `.env` file as follows: + + ```plaintext + ORCID_CLIENT_ID=your_orcid_client_id + ORCID_CLIENT_SECRET=your_orcid_client_secret + ``` + +## Required Scopes for ORCID SSO + +The following scopes are required for ORCID SSO: + +- `/authenticate`: This scope allows the application to read public profile information. +- `/activities/update` (optional): This scope allows the application to update ORCID activities. diff --git a/docs/4-Authentication/paypal.md b/docs/4-Authentication/paypal.md new file mode 100644 index 000000000000..c1fd3a96fef7 --- /dev/null +++ b/docs/4-Authentication/paypal.md @@ -0,0 +1,55 @@ +# PayPal SSO Documentation + +## Overview + +This document provides information on how to configure and use the PayPal SSO (Single Sign-On) implemented in `./sso/paypal.py`. The PayPal SSO allows you to authenticate users with their PayPal accounts and perform actions such as retrieving user info and sending payments. + +## Required Environment Variables + +To use the PayPal SSO, you must set up the following environment variables: + +- `PAYPAL_CLIENT_ID`: PayPal OAuth client ID. +- `PAYPAL_CLIENT_SECRET`: PayPal OAuth client secret. + +Ensure you add these environment variables to your `.env` file. + +### Steps to Acquire PayPal Client ID and Secret + +1. **Log in to PayPal Developer Dashboard:** + + Go to the [PayPal Developer Dashboard](https://developer.paypal.com/). + +2. **Create a New App:** + + - Navigate to **My Apps & Credentials**. + - Click on **Create App** under the **REST API apps** section. + - Provide an **App Name** and select a sandbox business account. + - Click **Create App**. + +3. **Get Client ID and Secret:** + + - Once the app is created, you�ll find your **Client ID** and **Secret** on the app�s page. + - Copy the **Client ID** and **Secret** and add them to your `.env` file as follows: + + ```plaintext + PAYPAL_CLIENT_ID=YOUR_CLIENT_ID + PAYPAL_CLIENT_SECRET=YOUR_CLIENT_SECRET + ``` + +## Configuration of Redirect URI + +Make sure your `redirect_uri` is correctly set up in the PayPal Developer Dashboard: + +- Go to your app settings. +- Add the `redirect_uri` to the **Return URL** section under the **App settings**. + +## Required APIs + +Ensure that you have the PayPal REST API enabled and the appropriate client credentials. + +## Required Scopes for PayPal OAuth + +To authenticate users and retrieve their information, you will need the following OAuth scopes: + +- `email` +- `openid` diff --git a/docs/4-Authentication/ping_identity.md b/docs/4-Authentication/ping_identity.md new file mode 100644 index 000000000000..cfec39ce22f3 --- /dev/null +++ b/docs/4-Authentication/ping_identity.md @@ -0,0 +1,47 @@ +# Ping Identity SSO + +This section explains how to set up Single Sign-On (SSO) with Ping Identity using the provided `ping_identity.py` script. The script handles OAuth authentication and allows users to fetch user information and send emails via Ping Identity. Follow the steps below to configure and use the Ping Identity SSO integration. + +## Required Environment Variables + +- `PING_IDENTITY_CLIENT_ID`: Ping Identity OAuth client ID +- `PING_IDENTITY_CLIENT_SECRET`: Ping Identity OAuth client secret + +## Required APIs + +Ensure that you have the following APIs enabled for your Ping Identity application. You can enable these APIs through the Ping Identity admin dashboard. + +1. **UserInfo API**: Used to fetch user information. +2. **Email API**: Used to send emails on behalf of the user. + +## Acquiring Required Keys + +### Steps to Acquire `PING_IDENTITY_CLIENT_ID` and `PING_IDENTITY_CLIENT_SECRET` + +1. **Log in to Ping Identity Admin Portal**: Access the Ping Identity admin portal. + +2. **Create an OAuth Client**: + - Go to the 'Connections' tab. + - Select 'Applications' and click 'Add Application' to create a new OAuth client. + - Fill in the required details (application name, description, etc.) and select the OAuth grant type you intend to use (Authorization Code Grant in most cases). + +3. **Configure Callback URL**: + - Set the Redirect URI in your application settings to the endpoint where Ping Identity will send authorization responses (e.g., `https://your-app.com/callback`). + +4. **Obtain Client ID and Client Secret**: + - After creating the OAuth client, you will be provided with a `Client ID` and `Client Secret`. Save these credentials securely, as they will be required in your application. + +5. **Enable Necessary Scopes**: + - Ensure your OAuth client is configured to request the necessary scopes: + - `profile` + - `email` + - `openid` + +## Setup Environment Variables + +Add the following lines to your `.env` file, replacing the placeholders with your actual client ID and client secret. + + ```env + PING_IDENTITY_CLIENT_ID=your_client_id + PING_IDENTITY_CLIENT_SECRET=your_client_secret + ``` diff --git a/docs/4-Authentication/pixiv.md b/docs/4-Authentication/pixiv.md new file mode 100644 index 000000000000..7ee591eb6acc --- /dev/null +++ b/docs/4-Authentication/pixiv.md @@ -0,0 +1,37 @@ +# Pixiv SSO + +This documentation will guide you through setting up Pixiv Single Sign-On (SSO) in your application using the provided `pixiv.py` script. The integration leverages Pixiv's OAuth for user authentication. + +## Required Environment Variables + +To use Pixiv SSO, you need to set up the following environment variables in your `.env` file: + +- `PIXIV_CLIENT_ID`: Your Pixiv OAuth client ID. +- `PIXIV_CLIENT_SECRET`: Your Pixiv OAuth client secret. + +## Required APIs + +Before setting up your environment variables, ensure you have the necessary Pixiv APIs enabled. Your application will need the following scopes to perform authentication through Pixiv OAuth: + +- `pixiv.scope.profile.read` + +## How To Acquire Keys + +Follow these steps to get your `PIXIV_CLIENT_ID` and `PIXIV_CLIENT_SECRET`: + +1. **Create a Pixiv OAuth Application**: + - Go to the Pixiv developer site and log in with your Pixiv account. + - Navigate to the section where you can manage your OAuth applications. + - Create a new application. You will be asked to provide details such as the name and description of your application. Ensure you also specify the required scopes (`pixiv.scope.profile.read`). + +2. **Get Client ID and Client Secret**: + - Once the application is created, Pixiv will provide you with a `client_id` and `client_secret`. These will be your `PIXIV_CLIENT_ID` and `PIXIV_CLIENT_SECRET`. + +3. **Add Environment Variables**: + - Open or create a `.env` file in the root of your project directory. + - Add the following lines to your `.env` file: + + ```env + PIXIV_CLIENT_ID=your_pixiv_client_id_here + PIXIV_CLIENT_SECRET=your_pixiv_client_secret_here + ``` diff --git a/docs/4-Authentication/reddit.md b/docs/4-Authentication/reddit.md new file mode 100644 index 000000000000..4423f360d499 --- /dev/null +++ b/docs/4-Authentication/reddit.md @@ -0,0 +1,29 @@ +# Reddit SSO Integration + +This module allows you to integrate Reddit Single Sign-On (SSO) into your application. By using this module, you can authenticate via Reddit, retrieve user information, and even submit posts to a subreddit on behalf of a user. The code leverages the Reddit OAuth2.0 for authentication and authorization. + +## Required Environment Variables + +Before you begin, you'll need to set up the following environment variables in your `.env` file: + +- `REDDIT_CLIENT_ID`: Your Reddit application's client ID. +- `REDDIT_CLIENT_SECRET`: Your Reddit application's client secret. + +## Required APIs + +Make sure you have a Reddit OAuth application set up. You can create one by following these steps: + +1. Go to [Reddit Apps](https://www.reddit.com/prefs/apps). +2. Scroll down to "Developed Applications" and click on `Create App`. +3. Fill in the application name, and choose `script` as the type. +4. Set the `redirect_uri` to a valid URL where you will receive the authorization code (e.g., `http://localhost:8000/reddit_callback`). +5. Note down the `client ID` and `client secret` from the created application. +6. Add the `REDDIT_CLIENT_ID` and `REDDIT_CLIENT_SECRET` to your `.env` file. + +## Required Scopes for Reddit OAuth + +Ensure that your Reddit OAuth application requests the following scopes: + +- `identity`: Access to the user�s Reddit identity. +- `submit`: Ability to submit and edit content. +- `read`: Ability to read private messages and save content. diff --git a/docs/4-Authentication/salesforce.md b/docs/4-Authentication/salesforce.md new file mode 100644 index 000000000000..3db1e49b982c --- /dev/null +++ b/docs/4-Authentication/salesforce.md @@ -0,0 +1,54 @@ +# Salesforce Single Sign-On (SSO) + +## Overview + +This module allows you to integrate Salesforce SSO into your application. By setting up the required environment variables and following the steps below, you can leverage Salesforce's OAuth capabilities for secure authentication and user information retrieval. + +## Required Environment Variables + +To integrate Salesforce SSO, you need to obtain the following environment variables: + +- `SALESFORCE_CLIENT_ID`: Salesforce OAuth client ID +- `SALESFORCE_CLIENT_SECRET`: Salesforce OAuth client secret + +### Steps to Acquire the Required Environment Variables + +1. **Create a Connected App in Salesforce**: + - Log in to your Salesforce account. + - Navigate to `Setup` by clicking on the gear icon in the top-right corner. + - In the Quick Find box, type `App Manager` and select it from the dropdown. + - Click the `New Connected App` button. + - Fill in the required fields: + - **Connected App Name**: A unique name for your app. + - **API Name**: This will auto-populate based on your app name. + - **Contact Email**: Your email address. + - In the **OAuth Settings** section: + - Check the `Enable OAuth Settings` checkbox. + - Enter the callback URL (e.g., `http://localhost:8000/callback`). + - Select the following OAuth Scopes: + - `Full access (full)` + - `Perform requests on your behalf at any time (refresh_token, offline_access)` + - `Access your basic information (id, profile, email, address, phone)` + - Click the `Save` button. + +2. **Retrieve Client ID and Client Secret**: + - After saving, navigate back to `App Manager`. + - Find your newly created app in the list and click on its name. + - Under `API (Enable OAuth Settings)`, you'll find the `Consumer Key` (Client ID) and `Consumer Secret` (Client Secret). + +3. **Add Environment Variables**: + - Create a `.env` file in the root directory of your project if it doesn't already exist. + - Add the following lines to the `.env` file, replacing the placeholders with your actual Consumer Key and Consumer Secret: + + ```env + SALESFORCE_CLIENT_ID=your_salesforce_client_id + SALESFORCE_CLIENT_SECRET=your_salesforce_client_secret + ``` + +## Required Salesforce OAuth Scopes + +Ensure your Salesforce Connected App has the following OAuth scopes enabled: + +- `refresh_token` +- `full` +- `email` diff --git a/docs/4-Authentication/sina_weibo.md b/docs/4-Authentication/sina_weibo.md new file mode 100644 index 000000000000..34c0dd357e3e --- /dev/null +++ b/docs/4-Authentication/sina_weibo.md @@ -0,0 +1,44 @@ +# Sina Weibo Single Sign-On (SSO) Integration + +## Overview + +This module enables Single Sign-On (SSO) integration with Sina Weibo through OAuth 2.0. This guide provides the necessary steps and details to set up the integration, including acquiring the required keys, setting up environment variables, and implementing the code provided. + +## Required Environment Variables + +In order to use the Sina Weibo SSO integration, you need to set several environment variables. These variables are used in the OAuth authentication process to communicate with the Weibo API. + +Below is a list of the required environment variables: + +- `WEIBO_CLIENT_ID`: Weibo OAuth client ID +- `WEIBO_CLIENT_SECRET`: Weibo OAuth client secret + +## Required APIs and Scopes + +Before proceeding, ensure you have the necessary APIs enabled and permissions configured. The required scopes for Weibo OAuth are: + +- `email` +- `statuses_update` + +## Steps to Acquire Keys and Set Up Environment + +### Step 1: Register Your Application with Weibo + +1. Log in to the [Weibo Open Platform](https://open.weibo.com/). +2. Navigate to "My Apps" and click on "Create App". +3. Fill in the required details about your application, such as name, description, and redirect URL. +4. Once your application is created, you will receive a `CLIENT_ID` and `CLIENT_SECRET`. + +### Step 2: Set Up Your Environment Variables + +Once you have your `CLIENT_ID` and `CLIENT_SECRET`, you'll need to add them to your environment as follows: + +1. Create a `.env` file in the root directory of your project. +2. Add the following lines to your `.env` file: + +```env +WEIBO_CLIENT_ID=your_client_id +WEIBO_CLIENT_SECRET=your_client_secret +``` + +Replace `your_client_id`, `your_client_secret`, and `your_redirect_uri` with your actual Weibo OAuth credentials and your application's redirect URL. diff --git a/docs/4-Authentication/spotify.md b/docs/4-Authentication/spotify.md new file mode 100644 index 000000000000..92fbc166a217 --- /dev/null +++ b/docs/4-Authentication/spotify.md @@ -0,0 +1,46 @@ +# Spotify SSO Integration Documentation + +This document outlines the steps required to integrate Spotify Single Sign-On (SSO) into your application, including how to set up the necessary environment variables and acquire the necessary API keys. + +## Required Environment Variables + +To use Spotify SSO, you need to set the following environment variables: + +1. `SPOTIFY_CLIENT_ID`: Your Spotify OAuth client ID. +2. `SPOTIFY_CLIENT_SECRET`: Your Spotify OAuth client secret. + +Ensure you have set these variables in your `.env` file. + +## Steps to Acquire Spotify Client ID and Client Secret + +1. **Create a Spotify Developer Account** + + If you don't have a Spotify Developer account, create one by registering at the [Spotify Developer Dashboard](https://developer.spotify.com/dashboard). + +2. **Create an App** + + Once you are logged in to the Spotify Developer Dashboard, create a new application: + - Go to **Dashboard**. + - Click on the **Create an App** button. + - Fill out the **App Name** and **App Description** fields. + - Check the **I understand and accept the Spotify Developer Terms of Service**. + - Click **Create**. + +3. **Retrieve Your Client ID and Client Secret** + + After creating the app, you will be redirected to your app's dashboard: + - Find the **Client ID** and **Client Secret** on this page. + - Add these values to your `.env` file as shown below: + + ```env + SPOTIFY_CLIENT_ID=your-client-id + SPOTIFY_CLIENT_SECRET=your-client-secret + ``` + +## Required APIs and Scopes for Spotify SSO + +You need to enable the necessary scopes to allow your application to access user data and functionalities: + +- `user-read-email`: Allows reading user's email. +- `user-read-private`: Allows reading user's subscription details. +- `playlist-read-private`: Allows reading user's private playlists. diff --git a/docs/4-Authentication/stack_exchange.md b/docs/4-Authentication/stack_exchange.md new file mode 100644 index 000000000000..9000fbf1bff8 --- /dev/null +++ b/docs/4-Authentication/stack_exchange.md @@ -0,0 +1,37 @@ +# Stack Exchange SSO Guide + +This guide provides detailed instructions on how to set up and use Stack Exchange Single Sign-On (SSO) in your application. Please follow each step carefully to ensure successful integration. + +## Required Environment Variables + +To utilize Stack Exchange SSO, you need to set the following environment variables in your `.env` file: + +- `STACKEXCHANGE_CLIENT_ID`: Your Stack Exchange OAuth client ID. +- `STACKEXCHANGE_CLIENT_SECRET`: Your Stack Exchange OAuth client secret. +- `STACKEXCHANGE_KEY`: (Optional) A key for additional API requests (can enhance rate limits). + +## Setting Up Stack Exchange OAuth Credentials + +1. **Create a Stack Exchange Application:** + - Go to the [Stack Exchange API Applications](https://stackapps.com/apps/oauth/register) page. + - Click the "Register Your Application" button. + - Fill in the required details such as Application Name, Description, Organization Information, etc. + - Set the OAuth Redirect URL (you will need this URL for redirect_uri). + - After their review, you will obtain the `Client ID` and `Client Secret` which are needed for the environment variables. + +2. **Enable Required Scopes:** + - The application will need the following scopes to function properly: + - read_inbox + - no_expiry + - private_info + - write_access + +3. **Add the Environment Variables:** + - Create a `.env` file at the root of your project if it does not exist already. + - Add the following lines to the file with your corresponding credentials: + + ```env + STACKEXCHANGE_CLIENT_ID=your_stack_exchange_client_id + STACKEXCHANGE_CLIENT_SECRET=your_stack_exchange_client_secret + STACKEXCHANGE_KEY=your_stack_exchange_key + ``` diff --git a/docs/4-Authentication/strava.md b/docs/4-Authentication/strava.md new file mode 100644 index 000000000000..9ef791511e83 --- /dev/null +++ b/docs/4-Authentication/strava.md @@ -0,0 +1,36 @@ +# Strava + +## Required environment variables + +To use the Strava SSO and activity creation functionality, you need to set the following environment variables in your `.env` file: + +- `STRAVA_CLIENT_ID`: Strava OAuth client ID +- `STRAVA_CLIENT_SECRET`: Strava OAuth client secret + +## How to Acquire Strava Client ID and Client Secret + +1. **Create a Strava Developer Account**: + - If you don�t already have a Strava account, you need to sign up for one at [Strava](https://www.strava.com/). + +2. **Register Your Application**: + - Go to [Strava Developers](https://developers.strava.com/). + - Sign in with your Strava account if needed. + - Navigate to the �Create & Manage Your App� section. + - Click on "Create New App." + - Fill in the required details such as Application Name, Category, Club, Website, Authorization Callback Domain, and Scope. + - After creating the app, you will be provided with a `Client ID` and `Client Secret`. + +3. **Set Environment Variables**: + - Add the `STRAVA_CLIENT_ID` and `STRAVA_CLIENT_SECRET` to your `.env` file: + + ```dotenv + STRAVA_CLIENT_ID=your_strava_client_id + STRAVA_CLIENT_SECRET=your_strava_client_secret + ``` + +## Required scopes for Strava OAuth + +When setting up OAuth for your Strava application, ensure that the following scopes are enabled: + +- `read` +- `activity:write` diff --git a/docs/4-Authentication/stripe.md b/docs/4-Authentication/stripe.md new file mode 100644 index 000000000000..976385c198c9 --- /dev/null +++ b/docs/4-Authentication/stripe.md @@ -0,0 +1,31 @@ +# Stripe + +## Required Environment Variables + +To use Stripe SSO, ensure that the following environment variables are set: + +- `STRIPE_CLIENT_ID`: Your Stripe OAuth client ID. +- `STRIPE_CLIENT_SECRET`: Your Stripe OAuth client secret. + +## Required Scopes for Stripe SSO + +Make sure you have the required scope for Stripe SSO: + +- `read_write` + +## How to Acquire Required Keys + +1. **Create a Stripe Account**: If you don't have a Stripe account, sign up at [Stripe](https://stripe.com/). +2. **Create a New Project**: Once logged in, navigate to your dashboard and create a new project. +3. **Get Your Client ID and Secret**: Go to "Developers" > "API keys". Here you will find your client ID and secret: + - **Client ID**: This is your OAuth client ID used for authentication. + - **Client Secret**: This is your OAuth client secret that should be kept secure. + +## Setting Up Environment Variables + +Once you have your client ID and secret, add them to your environment variables. If you�re using a `.env` file, it should look like this: + +```env +STRIPE_CLIENT_ID=your_stripe_client_id +STRIPE_CLIENT_SECRET=your_stripe_client_secret +``` diff --git a/docs/4-Authentication/twitch.md b/docs/4-Authentication/twitch.md new file mode 100644 index 000000000000..0f247ebf92bd --- /dev/null +++ b/docs/4-Authentication/twitch.md @@ -0,0 +1,46 @@ +# Twitch + +The integration for Twitch Single Sign-On (SSO) requires setting up environment variables and acquiring special identifiers and keys from the Twitch Developer Console. Below you'll find detailed instructions to guide you through the process. + +## Required Environment Variables + +Ensure the following environment variables are added to your `.env` file: + +- `TWITCH_CLIENT_ID`: Your Twitch OAuth client ID +- `TWITCH_CLIENT_SECRET`: Your Twitch OAuth client secret + +## Required Scope for Twitch OAuth + +To successfully use Twitch SSO, the following OAuth scope must be enabled: + +- `user:read:email` + +## Instructions to Acquire Required Keys + +1. **Create a Twitch Developer Account** + - Navigate to [Twitch Developer Console](https://dev.twitch.tv/). + - If you don't already have a developer account, you'll need to create one. + +2. **Register Your Application** + - Log in to the Twitch Developer Console. + - Click on the "Your Console" tab. + - Click on "Register Your Application". + - Fill out the required details, including: + - **Name**: Name your application. + - **OAuth Redirect URLs**: Add the URLs that Twitch should redirect to after OAuth authentication. + - **Category**: Select the category that best describes your application. + +3. **Retrieve Your Client ID and Client Secret** + - After registering, your application will be assigned a **Client ID** and a **Client Secret**. + - Copy the Client ID to `TWITCH_CLIENT_ID` and the Client Secret to `TWITCH_CLIENT_SECRET` in your `.env` file. + +## Summary + +Your `.env` file should look something like this: + +```dotenv +TWITCH_CLIENT_ID=your_twitch_client_id +TWITCH_CLIENT_SECRET=your_twitch_client_secret +``` + +Replace `your_twitch_client_id` and `your_twitch_client_secret` with the actual values obtained from the Twitch Developer Console. diff --git a/docs/4-Authentication/viadeo.md b/docs/4-Authentication/viadeo.md new file mode 100644 index 000000000000..61b56d7243ef --- /dev/null +++ b/docs/4-Authentication/viadeo.md @@ -0,0 +1,31 @@ +# Viadeo + +## Required environment variables + +- `VIADEO_CLIENT_ID`: Viadeo OAuth client ID +- `VIADEO_CLIENT_SECRET`: Viadeo OAuth client secret + +## Required APIs + +Ensure you have the required APIs enabled, then add the `VIADEO_CLIENT_ID` and `VIADEO_CLIENT_SECRET` environment variables to your `.env` file. + +To acquire the `VIADEO_CLIENT_ID` and `VIADEO_CLIENT_SECRET`, follow these steps: + +1. **Creating an App on Viadeo:** + - Navigate to the [Viadeo Developer Portal](https://developer.viadeo.com/). + - Sign in with your Viadeo account. + - Create a new application and provide the necessary details. + - Upon creation, you will be issued a `Client ID` and `Client Secret`. + +2. **Setting Up Environment Variables:** + - After obtaining the `Client ID` and `Client Secret`, add these values to your `.env` file: + + ```plaintext + VIADEO_CLIENT_ID=your_client_id_here + VIADEO_CLIENT_SECRET=your_client_secret_here + ``` + +### Required Scopes for Viadeo OAuth + +- `basic` (to access user profile) +- `email` (to access user email) diff --git a/docs/4-Authentication/vimeo.md b/docs/4-Authentication/vimeo.md new file mode 100644 index 000000000000..29e81747e59d --- /dev/null +++ b/docs/4-Authentication/vimeo.md @@ -0,0 +1,29 @@ +# Vimeo + +## Required Environment Variables + +To use Vimeo's OAuth system, you need to set up the following environment variables in your `.env` file: + +- `VIMEO_CLIENT_ID`: Vimeo OAuth client ID +- `VIMEO_CLIENT_SECRET`: Vimeo OAuth client secret + +## Required APIs + +Ensure you have the necessary APIs enabled in Vimeo's developer platform. Follow these steps to obtain your `VIMEO_CLIENT_ID` and `VIMEO_CLIENT_SECRET`: + +1. **Create a Vimeo Developer Account**: If you don't have one, you'll need to create a Vimeo developer account at [Vimeo Developer](https://developer.vimeo.com/). +2. **Create an App**: Go to your [My Apps](https://developer.vimeo.com/apps) page and create a new app. You will be given a `Client ID` and `Client Secret` which you need to copy and save. +3. **Set Up Scopes**: Ensure that your app has the following scopes enabled: + - `public`: Access public videos and account details. + - `private`: Access private videos. + - `video_files`: Access video files. + +4. **Add Environment Variables**: Copy your `VIMEO_CLIENT_ID` and `VIMEO_CLIENT_SECRET` into your `.env` file. + +## Required Scopes for Vimeo OAuth + +To ensure that your application can access the necessary Vimeo resources, the following scopes must be enabled: + +- `public` +- `private` +- `video_files` diff --git a/docs/4-Authentication/vk.md b/docs/4-Authentication/vk.md new file mode 100644 index 000000000000..2d09f59adcf7 --- /dev/null +++ b/docs/4-Authentication/vk.md @@ -0,0 +1,35 @@ +# VK SSO + +## Required Environment Variables + +- `VK_CLIENT_ID`: VK OAuth client ID +- `VK_CLIENT_SECRET`: VK OAuth client secret + +## Required APIs + +Ensure that you have the necessary VK APIs enabled by following these instructions. Once confirmed, add the `VK_CLIENT_ID` and `VK_CLIENT_SECRET` environment variables to your `.env` file. + +1. **VK API Access Setup:** + - Visit VK's [Developers Page](https://vk.com/dev) and create a new application if you haven't done so already. + - Note down the Application ID (this will be your VK Client ID) and secure your Application Secret (this will be your VK Client Secret). + - Configure your application to use VK API. + +2. **Get the VK Client ID and Client Secret:** + - After setting up your VK application, go to the application settings. + - From the application settings, retrieve the **Application ID** which will serve as `VK_CLIENT_ID`. + - Retrieve the **Secure Key** which will serve as `VK_CLIENT_SECRET`. + +Add these values to your `.env` file in the following format: + +```env +VK_CLIENT_ID=your_vk_client_id +VK_CLIENT_SECRET=your_vk_client_secret +``` + +## Required Scopes for VK SSO + +To authenticate users via VK SSO, you need the following scope: + +- `email` + +Make sure your VK application requests this scope during the OAuth authorization process. diff --git a/docs/4-Authentication/wechat.md b/docs/4-Authentication/wechat.md new file mode 100644 index 000000000000..f056263539ae --- /dev/null +++ b/docs/4-Authentication/wechat.md @@ -0,0 +1,36 @@ +# WeChat + +WeChat SSO allows users to log in to your application using their WeChat account. This is implemented in the `wechat.py` file, which handles the OAuth flow, token management, and fetching user information. + +## Required Environment Variables + +To use WeChat SSO, you need to add the following environment variables to your environment or `.env` file: + +- `WECHAT_CLIENT_ID`: WeChat OAuth client ID. +- `WECHAT_CLIENT_SECRET`: WeChat OAuth client secret. + +## Acquiring WeChat Client ID and Client Secret + +1. **Register Your Application:** + - Visit the [WeChat Open Platform](https://open.weixin.qq.com/) and sign in with your WeChat account. + - Navigate to the "Manage Center" and click on "Create Application". + - Fill out the required information about your application. + +2. **Get Your Credentials:** + - Once your application is created, navigate to the "Basic Configuration" section. + - Copy the `AppID` and `AppSecret`. These correspond to `WECHAT_CLIENT_ID` and `WECHAT_CLIENT_SECRET` respectively. + +3. **Add Redirect URI:** + - In the same section, add the authorization callback URL, which is the `redirect_uri` you will use for WeChat OAuth. + +4. **Set Environment Variables:** + - Add the `WECHAT_CLIENT_ID` and `WECHAT_CLIENT_SECRET` to your environment or `.env` file: + + ```text + WECHAT_CLIENT_ID=your_client_id + WECHAT_CLIENT_SECRET=your_client_secret + ``` + +## Required Scopes for WeChat SSO + +- `snsapi_userinfo`: This scope allows your application to fetch the user's profile information. diff --git a/docs/4-Authentication/withings.md b/docs/4-Authentication/withings.md new file mode 100644 index 000000000000..7107be391289 --- /dev/null +++ b/docs/4-Authentication/withings.md @@ -0,0 +1,35 @@ +# Withings SSO Integration + +This document provides detailed instructions on how to configure and use Withings Single Sign-On (SSO) for authenticating users and accessing user data using OAuth. + +## Required Environment Variables + +To set up Withings SSO, you need the following environment variables. Ensure you add these to your `.env` file. + +- `WITHINGS_CLIENT_ID`: Withings OAuth client ID +- `WITHINGS_CLIENT_SECRET`: Withings OAuth client secret + +### Steps to Acquire Withings Client ID and Secret + +1. **Register your application with Withings**: + - Visit the [Withings Developer Portal](https://developer.withings.com/). + - Log in or sign up if you don't already have an account. + - Create a new application under your account. + - Fill in the required details such as application name, description, and redirect URIs. + +2. **Generate the Client ID and Secret**: + - Once your application is created, navigate to the application details page. + - You will find the `Client ID` and `Client Secret` here. Copy these values to your `.env` file. + +```plaintext +WITHINGS_CLIENT_ID=your_withings_client_id +WITHINGS_CLIENT_SECRET=your_withings_client_secret +``` + +## Required Scopes for Withings SSO + +When configuring the Withings SSO, make sure to request the following scopes. These scopes ensure that your application has the necessary permissions to access user information and metrics. + +- `user.info`: Access basic user information. +- `user.metrics`: Access user's health metrics. +- `user.activity`: Access user's activity data. diff --git a/docs/4-Authentication/xero.md b/docs/4-Authentication/xero.md new file mode 100644 index 000000000000..75b56a1b08f1 --- /dev/null +++ b/docs/4-Authentication/xero.md @@ -0,0 +1,41 @@ +# Xero SSO Integration Documentation + +## Overview + +This document describes how to integrate Xero Single Sign-On (SSO) using the provided `xero.py` script. The script leverages Xero's OAuth 2.0 for authentication and retrieving user information. + +## Prerequisites + +1. **Create a Xero App**: + + To start using Xero's SSO, you need to create an app in the Xero Developer portal: + - Go to the [Xero Developer Portal](https://developer.xero.com/myapps). + - Sign in with your Xero account. + - Click on "New App" and fill in the necessary details. + - Application name: Provide a name for your application. + - Integration: Select the type of integration (e.g., Web application). + - OAuth 2.0 redirect URI: Provide your application's redirect URL. + - Once the app is created, you will get the `CLIENT_ID` and `CLIENT_SECRET`. These are necessary for the OAuth flow. + +2. **Environment Variables**: + + The keys fetched from the Xero Developer Portal need to be saved as environment variables in your project. + + - `XERO_CLIENT_ID`: Xero OAuth client ID + - `XERO_CLIENT_SECRET`: Xero OAuth client secret + + Add these to your `.env` file: + + ```plaintext + XERO_CLIENT_ID=your_client_id_here + XERO_CLIENT_SECRET=your_client_secret_here + ``` + +## Required Scopes + +When setting up OAuth access for Xero, ensure that you enable the following scopes: + +- `offline_access`: Allows your application to access Xero data when the user is not present. +- `openid`: Provides basic user information. +- `profile`: Access to the user's profile information. +- `email`: Access to the user's email address. diff --git a/docs/4-Authentication/xing.md b/docs/4-Authentication/xing.md new file mode 100644 index 000000000000..adb836d5fe90 --- /dev/null +++ b/docs/4-Authentication/xing.md @@ -0,0 +1,39 @@ +# Xing + +The provided module allows for Single Sign-On (SSO) with Xing and includes the functionality to retrieve user information as well as send emails. Follow the steps below to set up and use the Xing SSO module. + +## Required Environment Variables + +To use the Xing SSO module, you need to set up two environment variables: + +- `XING_CLIENT_ID`: Your Xing OAuth client ID +- `XING_CLIENT_SECRET`: Your Xing OAuth client secret + +## Acquiring Environment Variables + +1. **Xing OAuth Client ID and Client Secret**: + - You will first need to create an application on [Xing Developer Portal](https://dev.xing.com/). Here�s how: + 1. Sign up or log into the [Xing Developer Portal](https://dev.xing.com/). + 2. Create a new application (you might have to complete some verification steps). + 3. Once your application is created, you will get access to the client ID and client secret. + +2. **Setting Up Environment Variables**: + - Add the `XING_CLIENT_ID` and `XING_CLIENT_SECRET` to your `.env` file: + + ```env + XING_CLIENT_ID=your_xing_oauth_client_id + XING_CLIENT_SECRET=your_xing_oauth_client_secret + ``` + +## Required APIs + +Ensure you have the following APIs enabled: + +- [Xing API](https://dev.xing.com/) + +## Required Scopes for Xing SSO + +These are the API scopes required for Xing SSO: + +- `https://api.xing.com/v1/users/me` +- `https://api.xing.com/v1/authorize` diff --git a/docs/4-Authentication/yahoo.md b/docs/4-Authentication/yahoo.md new file mode 100644 index 000000000000..98a430ab7f98 --- /dev/null +++ b/docs/4-Authentication/yahoo.md @@ -0,0 +1,48 @@ +# Yahoo SSO + +The `YahooSSO` class and the `yahoo_sso` function facilitate Single Sign-On (SSO) with Yahoo, allowing you to retrieve user information (email, first name, last name) and send emails through Yahoo's mail services. Yahoo SSO requires specific OAuth credentials and certain API permissions to function properly. + +## Required Environment Variables + +To configure Yahoo SSO, you need to set the following environment variables. Add them to your `.env` file: + +- `YAHOO_CLIENT_ID`: Yahoo OAuth client ID +- `YAHOO_CLIENT_SECRET`: Yahoo OAuth client secret + +## Acquiring Yahoo OAuth Credentials + +1. **Create a Yahoo Developer Account:** + - Go to the [Yahoo Developer Network](https://developer.yahoo.com). + - Sign in with your Yahoo account or create a new one. + +2. **Create an App and Obtain Client ID and Secret:** + - Navigate to the [Yahoo Developer Dashboard](https://developer.yahoo.com/apps/). + - Click on "Create an App". + - Fill in the required details such as application name, description, and Redirect URI. + - Select the required permissions: `profile`, `email`, and `mail-w`. + - After creating the app, you will be provided with the `Client ID` and `Client Secret`. Note these down as you'll need to add them to your environment variables. + +3. **Add Redirect URI:** + - Ensure that you have specified a valid Redirect URI in your app settings. This URI is where Yahoo will redirect users after authentication with the authorization code. + - Example Redirect URI: `https://yourdomain.com/oauth/callback/yahoo`. + +## Required APIs and Scopes + +Ensure that your Yahoo app has the following scopes enabled: + +- `profile` +- `email` +- `mail-w` + +These scopes allow your application to view user profiles, retrieve email addresses, and send emails. + +## Setting Up Your Environment Variables + +Add the following lines to your `.env` file: + +```plaintext +YAHOO_CLIENT_ID=your_yahoo_client_id +YAHOO_CLIENT_SECRET=your_yahoo_client_secret +``` + +Replace `your_yahoo_client_id` and `your_yahoo_client_secret` with the credentials you obtained from the Yahoo Developer Dashboard. Set `MAGIC_LINK_URL` to your application's redirect URI. diff --git a/docs/4-Authentication/yammer.md b/docs/4-Authentication/yammer.md new file mode 100644 index 000000000000..cdbcea0d9053 --- /dev/null +++ b/docs/4-Authentication/yammer.md @@ -0,0 +1,40 @@ +# Yammer Integration + +This module provides Single Sign-On (SSO) and messaging capabilities using Yammer's OAuth 2.0. It allows a user to authenticate via Yammer, acquire an access token, retrieve user information, and send messages to Yammer groups. + +## Prerequisites + +Before you can use this module, you need to set up a few things on Yammer and obtain the necessary credentials. Here is a step-by-step guide to help you: + +## Step-by-Step Guide + +1. **Creating a Yammer App:** + - Go to the [Yammer Developer Site](https://www.yammer.com/client_applications) + - Click on "Register New App". + - Fill out the form with the required details such as: + - **App Name** + - **Organization** + - **Support Email** + - In the "Redirect URL" field, enter the URL where users will be redirected after authentication (usually your application's URL). + +2. **Obtaining the Client ID and Client Secret:** + - After creating your app, you will be taken to the app details page. + - Here, you will find your **Client ID** and **Client Secret**. + +3. **Environment Configuration:** + - Create a `.env` file in your project root directory if you don't already have one. + - Add the following environment variables to your `.env` file: + + ```plaintext + YAMMER_CLIENT_ID=your_yammer_client_id + YAMMER_CLIENT_SECRET=your_yammer_client_secret + ``` + +## Required APIs + +Make sure to enable the following Yammer API scopes: + +- `messages:email` +- `messages:post` + +These scopes can be configured when registering your app on the Yammer developer site. diff --git a/docs/4-Authentication/yandex.md b/docs/4-Authentication/yandex.md new file mode 100644 index 000000000000..e9455fbfa992 --- /dev/null +++ b/docs/4-Authentication/yandex.md @@ -0,0 +1,20 @@ +# Yandex + +## Required Environment Variables + +To integrate Yandex SSO into your application, you need to set the following environment variables: + +- `YANDEX_CLIENT_ID`: Your Yandex OAuth client ID. +- `YANDEX_CLIENT_SECRET`: Your Yandex OAuth client secret. + +## Required APIs + +Make sure the necessary APIs are enabled in your Yandex application. + +## Required Scopes for Yandex OAuth + +These OAuth scopes are required for Yandex SSO: + +- `login:info` +- `login:email` +- `mail.send` diff --git a/docs/4-Authentication/yelp.md b/docs/4-Authentication/yelp.md new file mode 100644 index 000000000000..9566b1b23a7c --- /dev/null +++ b/docs/4-Authentication/yelp.md @@ -0,0 +1,38 @@ +# Yelp SSO Integration + +The Yelp SSO integration allows users to authenticate and retrieve user information through Yelp's OAuth system. This section provides a detailed guide on how to configure and use the Yelp SSO in your application. + +## Required Environment Variables + +This integration requires the following environment variables: + +- `YELP_CLIENT_ID`: Yelp OAuth client ID +- `YELP_CLIENT_SECRET`: Yelp OAuth client secret + +To set these variables, ensure they are included in your application's environment configuration file (e.g., `.env`): + +```plaintext +YELP_CLIENT_ID=your_client_id_here +YELP_CLIENT_SECRET=your_client_secret_here +``` + +## Acquiring Client ID and Client Secret + +To obtain the `YELP_CLIENT_ID` and `YELP_CLIENT_SECRET`, follow these steps: + +1. **Register Your App:** + - Go to the [Yelp Developer Portal](https://www.yelp.com/developers/v3/get_started). + - Log in or create a Yelp account. + - Navigate to the "Create App" section. + - Fill out the required details to register your application. + +2. **Retrieve Credentials:** + - Once your application is registered, you will be provided with a `CLIENT_ID` and `CLIENT_SECRET`. + +## Required Scopes + +Ensure your application requests the necessary scopes for Yelp OAuth: + +- `business` + +These scopes allow the application to access specific user information and perform operations permitted by Yelp's API. diff --git a/docs/4-Authentication/zendesk.md b/docs/4-Authentication/zendesk.md new file mode 100644 index 000000000000..c3f9ecaf371b --- /dev/null +++ b/docs/4-Authentication/zendesk.md @@ -0,0 +1,51 @@ +# Zendesk SSO Integration Guide + +This guide provides detailed instructions to integrate Zendesk Single Sign-On (SSO) into your application using the provided `zendesk.py` script. By following these steps, you will set up the necessary environment variables, OAuth client, and required OAuth scopes to enable seamless authentication and email sending through Zendesk. + +## Required Environment Variables + +Before you proceed with the Zendesk integration, ensure that you have the following environment variables set up in your `.env` file: + +- `ZENDESK_CLIENT_ID`: Your Zendesk OAuth client ID +- `ZENDESK_CLIENT_SECRET`: Your Zendesk OAuth client secret +- `ZENDESK_SUBDOMAIN`: Your Zendesk subdomain (e.g., if your Zendesk URL is `https://yourcompany.zendesk.com`, then your subdomain is `yourcompany`) + +## Acquiring the Required Keys + +To acquire the required keys for setting up the environment variables: + +1. **Create a Zendesk OAuth Client:** + - Log in to your Zendesk Admin Center. + - Navigate to `Channels` > `API`. + - Under the `OAuth Clients` tab, click on the `Add OAuth Client` button. + - Fill in the required details: + - **Client Name:** A recognizable name for your OAuth client. + - **Description:** A brief description for your reference. + - **Client ID:** This will be automatically generated. Copy this value. + - **Client Secret:** Click the `Reveal` button to see the client secret. Copy this value too. + - **Redirect URLs:** Add the URL(s) where Zendesk will redirect after an authentication attempt. + - Save your new OAuth client. + +2. **Set the environment variables**: + Add the following lines to your `.env` file: + + ```env + ZENDESK_CLIENT_ID=your_zendesk_client_id + ZENDESK_CLIENT_SECRET=your_zendesk_client_secret + ZENDESK_SUBDOMAIN=your_zendesk_subdomain + ``` + +## Required APIs + +Make sure the following APIs are enabled in your Zendesk account: + +- OAuth API for authentication and token exchange. +- Users API to retrieve user information. +- Requests API to handle support requests and send emails. + +## Required Scopes for Zendesk OAuth + +When you set up your OAuth client, ensure that the following scopes are enabled: + +- `read`: To grant read access. +- `write`: To grant write access. diff --git a/examples/AGiXT-Expert-OAI.ipynb b/examples/AGiXT-Expert-OAI.ipynb index e04854d3c1b7..f335a4d1ac72 100644 --- a/examples/AGiXT-Expert-OAI.ipynb +++ b/examples/AGiXT-Expert-OAI.ipynb @@ -107,7 +107,7 @@ " agent_name=agent_name,\n", " file_name=zip_file_name,\n", " file_content=training_data,\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")" ] }, diff --git a/examples/AGiXT-Expert-ezLocalai.ipynb b/examples/AGiXT-Expert-ezLocalai.ipynb index 46566943f430..459e611f9ddd 100644 --- a/examples/AGiXT-Expert-ezLocalai.ipynb +++ b/examples/AGiXT-Expert-ezLocalai.ipynb @@ -111,7 +111,7 @@ " agent_name=agent_name,\n", " file_name=zip_file_name,\n", " file_content=training_data,\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")" ] }, diff --git a/examples/Chatbot.ipynb b/examples/Chatbot.ipynb index bd39f6f2ec7c..2bca2c1250c0 100644 --- a/examples/Chatbot.ipynb +++ b/examples/Chatbot.ipynb @@ -62,7 +62,7 @@ "outputs": [], "source": [ "ApiClient.learn_url(\n", - " agent_name=agent_name, url=\"https://josh-xt.github.io/AGiXT/\", collection_number=0\n", + " agent_name=agent_name, url=\"https://josh-xt.github.io/AGiXT/\", collection_number="0"\n", ")\n" ] }, diff --git a/requirements.txt b/requirements.txt index 39d259e1bf87..53c48915dd3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -agixtsdk==0.0.45 +agixtsdk==0.0.47 safeexecute==0.0.9 google-generativeai==0.4.1 discord==2.3.2 @@ -17,5 +17,5 @@ google-api-python-client==2.125.0 google-auth-oauthlib python-multipart==0.0.9 nest_asyncio -g4f==0.3.1.9 +g4f==0.3.2.0 pyotp \ No newline at end of file diff --git a/tests/tests.ipynb b/tests/tests.ipynb index 448b626fc153..088f2f114b96 100644 --- a/tests/tests.ipynb +++ b/tests/tests.ipynb @@ -759,7 +759,7 @@ " agent_name=agent_name,\n", " user_input=\"What is AGiXT?\",\n", " text=\"AGiXT is an open-source artificial intelligence automation platform.\",\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")\n", "print(text_learning)" ] @@ -801,7 +801,7 @@ " agent_name=agent_name,\n", " file_name=learn_file_path,\n", " file_content=learn_file_content,\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")\n", "print(file_learning)" ] @@ -836,7 +836,7 @@ "url_learning = ApiClient.learn_url(\n", " agent_name=agent_name,\n", " url=\"https://josh-xt.github.io/AGiXT\",\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")\n", "print(url_learning)" ] @@ -874,7 +874,7 @@ " user_input=\"What can you tell me about AGiXT?\",\n", " limit=10,\n", " min_relevance_score=0.2,\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")\n", "print(memories)" ] @@ -949,7 +949,7 @@ " user_input=\"What can you tell me about AGiXT?\",\n", " limit=1,\n", " min_relevance_score=0.2,\n", - " collection_number=0,\n", + " collection_number=\"0\",\n", ")\n", "# Remove the first memory\n", "memory = memories[0]\n", @@ -957,7 +957,7 @@ "print(\"Memory:\", memory)\n", "# print(\"Memory ID:\", memory_id)\n", "# delete_memory_resp = ApiClient.delete_agent_memory(\n", - "# agent_name=agent_name, memory_id=memory_id, collection_number=0\n", + "# agent_name=agent_name, memory_id=memory_id, collection_number=\"0\"\n", "# )\n", "# print(\"Delete memory response:\", delete_memory_resp)" ] @@ -993,7 +993,7 @@ "# Note: Use this function with caution as it will erase the agent's memory.\n", "agent_name = \"new_agent\"\n", "wipe_mem_resp = ApiClient.wipe_agent_memories(\n", - " agent_name=agent_name, collection_number=0\n", + " agent_name=agent_name, collection_number=\"0\"\n", ")\n", "print(\"Wipe agent memories response:\", wipe_mem_resp)" ] From b04888d5519282d7c6ef457a4fe81d1a9980c183 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 14 Jun 2024 16:38:12 -0400 Subject: [PATCH 0166/1256] Add support for ppt and pptx uploads --- agixt/readers/file.py | 5 +++++ requirements.txt | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/agixt/readers/file.py b/agixt/readers/file.py index 8a60468b6ed8..8d73df0c59eb 100644 --- a/agixt/readers/file.py +++ b/agixt/readers/file.py @@ -6,6 +6,7 @@ import zipfile import shutil import logging +from pptxtopdf import convert class FileReader(Memories): @@ -44,6 +45,10 @@ async def write_file_to_memory(self, file_path: str): else: file_path = os.path.normpath(file_path) filename = os.path.basename(file_path) + if file_path.endswith((".ppt", ".pptx")): + pdf_file_path = file_path.replace(".pptx", ".pdf").replace(".ppt", ".pdf") + convert(file_path, pdf_file_path) + file_path = pdf_file_path content = "" try: # If file extension is pdf, convert to text diff --git a/requirements.txt b/requirements.txt index 53c48915dd3b..7b2c0043eb80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ google-auth-oauthlib python-multipart==0.0.9 nest_asyncio g4f==0.3.2.0 -pyotp \ No newline at end of file +pyotp +pptxtopdf==0.0.2 \ No newline at end of file From 5c63f5b2fe7b5da1241b39c36a3986a813816105 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 14 Jun 2024 16:59:36 -0400 Subject: [PATCH 0167/1256] revert change --- agixt/readers/file.py | 3 ++- requirements.txt | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agixt/readers/file.py b/agixt/readers/file.py index 8d73df0c59eb..207eab3d1ef6 100644 --- a/agixt/readers/file.py +++ b/agixt/readers/file.py @@ -6,7 +6,6 @@ import zipfile import shutil import logging -from pptxtopdf import convert class FileReader(Memories): @@ -45,10 +44,12 @@ async def write_file_to_memory(self, file_path: str): else: file_path = os.path.normpath(file_path) filename = os.path.basename(file_path) + """ if file_path.endswith((".ppt", ".pptx")): pdf_file_path = file_path.replace(".pptx", ".pdf").replace(".ppt", ".pdf") convert(file_path, pdf_file_path) file_path = pdf_file_path + """ content = "" try: # If file extension is pdf, convert to text diff --git a/requirements.txt b/requirements.txt index 7b2c0043eb80..1e082060aecb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,3 @@ python-multipart==0.0.9 nest_asyncio g4f==0.3.2.0 pyotp -pptxtopdf==0.0.2 \ No newline at end of file From 02fe048f18e343e9cdf049e5387a4acbb0d13a94 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 14 Jun 2024 18:06:06 -0400 Subject: [PATCH 0168/1256] add error --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 51fe1172f484..e415b2241d1f 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -655,7 +655,7 @@ async def learn_from_file( if res == True: response = f"I have read the entire content of the file called {file_name} into my memory." else: - response = f"I was unable to read the file called {file_name}." + response = f"[ERROR] I was unable to read the file called {file_name}." if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, From ffc0e6c1d0878d5fc3dfc75cfd20877f4ccbfba8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 14 Jun 2024 21:15:21 -0400 Subject: [PATCH 0169/1256] use localhost --- docker-compose-dev.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index fe8491418252..bbc7b48803de 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -22,7 +22,7 @@ services: USING_JWT: ${USING_JWT:-false} AGIXT_API_KEY: ${AGIXT_API_KEY} AGIXT_URI: ${AGIXT_URI-http://agixt:7437} - MAGIC_LINK_URL: ${AUTH_WEB-http://agixtinteractive:3437/user} + MAGIC_LINK_URL: ${AUTH_WEB-http://localhost:3437/user} DISABLED_EXTENSIONS: ${DISABLED_EXTENSIONS:-} DISABLED_PROVIDERS: ${DISABLED_PROVIDERS:-} WORKING_DIRECTORY: ${WORKING_DIRECTORY:-/agixt/WORKSPACE} @@ -258,8 +258,8 @@ services: APP_DESCRIPTION: ${APP_DESCRIPTION-A chat powered by AGiXT.} INTERACTIVE_MODE: ${INTERACTIVE_MODE-chat} APP_NAME: ${APP_NAME-AGiXT} - APP_URI: ${APP_URI-http://agixtinteractive:3437} - AUTH_WEB: ${AUTH_WEB-http://agixtinteractive:3437/user} + APP_URI: ${APP_URI-http://localhost:3437} + AUTH_WEB: ${AUTH_WEB-http://localhost:3437/user} LOG_VERBOSITY_SERVER: 3 THEME_NAME: ${THEME_NAME} ALLOW_EMAIL_SIGN_IN: ${ALLOW_EMAIL_SIGN_IN-true} From 40086bff5fc56beb427cc3711db5afbb580d3749 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Fri, 14 Jun 2024 21:21:41 -0400 Subject: [PATCH 0170/1256] add agent selector --- docker-compose-dev.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index bbc7b48803de..aeee73a29a6c 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -253,7 +253,7 @@ services: AGIXT_SERVER: ${AGIXT_URI-http://agixt:7437} AGIXT_SHOW_AGENT_BAR: ${AGIXT_SHOW_AGENT_BAR-true} AGIXT_SHOW_APP_BAR: ${AGIXT_SHOW_APP_BAR-true} - AGIXT_SHOW_CONVERSATION_BAR: ${AGIXT_SHOW_CONVERSATION_BAR-true} + AGIXT_SHOW_SELECTION: ${AGIXT_SHOW_SELECTION-conversation,agent} AGIXT_CONVERSATION_MODE: ${AGIXT_CONVERSATION_MODE-select} APP_DESCRIPTION: ${APP_DESCRIPTION-A chat powered by AGiXT.} INTERACTIVE_MODE: ${INTERACTIVE_MODE-chat} From 936460cb8b124917ba93f98ce1d326afc6f2ec96 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 15 Jun 2024 06:27:14 -0400 Subject: [PATCH 0171/1256] add get_agent_id for global --- agixt/Agent.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 7d7368632e5a..e7f8bf667ddf 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -526,5 +526,15 @@ def get_agent_id(self): .first() ) if not agent: - return None + agent = ( + self.session.query(AgentModel) + .filter( + AgentModel.name == self.agent_name, + AgentModel.user_id + == self.session.query(User).filter(User.email == DEFAULT_USER), + ) + .first() + ) + if not agent: + return None return agent.id From d139bae7c11c5f1aa19201a532fd420696ba74ca Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Sat, 15 Jun 2024 17:37:15 -0400 Subject: [PATCH 0172/1256] Improve File Uploads, Vision Always On (#1210) * add convert ppt to pdf * add pdf2img * Recall images and persist vision through conversations * fix broken ref * remove images ref * fix error * add data analysis function * add multifile data analysis support * improve prompt * add analyze_csv to pipeline * make sure path starts with current dir * handle path properly * improve handling of xls and xlsx * improve xls handling * always convert xls to csv * improve csv item format --- Dockerfile | 2 +- agixt/Agent.py | 3 +- agixt/Conversations.py | 48 ++- agixt/Interactions.py | 45 ++- agixt/Memories.py | 13 +- agixt/Prompts.py | 32 +- agixt/XT.py | 273 +++++++++++++++++- .../Default/Code Interpreter Multifile.txt | 37 +++ agixt/prompts/Default/Determine File.txt | 15 + .../Verify Code Interpreter Multifile.txt | 25 ++ agixt/readers/file.py | 7 +- static-requirements.txt | 4 +- 12 files changed, 476 insertions(+), 28 deletions(-) create mode 100644 agixt/prompts/Default/Code Interpreter Multifile.txt create mode 100644 agixt/prompts/Default/Determine File.txt create mode 100644 agixt/prompts/Default/Verify Code Interpreter Multifile.txt diff --git a/Dockerfile b/Dockerfile index 0d8ec1c509a2..fba5deca0824 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ apt-get update --fix-missing ; \ apt-get upgrade -y ; \ curl -sL https://deb.nodesource.com/setup_14.x | bash - ; \ - apt-get install -y --fix-missing --no-install-recommends git build-essential gcc g++ sqlite3 libsqlite3-dev wget libgomp1 ffmpeg python3 python3-pip python3-dev curl postgresql-client libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libatspi2.0-0 libxcomposite1 nodejs libportaudio2 libasound-dev && \ + apt-get install -y --fix-missing --no-install-recommends git build-essential gcc g++ sqlite3 libsqlite3-dev wget libgomp1 ffmpeg python3 python3-pip python3-dev curl postgresql-client libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libatspi2.0-0 libxcomposite1 nodejs libportaudio2 libasound-dev libreoffice unoconv poppler-utils && \ apt-get install -y gcc-10 g++-10 && \ update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 10 && \ update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 10 && \ diff --git a/agixt/Agent.py b/agixt/Agent.py index e7f8bf667ddf..b9af3594464f 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -530,8 +530,7 @@ def get_agent_id(self): self.session.query(AgentModel) .filter( AgentModel.name == self.agent_name, - AgentModel.user_id - == self.session.query(User).filter(User.email == DEFAULT_USER), + User.email == DEFAULT_USER, ) .first() ) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 520b3876fc58..a56bbe285586 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -122,6 +122,47 @@ def get_conversation(self, limit=100, page=1): return_messages.append(msg) return {"interactions": return_messages} + def get_activities(self, limit=100, page=1): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + if not self.conversation_name: + self.conversation_name = f"{str(datetime.now())} Conversation" + conversation = ( + session.query(Conversation) + .filter( + Conversation.name == self.conversation_name, + Conversation.user_id == user_id, + ) + .first() + ) + if not conversation: + return {"activities": []} + offset = (page - 1) * limit + messages = ( + session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.timestamp.asc()) + .limit(limit) + .offset(offset) + .all() + ) + if not messages: + return {"activities": []} + return_activities = [] + for message in messages: + if message.content.startswith("[ACTIVITY]"): + msg = { + "id": message.id, + "role": message.role, + "message": message.content, + "timestamp": message.timestamp, + } + return_activities.append(msg) + # Order messages by timestamp oldest to newest + return_activities = sorted(return_activities, key=lambda x: x["timestamp"]) + return {"activities": return_activities} + def new_conversation(self, conversation_content=[]): session = get_session() user_data = session.query(User).filter(User.email == self.user).first() @@ -189,7 +230,12 @@ def log_interaction(self, role, message): if role.lower() == "user": logging.info(f"{self.user}: {message}") else: - logging.info(f"{role}: {message}") + if "[WARN]" in message: + logging.warning(f"{role}: {message}") + elif "[ERROR]" in message: + logging.error(f"{role}: {message}") + else: + logging.info(f"{role}: {message}") def delete_conversation(self): session = get_session() diff --git a/agixt/Interactions.py b/agixt/Interactions.py index b117cb043028..040cef7ec008 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -194,17 +194,58 @@ async def format_prompt( limit=top_results, min_relevance_score=min_relevance_score, ) - context += await FileReader( + conversation_memories = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, collection_number=c.get_conversation_id(), ApiClient=self.ApiClient, user=self.user, - ).get_memories( + ) + conversation_context = await conversation_memories.get_memories( user_input=user_input, limit=top_results, min_relevance_score=min_relevance_score, ) + context += conversation_context + if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: + vision_provider = self.agent.AGENT_CONFIG["settings"][ + "vision_provider" + ] + if ( + vision_provider != "None" + and vision_provider != "" + and vision_provider != None + ): + for memory in conversation_context: + # If the memory starts with "Sourced from image", get a new vision response to add and inject + if memory.startswith("Sourced from image "): + file_name = memory.split("Sourced from image ")[ + 1 + ].split(":")[0] + # File will be in the agent's workspace /{file_name} + file_path = os.path.join( + self.agent.working_directory, file_name + ) + if os.path.exists(file_path): + images = [file_path] + timestamp = datetime.now().strftime( + "%B %d, %Y %I:%M %p" + ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Looking at image `{file_name}`.", + ) + vision_response = await self.agent.inference( + prompt=user_input, images=images + ) + await conversation_memories.write_text_to_memory( + user_input=user_input, + text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}`:\n{vision_response}\n", + external_source=f"image {file_name}", + ) + context.append( + f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}`:\n{vision_response}\n" + ) else: context = [] if "context" in kwargs: diff --git a/agixt/Memories.py b/agixt/Memories.py index 3a725f6586e8..c38d64817e21 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -282,8 +282,10 @@ async def get_collection(self): ) except: try: - return self.chroma_client.get_or_create_collection( - name=self.collection_name, embedding_function=self.embedder + return self.chroma_client.create_collection( + name=self.collection_name, + embedding_function=self.embedder, + get_or_create=True, ) except: logging.warning(f"Error getting collection: {self.collection_name}") @@ -452,8 +454,13 @@ async def get_memories( if "external_source_name" in result else None ) + timestamp = ( + result["timestamp"] + if "timestamp" in result + else datetime.now().strftime("%Y-%m-%d %H:%M:%S") + ) if external_source: - metadata = f"Sourced from {external_source}:\n{metadata}" + metadata = f"Sourced from {external_source}:\nSourced on: {timestamp}\n{metadata}" if metadata not in response and metadata != "": response.append(metadata) return response diff --git a/agixt/Prompts.py b/agixt/Prompts.py index 544e3fa422c5..85782efdc26e 100644 --- a/agixt/Prompts.py +++ b/agixt/Prompts.py @@ -1,5 +1,6 @@ from DB import Prompt, PromptCategory, Argument, User, get_session from Globals import DEFAULT_USER +import os class Prompts: @@ -50,7 +51,7 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): self.session.add(argument) self.session.commit() - def get_prompt(self, prompt_name, prompt_category="Default"): + def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() prompt = ( self.session.query(Prompt) @@ -93,6 +94,35 @@ def get_prompt(self, prompt_name, prompt_category="Default"): ) .first() ) + if not prompt: + prompt_file = os.path.normpath( + os.path.join(os.getcwd(), "prompts", "Default", f"{prompt_name}.txt") + ) + base_path = os.path.join(os.getcwd(), "prompts") + if not prompt_file.startswith(base_path): + return None + if os.path.exists(prompt_file): + with open(prompt_file, "r") as f: + prompt_content = f.read() + self.add_prompt( + prompt_name=prompt_name, + prompt=prompt_content, + prompt_category="Default", + ) + prompt = ( + self.session.query(Prompt) + .filter( + Prompt.name == prompt_name, + Prompt.user_id == self.user_id, + Prompt.prompt_category.has(name="Default"), + ) + .join(PromptCategory) + .filter( + PromptCategory.name == "Default", + Prompt.user_id == self.user_id, + ) + .first() + ) if prompt: return prompt.content return None diff --git a/agixt/XT.py b/agixt/XT.py index e415b2241d1f..1839788ca351 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -9,12 +9,15 @@ from typing import Type, get_args, get_origin, Union, List from enum import Enum from pydantic import BaseModel +from pdf2image import convert_from_path +import pandas as pd +import subprocess import logging import asyncio import os +import re import base64 import uuid -import requests import json import time @@ -39,6 +42,7 @@ def __init__(self, user: str, agent_name: str, api_key: str): self.agent_workspace = self.agent.working_directory os.makedirs(self.agent_workspace, exist_ok=True) self.outputs = f"{self.uri}/outputs/{self.agent.agent_id}" + self.failures = 0 async def prompts(self, prompt_category: str = "Default"): """ @@ -564,15 +568,22 @@ async def learn_from_file( ) file_name = file_data["file_name"] file_path = os.path.join(self.agent_workspace, file_name) + file_type = file_name.split(".")[-1] + if file_type in ["ppt", "pptx"]: + # Convert it to a PDF + pdf_file_path = file_path.replace(".pptx", ".pdf").replace(".ppt", ".pdf") + subprocess.run( + ["unoconv", "-f", "pdf", "-o", pdf_file_path, file_path], check=True + ) + file_path = pdf_file_path if conversation_name != "" and conversation_name != None: c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Reading file {file_name} into memory.", + message=f"[ACTIVITY] Reading file `{file_name}` into memory.", ) if user_input == "": user_input = "Describe each stage of this image." - file_type = file_name.split(".")[-1] file_reader = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, @@ -580,7 +591,68 @@ async def learn_from_file( ApiClient=self.ApiClient, user=self.user_email, ) - if ( + if file_type == "pdf": + # Turn the pdf to images, then run inference on each image + pdf_path = file_path + images = convert_from_path(pdf_path) + for i, image in enumerate(images): + image_path = os.path.join(self.agent_workspace, f"{file_name}_{i}.png") + image.save(image_path, "PNG") + await self.learn_from_file( + file_url=image_path, + file_name=f"{file_name}_{i}.png", + user_input=user_input, + collection_id=collection_id, + conversation_name=conversation_name, + ) + if file_type == "xlsx" or file_type == "xls": + df = pd.read_excel(file_path) + # Check if the spreadsheet has multiple sheets + if isinstance(df, dict): + sheet_names = list(df.keys()) + x = 0 + csv_files = [] + for sheet_name in sheet_names: + x += 1 + df = pd.read_excel(file_path, sheet_name=sheet_name) + file_path = file_path.replace(f".{file_type}", f"_{x}.csv") + csv_file_name = os.path.basename(file_path) + df.to_csv(file_path, index=False) + csv_files.append(f"`{csv_file_name}`") + await self.learn_from_file( + file_url=f"{self.outputs}/{csv_file_name}", + file_name=csv_file_name, + user_input=f"Original file: {file_name}\nSheet: {sheet_name}\nNew file: {csv_file_name}\n{user_input}", + collection_id=collection_id, + conversation_name=conversation_name, + ) + str_csv_files = ", ".join(csv_files) + response = f"Separated the content of the spreadsheet called {file_name} into {x} files called {str_csv_files} and read them into memory." + else: + # Save it as a CSV file and run this function again + file_path = file_path.replace(f".{file_type}", ".csv") + csv_file_name = os.path.basename(file_path) + df.to_csv(file_path, index=False) + return await self.learn_from_file( + file_url=f"{self.outputs}/{csv_file_name}", + file_name=csv_file_name, + user_input=f"Original file: {file_name}\nNew file: {csv_file_name}\n{user_input}", + collection_id=collection_id, + conversation_name=conversation_name, + ) + elif file_type == "csv": + df = pd.read_csv(file_path) + df_dict = df.to_dict() + for line in df_dict: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + message = f"Content from file uploaded at {timestamp} named `{file_name}`:\n```json\n{json.dumps(df_dict[line], indent=2)}```\n" + await file_reader.write_text_to_memory( + user_input=f"{user_input}\n{message}", + text=message, + external_source=f"file {file_path}", + ) + response = f"Read the content of the file called {file_name} into memory." + elif ( file_type == "wav" or file_type == "mp3" or file_type == "ogg" @@ -600,7 +672,7 @@ async def learn_from_file( await file_reader.write_text_to_memory( user_input=user_input, text=f"Transcription from the audio file called `{file_name}`:\n{audio_response}\n", - external_source=f"Audio file called `{file_name}`", + external_source=f"audio {file_name}", ) response = ( f"I have transcribed the audio from `{file_name}` into my memory." @@ -632,12 +704,13 @@ async def learn_from_file( vision_response = await self.agent.inference( prompt=user_input, images=[file_url] ) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await file_reader.write_text_to_memory( user_input=user_input, - text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}`:\n{vision_response}\n", - external_source=f"Image called `{file_name}`", + text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}` from {timestamp}:\n{vision_response}\n", + external_source=f"image {file_name}", ) - response = f"I have generated a description of the image called `{file_name}` into my memory." + response = f"Generated a description of the image called `{file_name}` into my memory." except Exception as e: logging.error(f"Error getting vision response: {e}") response = f"[ERROR] I was unable to view the image called `{file_name}`." @@ -646,14 +719,11 @@ async def learn_from_file( f"[ERROR] I was unable to view the image called `{file_name}`." ) else: - if conversation_name != "" and conversation_name != None: - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] Reading file `{file_name}` into memory.", - ) res = await file_reader.write_file_to_memory(file_path=file_path) if res == True: - response = f"I have read the entire content of the file called {file_name} into my memory." + response = ( + f"Read the content of the file called {file_name} into memory." + ) else: response = f"[ERROR] I was unable to read the file called {file_name}." if conversation_name != "" and conversation_name != None: @@ -1073,6 +1143,8 @@ async def chat_completions(self, prompt: ChatCompletions): new_prompt += transcribed_audio # Add user input to conversation c = Conversations(conversation_name=conversation_name, user=self.user_email) + for file in files: + new_prompt += f"\nUploaded file: `{file['file_name']}`." c.log_interaction(role="USER", message=new_prompt) conversation_id = c.get_conversation_id() for file in files: @@ -1089,6 +1161,11 @@ async def chat_completions(self, prompt: ChatCompletions): summarize_content=False, conversation_name=conversation_name, ) + await self.analyze_csv( + user_input=new_prompt, + conversation_name=conversation_name, + file_content=None, + ) if mode == "command" and command_name and command_variable: try: command_args = ( @@ -1376,3 +1453,171 @@ async def convert_to_pydantic_model( response_type=response_type, failures=failures, ) + + def get_agent_workspace_markdown(self): + def generate_markdown_structure(folder_path, indent=0): + if not os.path.isdir(folder_path): + return "" + markdown_output = "" + items = sorted(os.listdir(folder_path)) + for item in items: + item_path = os.path.join(folder_path, item) + if os.path.isdir(item_path): + markdown_output += f"{' ' * indent}* **{item}/**\n" + markdown_output += generate_markdown_structure( + item_path, indent + 1 + ) + else: + markdown_output += f"{' ' * indent}* {item}\n" + return markdown_output + + return generate_markdown_structure(folder_path=self.agent_workspace) + + async def analyze_csv( + self, + user_input: str, + conversation_name: str, + file_content=None, + ): + c = Conversations(conversation_name=conversation_name, user=self.user_email) + if not file_content: + files = os.listdir(self.agent_workspace) + file_names = [] + file_name = "" + # Check if any files are csv files, if not, return empty string + csv_files = [file for file in files if file.endswith(".csv")] + if len(csv_files) == 0: + return "" + activities = c.get_activities(limit=20)["activities"] + if len(activities) == 0: + return "" + likely_files = [] + for activity in activities: + if ".csv" in activity["message"]: + likely_files.append(activity["message"].split("`")[1]) + if len(likely_files) == 0: + return "" + elif len(likely_files) == 1: + file_name = likely_files[0] + file_path = os.path.join(self.agent_workspace, file_name) + file_content = open(file_path, "r").read() + else: + file_determination = await self.inference( + user_input=user_input, + prompt_category="Default", + prompt_name="Determine File", + directory_listing="\n".join(csv_files), + conversation_results=10, + conversation_name=conversation_name, + log_user_input=False, + log_output=False, + voice_response=False, + ) + # Iterate over files and use regex to see if the file name is in the response + for file in files: + if re.search(file, file_determination): + file_names.append(file) + if len(file_names) == 1: + file_name = file_names[0] + file_path = os.path.join(self.agent_workspace, file_name) + file_content = open(file_path, "r").read() + if file_name == "": + return "" + if len(file_names) > 1: + # Found multiple files, do things a little differently. + previews = [] + import_files = "" + for file in file_names: + if import_files == "": + import_files = f"`{self.agent_workspace}/{file}`" + else: + import_files += f", `{self.agent_workspace}/{file}`" + file_path = os.path.join(self.agent_workspace, file) + file_content = open(file_path, "r").read() + lines = file_content.split("\n") + lines = lines[:2] + file_preview = "\n".join(lines) + previews.append(f"`{file_path}`\n```csv\n{file_preview}\n```") + file_preview = "\n".join(previews) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Analyzing data from multiple files: {import_files}.", + ) + else: + lines = file_content.split("\n") + lines = lines[:2] + file_preview = "\n".join(lines) + c.log_interaction( + "[ACTIVITY] Analyzing data from file `{file_name}`.", + ) + code_interpreter = await self.inference( + user_input=user_input, + prompt_category="Default", + prompt_name=( + "Code Interpreter Multifile" + if len(file_names) > 1 + else "Code Interpreter" + ), + import_file=import_files if len(file_names) > 1 else file_path, + file_preview=file_preview, + conversation_name=conversation_name, + log_user_input=False, + log_output=False, + browse_links=False, + websearch=False, + websearch_depth=0, + voice_response=False, + ) + # Step 5 - Verify the code is good before executing it. + code_verification = await self.inference( + user_input=user_input, + prompt_category="Default", + prompt_name=( + "Verify Code Interpreter Multifile" + if len(file_names) > 1 + else "Verify Code Interpreter" + ), + import_file=import_files if len(file_names) > 1 else file_path, + file_preview=file_preview, + code=code_interpreter, + conversation_name=conversation_name, + log_user_input=False, + log_output=False, + browse_links=False, + websearch=False, + websearch_depth=0, + voice_response=False, + ) + # Step 6 - Execute the code, will need to revert to step 4 if the code is not correct to try again. + code_execution = await self.execute_command( + command_name="Execute Python Code", + command_args={"code": code_verification, "text": file_content}, + conversation_name=conversation_name, + ) + if not code_execution.startswith("Error"): + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Data analysis complete.", + ) + c.log_interaction( + role=self.agent_name, + message=f"## Results from analyzing data in `{file_name}`:\n{code_execution}", + ) + else: + self.failures += 1 + if self.failures < 3: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY][WARN] Data analysis failed, trying again ({self.failures}/3).", + ) + return await self.analyze_csv( + user_input=user_input, + conversation_name=conversation_name, + file_content=file_content, + ) + else: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY][ERROR] Data analysis failed after 3 attempts.", + ) + return code_execution diff --git a/agixt/prompts/Default/Code Interpreter Multifile.txt b/agixt/prompts/Default/Code Interpreter Multifile.txt new file mode 100644 index 000000000000..cba4273a4705 --- /dev/null +++ b/agixt/prompts/Default/Code Interpreter Multifile.txt @@ -0,0 +1,37 @@ +You are very powerful Python Code Interpreter, designed to assist with a wide range of tasks, particularly those related to data science, data analysis, data visualization, and file manipulation. + +Unlike many text-based AIs, You have the capability to directly manipulate files, convert images, and perform a variety of other tasks. Here are some examples: + +- Image Description and Manipulation: You can directly manipulate images, including zooming, cropping, color grading, and resolution enhancement. It can also convert images from one format to another. +- QR Code Generation: You can create QR codes for various purposes. +- Project Management: You can assist in creating Gantt charts and mapping out project steps. +- Study Scheduling: You can design optimized study schedules for exam preparation. +- File Conversion: You can directly convert files from one format to another, such as PDF to text or video to audio. +- Mathematical Computation: You can solve complex math equations and produce graphs. +- Document Analysis: You can analyze, summarize, or extract information from large documents. +- Data Visualization: You can analyze datasets, identify trends, and create various types of graphs. +- Geolocation Visualization: You can provide geolocation maps to showcase specific trends or occurrences. +- Code Analysis and Creation: You can analyze and critique code, and even create code from scratch. +- Many other things that can be accomplished running python code in a jupyter environment. +- Multiple visualizations are allowed as long as the return is a markdown string of the base64 image. +- The date today is {date} . + +You can execute Python code within a sandboxed Jupyter kernel environment. You come equipped with a variety of pre-installed Python packages including numpy, pandas, matplotlib, seaborn, scikit-learn, yfinance, scipy, statsmodels, sympy, bokeh, plotly, dash, and networkx. Additionally, you have the ability to use other packages which automatically get installed when found in the code, simply comment `# pip install packageName` anywhere in the code to have it automatically installed. + +Remember, You are constantly learning and improving. You are capable of generating human-like text based on the input it receives, engaging in natural-sounding conversations, and providing responses that are coherent and relevant to the topic at hand. Enjoy your coding session! + +If the user's input doesn't request any specific analysis or asks to surprise them, write code that will to plot something interesting to provide them with insights into the data through visualizations. + +**Make sure the final output of the code is a visualization. The functions final return should be a print of base64 image markdown string that can be displayed on a website parsing markdown code. Example `print('![Generated Image](data:image/png;base64,IMAGE_CONTENT)')`** + +You are working the with files: {import_file} + +Use these exact file paths in any code that will analyze them. + +CSV file previews: + +{file_preview} + +User's input: {user_input} + +```python \ No newline at end of file diff --git a/agixt/prompts/Default/Determine File.txt b/agixt/prompts/Default/Determine File.txt new file mode 100644 index 000000000000..864e0a071414 --- /dev/null +++ b/agixt/prompts/Default/Determine File.txt @@ -0,0 +1,15 @@ +Recent conversation history for context: + {conversation_history} + +Today's date is {date} . + +## Directory listing of the assistant's working directory + + {directory_listing} + +User's last message: + {user_input} + +The assistant needs to determine which file the user is referring to using context from their last message, activities, and conversation history. Provide only the name of the file the user is referring to based on the conversation history and recent activities. + +**Respond only with the name of the file. Multiple files can be selected by using new lines between file names. Simply say "None" if there is nothing relevant to do with any files from the user's last message and context.** diff --git a/agixt/prompts/Default/Verify Code Interpreter Multifile.txt b/agixt/prompts/Default/Verify Code Interpreter Multifile.txt new file mode 100644 index 000000000000..27e953e440cb --- /dev/null +++ b/agixt/prompts/Default/Verify Code Interpreter Multifile.txt @@ -0,0 +1,25 @@ +The date today is {date} . + +You are working the with files: {import_file} . + +Use this exact file path in any code that will analyze it. + +CSV file previews: + +{file_preview} + +We built python code to build a visualization for the user's input. + +User's input: {user_input} + +**Ensure the code does not modify the file system, if it does, remove the portion of the code that does.** + +**Make sure the final output of the code is a visualization. The functions final return should be a print of base64 image markdown string that can be displayed on a website parsing markdown code. Example `print('![Generated Image](data:image/png;base64,IMAGE_CONTENT)')`** + +**Confirm that the code follows the rules and CSV format. Return full updated code without placeholders that is confirmed.** + +```python +{code} +``` + +```python \ No newline at end of file diff --git a/agixt/readers/file.py b/agixt/readers/file.py index 207eab3d1ef6..a22f6941f8d7 100644 --- a/agixt/readers/file.py +++ b/agixt/readers/file.py @@ -6,6 +6,7 @@ import zipfile import shutil import logging +from datetime import datetime class FileReader(Memories): @@ -92,11 +93,11 @@ async def write_file_to_memory(self, file_path: str): with open(file_path, "r") as f: content = f.read() if content != "": - stored_content = f"From file: {filename}\n{content}" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await self.write_text_to_memory( user_input=file_path, - text=stored_content, - external_source=f"file called {filename}", + text=f"Content from file uploaded at {timestamp} named `{filename}`:\n{content}", + external_source=f"file {filename}", ) return True except: diff --git a/static-requirements.txt b/static-requirements.txt index cda4bd35e26b..4b43b8522534 100644 --- a/static-requirements.txt +++ b/static-requirements.txt @@ -23,4 +23,6 @@ sendgrid==6.11.0 httpx==0.27.0 numpy==1.26.4 mysql-connector-python==8.3.0 -pydub==0.25.1 \ No newline at end of file +pydub==0.25.1 +python-pptx==0.6.23 +pdf2image==1.17.0 \ No newline at end of file From f61f6c87c37b5f7798bfe3f6f9e9c74f205b8f13 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Sat, 15 Jun 2024 19:32:07 -0400 Subject: [PATCH 0173/1256] Add GitHub Repo Download to Memory to Chat Completions (#1211) * move github repo download * fix security issue * secure path * handle path control * fix name ref * improve logic and security * fix typo --- agixt/Interactions.py | 3 +- agixt/Websearch.py | 56 ------------ agixt/XT.py | 203 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 194 insertions(+), 68 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 040cef7ec008..243d73282238 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -8,7 +8,6 @@ import uuid from datetime import datetime from readers.file import FileReader -from readers.github import GithubReader from Websearch import Websearch from Extensions import Extensions from ApiClient import ( @@ -66,7 +65,7 @@ def __init__( ApiClient=self.ApiClient, user=self.user, ) - self.github_memories = GithubReader( + self.github_memories = FileReader( agent_name=self.agent_name, agent_config=self.agent.AGENT_CONFIG, collection_number="7", diff --git a/agixt/Websearch.py b/agixt/Websearch.py index a0909dcc13d5..a8c5632db471 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -13,12 +13,9 @@ from ApiClient import Agent, Conversations from Globals import getenv, get_tokens from readers.youtube import YoutubeReader -from readers.github import GithubReader from datetime import datetime from Memories import extract_keywords from googleapiclient.discovery import build -from googleapiclient.errors import HttpError - logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -162,59 +159,6 @@ async def get_web_content( external_source=url, ) return content, None - if url.startswith("https://github.com/"): - do_not_pull_repo = [ - "/pull/", - "/issues", - "/discussions", - "/actions/", - "/projects", - "/security", - "/releases", - "/commits", - "/branches", - "/tags", - "/stargazers", - "/watchers", - "/network", - "/settings", - "/compare", - "/archive", - ] - if any(x in url for x in do_not_pull_repo): - res = False - else: - if "/tree/" in url: - branch = url.split("/tree/")[1].split("/")[0] - else: - branch = "main" - res = await GithubReader( - agent_name=self.agent_name, - agent_config=self.agent.AGENT_CONFIG, - collection_number="7", - user=self.user, - ApiClient=self.ApiClient, - ).write_github_repository_to_memory( - github_repo=url, - github_user=( - self.agent_settings["GITHUB_USER"] - if "GITHUB_USER" in self.agent_settings - else None - ), - github_token=( - self.agent_settings["GITHUB_TOKEN"] - if "GITHUB_TOKEN" in self.agent_settings - else None - ), - github_branch=branch, - ) - if res: - self.browsed_links.append(url) - self.agent.add_browsed_link(url=url, conversation_id=conversation_id) - return ( - f"Content from GitHub repository at {url} has been added to memory.", - None, - ) try: async with async_playwright() as p: browser = await p.chromium.launch() diff --git a/agixt/XT.py b/agixt/XT.py index 1839788ca351..6cffeedf70eb 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -10,10 +10,14 @@ from enum import Enum from pydantic import BaseModel from pdf2image import convert_from_path +import pdfplumber +import docx2txt +import zipfile import pandas as pd import subprocess import logging import asyncio +import requests import os import re import base64 @@ -569,15 +573,20 @@ async def learn_from_file( file_name = file_data["file_name"] file_path = os.path.join(self.agent_workspace, file_name) file_type = file_name.split(".")[-1] + c = Conversations(conversation_name=conversation_name, user=self.user_email) if file_type in ["ppt", "pptx"]: # Convert it to a PDF pdf_file_path = file_path.replace(".pptx", ".pdf").replace(".ppt", ".pdf") + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Converting PowerPoint file `{file_name}` to PDF.", + ) subprocess.run( ["unoconv", "-f", "pdf", "-o", pdf_file_path, file_path], check=True ) file_path = pdf_file_path if conversation_name != "" and conversation_name != None: - c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, message=f"[ACTIVITY] Reading file `{file_name}` into memory.", @@ -591,7 +600,11 @@ async def learn_from_file( ApiClient=self.ApiClient, user=self.user_email, ) - if file_type == "pdf": + # The only thing we disallow is binary that we can't convert to text + disallowed_types = ["exe", "bin", "rar"] + if file_type in disallowed_types: + response = f"[ERROR] I was unable to read the file called `{file_name}`." + elif file_type == "pdf": # Turn the pdf to images, then run inference on each image pdf_path = file_path images = convert_from_path(pdf_path) @@ -605,7 +618,43 @@ async def learn_from_file( collection_id=collection_id, conversation_name=conversation_name, ) - if file_type == "xlsx" or file_type == "xls": + with pdfplumber.open(file_path) as pdf: + content = "\n".join([page.extract_text() for page in pdf.pages]) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + await file_reader.write_text_to_memory( + user_input=user_input, + text=f"Content from PDF uploaded at {timestamp} named `{file_name}`:\n{content}", + external_source=f"file {file_path}", + ) + response = ( + f"Read the content of the PDF file called `{file_name}` into memory." + ) + elif file_path.endswith(".zip"): + new_folder = os.path.normpath( + os.path.join(self.agent_workspace, f"extracted_{file_name}") + ) + if os.path.normpath(file_path).startswith( + self.agent_workspace + ) and new_folder.startswith(self.agent_workspace): + with zipfile.ZipFile(file_path, "r") as zipObj: + zipObj.extractall(path=new_folder) + # Iterate over every file that was extracted including subdirectories + for root, dirs, files in os.walk(new_folder): + for name in files: + file_path = os.path.join(root, name) + await self.learn_from_file( + file_url=file_path, + file_name=name, + user_input=user_input, + collection_id=collection_id, + conversation_name=conversation_name, + ) + response = f"Extracted the content of the zip file called `{file_name}` and read them into memory." + else: + response = ( + f"[ERROR] I was unable to read the file called `{file_name}`." + ) + elif file_type == "xlsx" or file_type == "xls": df = pd.read_excel(file_path) # Check if the spreadsheet has multiple sheets if isinstance(df, dict): @@ -627,7 +676,7 @@ async def learn_from_file( conversation_name=conversation_name, ) str_csv_files = ", ".join(csv_files) - response = f"Separated the content of the spreadsheet called {file_name} into {x} files called {str_csv_files} and read them into memory." + response = f"Separated the content of the spreadsheet called `{file_name}` into {x} files called {str_csv_files} and read them into memory." else: # Save it as a CSV file and run this function again file_path = file_path.replace(f".{file_type}", ".csv") @@ -640,6 +689,14 @@ async def learn_from_file( collection_id=collection_id, conversation_name=conversation_name, ) + elif file_path.endswith(".doc") or file_path.endswith(".docx"): + file_content = docx2txt.process(file_path) + await file_reader.write_text_to_memory( + user_input=user_input, + text=file_content, + external_source=f"file {file_path}", + ) + response = f"Read the content of the file called `{file_name}` into memory." elif file_type == "csv": df = pd.read_csv(file_path) df_dict = df.to_dict() @@ -651,7 +708,7 @@ async def learn_from_file( text=message, external_source=f"file {file_path}", ) - response = f"Read the content of the file called {file_name} into memory." + response = f"Read the content of the file called `{file_name}` into memory." elif ( file_type == "wav" or file_type == "mp3" @@ -719,13 +776,22 @@ async def learn_from_file( f"[ERROR] I was unable to view the image called `{file_name}`." ) else: - res = await file_reader.write_file_to_memory(file_path=file_path) - if res == True: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + if os.path.normpath(file_path).startswith(self.agent_workspace): + with open(file_path, "r") as f: + file_content = f.read() + await file_reader.write_text_to_memory( + user_input=user_input, + text=f"Content from file uploaded named `{file_name}` at {timestamp}:\n{file_content}", + external_source=f"file {file_path}", + ) response = ( - f"Read the content of the file called {file_name} into memory." + f"Read the content of the file called `{file_name}` into memory." ) else: - response = f"[ERROR] I was unable to read the file called {file_name}." + response = ( + f"[ERROR] I was unable to read the file called `{file_name}`." + ) if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, @@ -979,6 +1045,8 @@ async def chat_completions(self, prompt: ChatCompletions): dict: Chat completion response """ conversation_name = prompt.user + c = Conversations(conversation_name=conversation_name, user=self.user_email) + conversation_id = c.get_conversation_id() urls = [] files = [] new_prompt = "" @@ -1099,6 +1167,122 @@ async def chat_completions(self, prompt: ChatCompletions): for key, value in msg.items(): if "_url" in key: url = str(value["url"] if "url" in value else value) + if url.startswith("https://github.com/"): + do_not_pull_repo = [ + "/pull/", + "/issues", + "/discussions", + "/actions/", + "/projects", + "/security", + "/releases", + "/commits", + "/branches", + "/tags", + "/stargazers", + "/watchers", + "/network", + "/settings", + "/compare", + "/archive", + ] + if any(x in url for x in do_not_pull_repo): + # If the URL is not a repository, don't pull it + urls.append(url) + else: + # Download the zip for the repo + github_user = ( + self.agent_settings["GITHUB_USERNAME"] + if "GITHUB_USERNAME" in self.agent_settings + else None + ) + github_token = ( + self.agent_settings["GITHUB_TOKEN"] + if "GITHUB_TOKEN" in self.agent_settings + else None + ) + github_repo = url.replace( + "https://github.com/", "" + ) + github_repo = github_repo.replace( + "https://www.github.com/", "" + ) + if not github_branch: + github_branch = "main" + user = github_repo.split("/")[0] + repo = github_repo.split("/")[1] + if " " in repo: + repo = repo.split(" ")[0] + if "\n" in repo: + repo = repo.split("\n")[0] + # Remove any symbols that would not be in the user, repo, or branch + for symbol in [ + " ", + "\n", + "\t", + "\r", + "\\", + "/", + ":", + "*", + "?", + '"', + "<", + ">", + ]: + repo = repo.replace(symbol, "") + user = user.replace(symbol, "") + github_branch = github_branch.replace( + symbol, "" + ) + repo_url = f"https://github.com/{user}/{repo}/archive/refs/heads/{github_branch}.zip" + try: + if github_user and github_token: + response = requests.get( + repo_url, + auth=(github_user, github_token), + ) + else: + response = requests.get(repo_url) + except: + github_branch = "master" + repo_url = f"https://github.com/{user}/{repo}/archive/refs/heads/{github_branch}.zip" + try: + if github_user and github_token: + response = requests.get( + repo_url, + auth=( + github_user, + github_token, + ), + ) + else: + response = requests.get(repo_url) + except: + pass + if response.status_code == 200: + file_name = ( + f"{user}_{repo}_{github_branch}.zip" + ) + file_data = response.content + file_path = os.path.normpath( + os.path.join( + self.agent_workspace, file_name + ) + ) + if file_path.startswith( + self.agent_workspace + ): + with open(file_path, "wb") as f: + f.write(file_data) + files.append( + { + "file_name": file_name, + "file_url": f"{self.outputs}/{file_name}", + } + ) + else: + urls.append(url) if "file_name" in msg: file_name = str(msg["file_name"]) else: @@ -1142,7 +1326,6 @@ async def chat_completions(self, prompt: ChatCompletions): ) new_prompt += transcribed_audio # Add user input to conversation - c = Conversations(conversation_name=conversation_name, user=self.user_email) for file in files: new_prompt += f"\nUploaded file: `{file['file_name']}`." c.log_interaction(role="USER", message=new_prompt) From 09aa60648c5a4a9e197f7416c82f7843857a5001 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sat, 15 Jun 2024 20:09:16 -0400 Subject: [PATCH 0174/1256] Cite sources when context is injected. --- agixt/Interactions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 243d73282238..5b71e4174052 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -255,7 +255,7 @@ async def format_prompt( ) if context != [] and context != "": context = "\n".join(context) - context = f"The user's input causes you remember these things:\n{context}\n" + context = f"The user's input causes you remember these things:\n{context}\n\nIf referencing a file, paper, or website, cite sources.\n" else: context = "" working_directory = self.agent.working_directory From 5efec08de4fa1357c11b7d14656a9874de2bcff9 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 16 Jun 2024 01:19:54 -0400 Subject: [PATCH 0175/1256] fix refs to api key --- agixt/MagicalAuth.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index d789c3cb5bda..0188b524f18b 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -31,7 +31,7 @@ - SENDGRID_API_KEY: SendGrid API key - SENDGRID_FROM_EMAIL: Default email address to send emails from -- ENCRYPTION_SECRET: Encryption key to encrypt and decrypt data +- AGIXT_API_KEY: Encryption key to encrypt and decrypt data - MAGIC_LINK_URL: URL to send in the email for the user to click on - REGISTRATION_WEBHOOK: URL to send a POST request to when a user registers """ @@ -95,23 +95,23 @@ def webhook_create_user( def verify_api_key(authorization: str = Header(None)): - ENCRYPTION_SECRET = getenv("ENCRYPTION_SECRET") + AGIXT_API_KEY = getenv("AGIXT_API_KEY") if getenv("AUTH_PROVIDER") == "magicalauth": - ENCRYPTION_SECRET = f'{ENCRYPTION_SECRET}{datetime.now().strftime("%Y%m%d")}' + AGIXT_API_KEY = f'{AGIXT_API_KEY}{datetime.now().strftime("%Y%m%d")}' authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") - if ENCRYPTION_SECRET: + if AGIXT_API_KEY: if authorization is None: raise HTTPException( status_code=401, detail="Authorization header is missing" ) - if authorization == ENCRYPTION_SECRET: + if authorization == AGIXT_API_KEY: return "ADMIN" try: - if authorization == ENCRYPTION_SECRET: + if authorization == AGIXT_API_KEY: return "ADMIN" token = jwt.decode( jwt=authorization, - key=ENCRYPTION_SECRET, + key=AGIXT_API_KEY, algorithms=["HS256"], ) db = get_session() @@ -163,7 +163,7 @@ def send_email( class MagicalAuth: def __init__(self, token: str = None): - encryption_key = getenv("ENCRYPTION_SECRET") + encryption_key = getenv("AGIXT_API_KEY") self.link = getenv("MAGIC_LINK_URL") self.encryption_key = f'{encryption_key}{datetime.now().strftime("%Y%m%d")}' self.token = ( @@ -400,7 +400,7 @@ def register( requests.post( registration_webhook, json={"email": self.email}, - headers={"Authorization": getenv("ENCRYPTION_SECRET")}, + headers={"Authorization": getenv("AGIXT_API_KEY")}, ) except Exception as e: pass From 592810798c78e48c7998a8764a730d531883cd45 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Sun, 16 Jun 2024 02:13:42 -0400 Subject: [PATCH 0176/1256] improve refs --- docker-compose-dev.yml | 54 ++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index aeee73a29a6c..48a997b0fb98 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -19,16 +19,15 @@ services: DATABASE_NAME: ${DATABASE_NAME:-postgres} DATABASE_PORT: ${DATABASE_PORT:-5432} UVICORN_WORKERS: ${UVICORN_WORKERS:-10} - USING_JWT: ${USING_JWT:-false} - AGIXT_API_KEY: ${AGIXT_API_KEY} - AGIXT_URI: ${AGIXT_URI-http://agixt:7437} - MAGIC_LINK_URL: ${AUTH_WEB-http://localhost:3437/user} - DISABLED_EXTENSIONS: ${DISABLED_EXTENSIONS:-} - DISABLED_PROVIDERS: ${DISABLED_PROVIDERS:-} + AGIXT_API_KEY: ${AGIXT_API_KEY:-None} + AGIXT_URI: ${AGIXT_URI:-http://agixt:7437} + AUTH_PROVIDER: ${AUTH_PROVIDER:-magicalauth} + MAGIC_LINK_URL: ${AUTH_WEB:-http://localhost:3437/user} + DISABLED_EXTENSIONS: ${DISABLED_EXTENSIONS} + DISABLED_PROVIDERS: ${DISABLED_PROVIDERS} WORKING_DIRECTORY: ${WORKING_DIRECTORY:-/agixt/WORKSPACE} TOKENIZERS_PARALLELISM: False LOG_LEVEL: ${LOG_LEVEL:-INFO} - AUTH_PROVIDER: ${AUTH_PROVIDER:-none} AOL_CLIENT_ID: ${AOL_CLIENT_ID} AOL_CLIENT_SECRET: ${AOL_CLIENT_SECRET} APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} @@ -151,7 +150,7 @@ services: ZENDESK_CLIENT_ID: ${ZENDESK_CLIENT_ID} ZENDESK_CLIENT_SECRET: ${ZENDESK_CLIENT_SECRET} ZENDESK_SUBDOMAIN: ${ZENDESK_SUBDOMAIN} - TZ: ${TZ-America/New_York} + TZ: ${TZ:-America/New_York} ports: - 7437:7437 volumes: @@ -168,8 +167,7 @@ services: depends_on: - agixt environment: - AGIXT_URI: ${AGIXT_URI-http://agixt:7437} - AGIXT_API_KEY: ${AGIXT_API_KEY} + AGIXT_URI: ${AGIXT_URI:-http://agixt:7437} AOL_CLIENT_ID: ${AOL_CLIENT_ID} APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} AUTODESK_CLIENT_ID: ${AUTODESK_CLIENT_ID} @@ -244,25 +242,25 @@ services: init: true environment: NEXT_TELEMETRY_DISABLED: 1 - AGIXT_AGENT: ${AGIXT_AGENT-gpt4free} - AGIXT_FILE_UPLOAD_ENABLED: ${AGIXT_FILE_UPLOAD_ENABLED-true} - AGIXT_VOICE_INPUT_ENABLED: ${AGIXT_VOICE_INPUT_ENABLED-true} - AGIXT_FOOTER_MESSAGE: ${AGIXT_FOOTER_MESSAGE-Powered by AGiXT} - AGIXT_REQUIRE_API_KEY: ${AGIXT_REQUIRE_API_KEY-false} - AGIXT_RLHF: ${AGIXT_RLHF-true} - AGIXT_SERVER: ${AGIXT_URI-http://agixt:7437} - AGIXT_SHOW_AGENT_BAR: ${AGIXT_SHOW_AGENT_BAR-true} - AGIXT_SHOW_APP_BAR: ${AGIXT_SHOW_APP_BAR-true} - AGIXT_SHOW_SELECTION: ${AGIXT_SHOW_SELECTION-conversation,agent} - AGIXT_CONVERSATION_MODE: ${AGIXT_CONVERSATION_MODE-select} - APP_DESCRIPTION: ${APP_DESCRIPTION-A chat powered by AGiXT.} - INTERACTIVE_MODE: ${INTERACTIVE_MODE-chat} - APP_NAME: ${APP_NAME-AGiXT} - APP_URI: ${APP_URI-http://localhost:3437} - AUTH_WEB: ${AUTH_WEB-http://localhost:3437/user} + AGIXT_AGENT: ${AGIXT_AGENT:-gpt4free} + AGIXT_FILE_UPLOAD_ENABLED: ${AGIXT_FILE_UPLOAD_ENABLED:-true} + AGIXT_VOICE_INPUT_ENABLED: ${AGIXT_VOICE_INPUT_ENABLED:-true} + AGIXT_FOOTER_MESSAGE: ${AGIXT_FOOTER_MESSAGE:-Powered by AGiXT} + AGIXT_REQUIRE_API_KEY: ${AGIXT_REQUIRE_API_KEY:-false} + AGIXT_RLHF: ${AGIXT_RLH:-true} + AGIXT_SERVER: ${AGIXT_URI:-http://agixt:7437} + AGIXT_SHOW_AGENT_BAR: ${AGIXT_SHOW_AGENT_BAR:-true} + AGIXT_SHOW_APP_BAR: ${AGIXT_SHOW_APP_BAR:-true} + AGIXT_SHOW_SELECTION: ${AGIXT_SHOW_SELECTION:-conversation,agent} + AGIXT_CONVERSATION_MODE: ${AGIXT_CONVERSATION_MODE:-select} + APP_DESCRIPTION: ${APP_DESCRIPTION:-A chat powered by AGiXT.} + INTERACTIVE_MODE: ${INTERACTIVE_MODE:-chat} + APP_NAME: ${APP_NAME:-AGiXT} + APP_URI: ${APP_URI:-http://localhost:3437} + AUTH_WEB: ${AUTH_WEB:-http://localhost:3437/user} LOG_VERBOSITY_SERVER: 3 THEME_NAME: ${THEME_NAME} - ALLOW_EMAIL_SIGN_IN: ${ALLOW_EMAIL_SIGN_IN-true} + ALLOW_EMAIL_SIGN_IN: ${ALLOW_EMAIL_SIGN_IN:-true} AOL_CLIENT_ID: ${AOL_CLIENT_ID} APPLE_CLIENT_ID: ${APPLE_CLIENT_ID} AUTODESK_CLIENT_ID: ${AUTODESK_CLIENT_ID} @@ -328,7 +326,7 @@ services: YELP_CLIENT_ID: ${YELP_CLIENT_ID} ZENDESK_CLIENT_ID: ${ZENDESK_CLIENT_ID} ZENDESK_SUBDOMAIN: ${ZENDESK_SUBDOMAIN} - TZ: ${TZ-America/New_York} + TZ: ${TZ:-America/New_York} ports: - 3437:3437 restart: unless-stopped From 7a0c144b507e2af2aba7221866b600c67954a9a4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 09:22:42 -0400 Subject: [PATCH 0177/1256] retrieve email properly for agent id --- agixt/Agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index b9af3594464f..db155e51fc28 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -530,7 +530,7 @@ def get_agent_id(self): self.session.query(AgentModel) .filter( AgentModel.name == self.agent_name, - User.email == DEFAULT_USER, + AgentModel.user.has(email=DEFAULT_USER), ) .first() ) From 7a7fda00b21b0eb0ef0602df2c916e6995d7ba09 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 10:52:28 -0400 Subject: [PATCH 0178/1256] fix learn file endpoint --- agixt/XT.py | 3 ++- agixt/endpoints/Memory.py | 39 +++++++++++++-------------------------- 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 6cffeedf70eb..9bf5defd7026 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -1731,7 +1731,8 @@ async def analyze_csv( lines = lines[:2] file_preview = "\n".join(lines) c.log_interaction( - "[ACTIVITY] Analyzing data from file `{file_name}`.", + role=self.agent_name, + message=f"[ACTIVITY] Analyzing data from file `{file_name}`.", ) code_interpreter = await self.inference( user_input=user_input, diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 4c8383b2eb78..3ec739189d73 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -12,6 +12,7 @@ from readers.file import FileReader from readers.arxiv import ArxivReader from readers.youtube import YoutubeReader +from datetime import datetime from Models import ( AgentMemoryQuery, TextMemoryInput, @@ -142,12 +143,11 @@ async def learn_file( user=Depends(verify_api_key), authorization: str = Header(None), ) -> ResponseMessage: - ApiClient = get_api_client(authorization=authorization) # Strip any path information from the file name + agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) file.file_name = os.path.basename(file.file_name) - base_path = os.path.join(os.getcwd(), "WORKSPACE") - file_path = os.path.normpath(os.path.join(base_path, file.file_name)) - if not file_path.startswith(base_path): + file_path = os.path.normpath(os.path.join(agixt.agent_workspace, file.file_name)) + if not file_path.startswith(agixt.agent_workspace): raise Exception("Path given not allowed") try: file_content = base64.b64decode(file.file_content) @@ -155,28 +155,15 @@ async def learn_file( file_content = file.file_content.encode("utf-8") with open(file_path, "wb") as f: f.write(file_content) - try: - agent_config = Agent( - agent_name=agent_name, user=user, ApiClient=ApiClient - ).get_agent_config() - await FileReader( - agent_name=agent_name, - agent_config=agent_config, - collection_number=str(file.collection_number), - ApiClient=ApiClient, - user=user, - ).write_file_to_memory(file_path=file_path) - try: - os.remove(file_path) - except Exception: - pass - return ResponseMessage(message="Agent learned the content from the file.") - except Exception as e: - try: - os.remove(file_path) - except Exception: - pass - raise HTTPException(status_code=500, detail=str(e)) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + response = await agixt.learn_from_file( + file_url=f"{agixt.outputs}/{file.file_name}", + file_name=file.file_name, + user_input=f"File {file.file_name} uploaded on {timestamp} to {agixt.outputs}/{file.file_name} .", + collection_id=str(file.collection_number), + conversation_name=f"File uploaded on {timestamp}", + ) + return ResponseMessage(message=response) @app.post( From dfe789d06ef0f0d8ea18835165429e3d8c0215c2 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:02:57 -0400 Subject: [PATCH 0179/1256] allow any type for collection --- agixt/Models.py | 2 +- agixt/XT.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/agixt/Models.py b/agixt/Models.py index e30d0f4fd20d..22444b953062 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -162,7 +162,7 @@ class UrlInput(BaseModel): class FileInput(BaseModel): file_name: str file_content: str - collection_number: Optional[str] = "0" + collection_number: Optional[Any] = "0" class TextMemoryInput(BaseModel): diff --git a/agixt/XT.py b/agixt/XT.py index 9bf5defd7026..e09b9f375c8c 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -831,8 +831,12 @@ async def download_file_to_workspace(self, url: str, file_name: str = ""): if url.startswith("http"): return {"file_name": file_name, "file_url": url} else: - file_type = url.split(",")[0].split("/")[1].split(";")[0] - file_data = base64.b64decode(url.split(",")[1]) + if "," in url: + file_type = url.split(",")[0].split("/")[1].split(";")[0] + file_data = base64.b64decode(url.split(",")[1]) + else: + file_type = file_name.split(".")[-1] + file_data = base64.b64decode(url) full_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) if not full_path.startswith(self.agent_workspace): raise Exception("Path given not allowed") From bcde1026a405844b72a7ccde0e63a0fb66e56228 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:15:12 -0400 Subject: [PATCH 0180/1256] add logging --- agixt/XT.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index e09b9f375c8c..f0e6a1fbdfbc 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -562,11 +562,17 @@ async def learn_from_file( Returns: str: Response from the agent """ + logging.info(f"Learning from file: {file_url}") + logging.info(f"File name: {file_name}") + logging.info(f"User input: {user_input}") + logging.info(f"Collection ID: {collection_id}") + logging.info(f"Conversation name: {conversation_name}") if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): - file_path = os.path.join(self.agent_workspace, file_name) + file_path = os.path.join(self.agent_workspace, file_url.split("/")[-1]) else: + logging.info(f"{file_url} does not start with {self.outputs}") file_data = await self.download_file_to_workspace( url=file_url, file_name=file_name ) From 669cecf460316c2e545d2746c301b40e5eb84a9c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:26:28 -0400 Subject: [PATCH 0181/1256] fix output ref --- agixt/XT.py | 15 +++++++++------ agixt/endpoints/Memory.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index f0e6a1fbdfbc..9c524081a3a7 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -555,7 +555,8 @@ async def learn_from_file( Args: file_url (str): URL of the file - file_path (str): Path to the file + file_name (str): Name of the file + user_input (str): User input to the agent collection_id (str): Collection ID to save the file to conversation_name (str): Name of the conversation @@ -615,11 +616,12 @@ async def learn_from_file( pdf_path = file_path images = convert_from_path(pdf_path) for i, image in enumerate(images): - image_path = os.path.join(self.agent_workspace, f"{file_name}_{i}.png") + image_file_name = f"{file_name}_{i}.png" + image_path = os.path.join(self.agent_workspace, image_file_name) image.save(image_path, "PNG") await self.learn_from_file( - file_url=image_path, - file_name=f"{file_name}_{i}.png", + file_url=f"{self.outputs}/{image_file_name}", + file_name=image_file_name, user_input=user_input, collection_id=collection_id, conversation_name=conversation_name, @@ -648,9 +650,10 @@ async def learn_from_file( for root, dirs, files in os.walk(new_folder): for name in files: file_path = os.path.join(root, name) + file_name = os.path.basename(file_path) await self.learn_from_file( - file_url=file_path, - file_name=name, + file_url=f"{self.outputs}/{file_name}", + file_name=file_name, user_input=user_input, collection_id=collection_id, conversation_name=conversation_name, diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 3ec739189d73..21698b140f3f 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -161,7 +161,7 @@ async def learn_file( file_name=file.file_name, user_input=f"File {file.file_name} uploaded on {timestamp} to {agixt.outputs}/{file.file_name} .", collection_id=str(file.collection_number), - conversation_name=f"File uploaded on {timestamp}", + conversation_name=f"{datetime.now().strftime('%Y-%m-%d')} Conversation", ) return ResponseMessage(message=response) From 4852c264ef21ae22de2f5edb503e84652f2b406e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:31:26 -0400 Subject: [PATCH 0182/1256] fix zip upload --- agixt/XT.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 9c524081a3a7..fa072e33bb9f 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -650,10 +650,9 @@ async def learn_from_file( for root, dirs, files in os.walk(new_folder): for name in files: file_path = os.path.join(root, name) - file_name = os.path.basename(file_path) await self.learn_from_file( - file_url=f"{self.outputs}/{file_name}", - file_name=file_name, + file_url=f"{self.outputs}/extracted_{file_name}/{name}", + file_name=name, user_input=user_input, collection_id=collection_id, conversation_name=conversation_name, From 712d7a0be92b113d5b5dd26afb659e3ff0a81f53 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:33:53 -0400 Subject: [PATCH 0183/1256] use extracted_zip_folder_name --- agixt/XT.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index fa072e33bb9f..caf03ccdb343 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -638,8 +638,9 @@ async def learn_from_file( f"Read the content of the PDF file called `{file_name}` into memory." ) elif file_path.endswith(".zip"): + extracted_zip_folder_name = f"extracted_{file_name.replace('.zip', '_zip')}" new_folder = os.path.normpath( - os.path.join(self.agent_workspace, f"extracted_{file_name}") + os.path.join(self.agent_workspace, extracted_zip_folder_name) ) if os.path.normpath(file_path).startswith( self.agent_workspace @@ -651,7 +652,7 @@ async def learn_from_file( for name in files: file_path = os.path.join(root, name) await self.learn_from_file( - file_url=f"{self.outputs}/extracted_{file_name}/{name}", + file_url=f"{self.outputs}/{extracted_zip_folder_name}/{name}", file_name=name, user_input=user_input, collection_id=collection_id, From 78b437d2cd83f0ff8d264180b277f3a16b9e6988 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:40:19 -0400 Subject: [PATCH 0184/1256] fix output url on zip upload --- agixt/XT.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index caf03ccdb343..93b855337acc 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -651,8 +651,10 @@ async def learn_from_file( for root, dirs, files in os.walk(new_folder): for name in files: file_path = os.path.join(root, name) + output_url = f"{self.outputs}/{extracted_zip_folder_name}/{dirs[0]}/{name}" + logging.info(f"Output URL: {output_url}") await self.learn_from_file( - file_url=f"{self.outputs}/{extracted_zip_folder_name}/{name}", + file_url=output_url, file_name=name, user_input=user_input, collection_id=collection_id, From da50956674e9aa493a38c6250f89aae37cbff0b6 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:43:55 -0400 Subject: [PATCH 0185/1256] fix zip paths --- agixt/XT.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 93b855337acc..82467b7d44c3 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -651,7 +651,8 @@ async def learn_from_file( for root, dirs, files in os.walk(new_folder): for name in files: file_path = os.path.join(root, name) - output_url = f"{self.outputs}/{extracted_zip_folder_name}/{dirs[0]}/{name}" + current_folder = root.replace(new_folder, "") + output_url = f"{self.outputs}/{extracted_zip_folder_name}/{current_folder}/{name}" logging.info(f"Output URL: {output_url}") await self.learn_from_file( file_url=output_url, From fbd07f78ab3033347ade5be4566915e64a4e6155 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 11:48:18 -0400 Subject: [PATCH 0186/1256] split on url for depth into workspace --- agixt/XT.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 82467b7d44c3..1ed51a3cec2f 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -571,7 +571,9 @@ async def learn_from_file( if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): - file_path = os.path.join(self.agent_workspace, file_url.split("/")[-1]) + file_path = os.path.join( + self.agent_workspace, file_url.split(self.outputs)[1] + ) else: logging.info(f"{file_url} does not start with {self.outputs}") file_data = await self.download_file_to_workspace( From 19c20f732eb62be855b7c69eb91703c19d9bfe79 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 12:03:01 -0400 Subject: [PATCH 0187/1256] remove url from input --- agixt/endpoints/Memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 21698b140f3f..af9171ca577c 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -159,7 +159,7 @@ async def learn_file( response = await agixt.learn_from_file( file_url=f"{agixt.outputs}/{file.file_name}", file_name=file.file_name, - user_input=f"File {file.file_name} uploaded on {timestamp} to {agixt.outputs}/{file.file_name} .", + user_input=f"File {file.file_name} uploaded on {timestamp}.", collection_id=str(file.collection_number), conversation_name=f"{datetime.now().strftime('%Y-%m-%d')} Conversation", ) From bf734e8e66e8c67335ad10a5fc05ff9224741e8f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 12:09:46 -0400 Subject: [PATCH 0188/1256] add logging --- agixt/XT.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 1ed51a3cec2f..bee60151ec92 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -644,7 +644,11 @@ async def learn_from_file( new_folder = os.path.normpath( os.path.join(self.agent_workspace, extracted_zip_folder_name) ) - if os.path.normpath(file_path).startswith( + logging.info(f"Extracting zip file to {new_folder}") + norm_file_path = os.path.normpath(file_path) + logging.info(f"Normalized file path: {norm_file_path}") + logging.info(f"Agent workspace: {self.agent_workspace}") + if norm_file_path.startswith( self.agent_workspace ) and new_folder.startswith(self.agent_workspace): with zipfile.ZipFile(file_path, "r") as zipObj: From 658b4834b42f9036d9a62a528ff32e69bb8233dc Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 12:22:51 -0400 Subject: [PATCH 0189/1256] Updates --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index bee60151ec92..56b2fd659fa6 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -616,7 +616,7 @@ async def learn_from_file( elif file_type == "pdf": # Turn the pdf to images, then run inference on each image pdf_path = file_path - images = convert_from_path(pdf_path) + images = convert_from_path(pdf_path, output_folder=self.agent_workspace) for i, image in enumerate(images): image_file_name = f"{file_name}_{i}.png" image_path = os.path.join(self.agent_workspace, image_file_name) From 24b896b0695ba70f4ce535847ecb732f10e45876 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 12:41:01 -0400 Subject: [PATCH 0190/1256] fix path --- agixt/XT.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 56b2fd659fa6..c5ff3721174d 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -620,6 +620,7 @@ async def learn_from_file( for i, image in enumerate(images): image_file_name = f"{file_name}_{i}.png" image_path = os.path.join(self.agent_workspace, image_file_name) + path = os.path.normpath(image_path).split(self.agent_workspace)[1] image.save(image_path, "PNG") await self.learn_from_file( file_url=f"{self.outputs}/{image_file_name}", @@ -645,7 +646,9 @@ async def learn_from_file( os.path.join(self.agent_workspace, extracted_zip_folder_name) ) logging.info(f"Extracting zip file to {new_folder}") - norm_file_path = os.path.normpath(file_path) + norm_file_path = os.path.normpath( + os.path.join(self.agent_workspace, file_name) + ) logging.info(f"Normalized file path: {norm_file_path}") logging.info(f"Agent workspace: {self.agent_workspace}") if norm_file_path.startswith( From b171deafe4c2000ee69a8e194f6046b1f16b4023 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 12:49:09 -0400 Subject: [PATCH 0191/1256] fix zip path --- agixt/XT.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index c5ff3721174d..c2ad98ecb6a7 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -645,15 +645,7 @@ async def learn_from_file( new_folder = os.path.normpath( os.path.join(self.agent_workspace, extracted_zip_folder_name) ) - logging.info(f"Extracting zip file to {new_folder}") - norm_file_path = os.path.normpath( - os.path.join(self.agent_workspace, file_name) - ) - logging.info(f"Normalized file path: {norm_file_path}") - logging.info(f"Agent workspace: {self.agent_workspace}") - if norm_file_path.startswith( - self.agent_workspace - ) and new_folder.startswith(self.agent_workspace): + if new_folder.startswith(self.agent_workspace): with zipfile.ZipFile(file_path, "r") as zipObj: zipObj.extractall(path=new_folder) # Iterate over every file that was extracted including subdirectories From 54cf2b81d991699d236c30bd547ae0d9b7ed07e4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:02:03 -0400 Subject: [PATCH 0192/1256] replace path --- agixt/XT.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index c2ad98ecb6a7..30e89afdef23 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -571,8 +571,8 @@ async def learn_from_file( if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): - file_path = os.path.join( - self.agent_workspace, file_url.split(self.outputs)[1] + file_path = os.path.normpath( + os.path.join(self.agent_workspace, file_url.split(self.outputs)[1]) ) else: logging.info(f"{file_url} does not start with {self.outputs}") @@ -580,12 +580,13 @@ async def learn_from_file( url=file_url, file_name=file_name ) file_name = file_data["file_name"] - file_path = os.path.join(self.agent_workspace, file_name) + file_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) file_type = file_name.split(".")[-1] c = Conversations(conversation_name=conversation_name, user=self.user_email) if file_type in ["ppt", "pptx"]: # Convert it to a PDF pdf_file_path = file_path.replace(".pptx", ".pdf").replace(".ppt", ".pdf") + file_name = str(file_name).replace(".pptx", ".pdf").replace(".ppt", ".pdf") if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, @@ -615,12 +616,10 @@ async def learn_from_file( response = f"[ERROR] I was unable to read the file called `{file_name}`." elif file_type == "pdf": # Turn the pdf to images, then run inference on each image - pdf_path = file_path - images = convert_from_path(pdf_path, output_folder=self.agent_workspace) + images = convert_from_path(file_path, output_folder=self.agent_workspace) for i, image in enumerate(images): image_file_name = f"{file_name}_{i}.png" image_path = os.path.join(self.agent_workspace, image_file_name) - path = os.path.normpath(image_path).split(self.agent_workspace)[1] image.save(image_path, "PNG") await self.learn_from_file( file_url=f"{self.outputs}/{image_file_name}", @@ -651,7 +650,6 @@ async def learn_from_file( # Iterate over every file that was extracted including subdirectories for root, dirs, files in os.walk(new_folder): for name in files: - file_path = os.path.join(root, name) current_folder = root.replace(new_folder, "") output_url = f"{self.outputs}/{extracted_zip_folder_name}/{current_folder}/{name}" logging.info(f"Output URL: {output_url}") From 3ffd0681c807b8d312f6167046553fd866a2ae57 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:09:35 -0400 Subject: [PATCH 0193/1256] improve file path manipulation --- agixt/XT.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 30e89afdef23..c0dd88d96111 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -572,7 +572,10 @@ async def learn_from_file( file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): file_path = os.path.normpath( - os.path.join(self.agent_workspace, file_url.split(self.outputs)[1]) + os.path.join( + self.agent_workspace, + file_url.split(self.outputs.split(str(self.agent.agent_id))[0])[1], + ) ) else: logging.info(f"{file_url} does not start with {self.outputs}") @@ -581,6 +584,7 @@ async def learn_from_file( ) file_name = file_data["file_name"] file_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) + logging.info(f"File path: {file_path}") file_type = file_name.split(".")[-1] c = Conversations(conversation_name=conversation_name, user=self.user_email) if file_type in ["ppt", "pptx"]: From 4a29fc5f256352cc7f6a349f4ee75f99dc9c13ef Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:24:46 -0400 Subject: [PATCH 0194/1256] add logging --- agixt/endpoints/Memory.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index af9171ca577c..0258d970992a 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -28,7 +28,13 @@ UserInput, FeedbackInput, ) +import logging +from Globals import getenv +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) app = APIRouter() @@ -144,10 +150,11 @@ async def learn_file( authorization: str = Header(None), ) -> ResponseMessage: # Strip any path information from the file name - agixt = AGiXT(user=user, agent_name=agent_name, api_key=authorization) + agent = AGiXT(user=user, agent_name=agent_name, api_key=authorization) file.file_name = os.path.basename(file.file_name) - file_path = os.path.normpath(os.path.join(agixt.agent_workspace, file.file_name)) - if not file_path.startswith(agixt.agent_workspace): + file_path = os.path.normpath(os.path.join(agent.agent_workspace, file.file_name)) + logging.info(f"File path: {file_path}") + if not file_path.startswith(agent.agent_workspace): raise Exception("Path given not allowed") try: file_content = base64.b64decode(file.file_content) @@ -156,8 +163,10 @@ async def learn_file( with open(file_path, "wb") as f: f.write(file_content) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - response = await agixt.learn_from_file( - file_url=f"{agixt.outputs}/{file.file_name}", + logging.info(f"File {file.file_name} uploaded on {timestamp}.") + logging.info(f"URL of file: {agent.outputs}/{file.file_name}") + response = await agent.learn_from_file( + file_url=f"{agent.outputs}/{file.file_name}", file_name=file.file_name, user_input=f"File {file.file_name} uploaded on {timestamp}.", collection_id=str(file.collection_number), From 5437d344e4f9319aa60985424feaee451a52ae0f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:32:44 -0400 Subject: [PATCH 0195/1256] fix path --- agixt/XT.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index c0dd88d96111..eeff698aa582 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -572,10 +572,7 @@ async def learn_from_file( file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): file_path = os.path.normpath( - os.path.join( - self.agent_workspace, - file_url.split(self.outputs.split(str(self.agent.agent_id))[0])[1], - ) + os.path.join(self.agent_workspace, file_url.split(self.outputs)[1]) ) else: logging.info(f"{file_url} does not start with {self.outputs}") From 16343bf69b2e0cda0770b2d1c840ba7240c71881 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:44:46 -0400 Subject: [PATCH 0196/1256] use full path --- agixt/XT.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index eeff698aa582..5bd7c34f338c 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -571,8 +571,9 @@ async def learn_from_file( if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): + folder_path = file_url.split(self.outputs)[1] file_path = os.path.normpath( - os.path.join(self.agent_workspace, file_url.split(self.outputs)[1]) + os.path.join(self.agent_workspace, folder_path, file_name) ) else: logging.info(f"{file_url} does not start with {self.outputs}") From 5b620719fded54bf6f6fccb95c7730274f51f9c5 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 13:52:29 -0400 Subject: [PATCH 0197/1256] add file path correction --- agixt/XT.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 5bd7c34f338c..cf998377e757 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -568,12 +568,14 @@ async def learn_from_file( logging.info(f"User input: {user_input}") logging.info(f"Collection ID: {collection_id}") logging.info(f"Conversation name: {conversation_name}") + logging.info(f"Agent workspace: {self.agent_workspace}") + logging.info(f"Outputs: {self.outputs}") if file_name == "": file_name = file_url.split("/")[-1] if file_url.startswith(self.outputs): - folder_path = file_url.split(self.outputs)[1] + folder_path = file_url.split(f"{self.outputs}/")[1] file_path = os.path.normpath( - os.path.join(self.agent_workspace, folder_path, file_name) + os.path.join(self.agent_workspace, folder_path) ) else: logging.info(f"{file_url} does not start with {self.outputs}") @@ -583,6 +585,11 @@ async def learn_from_file( file_name = file_data["file_name"] file_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) logging.info(f"File path: {file_path}") + if not file_path.startswith(self.agent_workspace): + file_path = os.path.normpath( + os.path.join(self.agent_workspace, file_name) + ) + logging.info(f"Corrected file path: {file_path}") file_type = file_name.split(".")[-1] c = Conversations(conversation_name=conversation_name, user=self.user_email) if file_type in ["ppt", "pptx"]: From 3ca7414b96b84d08ea9e7969119cde38dc6c8e4a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 14:17:52 -0400 Subject: [PATCH 0198/1256] fix conversation name for log --- agixt/endpoints/Memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Memory.py b/agixt/endpoints/Memory.py index 0258d970992a..1aade8ae372c 100644 --- a/agixt/endpoints/Memory.py +++ b/agixt/endpoints/Memory.py @@ -170,7 +170,7 @@ async def learn_file( file_name=file.file_name, user_input=f"File {file.file_name} uploaded on {timestamp}.", collection_id=str(file.collection_number), - conversation_name=f"{datetime.now().strftime('%Y-%m-%d')} Conversation", + conversation_name=f"Agent Training on {datetime.now().strftime('%Y-%m-%d')} by {user}", ) return ResponseMessage(message=response) From 805d85da7ba308165742bc90d5fef4f3bc3be85f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 15:28:57 -0400 Subject: [PATCH 0199/1256] update vision prompt --- agixt/XT.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index cf998377e757..751bc5216e36 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -586,9 +586,7 @@ async def learn_from_file( file_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) logging.info(f"File path: {file_path}") if not file_path.startswith(self.agent_workspace): - file_path = os.path.normpath( - os.path.join(self.agent_workspace, file_name) - ) + file_path = os.path.normpath(os.path.join(self.agent_workspace, file_name)) logging.info(f"Corrected file path: {file_path}") file_type = file_name.split(".")[-1] c = Conversations(conversation_name=conversation_name, user=self.user_email) @@ -778,8 +776,9 @@ async def learn_from_file( message=f"[ACTIVITY] Viewing image at {file_url}.", ) try: + vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." vision_response = await self.agent.inference( - prompt=user_input, images=[file_url] + prompt=vision_prompt, images=[file_url] ) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await file_reader.write_text_to_memory( From 7e208198de7c55cb7db026be8621af56da377965 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 15:42:13 -0400 Subject: [PATCH 0200/1256] add vision_inference function --- agixt/Agent.py | 29 ++++++++++++++++++++++++++ agixt/XT.py | 56 +++++++++++++++++++++++++------------------------- 2 files changed, 57 insertions(+), 28 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index db155e51fc28..59e275e93f1c 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -191,6 +191,25 @@ def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): self.PROVIDER = Providers( name=self.AI_PROVIDER, ApiClient=ApiClient, **self.PROVIDER_SETTINGS ) + vision_provider = ( + self.AGENT_CONFIG["settings"]["vision_provider"] + if "vision_provider" in self.AGENT_CONFIG["settings"] + else "None" + ) + if ( + vision_provider != "None" + and vision_provider != None + and vision_provider != "" + ): + try: + self.VISION_PROVIDER = Providers( + name=vision_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS + ) + except Exception as e: + logging.error(f"Error loading vision provider: {str(e)}") + self.VISION_PROVIDER = None + else: + self.VISION_PROVIDER = None tts_provider = ( self.AGENT_CONFIG["settings"]["tts_provider"] if "tts_provider" in self.AGENT_CONFIG["settings"] @@ -323,6 +342,16 @@ async def inference(self, prompt: str, tokens: int = 0, images: list = []): ) return answer.replace("\_", "_") + async def vision_inference(self, prompt: str, tokens: int = 0, images: list = []): + if not prompt: + return "" + if not self.VISION_PROVIDER: + return "" + answer = await self.VISION_PROVIDER.inference( + prompt=prompt, tokens=tokens, images=images + ) + return answer.replace("\_", "_") + def embeddings(self, input) -> np.ndarray: return self.embedder(input=input) diff --git a/agixt/XT.py b/agixt/XT.py index 751bc5216e36..04d94c254475 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -763,37 +763,37 @@ async def learn_from_file( "bmp", "svg", ]: - if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: - vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] - if ( - vision_provider != "None" - and vision_provider != "" - and vision_provider != None - ): - if conversation_name != "" and conversation_name != None: - c.log_interaction( - role=self.agent_name, - message=f"[ACTIVITY] Viewing image at {file_url}.", - ) - try: - vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." - vision_response = await self.agent.inference( - prompt=vision_prompt, images=[file_url] - ) - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - await file_reader.write_text_to_memory( - user_input=user_input, - text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}` from {timestamp}:\n{vision_response}\n", - external_source=f"image {file_name}", - ) - response = f"Generated a description of the image called `{file_name}` into my memory." - except Exception as e: - logging.error(f"Error getting vision response: {e}") - response = f"[ERROR] I was unable to view the image called `{file_name}`." - else: + if ( + self.agent.VISION_PROVIDER != "None" + and self.agent.VISION_PROVIDER != "" + and self.agent.VISION_PROVIDER != None + ): + if conversation_name != "" and conversation_name != None: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Viewing image at {file_url}.", + ) + try: + vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." + vision_response = await self.agent.vision_inference( + prompt=vision_prompt, images=[file_url] + ) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + await file_reader.write_text_to_memory( + user_input=user_input, + text=f"{self.agent_name}'s visual description from viewing uploaded image called `{file_name}` from {timestamp}:\n{vision_response}\n", + external_source=f"image {file_name}", + ) + response = f"Generated a description of the image called `{file_name}` into my memory." + except Exception as e: + logging.error(f"Error getting vision response: {e}") response = ( f"[ERROR] I was unable to view the image called `{file_name}`." ) + else: + response = ( + f"[ERROR] I was unable to view the image called `{file_name}`." + ) else: timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if os.path.normpath(file_path).startswith(self.agent_workspace): From 03447f96ffce3402a1f127fa23f24f289ec7845a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 15:45:51 -0400 Subject: [PATCH 0201/1256] use vision inference with images --- agixt/Interactions.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 5b71e4174052..eb716a716dcb 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -581,12 +581,11 @@ async def run( ) vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: - vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] if ( images != [] - and vision_provider != "None" - and vision_provider != "" - and vision_provider != None + and self.agent.VISION_PROVIDER != "None" + and self.agent.VISION_PROVIDER != "" + and self.agent.VISION_PROVIDER != None ): logging.info(f"Getting vision response for images: {images}") message = ( @@ -597,7 +596,7 @@ async def run( message=f"[ACTIVITY] {message}", ) try: - vision_response = await self.agent.inference( + vision_response = await self.agent.vision_inference( prompt=user_input, images=images ) logging.info(f"Vision Response: {vision_response}") From 68c2be510d7ac3564542e2ccb27f07e01883b4eb Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 15:59:58 -0400 Subject: [PATCH 0202/1256] remove trailing . --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 04d94c254475..c8f677098749 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -771,7 +771,7 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Viewing image at {file_url}.", + message=f"[ACTIVITY] Viewing image at {file_url} ", ) try: vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." From 76350eba9858fc4e9155617a0d6cf9099b0b5c0c Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 18:03:04 -0400 Subject: [PATCH 0203/1256] improve prompt --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index c8f677098749..d13ce5e5b6c7 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -774,7 +774,7 @@ async def learn_from_file( message=f"[ACTIVITY] Viewing image at {file_url} ", ) try: - vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." + vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\nThe uploaded image is `{file_name}`.\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." vision_response = await self.agent.vision_inference( prompt=vision_prompt, images=[file_url] ) From a5e35c12a13cdf2667b29a7bb3fd93997311678a Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 18:12:47 -0400 Subject: [PATCH 0204/1256] move images for pdf --- agixt/XT.py | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index d13ce5e5b6c7..10c40b9daf71 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -9,7 +9,6 @@ from typing import Type, get_args, get_origin, Union, List from enum import Enum from pydantic import BaseModel -from pdf2image import convert_from_path import pdfplumber import docx2txt import zipfile @@ -622,27 +621,44 @@ async def learn_from_file( if file_type in disallowed_types: response = f"[ERROR] I was unable to read the file called `{file_name}`." elif file_type == "pdf": - # Turn the pdf to images, then run inference on each image - images = convert_from_path(file_path, output_folder=self.agent_workspace) - for i, image in enumerate(images): - image_file_name = f"{file_name}_{i}.png" - image_path = os.path.join(self.agent_workspace, image_file_name) - image.save(image_path, "PNG") - await self.learn_from_file( - file_url=f"{self.outputs}/{image_file_name}", - file_name=image_file_name, - user_input=user_input, - collection_id=collection_id, - conversation_name=conversation_name, - ) with pdfplumber.open(file_path) as pdf: content = "\n".join([page.extract_text() for page in pdf.pages]) + # Save images to workspace + for i, image in enumerate(pdf.images): + image_file_name = f"{file_name}_{i}.png" + image_path = os.path.join(self.agent_workspace, image_file_name) + image.save(image_path, "PNG") + await self.learn_from_file( + file_url=f"{self.outputs}/{image_file_name}", + file_name=image_file_name, + user_input=user_input, + collection_id=collection_id, + conversation_name=conversation_name, + ) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await file_reader.write_text_to_memory( user_input=user_input, text=f"Content from PDF uploaded at {timestamp} named `{file_name}`:\n{content}", external_source=f"file {file_path}", ) + if ( + self.agent.VISION_PROVIDER != "None" + and self.agent.VISION_PROVIDER != "" + and self.agent.VISION_PROVIDER != None + ): + with pdfplumber.open(file_path) as pdf: + # Save images to workspace + for i, image in enumerate(pdf.images): + image_file_name = f"{file_name}_{i}.png" + image_path = os.path.join(self.agent_workspace, image_file_name) + image.save(image_path, "PNG") + await self.learn_from_file( + file_url=f"{self.outputs}/{image_file_name}", + file_name=image_file_name, + user_input=f"{user_input}\nUploaded file: {image_file_name} from PDF file: {file_name}", + collection_id=collection_id, + conversation_name=conversation_name, + ) response = ( f"Read the content of the PDF file called `{file_name}` into memory." ) From 6c4afea7750edbffc245d6db6fd997fd873ba7d1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 18:24:50 -0400 Subject: [PATCH 0205/1256] use images properly --- agixt/XT.py | 45 +++++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 10c40b9daf71..1181de99ad3e 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -623,18 +623,6 @@ async def learn_from_file( elif file_type == "pdf": with pdfplumber.open(file_path) as pdf: content = "\n".join([page.extract_text() for page in pdf.pages]) - # Save images to workspace - for i, image in enumerate(pdf.images): - image_file_name = f"{file_name}_{i}.png" - image_path = os.path.join(self.agent_workspace, image_file_name) - image.save(image_path, "PNG") - await self.learn_from_file( - file_url=f"{self.outputs}/{image_file_name}", - file_name=image_file_name, - user_input=user_input, - collection_id=collection_id, - conversation_name=conversation_name, - ) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") await file_reader.write_text_to_memory( user_input=user_input, @@ -647,18 +635,27 @@ async def learn_from_file( and self.agent.VISION_PROVIDER != None ): with pdfplumber.open(file_path) as pdf: - # Save images to workspace - for i, image in enumerate(pdf.images): - image_file_name = f"{file_name}_{i}.png" - image_path = os.path.join(self.agent_workspace, image_file_name) - image.save(image_path, "PNG") - await self.learn_from_file( - file_url=f"{self.outputs}/{image_file_name}", - file_name=image_file_name, - user_input=f"{user_input}\nUploaded file: {image_file_name} from PDF file: {file_name}", - collection_id=collection_id, - conversation_name=conversation_name, - ) + # Iterate over each page + for i, page in enumerate(pdf.pages): + # Extract images + images = page.images + # Save each image + for j, img in enumerate(images): + # Extract image bytes and convert to an image object + image_bytes = page.extract_image(img["object_id"])["image"] + im = Image.open(io.BytesIO(image_bytes)) + image_name = file_name.replace( + ".pdf", f"_page_{i}_image_{j}.png" + ) + # Save the image + im.save(image_name) + await self.learn_from_file( + file_url=f"{self.outputs}/{image_name}", + file_name=image_name, + user_input=f"Original file: {file_name}\nPage: {i} Image: {j}\n{user_input}", + collection_id=collection_id, + conversation_name=conversation_name, + ) response = ( f"Read the content of the PDF file called `{file_name}` into memory." ) From 9f59ea1d5f0e5e1813b63f6cc21e308b4cd5403f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 18:31:22 -0400 Subject: [PATCH 0206/1256] fix imports --- agixt/XT.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/agixt/XT.py b/agixt/XT.py index 1181de99ad3e..5a7c462d6709 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -9,6 +9,8 @@ from typing import Type, get_args, get_origin, Union, List from enum import Enum from pydantic import BaseModel +from PIL import Image +import io import pdfplumber import docx2txt import zipfile From ee56585e74989ff8bbbdf7881a9b2e1ce9c0b50e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 19:01:37 -0400 Subject: [PATCH 0207/1256] fix image ref --- agixt/XT.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 5a7c462d6709..65e0a0b424ee 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -644,7 +644,9 @@ async def learn_from_file( # Save each image for j, img in enumerate(images): # Extract image bytes and convert to an image object - image_bytes = page.extract_image(img["object_id"])["image"] + image_bytes = ( + page.to_image(resolution=300).to_image().original_bytes + ) im = Image.open(io.BytesIO(image_bytes)) image_name = file_name.replace( ".pdf", f"_page_{i}_image_{j}.png" From 9ad286df82509502108e75f27fa4b8de4963452f Mon Sep 17 00:00:00 2001 From: Josh XT Date: Mon, 17 Jun 2024 19:02:16 -0400 Subject: [PATCH 0208/1256] clear pdf img for now --- agixt/XT.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index 65e0a0b424ee..3553c5bb3845 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -631,35 +631,6 @@ async def learn_from_file( text=f"Content from PDF uploaded at {timestamp} named `{file_name}`:\n{content}", external_source=f"file {file_path}", ) - if ( - self.agent.VISION_PROVIDER != "None" - and self.agent.VISION_PROVIDER != "" - and self.agent.VISION_PROVIDER != None - ): - with pdfplumber.open(file_path) as pdf: - # Iterate over each page - for i, page in enumerate(pdf.pages): - # Extract images - images = page.images - # Save each image - for j, img in enumerate(images): - # Extract image bytes and convert to an image object - image_bytes = ( - page.to_image(resolution=300).to_image().original_bytes - ) - im = Image.open(io.BytesIO(image_bytes)) - image_name = file_name.replace( - ".pdf", f"_page_{i}_image_{j}.png" - ) - # Save the image - im.save(image_name) - await self.learn_from_file( - file_url=f"{self.outputs}/{image_name}", - file_name=image_name, - user_input=f"Original file: {file_name}\nPage: {i} Image: {j}\n{user_input}", - collection_id=collection_id, - conversation_name=conversation_name, - ) response = ( f"Read the content of the PDF file called `{file_name}` into memory." ) From b90ebaac153a6d1da382cbff3c04b85255afa466 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 11:44:28 -0400 Subject: [PATCH 0209/1256] Use agent dir for working dir --- agixt/Agent.py | 6 +----- agixt/Globals.py | 2 -- agixt/XT.py | 2 +- start.py | 49 +++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index 59e275e93f1c..d55685d8f525 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -371,11 +371,7 @@ async def text_to_speech(self, text: str): def get_commands_string(self): if len(self.available_commands) == 0: return "" - working_dir = ( - self.AGENT_CONFIG["WORKING_DIRECTORY"] - if "WORKING_DIRECTORY" in self.AGENT_CONFIG - else os.path.join(os.getcwd(), "WORKSPACE") - ) + working_dir = self.working_directory verbose_commands = f"### Available Commands\n**The assistant has commands available to use if they would be useful to provide a better user experience.**\nIf a file needs saved, the assistant's working directory is {working_dir}, use that as the file path.\n\n" verbose_commands += "**See command execution examples of commands that the assistant has access to below:**\n" for command in self.available_commands: diff --git a/agixt/Globals.py b/agixt/Globals.py index 28efef1582bf..56c00c948b26 100644 --- a/agixt/Globals.py +++ b/agixt/Globals.py @@ -26,8 +26,6 @@ "WEBSEARCH_TIMEOUT": 0, "WAIT_BETWEEN_REQUESTS": 1, "WAIT_AFTER_FAILURE": 3, - "WORKING_DIRECTORY": "./WORKSPACE", - "WORKING_DIRECTORY_RESTRICTED": True, "persona": "", } diff --git a/agixt/XT.py b/agixt/XT.py index 3553c5bb3845..fd3b12a62cc1 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -1674,9 +1674,9 @@ async def analyze_csv( file_content=None, ): c = Conversations(conversation_name=conversation_name, user=self.user_email) + file_names = [] if not file_content: files = os.listdir(self.agent_workspace) - file_names = [] file_name = "" # Check if any files are csv files, if not, return empty string csv_files = [file for file in files if file.endswith(".csv")] diff --git a/start.py b/start.py index 0c954f57896c..cee5441cff9f 100644 --- a/start.py +++ b/start.py @@ -85,6 +85,35 @@ def start_ezlocalai(): ) +def get_cuda_vram(): + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.total,memory.free", + "--format=csv,noheader,nounits", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + lines = result.stdout.strip().split("\n") + if len(lines) == 0: + return 0, 0 + total_vram, free_vram = map(int, lines[0].split(",")) + return total_vram, free_vram + except FileNotFoundError: + print("nvidia-smi not found. No CUDA support.") + return 0, 0 + except subprocess.CalledProcessError as e: + print(f"nvidia-smi failed with error: {e.stderr}") + return 0, 0 + except Exception as e: + print(f"Error getting CUDA information: {e}") + return 0, 0 + + if not is_docker_installed(): print("Docker is not installed. Please install Docker and try again.") exit(1) @@ -130,7 +159,9 @@ def start_ezlocalai(): else: run_shell_command("cd ezlocalai && git pull && cd ..") ezlocalai_uri = prompt_user("Set your ezLocalai URI", f"http://{local_ip}:8091") - default_llm = prompt_user("Default LLM to use", "Mistral-7B-Instruct-v0.2") + default_llm = prompt_user( + "Default LLM to use", "QuantFactory/dolphin-2.9.2-qwen2-7b-GGUF" + ) default_vlm = prompt_user( "Use vision model? Enter model from Hugging Face or 'None' for no vision model", "deepseek-ai/deepseek-vl-1.3b-chat", @@ -142,13 +173,23 @@ def start_ezlocalai(): img_enabled = True else: img_enabled = False + total_vram, free_vram = get_cuda_vram() with open(".env", "a") as env_file: env_file.write("USE_EZLOCALAI=true\n") + gpu_layers = 0 + if total_vram > 0: + gpu_layers = min(33, total_vram // 500) + if gpu_layers < 0: + gpu_layers = 0 with open("ezlocalai/.env", "w") as env_file: - env_file.write(f"EZLOCALAI_URI={ezlocalai_uri}\n") + env_file.write(f"EZLOCALAI_URL={ezlocalai_uri}\n") env_file.write(f"DEFAULT_LLM={default_llm}\n") env_file.write(f"DEFAULT_VLM={default_vlm}\n") + env_file.write(f"GPU_LAYERS={gpu_layers}\n") env_file.write(f"IMG_ENABLED={img_enabled}\n") + if img_enabled: + env_file.write("SD_MODEL=stabilityai/sdxl-turbo") + env_file.write("IMG_DEVICE=cpu") # Create a default ezlocalai agent that will work with AGiXT out of the box ezlocalai_agent_settings = { "commands": {}, @@ -161,7 +202,7 @@ def start_ezlocalai(): "image_provider": "ezlocalai" if img_enabled else "default", "EZLOCALAI_API_KEY": api_key, "AI_MODEL": "Mistral-7B-Instruct-v0.2", - "API_URI": f"{ezlocalai_uri}/v1/", + "EZLOCALAI_API_URI": f"{ezlocalai_uri}/v1/", "MAX_TOKENS": "4096", "AI_TEMPERATURE": 0.5, "AI_TOP_P": 0.9, @@ -174,8 +215,6 @@ def start_ezlocalai(): "WEBSEARCH_TIMEOUT": 0, "WAIT_BETWEEN_REQUESTS": 1, "WAIT_AFTER_FAILURE": 3, - "WORKING_DIRECTORY": "./WORKSPACE", - "WORKING_DIRECTORY_RESTRICTED": True, "persona": "", }, } From 4e5cb86ba417bb4c8b9b7834b6d7c6cdf0a772b3 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 11:48:01 -0400 Subject: [PATCH 0210/1256] move refs for file name --- agixt/XT.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index fd3b12a62cc1..554446cea7af 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -1675,9 +1675,9 @@ async def analyze_csv( ): c = Conversations(conversation_name=conversation_name, user=self.user_email) file_names = [] + file_name = "" if not file_content: files = os.listdir(self.agent_workspace) - file_name = "" # Check if any files are csv files, if not, return empty string csv_files = [file for file in files if file.endswith(".csv")] if len(csv_files) == 0: @@ -1743,7 +1743,7 @@ async def analyze_csv( file_preview = "\n".join(lines) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Analyzing data from file `{file_name}`.", + message=f"[ACTIVITY] Analyzing data from file.", ) code_interpreter = await self.inference( user_input=user_input, @@ -1796,7 +1796,7 @@ async def analyze_csv( ) c.log_interaction( role=self.agent_name, - message=f"## Results from analyzing data in `{file_name}`:\n{code_execution}", + message=f"## Results from analyzing data:\n{code_execution}", ) else: self.failures += 1 From 89b8a3f46282e33f96ff2b217a37ddaef7a05a33 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 15:40:40 -0400 Subject: [PATCH 0211/1256] fix prompt agent endpoint --- agixt/endpoints/Agent.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index e9cdc271f8d8..2b50658134cf 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -193,20 +193,14 @@ async def prompt_agent( user=Depends(verify_api_key), authorization: str = Header(None), ): - ApiClient = get_api_client(authorization=authorization) - agent = Interactions(agent_name=agent_name, user=user, ApiClient=ApiClient) - if ( - "prompt" in agent_prompt.prompt_args - and "prompt_name" not in agent_prompt.prompt_args - ): - agent_prompt.prompt_args["prompt_name"] = agent_prompt.prompt_args["prompt"] - if "prompt_name" not in agent_prompt.prompt_args: - agent_prompt.prompt_args["prompt_name"] = "Chat" - if "prompt_category" not in agent_prompt.prompt_args: - agent_prompt.prompt_args["prompt_category"] = "Default" - agent_prompt.prompt_args = {k: v for k, v in agent_prompt.prompt_args.items()} - response = await agent.run( - log_user_input=True, + agent = AGiXT(user=user, agent_name=agent_name, authorization=authorization) + if "tts" in agent_prompt.prompt_args: + agent_prompt.prompt_args["voice_response"] = ( + str(agent_prompt.prompt_args["tts"]).lower() == "true" + ) + del agent_prompt.prompt_args["tts"] + response = await agent.inference( + prompt=agent_prompt.prompt_name, **agent_prompt.prompt_args, ) return {"response": str(response)} From 457dc33e01db760ee87be16d70dcd4ec7de526d6 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 15:50:31 -0400 Subject: [PATCH 0212/1256] fix ref --- agixt/endpoints/Agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index 2b50658134cf..c7622208c8c3 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -193,7 +193,7 @@ async def prompt_agent( user=Depends(verify_api_key), authorization: str = Header(None), ): - agent = AGiXT(user=user, agent_name=agent_name, authorization=authorization) + agent = AGiXT(user=user, agent_name=agent_name, api_key=authorization) if "tts" in agent_prompt.prompt_args: agent_prompt.prompt_args["voice_response"] = ( str(agent_prompt.prompt_args["tts"]).lower() == "true" From be5dfc8f81692bff75089a1eae2d91fe88a03891 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 16:52:05 -0400 Subject: [PATCH 0213/1256] migrate injected memories --- agixt/endpoints/Agent.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index c7622208c8c3..0f6b0c658ef2 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -199,6 +199,11 @@ async def prompt_agent( str(agent_prompt.prompt_args["tts"]).lower() == "true" ) del agent_prompt.prompt_args["tts"] + if "context_results" in agent_prompt.prompt_args: + agent_prompt.prompt_args["injected_memories"] = int( + agent_prompt.prompt_args["context_results"] + ) + del agent_prompt.prompt_args["context_results"] response = await agent.inference( prompt=agent_prompt.prompt_name, **agent_prompt.prompt_args, From 58e518ec289afa7aa94041b31cd0a7b61d42cbbb Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 17:22:01 -0400 Subject: [PATCH 0214/1256] add new message endpoint --- agixt/Models.py | 6 ++++++ agixt/endpoints/Conversation.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/agixt/Models.py b/agixt/Models.py index 22444b953062..a11b6c026de5 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -27,6 +27,12 @@ class UserInput(BaseModel): injected_memories: Optional[int] = 10 +class LogInteraction(BaseModel): + role: str + message: str + conversation_name: Optional[str] = "" + + class Dataset(BaseModel): batch_size: int = 5 diff --git a/agixt/endpoints/Conversation.py b/agixt/endpoints/Conversation.py index f8a951e4e992..18a53c555a80 100644 --- a/agixt/endpoints/Conversation.py +++ b/agixt/endpoints/Conversation.py @@ -7,6 +7,7 @@ ConversationHistoryMessageModel, UpdateConversationHistoryMessageModel, ResponseMessage, + LogInteraction, ) app = APIRouter() @@ -129,3 +130,20 @@ async def update_history_message( new_message=history.new_message, ) return ResponseMessage(message=f"Message updated.") + + +@app.post( + "/api/conversation/message", + tags=["Conversation"], + dependencies=[Depends(verify_api_key)], +) +async def log_interaction( + log_interaction: LogInteraction, user=Depends(verify_api_key) +) -> ResponseMessage: + Conversations( + conversation_name=log_interaction.conversation_name, user=user + ).log_interaction( + message=log_interaction.message, + role=log_interaction.role, + ) + return ResponseMessage(message=f"Interaction logged.") From 26bd4f93c612ce25778ee2ae73fce8f5191f9e66 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 20:26:21 -0400 Subject: [PATCH 0215/1256] improve logic on prompt_agent --- agixt/endpoints/Agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index 0f6b0c658ef2..ce528c63588d 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -204,8 +204,10 @@ async def prompt_agent( agent_prompt.prompt_args["context_results"] ) del agent_prompt.prompt_args["context_results"] + if "prompt" in agent_prompt.prompt_args: + agent_prompt.prompt_args["prompt_name"] = agent_prompt.prompt_args["prompt"] + del agent_prompt.prompt_args["prompt"] response = await agent.inference( - prompt=agent_prompt.prompt_name, **agent_prompt.prompt_args, ) return {"response": str(response)} From c1e8a3a39c74a3e54191768445242c347125bdf1 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 20:43:12 -0400 Subject: [PATCH 0216/1256] switch back --- agixt/endpoints/Agent.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index ce528c63588d..0f6b0c658ef2 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -204,10 +204,8 @@ async def prompt_agent( agent_prompt.prompt_args["context_results"] ) del agent_prompt.prompt_args["context_results"] - if "prompt" in agent_prompt.prompt_args: - agent_prompt.prompt_args["prompt_name"] = agent_prompt.prompt_args["prompt"] - del agent_prompt.prompt_args["prompt"] response = await agent.inference( + prompt=agent_prompt.prompt_name, **agent_prompt.prompt_args, ) return {"response": str(response)} From 088388ffd8cc8395b62d6a638f60c99412d4da36 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 21:20:51 -0400 Subject: [PATCH 0217/1256] add more activities --- agixt/Interactions.py | 5 ++++- agixt/XT.py | 42 +++++++++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index eb716a716dcb..c568bfd5ea1e 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -682,7 +682,10 @@ async def run( with open(audio_path, "wb") as f: f.write(audio_data) tts_response = f'' - self.response = f"{self.response}\n\n{tts_response}" + c.log_interaction( + role=self.agent_name, message=tts_response + ) + except Exception as e: logging.warning(f"Failed to get TTS response: {e}") if disable_memory != True: diff --git a/agixt/XT.py b/agixt/XT.py index 554446cea7af..5911797bc2be 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -194,12 +194,19 @@ async def generate_image(self, prompt: str, conversation_name: str = ""): ) return await self.agent.generate_image(prompt=prompt) - async def text_to_speech(self, text: str, conversation_name: str = ""): + async def text_to_speech( + self, + text: str, + conversation_name: str = "", + log_output: bool = False, + ): """ Generate Text to Speech audio from text Args: text (str): Text to convert to speech + conversation_name (str): Name of the conversation + log_output (bool): Whether to log the output Returns: str: URL of the generated audio @@ -222,6 +229,11 @@ async def text_to_speech(self, text: str, conversation_name: str = ""): with open(audio_path, "wb") as f: f.write(audio_data) tts_url = f"{self.outputs}/{file_name}" + if log_output: + c.log_interaction( + role=self.agent_name, + message=f'', + ) return tts_url async def audio_to_text(self, audio_path: str, conversation_name: str = ""): @@ -230,6 +242,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): Args: audio_path (str): Path to the audio file + conversation_name (str): Name of the conversation Returns str: Transcription of the audio @@ -249,6 +262,7 @@ async def translate_audio(self, audio_path: str, conversation_name: str = ""): Args: audio_path (str): Path to the audio file + conversation_name (str): Name of the conversation Returns str: Translation of the audio @@ -268,6 +282,7 @@ async def execute_command( command_args: dict, conversation_name: str = "", voice_response: bool = False, + log_output: bool = False, ): """ Execute a command with arguments @@ -277,6 +292,7 @@ async def execute_command( command_args (dict): Arguments for the command conversation_name (str): Name of the conversation voice_response (bool): Whether to generate a voice response + log_output (bool): Whether to log the output Returns: str: Response from the command @@ -304,8 +320,13 @@ async def execute_command( and self.agent_settings["tts_provider"] != "" and self.agent_settings["tts_provider"] != None ): - tts_response = await self.text_to_speech(text=response) - response = f"{response}\n\n{tts_response}" + await self.text_to_speech( + text=response, + conversation_name=conversation_name, + log_output=log_output, + ) + if log_output: + c.log_interaction(role=self.agent_name, message=response) return response async def run_chain_step( @@ -455,10 +476,9 @@ async def execute_chain( role="USER", message=user_input, ) - agent_name = agent_override if agent_override != "" else "AGiXT" if conversation_name != "": c.log_interaction( - role=agent_name, + role=self.agent_name, message=f"[ACTIVITY] Running chain `{chain_name}`.", ) response = "" @@ -495,20 +515,16 @@ async def execute_chain( response = step_responses[-1] if response == None: return f"Chain failed to complete, it failed on step {step_data['step']}. You can resume by starting the chain from the step that failed with chain ID {chain_run_id}." - if conversation_name != "": - c.log_interaction( - role=agent_name, - message=response, - ) + c.log_interaction(role=self.agent_name, message=response) if "tts_provider" in self.agent_settings and voice_response: if ( self.agent_settings["tts_provider"] != "None" and self.agent_settings["tts_provider"] != "" and self.agent_settings["tts_provider"] != None ): - tts_response = await self.text_to_speech(text=response) - response = f'{response}\n\n' - c.log_interaction(role=self.agent_name, message=response) + await self.text_to_speech( + text=response, conversation_name=conversation_name, log_output=True + ) return response async def learn_from_websites( From 24e5ce4825c6635b3554bd957ee9f04dfedf170e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 21:22:29 -0400 Subject: [PATCH 0218/1256] remove .text ref --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 5911797bc2be..1bc021c9fdb1 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -217,7 +217,7 @@ async def text_to_speech( role=self.agent_name, message=f"[ACTIVITY] Generating audio.", ) - tts_url = await self.agent.text_to_speech(text=text.text) + tts_url = await self.agent.text_to_speech(text=text) if not str(tts_url).startswith("http"): file_type = "wav" file_name = f"{uuid.uuid4().hex}.{file_type}" From 3b7858779fd85f9e106ccaff599101ec15b19be8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 21:24:16 -0400 Subject: [PATCH 0219/1256] improve response --- agixt/Interactions.py | 4 ++++ agixt/XT.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index c568bfd5ea1e..731482c0cc3f 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -669,6 +669,10 @@ async def run( and agent_settings["tts_provider"] != None ): try: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Generating audio response.", + ) tts_response = await self.agent.text_to_speech( text=self.response ) diff --git a/agixt/XT.py b/agixt/XT.py index 1bc021c9fdb1..27902a9a72fe 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -215,7 +215,7 @@ async def text_to_speech( c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Generating audio.", + message=f"[ACTIVITY] Generating audio response.", ) tts_url = await self.agent.text_to_speech(text=text) if not str(tts_url).startswith("http"): From 89273b3fa6404d83b4822ecb95b2122920f4e5fa Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 21:54:14 -0400 Subject: [PATCH 0220/1256] improve activity logging --- agixt/Interactions.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 731482c0cc3f..3fe32ffb49fe 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -659,6 +659,11 @@ async def run( self.response = re.sub( r"!\[.*?\]\(.*?\)", "", self.response, flags=re.DOTALL ) + if log_output: + c.log_interaction( + role=self.agent_name, + message=self.response, + ) tts = False if "tts" in kwargs: tts = str(kwargs["tts"]).lower() == "true" @@ -714,6 +719,10 @@ async def run( logging.info(f"Image Generation Decision Response: {create_img}") to_create_image = re.search(r"\byes\b", str(create_img).lower()) if to_create_image: + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY] Generating image.", + ) img_prompt = f"**The assistant is acting as a Stable Diffusion Prompt Generator.**\n\nUsers message: {user_input} \nAssistant response: {self.response} \n\nImportant rules to follow:\n- Describe subjects in detail, specify image type (e.g., digital illustration), art style (e.g., steampunk), and background. Include art inspirations (e.g., Art Station, specific artists). Detail lighting, camera (type, lens, view), and render (resolution, style). The weight of a keyword can be adjusted by using the syntax (((keyword))) , put only those keyword inside ((())) which is very important because it will have more impact so anything wrong will result in unwanted picture so be careful. Realistic prompts: exclude artist, specify lens. Separate with double lines. Max 60 words, avoiding 'real' for fantastical.\n- Based on the message from the user and response of the assistant, you will need to generate one detailed stable diffusion image generation prompt based on the context of the conversation to accompany the assistant response.\n- The prompt can only be up to 60 words long, so try to be concise while using enough descriptive words to make a proper prompt.\n- Following all rules will result in a $2000 tip that you can spend on anything!\n- Must be in markdown code block to be parsed out and only provide prompt in the code block, nothing else.\nStable Diffusion Prompt Generator: " image_generation_prompt = await self.agent.inference( prompt=img_prompt @@ -730,16 +739,14 @@ async def run( generated_image = await self.agent.generate_image( prompt=image_generation_prompt ) - self.response = f"{self.response}\n\n![Image generated by {self.agent_name}]({generated_image})" + c.log_interaction( + role=self.agent_name, + message=f"![Image generated by {self.agent_name}]({generated_image})", + ) except: logging.warning( f"Failed to generate image for prompt: {image_generation_prompt}" ) - if log_output: - c.log_interaction( - role=self.agent_name, - message=self.response, - ) if shots > 1: responses = [self.response] for shot in range(shots - 1): @@ -861,6 +868,10 @@ async def execution_agent(self, conversation_name): logging.error( f"Error: {self.agent_name} failed to execute command `{command_name}`. {e}" ) + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY][ERROR] Failed to execute command `{command_name}`.", + ) command_output = f"**Failed to execute command `{command_name}` with args `{command_args}`. Please try again.**" if command_output: c.log_interaction( From e8c400704c550c9d4fed589c683c2522e196f4b8 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 22:22:40 -0400 Subject: [PATCH 0221/1256] improve activity messages --- agixt/Websearch.py | 8 ++++---- agixt/XT.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/agixt/Websearch.py b/agixt/Websearch.py index a8c5632db471..b774331df1eb 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -267,7 +267,7 @@ async def recursive_browsing( if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Browsing link: {url}", + message=f"[ACTIVITY] Browsing {url} .", ) ( collected_data, @@ -322,7 +322,7 @@ async def recursive_browsing( ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Issues reading {url}. Moving on...", + message=f"[ACTIVITY][ERROR] Issues reading {url}. Moving on.", ) async def scrape_websites( @@ -358,7 +358,7 @@ async def scrape_websites( if conversation_name != "" and conversation_name is not None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Browsing {link} ", + message=f"[ACTIVITY] Browsing {link} .", ) text_content, link_list = await self.get_web_content( url=link, summarize_content=summarize_content @@ -380,7 +380,7 @@ async def scrape_websites( ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Browsing {sublink[1]} ", + message=f"[ACTIVITY] Browsing {sublink[1]} .", ) ( text_content, diff --git a/agixt/XT.py b/agixt/XT.py index 27902a9a72fe..73cc6b2b1e10 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -393,7 +393,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Running prompt: {prompt_name} with args: {args}", + message=f"[ACTIVITY] Running prompt: {prompt_name} with args:\n```json\n{json.dumps(args, indent=2)}```", ) if "prompt_name" not in args: args["prompt_name"] = prompt_name @@ -411,7 +411,7 @@ async def run_chain_step( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Running chain: {args['chain']} with args: {args}", + message=f"[ACTIVITY] Running chain: {args['chain']} with args:\n```json\n{json.dumps(args, indent=2)}```", ) if "chain_name" in args: args["chain"] = args["chain_name"] @@ -775,7 +775,7 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Viewing image at {file_url} ", + message=f"[ACTIVITY] Viewing image at {file_url} .", ) try: vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\nThe uploaded image is `{file_name}`.\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." From f5fe9024464fd68ff7a2d80d1bdbe9b8615ad5cb Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 22:52:20 -0400 Subject: [PATCH 0222/1256] improve logging --- agixt/Interactions.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 3fe32ffb49fe..19a14ee5a6e5 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -630,6 +630,9 @@ async def run( message=log_message, ) try: + c.log_interaction( + role=self.agent_name, message="[ACTIVITY] Generating response." + ) self.response = await self.agent.inference( prompt=formatted_prompt, tokens=tokens ) @@ -640,6 +643,10 @@ async def run( error += f"{err.args}\n{err.name}\n{err.msg}\n" logging.error(f"{self.agent.PROVIDER} Error: {error}") logging.info(f"TOKENS: {tokens} PROMPT CONTENT: {formatted_prompt}") + c.log_interaction( + role=self.agent_name, + message=f"[ACTIVITY][ERROR] Unable to generate response.", + ) return f"Unable to retrieve response." # Handle commands if the prompt contains the {COMMANDS} placeholder # We handle command injection that DOESN'T allow command execution by using {command_list} in the prompt From 2b3f0cc6ff5b98606702ed9582ae478b1eb89bfd Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 23:05:18 -0400 Subject: [PATCH 0223/1256] move activity for generating response --- agixt/Interactions.py | 3 --- agixt/XT.py | 4 ++++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 19a14ee5a6e5..e9267d216d78 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -630,9 +630,6 @@ async def run( message=log_message, ) try: - c.log_interaction( - role=self.agent_name, message="[ACTIVITY] Generating response." - ) self.response = await self.agent.inference( prompt=formatted_prompt, tokens=tokens ) diff --git a/agixt/XT.py b/agixt/XT.py index 73cc6b2b1e10..d5288e7178c0 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -158,6 +158,10 @@ async def inference( Returns: str: Response from the agent """ + c = Conversations(conversation_name=conversation_name, user=self.user_email) + c.log_interaction( + role=self.agent_name, message="[ACTIVITY] Generating response." + ) return await self.agent_interactions.run( user_input=user_input, prompt_category=prompt_category, From 415ac34705cd58efa0806e2e456a4e216049d767 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 23:07:47 -0400 Subject: [PATCH 0224/1256] too spammy --- agixt/XT.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/agixt/XT.py b/agixt/XT.py index d5288e7178c0..73cc6b2b1e10 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -158,10 +158,6 @@ async def inference( Returns: str: Response from the agent """ - c = Conversations(conversation_name=conversation_name, user=self.user_email) - c.log_interaction( - role=self.agent_name, message="[ACTIVITY] Generating response." - ) return await self.agent_interactions.run( user_input=user_input, prompt_category=prompt_category, From c0e39f3bfa93a58f14376cb0426cabd6fa811451 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Tue, 18 Jun 2024 23:22:33 -0400 Subject: [PATCH 0225/1256] force upper user --- agixt/Conversations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index a56bbe285586..dc5a515f3ab8 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -205,7 +205,8 @@ def log_interaction(self, role, message): ) .first() ) - + if role.lower() == "user": + role = "USER" if not conversation: conversation = self.new_conversation() session.close() From 60a4a32b11252fbf0738290f3933f982b56c2d41 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 00:00:47 -0400 Subject: [PATCH 0226/1256] format timestamp --- agixt/Conversations.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index dc5a515f3ab8..3c1571e90c72 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -110,11 +110,16 @@ def get_conversation(self, limit=100, page=1): return {"interactions": []} return_messages = [] for message in messages: + try: + formatted_timestamp = message.timestamp.strftime("%Y-%m-%d %I:%M:%S %p") + except Exception as e: + logging.info(f"Error formatting timestamp: {e}") + formatted_timestamp = message.timestamp msg = { "id": message.id, "role": message.role, "message": message.content, - "timestamp": message.timestamp, + "timestamp": formatted_timestamp, "updated_at": message.updated_at, "updated_by": message.updated_by, "feedback_received": message.feedback_received, From af05e83162070f529b8f8a17d4a3a59e1ffd17ae Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 00:04:46 -0400 Subject: [PATCH 0227/1256] fix timezone for messages --- agixt/Conversations.py | 7 +------ agixt/DB.py | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 3c1571e90c72..dc5a515f3ab8 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -110,16 +110,11 @@ def get_conversation(self, limit=100, page=1): return {"interactions": []} return_messages = [] for message in messages: - try: - formatted_timestamp = message.timestamp.strftime("%Y-%m-%d %I:%M:%S %p") - except Exception as e: - logging.info(f"Error formatting timestamp: {e}") - formatted_timestamp = message.timestamp msg = { "id": message.id, "role": message.role, "message": message.content, - "timestamp": formatted_timestamp, + "timestamp": message.timestamp, "updated_at": message.updated_at, "updated_by": message.updated_by, "feedback_received": message.feedback_received, diff --git a/agixt/DB.py b/agixt/DB.py index c9fcb984051e..893767cf1e79 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -274,7 +274,7 @@ class Message(Base): ) role = Column(Text, nullable=False) content = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=func.now()) + timestamp = Column(DateTime(timezone=True), server_default=func.now()) conversation_id = Column( UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, ForeignKey("conversation.id"), From 5e4453adb5e8f4cefb307185bf6c6b2f0e1c5ffb Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 00:39:23 -0400 Subject: [PATCH 0228/1256] add rename conversation function --- agixt/Conversations.py | 17 +++++++++ agixt/Models.py | 6 +++ agixt/endpoints/Conversation.py | 42 ++++++++++++++++++++- agixt/prompts/Default/Name Conversation.txt | 13 +++++++ 4 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 agixt/prompts/Default/Name Conversation.txt diff --git a/agixt/Conversations.py b/agixt/Conversations.py index dc5a515f3ab8..c08a21d380f0 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -454,3 +454,20 @@ def get_conversation_id(self): if not conversation: return None return str(conversation.id) + + def rename_conversation(self, new_name): + session = get_session() + user_data = session.query(User).filter(User.email == self.user).first() + user_id = user_data.id + conversation = ( + session.query(Conversation) + .filter( + Conversation.name == self.conversation_name, + Conversation.user_id == user_id, + ) + .first() + ) + if not conversation: + return + conversation.name = new_name + session.commit() diff --git a/agixt/Models.py b/agixt/Models.py index a11b6c026de5..224295d7c5bd 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -241,6 +241,12 @@ class ConversationHistoryModel(BaseModel): conversation_content: List[dict] = [] +class RenameConversationModel(BaseModel): + agent_name: str + conversation_name: str + new_conversation_name: Optional[str] = "/" + + class TTSInput(BaseModel): text: str diff --git a/agixt/endpoints/Conversation.py b/agixt/endpoints/Conversation.py index 18a53c555a80..2bd58d4d568e 100644 --- a/agixt/endpoints/Conversation.py +++ b/agixt/endpoints/Conversation.py @@ -1,6 +1,6 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Header from ApiClient import verify_api_key, Conversations -from typing import Optional +from XT import AGiXT from Models import ( HistoryModel, ConversationHistoryModel, @@ -8,7 +8,10 @@ UpdateConversationHistoryMessageModel, ResponseMessage, LogInteraction, + RenameConversationModel, ) +import json +from datetime import datetime app = APIRouter() @@ -147,3 +150,38 @@ async def log_interaction( role=log_interaction.role, ) return ResponseMessage(message=f"Interaction logged.") + + +# Ask AI to rename the conversation +@app.put( + "/api/conversation", + tags=["Conversation"], + dependencies=[Depends(verify_api_key)], +) +async def rename_conversation( + rename: RenameConversationModel, + user=Depends(verify_api_key), + authorization: str = Header(None), +): + if rename.new_conversation_name == "/": + agixt = AGiXT(user=user, agent_name=rename.agent_name, api_key=authorization) + response = await agixt.inference( + prompt_name="Name Conversation", + conversation_name=rename.conversation_name, + log_user_input=False, + log_output=False, + ) + if "```json" in response: + response = response.split("```json")[1].split("```")[0].strip() + elif "```" in response: + response = response.split("```")[1].strip() + try: + response = json.loads(response) + new_name = response["suggested_conversation_name"] + except: + new_name = datetime.now().strftime("Conversation Created %Y-%m-%d %I:%M %p") + rename.new_conversation_name = new_name + Conversations( + conversation_name=rename.conversation_name, user=user + ).rename_conversation(new_name=rename.new_conversation_name) + return {"conversation_name": rename.new_conversation_name} diff --git a/agixt/prompts/Default/Name Conversation.txt b/agixt/prompts/Default/Name Conversation.txt new file mode 100644 index 000000000000..ce0a49078f13 --- /dev/null +++ b/agixt/prompts/Default/Name Conversation.txt @@ -0,0 +1,13 @@ +## Conversation History +{conversation_history} + +Act as a JSON converter that converts any text into the desired JSON format based on the schema provided. Respond only with JSON in a properly formatted markdown code block, no explanations. Make your best assumptions based on data to try to fill in information to match the schema provided. +**DO NOT ADD FIELDS TO THE MODEL OR CHANGE TYPES OF FIELDS, FOLLOW THE PYDANTIC SCHEMA!** +**Reformat the following information into a structured format according to the schema provided:** + +Based on the conversation history, suggest a name for the conversation in the `suggested_conversation_name` as a string. + +## Pydantic Schema: +suggested_conversation_name: + +JSON Structured Output: From f5b254e87f7e617f253774d5a585ccd42aff320e Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 00:53:48 -0400 Subject: [PATCH 0229/1256] format times --- agixt/Conversations.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index c08a21d380f0..4bfeeebab54c 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -110,12 +110,17 @@ def get_conversation(self, limit=100, page=1): return {"interactions": []} return_messages = [] for message in messages: + tz = getenv("TZ") + timestamp = message.timestamp.astimezone(tz) + timestamp = timestamp.strftime("%Y-%m-%d %I:%M:%S %p") + updated_at = message.updated_at.astimezone(tz) + updated_at = updated_at.strftime("%Y-%m-%d %I:%M:%S %p") msg = { "id": message.id, "role": message.role, "message": message.content, - "timestamp": message.timestamp, - "updated_at": message.updated_at, + "timestamp": timestamp, + "updated_at": updated_at, "updated_by": message.updated_by, "feedback_received": message.feedback_received, } From 11a9b8681951ed51fd5f122e97716b1ba9f218de Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 01:01:59 -0400 Subject: [PATCH 0230/1256] undo ts change --- agixt/Conversations.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index 4bfeeebab54c..c08a21d380f0 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -110,17 +110,12 @@ def get_conversation(self, limit=100, page=1): return {"interactions": []} return_messages = [] for message in messages: - tz = getenv("TZ") - timestamp = message.timestamp.astimezone(tz) - timestamp = timestamp.strftime("%Y-%m-%d %I:%M:%S %p") - updated_at = message.updated_at.astimezone(tz) - updated_at = updated_at.strftime("%Y-%m-%d %I:%M:%S %p") msg = { "id": message.id, "role": message.role, "message": message.content, - "timestamp": timestamp, - "updated_at": updated_at, + "timestamp": message.timestamp, + "updated_at": message.updated_at, "updated_by": message.updated_by, "feedback_received": message.feedback_received, } From 32497ce7403a440be8bde5e77612144d2a303f27 Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 19 Jun 2024 01:46:24 -0400 Subject: [PATCH 0231/1256] add user prefs and timezones (#1213) --- agixt/Conversations.py | 23 +++++++++++++++++++++-- agixt/DB.py | 10 +++++++++- requirements.txt | 1 + 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/agixt/Conversations.py b/agixt/Conversations.py index c08a21d380f0..18eb2bc6163e 100644 --- a/agixt/Conversations.py +++ b/agixt/Conversations.py @@ -4,9 +4,11 @@ Conversation, Message, User, + UserPreferences, get_session, ) from Globals import getenv, DEFAULT_USER +import pytz logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -109,13 +111,30 @@ def get_conversation(self, limit=100, page=1): if not messages: return {"interactions": []} return_messages = [] + # Check if there is a user preference for timezone + user_preferences = ( + session.query(UserPreferences) + .filter( + UserPreferences.user_id == user_id, + UserPreferences.pref_key == "timezone", + ) + .first() + ) + if not user_preferences: + user_preferences = UserPreferences( + user_id=user_id, pref_key="timezone", pref_value=getenv("TZ") + ) + session.add(user_preferences) + session.commit() + gmt = pytz.timezone("GMT") + local_tz = pytz.timezone(user_preferences.pref_value) for message in messages: msg = { "id": message.id, "role": message.role, "message": message.content, - "timestamp": message.timestamp, - "updated_at": message.updated_at, + "timestamp": gmt.localize(message.timestamp).astimezone(local_tz), + "updated_at": gmt.localize(message.updated_at).astimezone(local_tz), "updated_by": message.updated_by, "feedback_received": message.feedback_received, } diff --git a/agixt/DB.py b/agixt/DB.py index 893767cf1e79..d2fcffe2e8fc 100644 --- a/agixt/DB.py +++ b/agixt/DB.py @@ -66,6 +66,14 @@ class User(Base): is_active = Column(Boolean, default=True) +class UserPreferences(Base): + __tablename__ = "user_preferences" + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("user.id")) + pref_key = Column(String, nullable=False) + pref_value = Column(String, nullable=True) + + class UserOAuth(Base): __tablename__ = "user_oauth" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -274,7 +282,7 @@ class Message(Base): ) role = Column(Text, nullable=False) content = Column(Text, nullable=False) - timestamp = Column(DateTime(timezone=True), server_default=func.now()) + timestamp = Column(DateTime, server_default=func.now()) conversation_id = Column( UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, ForeignKey("conversation.id"), diff --git a/requirements.txt b/requirements.txt index 1e082060aecb..8b9a87a582b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ python-multipart==0.0.9 nest_asyncio g4f==0.3.2.0 pyotp +pytz \ No newline at end of file From 00bb5425f52653b8eb52a7f0758c70545cd656ea Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:45:24 -0400 Subject: [PATCH 0232/1256] Add user preferences (#1214) * add user prefs and timezones * Add user preferences * Handle provider failure * fix user id refs * improve logging for error on sso provider * add prefs properly --- agixt/MagicalAuth.py | 54 ++++++++++++++++++++++++++++++++++++++-- agixt/OAuth2Providers.py | 11 +++++++- agixt/endpoints/Auth.py | 10 ++++++-- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index 0188b524f18b..e138181c18dc 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -1,4 +1,11 @@ -from DB import User, FailedLogins, UserOAuth, OAuthProvider, get_session +from DB import ( + User, + FailedLogins, + UserOAuth, + OAuthProvider, + UserPreferences, + get_session, +) from OAuth2Providers import get_sso_provider from Models import UserInfo, Register, Login from fastapi import Header, HTTPException @@ -347,11 +354,12 @@ def login(self, ip_address): ) user_id = user_info["sub"] user = session.query(User).filter(User.id == user_id).first() - session.close() if user is None: + session.close() raise HTTPException(status_code=404, detail="User not found") if str(user.id) == str(user_id): return user + session.close() self.add_failed_login(ip_address=ip_address) raise HTTPException( status_code=401, @@ -414,9 +422,29 @@ def update_user(self, **kwargs): session = get_session() user = session.query(User).filter(User.id == user.id).first() allowed_keys = list(UserInfo.__annotations__.keys()) + user_preferences = ( + session.query(UserPreferences) + .filter(UserPreferences.user_id == user.id) + .all() + ) for key, value in kwargs.items(): if key in allowed_keys: setattr(user, key, value) + else: + # Check if there is a user preference record, create one if not, update if so. + user_preference = next( + (x for x in user_preferences if x.pref_key == key), + None, + ) + if user_preference is None: + user_preference = UserPreferences( + user_id=user.id, + pref_key=key, + pref_value=value, + ) + session.add(user_preference) + else: + user_preference.pref_value = value session.commit() session.close() return "User updated successfully" @@ -508,3 +536,25 @@ def sso( referrer=referrer, send_link=False, ) + + def get_user_preferences(self): + user = verify_api_key(self.token) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + session = get_session() + user_preferences = ( + session.query(UserPreferences) + .filter(UserPreferences.user_id == user.id) + .all() + ) + user_preferences = {x.pref_key: x.pref_value for x in user_preferences} + session.close() + if "email" in user_preferences: + del user_preferences["email"] + if "first_name" in user_preferences: + del user_preferences["first_name"] + if "last_name" in user_preferences: + del user_preferences["last_name"] + if not user_preferences: + return {} + return user_preferences diff --git a/agixt/OAuth2Providers.py b/agixt/OAuth2Providers.py index 50253a17fbe0..7e35c11aa34f 100644 --- a/agixt/OAuth2Providers.py +++ b/agixt/OAuth2Providers.py @@ -56,6 +56,8 @@ from sso.yelp import yelp_sso from sso.zendesk import zendesk_sso from Globals import getenv +from fastapi import HTTPException +import logging def get_provider_info(provider): @@ -437,6 +439,13 @@ def get_provider_info(provider): def get_sso_provider(provider: str, code, redirect_uri=None): provider_info = get_provider_info(provider) if provider_info: - return provider_info["function"](code=code, redirect_uri=redirect_uri) + try: + return provider_info["function"](code=code, redirect_uri=redirect_uri) + except Exception as e: + logging.error(f"Error getting user information from {provider}: {e}") + raise HTTPException( + status_code=403, + detail=f"Error getting user information from {provider}.", + ) else: return None diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 10a8005c01f8..224e71b510aa 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -34,11 +34,14 @@ def log_in( request: Request, authorization: str = Header(None), ): - user_data = MagicalAuth(token=authorization).login(ip_address=request.client.host) + auth = MagicalAuth(token=authorization) + user_data = auth.login(ip_address=request.client.host) + user_preferences = auth.get_user_preferences() return { "email": user_data.email, "first_name": user_data.first_name, "last_name": user_data.last_name, + **user_preferences, } @@ -66,8 +69,11 @@ async def send_magic_link(request: Request, login: Login): summary="Update user details", ) def update_user(update: UserInfo, request: Request, authorization: str = Header(None)): + response = request.json() user = MagicalAuth(token=authorization).update_user( - ip_address=request.client.host, **update.model_dump() + ip_address=request.client.host, + **update.model_dump(), + **response, ) return Detail(detail=user) From 40454862f74009f6f2083c1c510fa9336a997715 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 13:46:04 -0400 Subject: [PATCH 0233/1256] use libreoffice to convert ppt --- agixt/XT.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 73cc6b2b1e10..3f2ef71ac2e9 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -617,7 +617,17 @@ async def learn_from_file( message=f"[ACTIVITY] Converting PowerPoint file `{file_name}` to PDF.", ) subprocess.run( - ["unoconv", "-f", "pdf", "-o", pdf_file_path, file_path], check=True + [ + "libreoffice", + "--headless", + "--convert-to", + "pdf", + "--outdir", + self.agent_workspace, + file_path, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) file_path = pdf_file_path if conversation_name != "" and conversation_name != None: From 45cbad2821d033c59e3140880d67a2073be525ef Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Wed, 19 Jun 2024 16:56:42 -0400 Subject: [PATCH 0234/1256] Add requirements (#1215) * add reqs * add json persist * simplifying --- agixt/MagicalAuth.py | 40 ++++++++++++++++++++++++++-- agixt/registration_requirements.json | 3 +++ docker-compose-dev.yml | 1 + docker-compose.yml | 1 + 4 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 agixt/registration_requirements.json diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index e138181c18dc..5e32749afb64 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -27,6 +27,8 @@ import requests import logging import jwt +import json +import os logging.basicConfig( @@ -400,6 +402,14 @@ def register( ) session.add(user) session.commit() + # Add default user preferences + user_preferences = UserPreferences( + user_id=user.id, + pref_key="timezone", + pref_value=getenv("TZ"), + ) + session.add(user_preferences) + session.commit() session.close() # Send registration webhook out to third party application such as AGiXT to create a user there. registration_webhook = getenv("REGISTRATION_WEBHOOK") @@ -537,6 +547,18 @@ def sso( send_link=False, ) + def registration_requirements(self): + if not os.path.exists("registration_requirements.json"): + requirements = {} + else: + with open("registration_requirements.json", "r") as file: + requirements = json.load(file) + if not requirements: + requirements = {} + if "subscription" not in requirements: + requirements["subscription"] = "None" + return requirements + def get_user_preferences(self): user = verify_api_key(self.token) if user is None: @@ -548,6 +570,8 @@ def get_user_preferences(self): .all() ) user_preferences = {x.pref_key: x.pref_value for x in user_preferences} + if not user_preferences: + return {} session.close() if "email" in user_preferences: del user_preferences["email"] @@ -555,6 +579,18 @@ def get_user_preferences(self): del user_preferences["first_name"] if "last_name" in user_preferences: del user_preferences["last_name"] - if not user_preferences: - return {} + if "missing_requirements" in user_preferences: + del user_preferences["missing_requirements"] + user_requirements = self.registration_requirements() + missing_requirements = [] + for key, value in user_requirements.items(): + if key not in user_preferences: + if key == "subscription": + if str(value).lower() != "none": + if str(value).lower() == "false": + raise HTTPException(status_code=402, detail=str(value)) + else: + missing_requirements.append(key) + if missing_requirements: + user_preferences["missing_requirements"] = missing_requirements return user_preferences diff --git a/agixt/registration_requirements.json b/agixt/registration_requirements.json new file mode 100644 index 000000000000..16293fc46f05 --- /dev/null +++ b/agixt/registration_requirements.json @@ -0,0 +1,3 @@ +{ + "subscription": "None" +} diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 48a997b0fb98..3c974d21cca2 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -161,6 +161,7 @@ services: - ./agixt/chains:/agixt/chains - ./agixt/memories:/agixt/memories - ./agixt/conversations:/agixt/conversations + - ./agixt/registration_requirements.json:/agixt/registration_requirements.json - /var/run/docker.sock:/var/run/docker.sock streamlit: image: joshxt/streamlit:main diff --git a/docker-compose.yml b/docker-compose.yml index d03b6fbddfb8..6b23a56c1b39 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -39,6 +39,7 @@ services: - ./agixt/chains:/agixt/chains - ./agixt/memories:/agixt/memories - ./agixt/conversations:/agixt/conversations + - ./agixt/registration_requirements.json:/agixt/registration_requirements.json - /var/run/docker.sock:/var/run/docker.sock streamlit: image: joshxt/streamlit:main From de9d6a1dcddb725201a842fff97fe6bc4c7c9c60 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 17:07:17 -0400 Subject: [PATCH 0235/1256] fix req --- agixt/MagicalAuth.py | 2 +- agixt/registration_requirements.json | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index 5e32749afb64..cb8d150db3c2 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -590,7 +590,7 @@ def get_user_preferences(self): if str(value).lower() == "false": raise HTTPException(status_code=402, detail=str(value)) else: - missing_requirements.append(key) + missing_requirements.append({key: value}) if missing_requirements: user_preferences["missing_requirements"] = missing_requirements return user_preferences diff --git a/agixt/registration_requirements.json b/agixt/registration_requirements.json index 16293fc46f05..0967ef424bce 100644 --- a/agixt/registration_requirements.json +++ b/agixt/registration_requirements.json @@ -1,3 +1 @@ -{ - "subscription": "None" -} +{} From 75960bd06a07ad02aee14dd741475ab244f8877d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 22:50:35 -0400 Subject: [PATCH 0236/1256] add registration disable function --- agixt/Globals.py | 2 +- agixt/MagicalAuth.py | 5 +++++ docker-compose-dev.yml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/agixt/Globals.py b/agixt/Globals.py index 56c00c948b26..48637f1a6317 100644 --- a/agixt/Globals.py +++ b/agixt/Globals.py @@ -35,7 +35,6 @@ def getenv(var_name: str): "AGIXT_URI": "http://localhost:7437", "AGIXT_API_KEY": None, "ALLOWED_DOMAINS": "*", - "ALLOWLIST": "*", "WORKSPACE": os.path.join(os.getcwd(), "WORKSPACE"), "APP_NAME": "AGiXT", "EMAIL_SERVER": "", @@ -53,6 +52,7 @@ def getenv(var_name: str): "CHROMA_SSL": "false", "DISABLED_EXTENSIONS": "", "DISABLED_PROVIDERS": "", + "REGISTRATION_DISABLED": "false", "AUTH_PROVIDER": "", } default_value = default_values[var_name] if var_name in default_values else "" diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index cb8d150db3c2..b32ca85176eb 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -375,6 +375,11 @@ def register( new_user.email = new_user.email.lower() self.email = new_user.email allowed_domains = getenv("ALLOWED_DOMAINS") + registration_disabled = getenv("REGISTRATION_DISABLED").lower() == "true" + if registration_disabled: + raise HTTPException( + status_code=403, detail="Registration is disabled for this server." + ) if allowed_domains is None or allowed_domains == "": allowed_domains = "*" if allowed_domains != "*": diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 3c974d21cca2..9dd2d210b186 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -26,6 +26,7 @@ services: DISABLED_EXTENSIONS: ${DISABLED_EXTENSIONS} DISABLED_PROVIDERS: ${DISABLED_PROVIDERS} WORKING_DIRECTORY: ${WORKING_DIRECTORY:-/agixt/WORKSPACE} + REGISTRATION_DISABLED: ${REGISTRATION_DISABLED:-false} TOKENIZERS_PARALLELISM: False LOG_LEVEL: ${LOG_LEVEL:-INFO} AOL_CLIENT_ID: ${AOL_CLIENT_ID} From a3919e55b48354f31de735536157470de5678d93 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Wed, 19 Jun 2024 23:16:27 -0400 Subject: [PATCH 0237/1256] fix name --- agixt/Globals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/Globals.py b/agixt/Globals.py index 48637f1a6317..6041bfbb687f 100644 --- a/agixt/Globals.py +++ b/agixt/Globals.py @@ -35,7 +35,7 @@ def getenv(var_name: str): "AGIXT_URI": "http://localhost:7437", "AGIXT_API_KEY": None, "ALLOWED_DOMAINS": "*", - "WORKSPACE": os.path.join(os.getcwd(), "WORKSPACE"), + "WORKING_DIRECTORY": os.path.join(os.getcwd(), "WORKSPACE"), "APP_NAME": "AGiXT", "EMAIL_SERVER": "", "LOG_LEVEL": "INFO", From 873823015e36e90b728ce70090463d0ee72ea5be Mon Sep 17 00:00:00 2001 From: Josh XT <102809327+Josh-XT@users.noreply.github.com> Date: Thu, 20 Jun 2024 11:55:45 -0400 Subject: [PATCH 0238/1256] Closing database sessions (#1216) * Closing database sessions * fix refs in all * move refs * fix imports * move webhook * fix update agent config * improve logging * use link for image * fix audio path --- agixt/Agent.py | 153 ++++++++++++++++-------- agixt/Chain.py | 254 +++++++++++++++++++++------------------- agixt/MagicalAuth.py | 62 +++------- agixt/Models.py | 1 + agixt/Prompts.py | 93 ++++++++------- agixt/XT.py | 14 ++- agixt/endpoints/Auth.py | 56 +++++++-- 7 files changed, 362 insertions(+), 271 deletions(-) diff --git a/agixt/Agent.py b/agixt/Agent.py index d55685d8f525..0c7b7dd553c2 100644 --- a/agixt/Agent.py +++ b/agixt/Agent.py @@ -16,6 +16,9 @@ from Providers import Providers from Extensions import Extensions from Globals import getenv, DEFAULT_SETTINGS, DEFAULT_USER +from MagicalAuth import get_user_id, is_agixt_admin +from agixtsdk import AGiXTSDK +from fastapi import HTTPException from datetime import datetime, timezone, timedelta import logging import json @@ -29,9 +32,9 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_USER): - session = get_session() if not agent_name: return {"message": "Agent name cannot be empty."} + session = get_session() # Check if agent already exists agent = ( session.query(AgentModel) @@ -39,6 +42,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US .first() ) if agent: + session.close() return {"message": f"Agent {agent_name} already exists."} agent = ( session.query(AgentModel) @@ -46,6 +50,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US .first() ) if agent: + session.close() return {"message": f"Agent {agent_name} already exists."} user_data = session.query(User).filter(User.email == user).first() user_id = user_data.id @@ -80,7 +85,7 @@ def add_agent(agent_name, provider_settings=None, commands=None, user=DEFAULT_US ) session.add(agent_command) session.commit() - + session.close() return {"message": f"Agent {agent_name} created."} @@ -94,6 +99,7 @@ def delete_agent(agent_name, user=DEFAULT_USER): .first() ) if not agent: + session.close() return {"message": f"Agent {agent_name} not found."}, 404 # Delete associated chain steps @@ -125,7 +131,7 @@ def delete_agent(agent_name, user=DEFAULT_USER): # Delete the agent session.delete(agent) session.commit() - + session.close() return {"message": f"Agent {agent_name} deleted."}, 200 @@ -139,11 +145,11 @@ def rename_agent(agent_name, new_name, user=DEFAULT_USER): .first() ) if not agent: + session.close() return {"message": f"Agent {agent_name} not found."}, 404 - agent.name = new_name session.commit() - + session.close() return {"message": f"Agent {agent_name} renamed to {new_name}."}, 200 @@ -162,21 +168,16 @@ def get_agents(user=DEFAULT_USER): if agent.name in [a["name"] for a in output]: continue output.append({"name": agent.name, "id": agent.id, "status": False}) + session.close() return output class Agent: - def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): + def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient: AGiXTSDK = None): self.agent_name = agent_name if agent_name is not None else "AGiXT" - self.session = get_session() user = user if user is not None else DEFAULT_USER self.user = user.lower() - try: - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id - except Exception as e: - logging.error(f"User {self.user} not found.") - raise + self.user_id = get_user_id(user=self.user) self.AGENT_CONFIG = self.get_agent_config() self.load_config_keys() if "settings" not in self.AGENT_CONFIG: @@ -286,8 +287,9 @@ def load_config_keys(self): setattr(self, key, self.AGENT_CONFIG[key]) def get_agent_config(self): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -295,26 +297,23 @@ def get_agent_config(self): ) if not agent: # Check if it is a global agent - global_user = ( - self.session.query(User).filter(User.email == DEFAULT_USER).first() - ) + global_user = session.query(User).filter(User.email == DEFAULT_USER).first() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == global_user.id, ) .first() ) - config = {"settings": {}, "commands": {}} if agent: - all_commands = self.session.query(Command).all() + all_commands = session.query(Command).all() agent_settings = ( - self.session.query(AgentSettingModel).filter_by(agent_id=agent.id).all() + session.query(AgentSettingModel).filter_by(agent_id=agent.id).all() ) agent_commands = ( - self.session.query(AgentCommand) + session.query(AgentCommand) .join(Command) .filter( AgentCommand.agent_id == agent.id, @@ -331,7 +330,10 @@ def get_agent_config(self): ) for setting in agent_settings: config["settings"][setting.name] = setting.value + session.commit() + session.close() return config + session.close() return {"settings": DEFAULT_SETTINGS, "commands": {}} async def inference(self, prompt: str, tokens: int = 0, images: list = []): @@ -387,24 +389,68 @@ def get_commands_string(self): return verbose_commands def update_agent_config(self, new_config, config_key): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) .first() ) if not agent: - logging.error(f"Agent '{self.agent_name}' not found in the database.") - return + if self.user == DEFAULT_USER: + return f"Agent {self.agent_name} not found." + # Check if it is a global agent. + global_user = session.query(User).filter(User.email == DEFAULT_USER).first() + global_agent = ( + session.query(AgentModel) + .filter( + AgentModel.name == self.agent_name, + AgentModel.user_id == global_user.id, + ) + .first() + ) + # if it is a global agent, copy it to the user's agents. + if global_agent: + agent = AgentModel( + name=self.agent_name, + user_id=self.user_id, + provider_id=global_agent.provider_id, + ) + session.add(agent) + agent_settings = ( + session.query(AgentSettingModel) + .filter_by(agent_id=global_agent.id) + .all() + ) + for setting in agent_settings: + agent_setting = AgentSettingModel( + agent_id=agent.id, + name=setting.name, + value=setting.value, + ) + session.add(agent_setting) + agent_commands = ( + session.query(AgentCommand) + .filter_by(agent_id=global_agent.id) + .all() + ) + for agent_command in agent_commands: + agent_command = AgentCommand( + agent_id=agent.id, + command_id=agent_command.command_id, + state=agent_command.state, + ) + session.add(agent_command) + session.commit() + session.close() + return f"Agent {self.agent_name} configuration updated successfully." if config_key == "commands": for command_name, enabled in new_config.items(): - command = ( - self.session.query(Command).filter_by(name=command_name).first() - ) + command = session.query(Command).filter_by(name=command_name).first() if command: agent_command = ( - self.session.query(AgentCommand) + session.query(AgentCommand) .filter_by(agent_id=agent.id, command_id=command.id) .first() ) @@ -414,12 +460,12 @@ def update_agent_config(self, new_config, config_key): agent_command = AgentCommand( agent_id=agent.id, command_id=command.id, state=enabled ) - self.session.add(agent_command) + session.add(agent_command) else: for setting_name, setting_value in new_config.items(): logging.info(f"Setting {setting_name} to {setting_value}.") agent_setting = ( - self.session.query(AgentSettingModel) + session.query(AgentSettingModel) .filter_by(agent_id=agent.id, name=setting_name) .first() ) @@ -429,15 +475,18 @@ def update_agent_config(self, new_config, config_key): agent_setting = AgentSettingModel( agent_id=agent.id, name=setting_name, value=str(setting_value) ) - self.session.add(agent_setting) + session.add(agent_setting) try: - self.session.commit() + session.commit() + session.close() logging.info(f"Agent {self.agent_name} configuration updated successfully.") except Exception as e: - self.session.rollback() + session.rollback() + session.close() logging.error(f"Error updating agent configuration: {str(e)}") - raise - + raise HTTPException( + status_code=500, detail=f"Error updating agent configuration: {str(e)}" + ) return f"Agent {self.agent_name} configuration updated." def get_browsed_links(self, conversation_id=None): @@ -447,21 +496,24 @@ def get_browsed_links(self, conversation_id=None): Returns: list: The list of URLs that have been browsed by the agent. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) .first() ) if not agent: + session.close() return [] browsed_links = ( - self.session.query(AgentBrowsedLink) + session.query(AgentBrowsedLink) .filter_by(agent_id=agent.id, conversation_id=conversation_id) .order_by(AgentBrowsedLink.id.desc()) .all() ) + session.close() if not browsed_links: return [] return browsed_links @@ -495,8 +547,9 @@ def add_browsed_link(self, url, conversation_id=None): Returns: str: The response message. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -507,8 +560,9 @@ def add_browsed_link(self, url, conversation_id=None): browsed_link = AgentBrowsedLink( agent_id=agent.id, url=url, conversation_id=conversation_id ) - self.session.add(browsed_link) - self.session.commit() + session.add(browsed_link) + session.commit() + session.close() return f"Link {url} added to browsed links." def delete_browsed_link(self, url, conversation_id=None): @@ -521,8 +575,9 @@ def delete_browsed_link(self, url, conversation_id=None): Returns: str: The response message. """ + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id, @@ -532,19 +587,21 @@ def delete_browsed_link(self, url, conversation_id=None): if not agent: return f"Agent {self.agent_name} not found." browsed_link = ( - self.session.query(AgentBrowsedLink) + session.query(AgentBrowsedLink) .filter_by(agent_id=agent.id, url=url, conversation_id=conversation_id) .first() ) if not browsed_link: return f"Link {url} not found." - self.session.delete(browsed_link) - self.session.commit() + session.delete(browsed_link) + session.commit() + session.close() return f"Link {url} deleted from browsed links." def get_agent_id(self): + session = get_session() agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user_id == self.user_id ) @@ -552,13 +609,15 @@ def get_agent_id(self): ) if not agent: agent = ( - self.session.query(AgentModel) + session.query(AgentModel) .filter( AgentModel.name == self.agent_name, AgentModel.user.has(email=DEFAULT_USER), ) .first() ) + session.close() if not agent: return None + session.close() return agent.id diff --git a/agixt/Chain.py b/agixt/Chain.py index 174571c8bffe..d62f81bbaa81 100644 --- a/agixt/Chain.py +++ b/agixt/Chain.py @@ -16,6 +16,7 @@ from Globals import getenv, DEFAULT_USER from Prompts import Prompts from Extensions import Extensions +from MagicalAuth import get_user_id import logging import asyncio @@ -27,29 +28,22 @@ class Chain: def __init__(self, user=DEFAULT_USER, ApiClient=None): - self.session = get_session() self.user = user self.ApiClient = ApiClient - try: - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id - except: - user_data = ( - self.session.query(User).filter(User.email == DEFAULT_USER).first() - ) - self.user_id = user_data.id + self.user_id = get_user_id(self.user) def get_chain(self, chain_name): + session = get_session() chain_name = chain_name.replace("%20", " ") - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.user_id == user_data.id, ChainDB.name == chain_name) .first() ) if chain_db is None: chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == chain_name, ChainDB.user_id == self.user_id, @@ -57,9 +51,10 @@ def get_chain(self, chain_name): .first() ) if chain_db is None: + session.close() return [] chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_db.id) .order_by(ChainStep.step_number) .all() @@ -67,24 +62,24 @@ def get_chain(self, chain_name): steps = [] for step in chain_steps: - agent_name = self.session.query(Agent).get(step.agent_id).name + agent_name = session.query(Agent).get(step.agent_id).name prompt = {} if step.target_chain_id: prompt["chain_name"] = ( - self.session.query(ChainDB).get(step.target_chain_id).name + session.query(ChainDB).get(step.target_chain_id).name ) elif step.target_command_id: prompt["command_name"] = ( - self.session.query(Command).get(step.target_command_id).name + session.query(Command).get(step.target_command_id).name ) elif step.target_prompt_id: prompt["prompt_name"] = ( - self.session.query(Prompt).get(step.target_prompt_id).name + session.query(Prompt).get(step.target_prompt_id).name ) # Retrieve argument data for the step arguments = ( - self.session.query(Argument, ChainStepArgument) + session.query(Argument, ChainStepArgument) .join(ChainStepArgument, ChainStepArgument.argument_id == Argument.id) .filter(ChainStepArgument.chain_step_id == step.id) .all() @@ -109,38 +104,42 @@ def get_chain(self, chain_name): "chain_name": chain_db.name, "steps": steps, } - + session.close() return chain_data def get_chains(self): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_chains = ( - self.session.query(ChainDB).filter(ChainDB.user_id == user_data.id).all() - ) - chains = ( - self.session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all() + session.query(ChainDB).filter(ChainDB.user_id == user_data.id).all() ) + chains = session.query(ChainDB).filter(ChainDB.user_id == self.user_id).all() chain_list = [] for chain in chains: chain_list.append(chain.name) for chain in global_chains: chain_list.append(chain.name) + session.close() return chain_list def add_chain(self, chain_name): + session = get_session() chain = ChainDB(name=chain_name, user_id=self.user_id) - self.session.add(chain) - self.session.commit() + session.add(chain) + session.commit() + session.close() def rename_chain(self, chain_name, new_name): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain: chain.name = new_name - self.session.commit() + session.commit() + session.close() def add_chain_step( self, @@ -150,13 +149,14 @@ def add_chain_step( prompt_type: str, prompt: dict, ): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) @@ -168,7 +168,7 @@ def add_chain_step( if prompt_type.lower() == "prompt": argument_key = "prompt_name" target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt["prompt_name"], Prompt.user_id == self.user_id, @@ -183,7 +183,7 @@ def add_chain_step( if argument_key not in prompt: argument_key = "chain" target_id = ( - self.session.query(Chain) + session.query(Chain) .filter( Chain.name == prompt["chain_name"], Chain.user_id == self.user_id ) @@ -194,7 +194,7 @@ def add_chain_step( elif prompt_type.lower() == "command": argument_key = "command_name" target_id = ( - self.session.query(Command) + session.query(Command) .filter(Command.name == prompt["command_name"]) .first() .id @@ -216,7 +216,7 @@ def add_chain_step( del prompt["input"] argument_key = "prompt_name" target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt["prompt_name"], Prompt.user_id == self.user_id, @@ -229,7 +229,6 @@ def add_chain_step( argument_value = prompt[argument_key] prompt_arguments = prompt.copy() del prompt_arguments[argument_key] - chain_step = ChainStep( chain_id=chain.id, step_number=step_number, @@ -240,14 +239,12 @@ def add_chain_step( target_command_id=target_id if target_type == "command" else None, target_prompt_id=target_id if target_type == "prompt" else None, ) - self.session.add(chain_step) - self.session.commit() + session.add(chain_step) + session.commit() for argument_name, argument_value in prompt_arguments.items(): argument = ( - self.session.query(Argument) - .filter(Argument.name == argument_name) - .first() + session.query(Argument).filter(Argument.name == argument_name).first() ) if not argument: # Handle the case where argument not found based on argument_name @@ -259,40 +256,39 @@ def add_chain_step( argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == step_number ) .first() ) - agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) agent_id = agent.id if agent else None - target_chain_id = None target_command_id = None target_prompt_id = None - if prompt_type == "Command": command_name = prompt.get("command_name") command_args = prompt.copy() del command_args["command_name"] command = ( - self.session.query(Command).filter(Command.name == command_name).first() + session.query(Command).filter(Command.name == command_name).first() ) if command: target_command_id = command.id @@ -302,7 +298,7 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): prompt_args = prompt.copy() del prompt_args["prompt_name"] prompt_obj = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.prompt_category.has(name=prompt_category), @@ -317,32 +313,26 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): chain_args = prompt.copy() del chain_args["chain_name"] chain_obj = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain_obj: target_chain_id = chain_obj.id - chain_step.agent_id = agent_id chain_step.prompt_type = prompt_type chain_step.prompt = prompt.get("prompt_name", None) chain_step.target_chain_id = target_chain_id chain_step.target_command_id = target_command_id chain_step.target_prompt_id = target_prompt_id - - self.session.commit() - + session.commit() # Update the arguments for the step - self.session.query(ChainStepArgument).filter( + session.query(ChainStepArgument).filter( ChainStepArgument.chain_step_id == chain_step.id ).delete() - for argument_name, argument_value in prompt_args.items(): argument = ( - self.session.query(Argument) - .filter(Argument.name == argument_name) - .first() + session.query(Argument).filter(Argument.name == argument_name).first() ) if argument: chain_step_argument = ChainStepArgument( @@ -350,56 +340,59 @@ def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() def delete_step(self, chain_name, step_number): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) - if chain: chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == step_number ) .first() ) if chain_step: - self.session.delete( - chain_step - ) # Remove the chain step from the session - self.session.commit() + session.delete(chain_step) # Remove the chain step from the session + session.commit() else: logging.info( f"No step found with number {step_number} in chain '{chain_name}'" ) else: logging.info(f"No chain found with name '{chain_name}'") + session.close() def delete_chain(self, chain_name): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) - self.session.delete(chain) - self.session.commit() + session.delete(chain) + session.commit() + session.close() def get_steps(self, chain_name): + session = get_session() chain_name = chain_name.replace("%20", " ") - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.user_id == user_data.id, ChainDB.name == chain_name) .first() ) if chain_db is None: chain_db = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == chain_name, ChainDB.user_id == self.user_id, @@ -407,13 +400,15 @@ def get_steps(self, chain_name): .first() ) if chain_db is None: + session.close() return [] chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_db.id) .order_by(ChainStep.step_number) .all() ) + session.close() return chain_steps def get_step(self, chain_name, step_number): @@ -426,13 +421,14 @@ def get_step(self, chain_name, step_number): return chain_step def move_step(self, chain_name, current_step_number, new_step_number): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain.id, ChainStep.step_number == current_step_number, @@ -441,7 +437,7 @@ def move_step(self, chain_name, current_step_number, new_step_number): ) chain_step.step_number = new_step_number if new_step_number < current_step_number: - self.session.query(ChainStep).filter( + session.query(ChainStep).filter( ChainStep.chain_id == chain.id, ChainStep.step_number >= new_step_number, ChainStep.step_number < current_step_number, @@ -449,22 +445,24 @@ def move_step(self, chain_name, current_step_number, new_step_number): {"step_number": ChainStep.step_number + 1}, synchronize_session=False ) else: - self.session.query(ChainStep).filter( + session.query(ChainStep).filter( ChainStep.chain_id == chain.id, ChainStep.step_number > current_step_number, ChainStep.step_number <= new_step_number, ).update( {"step_number": ChainStep.step_number - 1}, synchronize_session=False ) - self.session.commit() + session.commit() + session.close() def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): if chain_run_id is None: chain_run_id = self.get_last_chain_run_id(chain_name=chain_name) chain_data = self.get_chain(chain_name=chain_name) + session = get_session() if step_number == "all": chain_steps = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter(ChainStep.chain_id == chain_data["id"]) .order_by(ChainStep.step_number) .all() @@ -473,7 +471,7 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): responses = {} for step in chain_steps: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -483,12 +481,12 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): ) step_responses = [response.content for response in chain_step_responses] responses[str(step.step_number)] = step_responses - + session.close() return responses else: step_number = int(step_number) chain_step = ( - self.session.query(ChainStep) + session.query(ChainStep) .filter( ChainStep.chain_id == chain_data["id"], ChainStep.step_number == step_number, @@ -498,7 +496,7 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): if chain_step: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == chain_step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -507,49 +505,53 @@ def get_step_response(self, chain_name, chain_run_id=None, step_number="all"): .all() ) step_responses = [response.content for response in chain_step_responses] + session.close() return step_responses else: + session.close() return None def get_chain_responses(self, chain_name): chain_steps = self.get_steps(chain_name=chain_name) responses = {} + session = get_session() for step in chain_steps: chain_step_responses = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter(ChainStepResponse.chain_step_id == step.id) .order_by(ChainStepResponse.timestamp) .all() ) step_responses = [response.content for response in chain_step_responses] responses[str(step.step_number)] = step_responses + session.close() return responses def import_chain(self, chain_name: str, steps: dict): + session = get_session() chain = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter(ChainDB.name == chain_name, ChainDB.user_id == self.user_id) .first() ) if chain: + session.close() return None chain = ChainDB(name=chain_name, user_id=self.user_id) - self.session.add(chain) - self.session.commit() + session.add(chain) + session.commit() steps = steps["steps"] if "steps" in steps else steps for step_data in steps: agent_name = step_data["agent_name"] agent = ( - self.session.query(Agent) + session.query(Agent) .filter(Agent.name == agent_name, Agent.user_id == self.user_id) .first() ) if not agent: # Use the first agent in the database agent = ( - self.session.query(Agent) - .filter(Agent.user_id == self.user_id) - .first() + session.query(Agent).filter(Agent.user_id == self.user_id).first() ) prompt = step_data["prompt"] if "prompt_type" not in step_data: @@ -559,7 +561,7 @@ def import_chain(self, chain_name: str, steps: dict): argument_key = "prompt_name" prompt_category = prompt.get("prompt_category", "Default") target_id = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt[argument_key], Prompt.user_id == self.user_id, @@ -574,7 +576,7 @@ def import_chain(self, chain_name: str, steps: dict): if "chain" in prompt: argument_key = "chain" target_id = ( - self.session.query(ChainDB) + session.query(ChainDB) .filter( ChainDB.name == prompt[argument_key], ChainDB.user_id == self.user_id, @@ -586,7 +588,7 @@ def import_chain(self, chain_name: str, steps: dict): elif prompt_type == "command": argument_key = "command_name" target_id = ( - self.session.query(Command) + session.query(Command) .filter(Command.name == prompt[argument_key]) .first() .id @@ -609,11 +611,11 @@ def import_chain(self, chain_name: str, steps: dict): target_command_id=target_id if target_type == "command" else None, target_prompt_id=target_id if target_type == "prompt" else None, ) - self.session.add(chain_step) - self.session.commit() + session.add(chain_step) + session.commit() for argument_name, argument_value in prompt_arguments.items(): argument = ( - self.session.query(Argument) + session.query(Argument) .filter(Argument.name == argument_name) .first() ) @@ -627,8 +629,9 @@ def import_chain(self, chain_name: str, steps: dict): argument_id=argument.id, value=argument_value, ) - self.session.add(chain_step_argument) - self.session.commit() + session.add(chain_step_argument) + session.commit() + session.close() return f"Imported chain: {chain_name}" def get_chain_step_dependencies(self, chain_name): @@ -771,8 +774,9 @@ async def update_step_response( ): chain_step = self.get_step(chain_name=chain_name, step_number=step_number) if chain_step: + session = get_session() existing_response = ( - self.session.query(ChainStepResponse) + session.query(ChainStepResponse) .filter( ChainStepResponse.chain_step_id == chain_step.id, ChainStepResponse.chain_run_id == chain_run_id, @@ -785,48 +789,55 @@ async def update_step_response( response, dict ): existing_response.content.update(response) - self.session.commit() + session.commit() elif isinstance(existing_response.content, list) and isinstance( response, list ): existing_response.content.extend(response) - self.session.commit() + session.commit() else: chain_step_response = ChainStepResponse( chain_step_id=chain_step.id, chain_run_id=chain_run_id, content=response, ) - self.session.add(chain_step_response) - self.session.commit() + session.add(chain_step_response) + session.commit() else: chain_step_response = ChainStepResponse( chain_step_id=chain_step.id, chain_run_id=chain_run_id, content=response, ) - self.session.add(chain_step_response) - self.session.commit() + session.add(chain_step_response) + session.commit() + session.close() async def get_chain_run_id(self, chain_name): + session = get_session() chain_run = ChainRun( chain_id=self.get_chain(chain_name=chain_name)["id"], user_id=self.user_id, ) - self.session.add(chain_run) - self.session.commit() - return chain_run.id + session.add(chain_run) + session.commit() + chain_id = chain_run.id + session.close() + return chain_id async def get_last_chain_run_id(self, chain_name): chain_data = self.get_chain(chain_name=chain_name) + session = get_session() chain_run = ( - self.session.query(ChainRun) + session.query(ChainRun) .filter(ChainRun.chain_id == chain_data["id"]) .order_by(ChainRun.timestamp.desc()) .first() ) if chain_run: - return chain_run.id + chain_run_id = chain_run.id + session.close() + return chain_run_id else: return await self.get_chain_run_id(chain_name=chain_name) @@ -876,8 +887,9 @@ def new_task( task_description, estimated_hours, ): + session = get_session() task_category = ( - self.session.query(TaskCategory) + session.query(TaskCategory) .filter( TaskCategory.name == task_category, TaskCategory.user_id == self.user_id ) @@ -885,8 +897,8 @@ def new_task( ) if not task_category: task_category = TaskCategory(name=task_category, user_id=self.user_id) - self.session.add(task_category) - self.session.commit() + session.add(task_category) + session.commit() task = TaskItem( user_id=self.user_id, category_id=task_category.id, @@ -895,6 +907,8 @@ def new_task( estimated_hours=estimated_hours, memory_collection=str(conversation_id), ) - self.session.add(task) - self.session.commit() - return task.id + session.add(task) + session.commit() + task_id = task.id + session.close() + return task_id diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index b32ca85176eb..a5e01891766e 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -11,8 +11,6 @@ from fastapi import Header, HTTPException from Globals import getenv from datetime import datetime, timedelta -from Agent import add_agent -from agixtsdk import AGiXTSDK from fastapi import HTTPException from sendgrid import SendGridAPIClient from sendgrid.helpers.mail import ( @@ -58,51 +56,6 @@ def is_agixt_admin(email: str = "", api_key: str = ""): return False -def webhook_create_user( - api_key: str, - email: str, - role: str = "user", - agent_name: str = "", - settings: dict = {}, - commands: dict = {}, - training_urls: list = [], - github_repos: list = [], - ApiClient: AGiXTSDK = AGiXTSDK(), -): - if not is_agixt_admin(email=email, api_key=api_key): - return {"error": "Access Denied"}, 403 - session = get_session() - email = email.lower() - user_exists = session.query(User).filter_by(email=email).first() - if user_exists: - session.close() - return {"error": "User already exists"}, 400 - admin = True if role.lower() == "admin" else False - user = User( - email=email, - admin=admin, - first_name="", - last_name="", - ) - session.add(user) - session.commit() - session.close() - if agent_name != "" and agent_name is not None: - add_agent( - agent_name=agent_name, - provider_settings=settings, - commands=commands, - user=email, - ) - if training_urls != []: - for url in training_urls: - ApiClient.learn_url(agent_name=agent_name, url=url) - if github_repos != []: - for repo in github_repos: - ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) - return {"status": "Success"}, 200 - - def verify_api_key(authorization: str = Header(None)): AGIXT_API_KEY = getenv("AGIXT_API_KEY") if getenv("AUTH_PROVIDER") == "magicalauth": @@ -133,6 +86,21 @@ def verify_api_key(authorization: str = Header(None)): return authorization +def get_user_id(user: str): + session = get_session() + user_data = session.query(User).filter(User.email == user).first() + if user_data is None: + session.close() + raise HTTPException(status_code=404, detail=f"User {user} not found.") + try: + user_id = user_data.id + except Exception as e: + session.close() + raise HTTPException(status_code=404, detail=f"User {user} not found.") + session.close() + return user_id + + def send_email( email: str, subject: str, diff --git a/agixt/Models.py b/agixt/Models.py index 224295d7c5bd..7b8a2dd2a2b6 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -316,6 +316,7 @@ class WebhookUser(BaseModel): commands: Optional[Dict[str, Any]] = {} training_urls: Optional[List[str]] = [] github_repos: Optional[List[str]] = [] + zip_file_content: Optional[str] = "" # Auth user models diff --git a/agixt/Prompts.py b/agixt/Prompts.py index 85782efdc26e..eb97a9f93a31 100644 --- a/agixt/Prompts.py +++ b/agixt/Prompts.py @@ -1,21 +1,20 @@ from DB import Prompt, PromptCategory, Argument, User, get_session from Globals import DEFAULT_USER +from MagicalAuth import get_user_id import os class Prompts: def __init__(self, user=DEFAULT_USER): - self.session = get_session() self.user = user - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id + self.user_id = get_user_id(user) def add_prompt(self, prompt_name, prompt, prompt_category="Default"): + session = get_session() if not prompt_category: prompt_category = "Default" - prompt_category = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter( PromptCategory.name == prompt_category, PromptCategory.user_id == self.user_id, @@ -28,8 +27,8 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): description=f"{prompt_category} category", user_id=self.user_id, ) - self.session.add(prompt_category) - self.session.commit() + session.add(prompt_category) + session.commit() prompt_obj = Prompt( name=prompt_name, @@ -38,8 +37,8 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_category=prompt_category, user_id=self.user_id, ) - self.session.add(prompt_obj) - self.session.commit() + session.add(prompt_obj) + session.commit() # Populate prompt arguments prompt_args = self.get_prompt_args(prompt) @@ -48,13 +47,15 @@ def add_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_id=prompt_obj.id, name=arg, ) - self.session.add(argument) - self.session.commit() + session.add(argument) + session.commit() + session.close() def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == user_data.id, @@ -66,7 +67,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): ) if not prompt: prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -82,7 +83,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): if not prompt and prompt_category != "Default": # Prompt not found in specified category, try the default category prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -110,7 +111,7 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): prompt_category="Default", ) prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -124,13 +125,17 @@ def get_prompt(self, prompt_name: str, prompt_category: str = "Default"): .first() ) if prompt: - return prompt.content + prompt_content = prompt.content + session.close() + return prompt_content + session.close() return None def get_prompts(self, prompt_category="Default"): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_prompts = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.user_id == user_data.id, Prompt.prompt_category.has(name=prompt_category), @@ -142,7 +147,7 @@ def get_prompts(self, prompt_category="Default"): .all() ) user_prompts = ( - self.session.query(Prompt) + session.query(Prompt) .join(PromptCategory) .filter( PromptCategory.name == prompt_category, Prompt.user_id == self.user_id @@ -154,6 +159,7 @@ def get_prompts(self, prompt_category="Default"): prompts.append(prompt.name) for prompt in user_prompts: prompts.append(prompt.name) + session.close() return prompts def get_prompt_args(self, prompt_text): @@ -169,8 +175,9 @@ def get_prompt_args(self, prompt_text): return prompt_args def delete_prompt(self, prompt_name, prompt_category="Default"): + session = get_session() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter_by(name=prompt_name) .join(PromptCategory) .filter( @@ -179,12 +186,14 @@ def delete_prompt(self, prompt_name, prompt_category="Default"): .first() ) if prompt: - self.session.delete(prompt) - self.session.commit() + session.delete(prompt) + session.commit() + session.close() def update_prompt(self, prompt_name, prompt, prompt_category="Default"): + session = get_session() prompt_obj = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -195,7 +204,7 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): if prompt_obj: if prompt_category: prompt_category = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter( PromptCategory.name == prompt_category, PromptCategory.user_id == self.user_id, @@ -208,25 +217,21 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): description=f"{prompt_category} category", user_id=self.user_id, ) - self.session.add(prompt_category) - self.session.commit() + session.add(prompt_category) + session.commit() prompt_obj.prompt_category = prompt_category - prompt_obj.content = prompt - self.session.commit() - + session.commit() # Update prompt arguments prompt_args = self.get_prompt_args(prompt) existing_args = ( - self.session.query(Argument).filter_by(prompt_id=prompt_obj.id).all() + session.query(Argument).filter_by(prompt_id=prompt_obj.id).all() ) existing_arg_names = {arg.name for arg in existing_args} - # Delete removed arguments for arg in existing_args: if arg.name not in prompt_args: - self.session.delete(arg) - + session.delete(arg) # Add new arguments for arg in prompt_args: if arg not in existing_arg_names: @@ -234,13 +239,14 @@ def update_prompt(self, prompt_name, prompt, prompt_category="Default"): prompt_id=prompt_obj.id, name=arg, ) - self.session.add(argument) - - self.session.commit() + session.add(argument) + session.commit() + session.close() def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default"): + session = get_session() prompt = ( - self.session.query(Prompt) + session.query(Prompt) .filter( Prompt.name == prompt_name, Prompt.user_id == self.user_id, @@ -254,17 +260,19 @@ def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default") ) if prompt: prompt.name = new_prompt_name - self.session.commit() + session.commit() + session.close() def get_prompt_categories(self): - user_data = self.session.query(User).filter(User.email == DEFAULT_USER).first() + session = get_session() + user_data = session.query(User).filter(User.email == DEFAULT_USER).first() global_prompt_categories = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter(PromptCategory.user_id == user_data.id) .all() ) user_prompt_categories = ( - self.session.query(PromptCategory) + session.query(PromptCategory) .filter(PromptCategory.user_id == self.user_id) .all() ) @@ -273,4 +281,5 @@ def get_prompt_categories(self): prompt_categories.append(prompt_category.name) for prompt_category in user_prompt_categories: prompt_categories.append(prompt_category.name) + session.close() return prompt_categories diff --git a/agixt/XT.py b/agixt/XT.py index 3f2ef71ac2e9..33401ba0c16d 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -630,10 +630,14 @@ async def learn_from_file( stderr=subprocess.PIPE, ) file_path = pdf_file_path - if conversation_name != "" and conversation_name != None: + if ( + conversation_name != "" + and conversation_name != None + and file_type not in ["jpg", "jpeg", "png", "gif"] + ): c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Reading file `{file_name}` into memory.", + message=f"[ACTIVITY] Reading `{file_name}` into memory.", ) if user_input == "": user_input = "Describe each stage of this image." @@ -785,7 +789,7 @@ async def learn_from_file( if conversation_name != "" and conversation_name != None: c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] Viewing image at {file_url} .", + message=f"[ACTIVITY] [Uploaded {file_name}]({file_url}) .", ) try: vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\nThe uploaded image is `{file_name}`.\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." @@ -1347,7 +1351,9 @@ async def chat_completions(self, prompt: ChatCompletions): self.agent_workspace, audio_file_info["file_name"], ) - if url.startswith(self.agent_workspace): + if os.path.normpath(audio_file_path).startswith( + self.agent_workspace + ): wav_file = os.path.join( self.agent_workspace, f"{uuid.uuid4().hex}.wav", diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 224e71b510aa..91aa3059be01 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -1,6 +1,8 @@ from fastapi import APIRouter, Request, Header, Depends, HTTPException from Models import Detail, Login, UserInfo, Register -from MagicalAuth import MagicalAuth, verify_api_key, webhook_create_user +from MagicalAuth import MagicalAuth, verify_api_key, is_agixt_admin +from DB import get_session, User +from Agent import add_agent from ApiClient import get_api_client, is_admin from Models import WebhookUser from Globals import getenv @@ -99,18 +101,50 @@ async def createuser( account: WebhookUser, authorization: str = Header(None), ): + if not is_agixt_admin(email=email, api_key=authorization): + raise HTTPException(status_code=403, detail="Unauthorized") ApiClient = get_api_client(authorization=authorization) - return webhook_create_user( - api_key=authorization, - email=account.email, - role="user", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, + session = get_session() + email = account.email.lower() + agent_name = account.agent_name + settings = account.settings + commands = account.commands + training_urls = account.training_urls + github_repos = account.github_repos + zip_file_content = account.zip_file_content + user_exists = session.query(User).filter_by(email=email).first() + if user_exists: + session.close() + return {"status": "User already exists"}, 200 + user = User( + email=email, + admin=False, + first_name="", + last_name="", ) + session.add(user) + session.commit() + session.close() + if agent_name != "" and agent_name is not None: + add_agent( + agent_name=agent_name, + provider_settings=settings, + commands=commands, + user=email, + ) + if training_urls != []: + for url in training_urls: + ApiClient.learn_url(agent_name=agent_name, url=url) + if github_repos != []: + for repo in github_repos: + ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) + if zip_file_content != "": + ApiClient.learn_file( + agent_name=agent_name, + file_name="training_data.zip", + file_content=zip_file_content, + ) + return {"status": "Success"}, 200 @app.post( From 4710406122f685a568bb999613d16559ee671bb4 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 12:00:23 -0400 Subject: [PATCH 0239/1256] add openpyxl --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8b9a87a582b8..a2222e3f4685 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ python-multipart==0.0.9 nest_asyncio g4f==0.3.2.0 pyotp -pytz \ No newline at end of file +pytz +openpyxl==3.1.4 \ No newline at end of file From 395ebd5519d63c968f51094178f8b79b254cd6f3 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 12:19:42 -0400 Subject: [PATCH 0240/1256] convert tts to wav --- agixt/providers/default.py | 232 ++++++++++++++++++++++++++++++++++++- agixt/providers/google.py | 10 +- 2 files changed, 235 insertions(+), 7 deletions(-) diff --git a/agixt/providers/default.py b/agixt/providers/default.py index da6d103cce6d..e2fe9a50b4b0 100644 --- a/agixt/providers/default.py +++ b/agixt/providers/default.py @@ -5,10 +5,14 @@ import os import logging import numpy as np +import requests +import base64 +from pydub import AudioSegment +import uuid # Default provider uses: # llm: gpt4free -# tts: gTTS +# tts: Streamlabs TTS # transcription: faster-whisper # translation: faster-whisper @@ -17,9 +21,11 @@ class DefaultProvider: def __init__( self, AI_MODEL: str = "mixtral-8x7b", + VOICE: str = "Brian", **kwargs, ): self.AI_MODEL = AI_MODEL if AI_MODEL else "mixtral-8x7b" + self.VOICE = VOICE if VOICE else "Brian" self.AI_TEMPERATURE = 0.7 self.AI_TOP_P = 0.7 self.MAX_TOKENS = 16000 @@ -49,7 +55,229 @@ async def inference(self, prompt, tokens: int = 0, images: list = []): ).inference(prompt=prompt, tokens=tokens, images=images) async def text_to_speech(self, text: str): - return await GoogleProvider().text_to_speech(text=text) + voices = [ + "Filiz", + "Astrid", + "Tatyana", + "Maxim", + "Carmen", + "Ines", + "Cristiano", + "Vitoria", + "Ricardo", + "Maja", + "Jan", + "Jacek", + "Ewa", + "Ruben", + "Lotte", + "Liv", + "Seoyeon", + "Takumi", + "Mizuki", + "Giorgio", + "Carla", + "Bianca", + "Karl", + "Dora", + "Mathieu", + "Celine", + "Chantal", + "Penelope", + "Miguel", + "Mia", + "Enrique", + "Conchita", + "Geraint", + "Salli", + "Matthew", + "Kimberly", + "Kendra", + "Justin", + "Joey", + "Joanna", + "Ivy", + "Raveena", + "Aditi", + "Emma", + "Brian", + "Amy", + "Russell", + "Nicole", + "Vicki", + "Marlene", + "Hans", + "Naja", + "Mads", + "Gwyneth", + "Zhiyu", + "es-ES-Standard-A", + "it-IT-Standard-A", + "it-IT-Wavenet-A", + "ja-JP-Standard-A", + "ja-JP-Wavenet-A", + "ko-KR-Standard-A", + "ko-KR-Wavenet-A", + "pt-BR-Standard-A", + "tr-TR-Standard-A", + "sv-SE-Standard-A", + "nl-NL-Standard-A", + "nl-NL-Wavenet-A", + "en-US-Wavenet-A", + "en-US-Wavenet-B", + "en-US-Wavenet-C", + "en-US-Wavenet-D", + "en-US-Wavenet-E", + "en-US-Wavenet-F", + "en-GB-Standard-A", + "en-GB-Standard-B", + "en-GB-Standard-C", + "en-GB-Standard-D", + "en-GB-Wavenet-A", + "en-GB-Wavenet-B", + "en-GB-Wavenet-C", + "en-GB-Wavenet-D", + "en-US-Standard-B", + "en-US-Standard-C", + "en-US-Standard-D", + "en-US-Standard-E", + "de-DE-Standard-A", + "de-DE-Standard-B", + "de-DE-Wavenet-A", + "de-DE-Wavenet-B", + "de-DE-Wavenet-C", + "de-DE-Wavenet-D", + "en-AU-Standard-A", + "en-AU-Standard-B", + "en-AU-Wavenet-A", + "en-AU-Wavenet-B", + "en-AU-Wavenet-C", + "en-AU-Wavenet-D", + "en-AU-Standard-C", + "en-AU-Standard-D", + "fr-CA-Standard-A", + "fr-CA-Standard-B", + "fr-CA-Standard-C", + "fr-CA-Standard-D", + "fr-FR-Standard-C", + "fr-FR-Standard-D", + "fr-FR-Wavenet-A", + "fr-FR-Wavenet-B", + "fr-FR-Wavenet-C", + "fr-FR-Wavenet-D", + "da-DK-Wavenet-A", + "pl-PL-Wavenet-A", + "pl-PL-Wavenet-B", + "pl-PL-Wavenet-C", + "pl-PL-Wavenet-D", + "pt-PT-Wavenet-A", + "pt-PT-Wavenet-B", + "pt-PT-Wavenet-C", + "pt-PT-Wavenet-D", + "ru-RU-Wavenet-A", + "ru-RU-Wavenet-B", + "ru-RU-Wavenet-C", + "ru-RU-Wavenet-D", + "sk-SK-Wavenet-A", + "tr-TR-Wavenet-A", + "tr-TR-Wavenet-B", + "tr-TR-Wavenet-C", + "tr-TR-Wavenet-D", + "tr-TR-Wavenet-E", + "uk-UA-Wavenet-A", + "ar-XA-Wavenet-A", + "ar-XA-Wavenet-B", + "ar-XA-Wavenet-C", + "cs-CZ-Wavenet-A", + "nl-NL-Wavenet-B", + "nl-NL-Wavenet-C", + "nl-NL-Wavenet-D", + "nl-NL-Wavenet-E", + "en-IN-Wavenet-A", + "en-IN-Wavenet-B", + "en-IN-Wavenet-C", + "fil-PH-Wavenet-A", + "fi-FI-Wavenet-A", + "el-GR-Wavenet-A", + "hi-IN-Wavenet-A", + "hi-IN-Wavenet-B", + "hi-IN-Wavenet-C", + "hu-HU-Wavenet-A", + "id-ID-Wavenet-A", + "id-ID-Wavenet-B", + "id-ID-Wavenet-C", + "it-IT-Wavenet-B", + "it-IT-Wavenet-C", + "it-IT-Wavenet-D", + "ja-JP-Wavenet-B", + "ja-JP-Wavenet-C", + "ja-JP-Wavenet-D", + "cmn-CN-Wavenet-A", + "cmn-CN-Wavenet-B", + "cmn-CN-Wavenet-C", + "cmn-CN-Wavenet-D", + "nb-no-Wavenet-E", + "nb-no-Wavenet-A", + "nb-no-Wavenet-B", + "nb-no-Wavenet-C", + "nb-no-Wavenet-D", + "vi-VN-Wavenet-A", + "vi-VN-Wavenet-B", + "vi-VN-Wavenet-C", + "vi-VN-Wavenet-D", + "sr-rs-Standard-A", + "lv-lv-Standard-A", + "is-is-Standard-A", + "bg-bg-Standard-A", + "af-ZA-Standard-A", + "Tracy", + "Danny", + "Huihui", + "Yaoyao", + "Kangkang", + "HanHan", + "Zhiwei", + "Asaf", + "An", + "Stefanos", + "Filip", + "Ivan", + "Heidi", + "Herena", + "Kalpana", + "Hemant", + "Matej", + "Andika", + "Rizwan", + "Lado", + "Valluvar", + "Linda", + "Heather", + "Sean", + "Michael", + "Karsten", + "Guillaume", + "Pattara", + "Jakub", + "Szabolcs", + "Hoda", + "Naayf", + ] + if self.VOICE not in voices: + self.VOICE = "Brian" + response = requests.get( + f"https://api.streamelements.com/kappa/v2/speech?voice={self.VOICE}&text={text}" + ) + file_content = base64.b64encode(response.content).decode("utf-8") + # It is an mp3, convert to 16k wav + audio = AudioSegment.from_mp3(base64.b64decode(file_content)) + file_path = os.path.join(os.getcwd(), "WORKSPACE", f"{uuid.uuid4()}.wav") + audio.export(file_path, format="wav") + # Get content of the wav file to return base64 + with open(file_path, "rb") as f: + file_content = base64.b64encode(f.read()).decode("utf-8") + os.remove(file_path) + return file_content def embeddings(self, input) -> np.ndarray: return self.embedder.__call__(input=[input])[0] diff --git a/agixt/providers/google.py b/agixt/providers/google.py index 25320be05f62..a42593b4ee90 100644 --- a/agixt/providers/google.py +++ b/agixt/providers/google.py @@ -23,6 +23,7 @@ import gtts as ts from pydub import AudioSegment +import uuid class GoogleProvider: @@ -86,12 +87,11 @@ async def inference(self, prompt, tokens: int = 0, images: list = []): async def text_to_speech(self, text: str): tts = ts.gTTS(text) - mp3_path = "speech.mp3" + mp3_path = os.path.join(os.getcwd(), "WORKSPACE", f"{uuid.uuid4()}.mp3") tts.save(mp3_path) - wav_path = "output_speech.wav" - AudioSegment.from_mp3(mp3_path).set_frame_rate(16000).export( - wav_path, format="wav" - ) + wav_path = os.path.join(os.getcwd(), "WORKSPACE", f"{uuid.uuid4()}.wav") + audio = AudioSegment.from_mp3(mp3_path) + audio.export(wav_path, format="wav") os.remove(mp3_path) with open(wav_path, "rb") as f: audio_content = f.read() From 937047a8ac2b2b34daa0c726a026acda844f0d05 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 12:35:08 -0400 Subject: [PATCH 0241/1256] clear err --- agixt/providers/default.py | 230 +------------------------------------ 1 file changed, 1 insertion(+), 229 deletions(-) diff --git a/agixt/providers/default.py b/agixt/providers/default.py index e2fe9a50b4b0..da043a71fcd3 100644 --- a/agixt/providers/default.py +++ b/agixt/providers/default.py @@ -5,10 +5,6 @@ import os import logging import numpy as np -import requests -import base64 -from pydub import AudioSegment -import uuid # Default provider uses: # llm: gpt4free @@ -21,11 +17,9 @@ class DefaultProvider: def __init__( self, AI_MODEL: str = "mixtral-8x7b", - VOICE: str = "Brian", **kwargs, ): self.AI_MODEL = AI_MODEL if AI_MODEL else "mixtral-8x7b" - self.VOICE = VOICE if VOICE else "Brian" self.AI_TEMPERATURE = 0.7 self.AI_TOP_P = 0.7 self.MAX_TOKENS = 16000 @@ -55,229 +49,7 @@ async def inference(self, prompt, tokens: int = 0, images: list = []): ).inference(prompt=prompt, tokens=tokens, images=images) async def text_to_speech(self, text: str): - voices = [ - "Filiz", - "Astrid", - "Tatyana", - "Maxim", - "Carmen", - "Ines", - "Cristiano", - "Vitoria", - "Ricardo", - "Maja", - "Jan", - "Jacek", - "Ewa", - "Ruben", - "Lotte", - "Liv", - "Seoyeon", - "Takumi", - "Mizuki", - "Giorgio", - "Carla", - "Bianca", - "Karl", - "Dora", - "Mathieu", - "Celine", - "Chantal", - "Penelope", - "Miguel", - "Mia", - "Enrique", - "Conchita", - "Geraint", - "Salli", - "Matthew", - "Kimberly", - "Kendra", - "Justin", - "Joey", - "Joanna", - "Ivy", - "Raveena", - "Aditi", - "Emma", - "Brian", - "Amy", - "Russell", - "Nicole", - "Vicki", - "Marlene", - "Hans", - "Naja", - "Mads", - "Gwyneth", - "Zhiyu", - "es-ES-Standard-A", - "it-IT-Standard-A", - "it-IT-Wavenet-A", - "ja-JP-Standard-A", - "ja-JP-Wavenet-A", - "ko-KR-Standard-A", - "ko-KR-Wavenet-A", - "pt-BR-Standard-A", - "tr-TR-Standard-A", - "sv-SE-Standard-A", - "nl-NL-Standard-A", - "nl-NL-Wavenet-A", - "en-US-Wavenet-A", - "en-US-Wavenet-B", - "en-US-Wavenet-C", - "en-US-Wavenet-D", - "en-US-Wavenet-E", - "en-US-Wavenet-F", - "en-GB-Standard-A", - "en-GB-Standard-B", - "en-GB-Standard-C", - "en-GB-Standard-D", - "en-GB-Wavenet-A", - "en-GB-Wavenet-B", - "en-GB-Wavenet-C", - "en-GB-Wavenet-D", - "en-US-Standard-B", - "en-US-Standard-C", - "en-US-Standard-D", - "en-US-Standard-E", - "de-DE-Standard-A", - "de-DE-Standard-B", - "de-DE-Wavenet-A", - "de-DE-Wavenet-B", - "de-DE-Wavenet-C", - "de-DE-Wavenet-D", - "en-AU-Standard-A", - "en-AU-Standard-B", - "en-AU-Wavenet-A", - "en-AU-Wavenet-B", - "en-AU-Wavenet-C", - "en-AU-Wavenet-D", - "en-AU-Standard-C", - "en-AU-Standard-D", - "fr-CA-Standard-A", - "fr-CA-Standard-B", - "fr-CA-Standard-C", - "fr-CA-Standard-D", - "fr-FR-Standard-C", - "fr-FR-Standard-D", - "fr-FR-Wavenet-A", - "fr-FR-Wavenet-B", - "fr-FR-Wavenet-C", - "fr-FR-Wavenet-D", - "da-DK-Wavenet-A", - "pl-PL-Wavenet-A", - "pl-PL-Wavenet-B", - "pl-PL-Wavenet-C", - "pl-PL-Wavenet-D", - "pt-PT-Wavenet-A", - "pt-PT-Wavenet-B", - "pt-PT-Wavenet-C", - "pt-PT-Wavenet-D", - "ru-RU-Wavenet-A", - "ru-RU-Wavenet-B", - "ru-RU-Wavenet-C", - "ru-RU-Wavenet-D", - "sk-SK-Wavenet-A", - "tr-TR-Wavenet-A", - "tr-TR-Wavenet-B", - "tr-TR-Wavenet-C", - "tr-TR-Wavenet-D", - "tr-TR-Wavenet-E", - "uk-UA-Wavenet-A", - "ar-XA-Wavenet-A", - "ar-XA-Wavenet-B", - "ar-XA-Wavenet-C", - "cs-CZ-Wavenet-A", - "nl-NL-Wavenet-B", - "nl-NL-Wavenet-C", - "nl-NL-Wavenet-D", - "nl-NL-Wavenet-E", - "en-IN-Wavenet-A", - "en-IN-Wavenet-B", - "en-IN-Wavenet-C", - "fil-PH-Wavenet-A", - "fi-FI-Wavenet-A", - "el-GR-Wavenet-A", - "hi-IN-Wavenet-A", - "hi-IN-Wavenet-B", - "hi-IN-Wavenet-C", - "hu-HU-Wavenet-A", - "id-ID-Wavenet-A", - "id-ID-Wavenet-B", - "id-ID-Wavenet-C", - "it-IT-Wavenet-B", - "it-IT-Wavenet-C", - "it-IT-Wavenet-D", - "ja-JP-Wavenet-B", - "ja-JP-Wavenet-C", - "ja-JP-Wavenet-D", - "cmn-CN-Wavenet-A", - "cmn-CN-Wavenet-B", - "cmn-CN-Wavenet-C", - "cmn-CN-Wavenet-D", - "nb-no-Wavenet-E", - "nb-no-Wavenet-A", - "nb-no-Wavenet-B", - "nb-no-Wavenet-C", - "nb-no-Wavenet-D", - "vi-VN-Wavenet-A", - "vi-VN-Wavenet-B", - "vi-VN-Wavenet-C", - "vi-VN-Wavenet-D", - "sr-rs-Standard-A", - "lv-lv-Standard-A", - "is-is-Standard-A", - "bg-bg-Standard-A", - "af-ZA-Standard-A", - "Tracy", - "Danny", - "Huihui", - "Yaoyao", - "Kangkang", - "HanHan", - "Zhiwei", - "Asaf", - "An", - "Stefanos", - "Filip", - "Ivan", - "Heidi", - "Herena", - "Kalpana", - "Hemant", - "Matej", - "Andika", - "Rizwan", - "Lado", - "Valluvar", - "Linda", - "Heather", - "Sean", - "Michael", - "Karsten", - "Guillaume", - "Pattara", - "Jakub", - "Szabolcs", - "Hoda", - "Naayf", - ] - if self.VOICE not in voices: - self.VOICE = "Brian" - response = requests.get( - f"https://api.streamelements.com/kappa/v2/speech?voice={self.VOICE}&text={text}" - ) - file_content = base64.b64encode(response.content).decode("utf-8") - # It is an mp3, convert to 16k wav - audio = AudioSegment.from_mp3(base64.b64decode(file_content)) - file_path = os.path.join(os.getcwd(), "WORKSPACE", f"{uuid.uuid4()}.wav") - audio.export(file_path, format="wav") - # Get content of the wav file to return base64 - with open(file_path, "rb") as f: - file_content = base64.b64encode(f.read()).decode("utf-8") - os.remove(file_path) - return file_content + return await GoogleProvider().text_to_speech(text=text) def embeddings(self, input) -> np.ndarray: return self.embedder.__call__(input=[input])[0] From 67dd0f6b3a9a312a7f13c5140ce75cbe2c047547 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 12:58:44 -0400 Subject: [PATCH 0242/1256] Improve logging message --- agixt/XT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index 33401ba0c16d..8aef23f288d4 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -251,7 +251,7 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): c = Conversations(conversation_name=conversation_name, user=self.user_email) c.log_interaction( role=self.agent_name, - message=f"Transcribing audio.", + message=f"[ACTIVITY] Transcribing recorded audio.", ) response = await self.agent.transcribe_audio(audio_path=audio_path) return response From c0a8e85de24a02ebb2a8dafc9272606c6ee05432 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 13:07:04 -0400 Subject: [PATCH 0243/1256] add timer --- agixt/XT.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/agixt/XT.py b/agixt/XT.py index 8aef23f288d4..d41ad1b35987 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -253,7 +253,17 @@ async def audio_to_text(self, audio_path: str, conversation_name: str = ""): role=self.agent_name, message=f"[ACTIVITY] Transcribing recorded audio.", ) + # Start a timer + start = time.time() response = await self.agent.transcribe_audio(audio_path=audio_path) + if conversation_name != "" and conversation_name != None: + # End the timer + end = time.time() + elapsed_time = end - start + c.log_interaction( + role=self.agent_name, + message=f"Transcribed audio in {elapsed_time} seconds.", + ) return response async def translate_audio(self, audio_path: str, conversation_name: str = ""): From ea9bccdc5182455747ea9cd3a110b20cdcb21435 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 13:49:47 -0400 Subject: [PATCH 0244/1256] move ref --- agixt/endpoints/Auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 91aa3059be01..818b09567c94 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -101,11 +101,11 @@ async def createuser( account: WebhookUser, authorization: str = Header(None), ): + email = account.email.lower() if not is_agixt_admin(email=email, api_key=authorization): raise HTTPException(status_code=403, detail="Unauthorized") ApiClient = get_api_client(authorization=authorization) session = get_session() - email = account.email.lower() agent_name = account.agent_name settings = account.settings commands = account.commands From d96e590f45a266f2679425f7d6d885eac2ec54fa Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 13:52:28 -0400 Subject: [PATCH 0245/1256] close sessions --- agixt/MagicalAuth.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index a5e01891766e..6c87a1cb8e5d 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -48,11 +48,18 @@ def is_agixt_admin(email: str = "", api_key: str = ""): if api_key == getenv("AGIXT_API_KEY"): return True session = get_session() - user = session.query(User).filter_by(email=email).first() + try: + user = session.query(User).filter_by(email=email).first() + except: + session.close() + return False if not user: + session.close() return False if user.admin is True: + session.close() return True + session.close() return False From c69033fbec0d0ede2535876a636fda613ca95c98 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 16:40:09 -0400 Subject: [PATCH 0246/1256] show uploaded image --- agixt/XT.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/agixt/XT.py b/agixt/XT.py index d41ad1b35987..60be34a25665 100644 --- a/agixt/XT.py +++ b/agixt/XT.py @@ -797,9 +797,10 @@ async def learn_from_file( and self.agent.VISION_PROVIDER != None ): if conversation_name != "" and conversation_name != None: + # f"[ACTIVITY] [Uploaded {file_name}]({file_url}) ." c.log_interaction( role=self.agent_name, - message=f"[ACTIVITY] [Uploaded {file_name}]({file_url}) .", + message=f"[ACTIVITY] Uploaded `{file_name}` ![Uploaded {file_name}]({file_url}) .", ) try: vision_prompt = f"The assistant has an image in context\nThe user's last message was: {user_input}\nThe uploaded image is `{file_name}`.\n\nAnswer anything relevant to the image that the user is questioning if anything, additionally, describe the image in detail." From 998e61b12febecc54261aa664b532250e04350cf Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 17:06:38 -0400 Subject: [PATCH 0247/1256] add 402 --- agixt/MagicalAuth.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index 6c87a1cb8e5d..6ae96f5c7c07 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -550,8 +550,17 @@ def get_user_preferences(self): .all() ) user_preferences = {x.pref_key: x.pref_value for x in user_preferences} + user_requirements = self.registration_requirements() if not user_preferences: return {} + if "subscription" in user_requirements: + if "subscription" not in user_preferences: + user_preferences["subscription"] = "none" + if str(user_preferences["subscription"]).lower() != "none": + if user.is_active is False: + raise HTTPException( + status_code=402, detail=user_preferences["subscription"] + ) session.close() if "email" in user_preferences: del user_preferences["email"] @@ -561,15 +570,10 @@ def get_user_preferences(self): del user_preferences["last_name"] if "missing_requirements" in user_preferences: del user_preferences["missing_requirements"] - user_requirements = self.registration_requirements() missing_requirements = [] for key, value in user_requirements.items(): if key not in user_preferences: - if key == "subscription": - if str(value).lower() != "none": - if str(value).lower() == "false": - raise HTTPException(status_code=402, detail=str(value)) - else: + if key != "subscription": missing_requirements.append({key: value}) if missing_requirements: user_preferences["missing_requirements"] = missing_requirements From 36ad0cf595891cb4d7e43c222ac9545d22750381 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 17:22:12 -0400 Subject: [PATCH 0248/1256] add webhook --- agixt/Models.py | 4 +++ agixt/endpoints/Auth.py | 70 +++++++++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 3 files changed, 72 insertions(+), 3 deletions(-) diff --git a/agixt/Models.py b/agixt/Models.py index 7b8a2dd2a2b6..94c94458d6ed 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -251,6 +251,10 @@ class TTSInput(BaseModel): text: str +class WebhookModel(BaseModel): + success: str + + class ConversationHistoryMessageModel(BaseModel): agent_name: str conversation_name: str diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 818b09567c94..846c73af9b93 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -1,15 +1,22 @@ from fastapi import APIRouter, Request, Header, Depends, HTTPException from Models import Detail, Login, UserInfo, Register from MagicalAuth import MagicalAuth, verify_api_key, is_agixt_admin -from DB import get_session, User +from DB import get_session, User, UserPreferences from Agent import add_agent -from ApiClient import get_api_client, is_admin -from Models import WebhookUser +from ApiClient import get_api_client +from Models import WebhookUser, WebhookModel from Globals import getenv import pyotp +import stripe +import logging app = APIRouter() +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) + @app.post("/v1/user") def register(register: Register): @@ -147,6 +154,63 @@ async def createuser( return {"status": "Success"}, 200 +@app.post( + "/webhook", + summary="Webhook endpoint for events.", + response_model=WebhookModel, + tags=["Webhook"], +) +async def webhook(request: Request): + event = None + data = None + try: + event = stripe.Webhook.construct_event( + payload=(await request.body()).decode("utf-8"), + sig_header=request.headers.get("stripe-signature"), + secret=getenv("STRIPE_WEBHOOK_SECRET"), + ) + data = event["data"]["object"] + except stripe.error.SignatureVerificationError as e: + logging.debug(f"Webhook signature verification failed: {str(e)}.") + raise HTTPException( + status_code=400, detail="Webhook signature verification failed." + ) + logging.debug(f"Stripe Webhook Event of type {event['type']} received") + if event and event["type"] == "checkout.session.completed": + session = get_session() + logging.debug("Checkout session completed.") + email = data["customer_details"]["email"] + user = session.query(User).filter_by(email=email).first() + stripe_id = data["customer"] + name = data["customer_details"]["name"] + status = data["payment_status"] + if not user: + logging.debug("User not found.") + return {"success": "false"} + user_preferences = ( + session.query(UserPreferences) + .filter_by(user_id=user.id, pref_key="subscription") + .first() + ) + if not user_preferences: + user_preferences = UserPreferences( + user_id=user.id, pref_key="subscription", pref_value=stripe_id + ) + session.add(user_preferences) + session.commit() + name = name.split(" ") + user.first_name = name[0] + user.last_name = name[1] + session.commit() + session.close() + return {"success": "true"} + elif event and event["type"] == "customer.subscription.updated": + logging.debug("Customer Subscription Update session completed.") + else: + logging.debug("Unhandled Stripe event type {}".format(event["type"])) + return {"success": "true"} + + @app.post( "/v1/oauth2/{provider}", response_model=Detail, diff --git a/requirements.txt b/requirements.txt index a2222e3f4685..e36abee8d5c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ ngrok==1.2.0 faster-whisper==1.0.2 youtube-transcript-api==0.6.2 O365==2.0.34 +stripe==9.12.0 google-auth==2.29.0 google-api-python-client==2.125.0 google-auth-oauthlib From f67c5e912d40083be7eab0106cbf2fce7e1d975d Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 18:44:18 -0400 Subject: [PATCH 0249/1256] add active flag --- agixt/endpoints/Auth.py | 107 +++++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py index 846c73af9b93..31386a930b9f 100644 --- a/agixt/endpoints/Auth.py +++ b/agixt/endpoints/Auth.py @@ -154,61 +154,64 @@ async def createuser( return {"status": "Success"}, 200 -@app.post( - "/webhook", - summary="Webhook endpoint for events.", - response_model=WebhookModel, - tags=["Webhook"], -) -async def webhook(request: Request): - event = None - data = None - try: - event = stripe.Webhook.construct_event( - payload=(await request.body()).decode("utf-8"), - sig_header=request.headers.get("stripe-signature"), - secret=getenv("STRIPE_WEBHOOK_SECRET"), - ) - data = event["data"]["object"] - except stripe.error.SignatureVerificationError as e: - logging.debug(f"Webhook signature verification failed: {str(e)}.") - raise HTTPException( - status_code=400, detail="Webhook signature verification failed." - ) - logging.debug(f"Stripe Webhook Event of type {event['type']} received") - if event and event["type"] == "checkout.session.completed": - session = get_session() - logging.debug("Checkout session completed.") - email = data["customer_details"]["email"] - user = session.query(User).filter_by(email=email).first() - stripe_id = data["customer"] - name = data["customer_details"]["name"] - status = data["payment_status"] - if not user: - logging.debug("User not found.") - return {"success": "false"} - user_preferences = ( - session.query(UserPreferences) - .filter_by(user_id=user.id, pref_key="subscription") - .first() - ) - if not user_preferences: - user_preferences = UserPreferences( - user_id=user.id, pref_key="subscription", pref_value=stripe_id +if getenv("STRIPE_WEBHOOK_SECRET") != "": + + @app.post( + "/v1/webhook", + summary="Webhook endpoint for events.", + response_model=WebhookModel, + tags=["Webhook"], + ) + async def webhook(request: Request): + event = None + data = None + try: + event = stripe.Webhook.construct_event( + payload=(await request.body()).decode("utf-8"), + sig_header=request.headers.get("stripe-signature"), + secret=getenv("STRIPE_WEBHOOK_SECRET"), + ) + data = event["data"]["object"] + except stripe.error.SignatureVerificationError as e: + logging.debug(f"Webhook signature verification failed: {str(e)}.") + raise HTTPException( + status_code=400, detail="Webhook signature verification failed." ) - session.add(user_preferences) + logging.debug(f"Stripe Webhook Event of type {event['type']} received") + if event and event["type"] == "checkout.session.completed": + session = get_session() + logging.debug("Checkout session completed.") + email = data["customer_details"]["email"] + user = session.query(User).filter_by(email=email).first() + stripe_id = data["customer"] + name = data["customer_details"]["name"] + status = data["payment_status"] + if not user: + logging.debug("User not found.") + return {"success": "false"} + user_preferences = ( + session.query(UserPreferences) + .filter_by(user_id=user.id, pref_key="subscription") + .first() + ) + if not user_preferences: + user_preferences = UserPreferences( + user_id=user.id, pref_key="subscription", pref_value=stripe_id + ) + session.add(user_preferences) + session.commit() + name = name.split(" ") + user.is_active = True + user.first_name = name[0] + user.last_name = name[1] session.commit() - name = name.split(" ") - user.first_name = name[0] - user.last_name = name[1] - session.commit() - session.close() + session.close() + return {"success": "true"} + elif event and event["type"] == "customer.subscription.updated": + logging.debug("Customer Subscription Update session completed.") + else: + logging.debug("Unhandled Stripe event type {}".format(event["type"])) return {"success": "true"} - elif event and event["type"] == "customer.subscription.updated": - logging.debug("Customer Subscription Update session completed.") - else: - logging.debug("Unhandled Stripe event type {}".format(event["type"])) - return {"success": "true"} @app.post( From 2f60f8fa03d5739cd98d28171c69b150cf597850 Mon Sep 17 00:00:00 2001 From: Josh XT Date: Thu, 20 Jun 2024 19:16:29 -0400 Subject: [PATCH 0250/1256] incrase default injected interactions --- agixt/Interactions.py | 19 +++++++++++++++---- agixt/XT.py | 5 ++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/agixt/Interactions.py b/agixt/Interactions.py index e9267d216d78..ebfb7350c1db 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -280,21 +280,32 @@ async def format_prompt( activity_history = [ interaction for interaction in conversation["interactions"] - if interaction["message"].startswith("[ACTIVITY]") + if str(interaction["message"]).startswith("[ACTIVITY]") ] + activities = [] + for activity in activity_history: + if "audio response" not in activity["message"]: + activity["message"] = activity["message"].replace( + "[ACTIVITY]", "" + ) + activities.append(activity) if len(activity_history) > 5: activity_history = activity_history[-5:] conversation["interactions"] = [ interaction for interaction in conversation["interactions"] - if not interaction["message"].startswith("[ACTIVITY]") + if not str(interaction["message"]).startswith("[ACTIVITY]") ] + interactions = [] + for interaction in conversation["interactions"]: + if not str(interaction["message"]).startswith("