- {data.created_at} |
+ {data.created_at.slice(0, 19).replace('T', ' ')} |
{data.agent_name} |
{data.agent_execution_name} |
diff --git a/gui/pages/Content/Toolkits/Metrics.js b/gui/pages/Content/Toolkits/Metrics.js
new file mode 100644
index 000000000..2887d5d10
--- /dev/null
+++ b/gui/pages/Content/Toolkits/Metrics.js
@@ -0,0 +1,145 @@
+import React, {useState, useEffect, useRef} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import {
+ getApiKeys, getToolMetrics, getToolLogs, getKnowledgeMetrics, getKnowledgeLogs
+} from "@/pages/api/DashboardService";
+import {
+ loadingTextEffect,
+} from "@/utils/utils";
+
+export default function Metrics({toolName, knowledgeName}) {
+ const [apiKeys, setApiKeys] = useState([]);
+ const [totalTokens, setTotalTokens] = useState(0)
+ const [totalAgentsUsing, setTotalAgentsUsing] = useState(0)
+ const [totalCalls, setTotalCalls] = useState(0)
+ const [callLogs, setCallLogs] = useState([])
+ const [isLoading, setIsLoading] = useState(true)
+ const [loadingText, setLoadingText] = useState("Loading Metrics");
+ const metricsData = [
+ { label: 'Total Calls', value: totalCalls },
+ { label: 'Total Agents Using', value: totalAgentsUsing }
+ ];
+
+ useEffect(() => {
+ loadingTextEffect('Loading Metrics', setLoadingText, 500);
+ }, []);
+
+ useEffect(() => {
+ if(toolName && !knowledgeName){
+ fetchToolMetrics()
+ fetchToolLogs()
+ return;
+ }
+ if(!toolName && knowledgeName){
+ fetchKnowledgeMetrics()
+ fetchKnowledgeLogs()
+ return;
+ }
+ }, [toolName, knowledgeName]);
+
+ const fetchToolMetrics = () => {
+ getToolMetrics(toolName)
+ .then((response) => {
+ setTotalAgentsUsing(response.data.tool_unique_agents ? response.data.tool_unique_agents : 0)
+ setTotalCalls(response.data.tool_calls ? response.data.tool_calls : 0)
+ setIsLoading(false)
+ })
+ .catch((error) => {
+ console.error('Error fetching Metrics', error);
+ });
+ }
+
+ const fetchToolLogs = () => {
+ getToolLogs(toolName)
+ .then((response) => {
+ setCallLogs(response.data ? response.data : [])
+ setIsLoading(false)
+ })
+ .catch((error) => {
+ console.error('Error fetching Metrics', error);
+ });
+ }
+
+ const fetchKnowledgeMetrics = () => {
+ getKnowledgeMetrics(knowledgeName)
+ .then((response) => {
+ setTotalAgentsUsing(response.data.knowledge_unique_agents ? response.data.knowledge_unique_agents : 0)
+ setTotalCalls(response.data.knowledge_calls ? response.data.knowledge_calls : 0)
+ setIsLoading(false)
+ })
+ .catch((error) => {
+ console.error('Error fetching Metrics', error);
+ });
+ }
+
+ const fetchKnowledgeLogs = () => {
+ getKnowledgeLogs(knowledgeName)
+ .then((response) => {
+ setCallLogs(response.data ? response.data : [])
+ setIsLoading(false)
+ })
+ .catch((error) => {
+ console.error('Error fetching Metrics', error);
+ });
+ }
+
+ return (<>
+
+
+ {!isLoading ?
+
+
+ {metricsData.map((metric, index) => (
+
+ {metric.label}
+
+ {metric.value}
+
+
+ ))}
+
+
+ Call Logs
+ {callLogs.length > 0 ?
+
+
+
+ Log Timestamp |
+ Agent Name |
+ Run Name |
+ Model |
+ Tokens Used |
+
+
+
+
+
+
+ {callLogs.map((item, index) => (
+
+ {item.created_at} |
+ {item.agent_name} |
+ {item.agent_execution_name} |
+ {item.model} |
+ {item.tokens_consumed} |
+
+ ))}
+
+
+
+ :
+
+
+ No logs to show!
+ }
+
+
+ : }
+
+
+
+ >)
+}
\ No newline at end of file
diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js
index 8356a1eb6..942a280ce 100644
--- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js
+++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js
@@ -1,4 +1,4 @@
-import React, {useEffect, useState} from 'react';
+import React, {useEffect, useRef, useState} from 'react';
import Image from 'next/image';
import {ToastContainer, toast} from 'react-toastify';
import {
@@ -9,14 +9,19 @@ import {
} from "@/pages/api/DashboardService";
import styles from './Tool.module.css';
import {setLocalStorageValue, setLocalStorageArray, returnToolkitIcon, convertToTitleCase} from "@/utils/utils";
+import Metrics from "@/pages/Content/Toolkits/Metrics";
export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
- const [activeTab, setActiveTab] = useState('configuration')
+ const [activeTab, setActiveTab] = useState('metrics')
const [showDescription, setShowDescription] = useState(false)
const [apiConfigs, setApiConfigs] = useState([]);
const [toolsIncluded, setToolsIncluded] = useState([]);
const [loading, setLoading] = useState(true);
const authenticateToolkits = ['Google Calendar Toolkit', 'Twitter Toolkit'];
+ const [toolDropdown, setToolDropdown] = useState(false);
+ const toolRef = useRef(null);
+ const [currTool, setCurrTool] = useState(false);
+
let handleKeyChange = (event, index) => {
const updatedData = [...apiConfigs];
@@ -44,6 +49,7 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
if (toolkitDetails !== null) {
if (toolkitDetails.tools) {
setToolsIncluded(toolkitDetails.tools);
+ setCurrTool(toolkitDetails.tools[0].name)
}
getToolConfig(toolkitDetails.name)
@@ -115,10 +121,20 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
}
}, [internalId]);
+ useEffect(() => {
+ function handleClickOutside(event) {
+ if (toolRef.current && !toolRef.current.contains(event.target)) {
+ setToolDropdown(false)
+ }
+ }
+ document.addEventListener('mousedown', handleClickOutside);
+ return () => {
+ document.removeEventListener('mousedown', handleClickOutside);
+ };
+ }, []);
return (<>
-
-
+
@@ -133,7 +149,12 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
-
+
+
+ setLocalStorageValue('toolkit_tab_' + String(internalId), 'metrics', setActiveTab)}>
+ Metrics
+
setLocalStorageValue('toolkit_tab_' + String(internalId), 'configuration', setActiveTab)}>
Configuration
@@ -142,7 +163,28 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
onClick={() => setLocalStorageValue('toolkit_tab_' + String(internalId), 'tools_included', setActiveTab)}>
Tools Included
+
+ {!loading && activeTab === 'metrics' &&
+
+ setToolDropdown(!toolDropdown)}>
+ {currTool}
+
+
+ {toolDropdown &&
+ {toolsIncluded.map((tool, index) => (
+ {setCurrTool(tool.name); setToolDropdown(false)}}>
+ {tool.name}
+ ))}
+ }
+
+
+ }
+
+
+
{!loading && activeTab === 'configuration' &&
{apiConfigs.length > 0 ? (apiConfigs.map((config, index) => (
@@ -176,9 +218,14 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
))}
}
+
+
+
+ {activeTab === 'metrics' &&
+
+ }
-
>);
diff --git a/gui/pages/Content/Toolkits/Toolkits.js b/gui/pages/Content/Toolkits/Toolkits.js
index 0a0bfe337..b56ce12f6 100644
--- a/gui/pages/Content/Toolkits/Toolkits.js
+++ b/gui/pages/Content/Toolkits/Toolkits.js
@@ -18,7 +18,7 @@ export default function Toolkits({sendToolkitData, toolkits, env}) {
}
{toolkits && toolkits.length > 0 ? (
-
+
{toolkits.map((tool, index) =>
tool.name !== null && !excludedToolkits().includes(tool.name) && (
sendToolkitData(tool)}>
diff --git a/gui/pages/Dashboard/Settings/Settings.js b/gui/pages/Dashboard/Settings/Settings.js
index 7d1263383..66570367e 100644
--- a/gui/pages/Dashboard/Settings/Settings.js
+++ b/gui/pages/Dashboard/Settings/Settings.js
@@ -5,6 +5,7 @@ import Image from "next/image";
import Model from "@/pages/Dashboard/Settings/Model";
import Database from "@/pages/Dashboard/Settings/Database";
import ApiKeys from "@/pages/Dashboard/Settings/ApiKeys";
+import Webhooks from "@/pages/Dashboard/Settings/Webhooks";
export default function Settings({organisationId, sendDatabaseData}) {
const [activeTab, setActiveTab] = useState('model');
@@ -38,12 +39,17 @@ export default function Settings({organisationId, sendDatabaseData}) {
API Keys
+
{activeTab === 'model' && }
{activeTab === 'database' && }
{activeTab === 'apikeys' && }
+ {activeTab === 'webhooks' && }
diff --git a/gui/pages/Dashboard/Settings/Webhooks.js b/gui/pages/Dashboard/Settings/Webhooks.js
new file mode 100644
index 000000000..749266c6e
--- /dev/null
+++ b/gui/pages/Dashboard/Settings/Webhooks.js
@@ -0,0 +1,148 @@
+import React, {useState, useEffect} from 'react';
+import {ToastContainer, toast} from 'react-toastify';
+import 'react-toastify/dist/ReactToastify.css';
+import agentStyles from "@/pages/Content/Agents/Agents.module.css";
+import {
+ editWebhook,
+ getWebhook, saveWebhook,
+} from "@/pages/api/DashboardService";
+import {loadingTextEffect, removeTab} from "@/utils/utils";
+import styles from "@/pages/Content/Marketplace/Market.module.css";
+export default function Webhooks() {
+ const [webhookUrl, setWebhookUrl] = useState('');
+ const [webhookId, setWebhookId] = useState(-1);
+ const [isLoading, setIsLoading] = useState(true)
+ const [existingWebhook, setExistingWebhook] = useState(false)
+ const [isEdtiting, setIsEdtiting] = useState(false)
+ const [loadingText, setLoadingText] = useState("Loading Webhooks");
+ const [selectedCheckboxes, setSelectedCheckboxes] = useState([]);
+ const checkboxes = [
+ { label: 'Agent is running', value: 'RUNNING' },
+ { label: 'Agent run is paused', value: 'PAUSED' },
+ { label: 'Agent run is completed', value: 'COMPLETED' },
+ { label: 'Agent is terminated ', value: 'TERMINATED' },
+ { label: 'Agent run max iteration reached', value: 'MAX ITERATION REACHED' },
+ ];
+
+
+ useEffect(() => {
+ loadingTextEffect('Loading Webhooks', setLoadingText, 500);
+ fetchWebhooks();
+ }, []);
+
+ const handleWebhookChange = (event) => {
+ setWebhookUrl(event.target.value);
+ };
+
+ const handleSaveWebhook = () => {
+ if(!webhookUrl || webhookUrl.trim() === ""){
+ toast.error("Enter valid webhook", {autoClose: 1800});
+ return;
+ }
+ if(isEdtiting){
+ editWebhook(webhookId, { url: webhookUrl, filters: {status: selectedCheckboxes}})
+ .then((response) => {
+ setIsEdtiting(false)
+ fetchWebhooks()
+ toast.success("Webhook edited successfully", {autoClose: 1800});
+ })
+ .catch((error) => {
+ console.error('Error fetching webhook', error);
+ });
+ return;
+ }
+ saveWebhook({name : "Webhook 1", url: webhookUrl, headers: {}, filters: {status: selectedCheckboxes}})
+ .then((response) => {
+ setExistingWebhook(true)
+ setWebhookId(response.data.id)
+ toast.success("Webhook created successfully", {autoClose: 1800});
+ })
+ .catch((error) => {
+ toast.error("Unable to create webhook", {autoClose: 1800});
+ console.error('Error saving webhook', error);
+ });
+ }
+
+ const fetchWebhooks = () => {
+ getWebhook()
+ .then((response) => {
+ setIsLoading(false)
+ if(response.data){
+ setWebhookUrl(response.data.url)
+ setExistingWebhook(true)
+ setWebhookId(response.data.id)
+ setSelectedCheckboxes(response.data.filters.status)
+ }
+ else{
+ setWebhookUrl('')
+ setExistingWebhook(false)
+ setWebhookId(-1)
+ }
+ })
+ .catch((error) => {
+ console.error('Error fetching webhook', error);
+ });
+ }
+
+ const toggleCheckbox = (value) => {
+ if (selectedCheckboxes.includes(value)) {
+ setSelectedCheckboxes(selectedCheckboxes.filter((item) => item !== value));
+ } else {
+ setSelectedCheckboxes([...selectedCheckboxes, value]);
+ }
+ };
+
+ return (<>
+
+
+
+ {!isLoading ?
+
+ Webhooks
+ {existingWebhook &&
+ }
+
+
+
+
+
+
+
+
+ {checkboxes.map((checkbox) => (
+
+ ))}
+
+
+
+ {!existingWebhook &&
+
+
+ }
+
+ : }
+
+
+
+
+ >)
+}
\ No newline at end of file
diff --git a/gui/pages/_app.css b/gui/pages/_app.css
index fc813e8ec..8698f03a3 100644
--- a/gui/pages/_app.css
+++ b/gui/pages/_app.css
@@ -770,6 +770,7 @@ p {
.mt_70{margin-top: 70px;}
.mt_74{margin-top: 74px;}
.mt_80{margin-top: 80px;}
+.mt_90{margin-top: 90px;}
.mb_1{margin-bottom: 1px;}
.mb_2{margin-bottom: 2px;}
@@ -815,6 +816,7 @@ p {
.mb_60{margin-bottom: 60px;}
.mb_70{margin-bottom: 70px;}
.mb_74{margin-bottom: 74px;}
+.mb_90{margin-bottom: 90px;}
.ml_1{margin-left: 1px;}
@@ -968,6 +970,14 @@ p {
line-height: normal;
}
+.text_20 {
+ color: #FFF;
+ font-size: 20px;
+ font-style: normal;
+ font-weight: 400;
+ line-height: normal;
+}
+
.text_20_bold{
color: #FFF;
font-size: 20px;
@@ -1051,10 +1061,13 @@ p {
.r_0{right: 0}
.w_120p{width: 120px}
+.w_3{width: 3%}
+.w_180p{width: 180px}
.w_4{width: 4%}
.w_6{width: 6%}
.w_10{width: 10%}
.w_12{width: 12%}
+.w_15{width: 15%}
.w_18{width: 18%}
.w_20{width: 20%}
.w_22{width: 22%}
@@ -1062,6 +1075,8 @@ p {
.w_50{width: 50%}
.w_56{width: 56%}
.w_60{width: 60%}
+.w_73{width: 73%}
+.w_97{width: 97%}
.w_100{width: 100%}
.w_inherit{width: inherit}
.w_fit_content{width:fit-content}
@@ -1070,12 +1085,16 @@ p {
.mxw_100{max-width: 100%}
.mxw_360{max-width: 360px}
+.h_31p{height: 31px}
.h_32p{height: 32px}
.h_44p{height: 44px}
.h_100{height: 100%}
.h_auto{height: auto}
.h_60vh{height: 60vh}
.h_75vh{height: 75vh}
+.h_80vh{height: 80vh}
+.h_calc92{height: calc(100vh - 92px)}
+.h_calc_add40{height: calc(80vh + 40px)}
.mxh_78vh{max-height: 78vh}
@@ -1087,6 +1106,7 @@ p {
.justify_space_between{justify-content: space-between}
.display_flex{display: inline-flex}
+.display_flex_container{display: flex}
.align_center{align-items: center}
.align_start{align-items: flex-start}
@@ -1103,6 +1123,7 @@ p {
.cursor_pointer{cursor: pointer}
.cursor_default{cursor: default}
+.cursor_not_allowed{cursor: not-allowed}
.overflow_auto{overflow: auto}
.overflowY_scroll{overflow-y: scroll}
@@ -1111,12 +1132,17 @@ p {
.overflowX_auto{overflow-x: auto}
.gap_4{gap:4px;}
+.gap_5{gap:5px;}
.gap_6{gap:6px;}
.gap_8{gap:8px;}
.gap_16{gap:16px;}
.gap_20{gap:20px;}
+.border_gray{border: 1px solid rgba(255, 255, 255, 0.08)}
+.border_left_none{border-left: none;}
.border_top_none{border-top: none;}
+.border_bottom_none{border-bottom: none;}
+.border_bottom_grey{border-bottom: 1px solid rgba(255, 255, 255, 0.08)}
.border_radius_8{border-radius: 8px;}
.border_radius_25{border-radius: 25px;}
@@ -1143,6 +1169,8 @@ p {
.padding_12_14{padding: 12px 14px;}
.padding_0_15{padding: 0px 15px;}
+.pd_bottom_5{padding-bottom: 5px}
+
.flex_1{flex: 1}
.flex_wrap{flex-wrap: wrap;}
@@ -1452,6 +1480,7 @@ tr{
.bg_black{background: black}
.bg_white{background: white}
+.bg_none{background: none}
.container {
height: 100%;
@@ -1490,7 +1519,7 @@ tr{
.item_publisher {
font-style: normal;
font-weight: 500;
- font-size: 9px;
+ font-size: 11px;
line-height: 12px;
display: flex;
align-items: center;
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index 52782be6f..c26330d25 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -324,6 +324,21 @@ export const deleteApiKey = (apiId) => {
return api.delete(`/api-keys/${apiId}`);
};
+export const saveWebhook = (webhook) => {
+ return api.post(`/webhook/add`, webhook);
+};
+
+export const getWebhook = () => {
+ return api.get(`/webhook/get`);
+};
+
+export const editWebhook = (webhook_id, webook_data) => {
+ return api.post(`/webhook/edit/${webhook_id}`, webook_data);
+};
+
+export const publishToMarketplace = (executionId) => {
+ return api.post(`/agent_templates/publish_template/agent_execution_id/${executionId}`);
+};
export const storeApiKey = (model_provider, model_api_key) => {
return api.post(`/models_controller/store_api_keys`, {model_provider, model_api_key});
@@ -363,3 +378,21 @@ export const fetchMarketPlaceModel = () => {
return api.get(`/models_controller/get/list`)
}
+export const getToolMetrics = (toolName) => {
+ return api.get(`analytics/tools/${toolName}/usage`)
+}
+
+export const getToolLogs = (toolName) => {
+ return api.get(`analytics/tools/${toolName}/logs`)
+}
+
+export const publishTemplateToMarketplace = (agentData) => {
+ return api.post(`/agent_templates/publish_template`, agentData);
+};
+export const getKnowledgeMetrics = (knowledgeName) => {
+ return api.get(`analytics/knowledge/${knowledgeName}/usage`)
+}
+
+export const getKnowledgeLogs = (knowledgeName) => {
+ return api.get(`analytics/knowledge/${knowledgeName}/logs`)
+}
\ No newline at end of file
diff --git a/gui/public/images/twitter_icon.svg b/gui/public/images/twitter_icon.svg
index cbd05ea9d..2c6e581da 100644
--- a/gui/public/images/twitter_icon.svg
+++ b/gui/public/images/twitter_icon.svg
@@ -2,8 +2,8 @@
-
+
-
+
diff --git a/gui/public/images/webhook_icon.svg b/gui/public/images/webhook_icon.svg
new file mode 100644
index 000000000..0e316e430
--- /dev/null
+++ b/gui/public/images/webhook_icon.svg
@@ -0,0 +1,3 @@
+
diff --git a/main.py b/main.py
index 8a349e2d5..a0cd36223 100644
--- a/main.py
+++ b/main.py
@@ -45,6 +45,7 @@
from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
+from superagi.llms.llm_model_factory import build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
@@ -56,18 +57,24 @@
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.workflows.iteration_workflow import IterationWorkflow
from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep
+from urllib.parse import urlparse
app = FastAPI()
-database_url = get_config('POSTGRES_URL')
+db_host = get_config('DB_HOST', 'super__postgres')
+db_url = get_config('DB_URL', None)
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
env = get_config('ENV', "DEV")
-if db_username is None:
- db_url = f'postgresql://{database_url}/{db_name}'
+if db_url is None:
+ if db_username is None:
+ db_url = f'postgresql://{db_host}/{db_name}'
+ else:
+ db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'
else:
- db_url = f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}'
+ db_url = urlparse(db_url)
+ db_url = db_url.scheme + "://" + db_url.netloc + db_url.path
engine = create_engine(db_url,
pool_size=20, # Maximum number of database connections in the pool
@@ -344,15 +351,8 @@ async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJW
"""API to validate LLM API Key"""
source = request.model_source
api_key = request.model_api_key
- valid_api_key = False
- if source == "OpenAi":
- valid_api_key = OpenAi(api_key=api_key).verify_access_key()
- elif source == "Google Palm":
- valid_api_key = GooglePalm(api_key=api_key).verify_access_key()
- elif source == "Replicate":
- valid_api_key = Replicate(api_key=api_key).verify_access_key()
- elif source == "Hugging Face":
- valid_api_key = HuggingFace(api_key=api_key).verify_access_key()
+ model = build_model_with_api_key(source, api_key)
+ valid_api_key = model.verify_access_key() if model is not None else False
if valid_api_key:
return {"message": "Valid API Key", "status": "success"}
else:
diff --git a/migrations/env.py b/migrations/env.py
index f6a48ea87..a46ae289c 100644
--- a/migrations/env.py
+++ b/migrations/env.py
@@ -2,10 +2,8 @@
from sqlalchemy import engine_from_config
from sqlalchemy import pool
-
from alembic import context
-
-from superagi.config.config import get_config
+from urllib.parse import urlparse
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -23,16 +21,18 @@
from superagi.models.base_model import DBBaseModel
target_metadata = DBBaseModel.metadata
from superagi.models import *
+from superagi.config.config import get_config
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
-database_url = get_config('POSTGRES_URL')
+db_host = get_config('DB_HOST', 'super__postgres')
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
+database_url = get_config('DB_URL', None)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
@@ -47,10 +47,15 @@ def run_migrations_offline() -> None:
"""
- if db_username is None:
- db_url = f'postgresql://{database_url}/{db_name}'
+ db_url = database_url
+ if db_url is None:
+ if db_username is None:
+ db_url = f'postgresql://{db_host}/{db_name}'
+ else:
+ db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'
else:
- db_url = f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}'
+ db_url = urlparse(db_url)
+ db_url = db_url.scheme + "://" + db_url.netloc + db_url.path
config.set_main_option("sqlalchemy.url", db_url)
@@ -73,6 +78,23 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
+
+ db_host = get_config('DB_HOST', 'super__postgres')
+ db_username = get_config('DB_USERNAME')
+ db_password = get_config('DB_PASSWORD')
+ db_name = get_config('DB_NAME')
+ db_url = get_config('DB_URL', None)
+
+ if db_url is None:
+ if db_username is None:
+ db_url = f'postgresql://{db_host}/{db_name}'
+ else:
+ db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'
+ else:
+ db_url = urlparse(db_url)
+ db_url = db_url.scheme + "://" + db_url.netloc + db_url.path
+
+ config.set_main_option('sqlalchemy.url', db_url)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
diff --git a/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py b/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py
new file mode 100644
index 000000000..32f39be18
--- /dev/null
+++ b/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py
@@ -0,0 +1,28 @@
+"""add filter colume in webhooks
+
+Revision ID: 40affbf3022b
+Revises: 5d5f801f28e7
+Create Date: 2023-08-28 12:30:35.171176
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '40affbf3022b'
+down_revision = '5d5f801f28e7'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('webhooks', sa.Column('filters', sa.JSON(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('webhooks', 'filters')
+ # ### end Alembic commands ###
diff --git a/requirements.txt b/requirements.txt
index 554fd54ea..ab45bb1c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -156,3 +156,5 @@ html2text==2020.1.16
duckduckgo-search==3.8.3
google-generativeai==0.1.0
unstructured==0.8.1
+ai21==1.2.6
+typing-extensions==4.5.0
diff --git a/superagi/agent/output_handler.py b/superagi/agent/output_handler.py
index 2a0aa950f..1fdeb7531 100644
--- a/superagi/agent/output_handler.py
+++ b/superagi/agent/output_handler.py
@@ -100,7 +100,7 @@ def handle_tool_response(self, session, assistant_reply):
action = self.output_parser.parse(assistant_reply)
agent = session.query(Agent).filter(Agent.id == self.agent_config["agent_id"]).first()
organisation = agent.get_agent_organisation(session)
- tool_executor = ToolExecutor(organisation_id=organisation.id, agent_id=agent.id, tools=self.tools)
+ tool_executor = ToolExecutor(organisation_id=organisation.id, agent_id=agent.id, tools=self.tools, agent_execution_id=self.agent_execution_id)
return tool_executor.execute(session, action.name, action.args)
def _check_permission_in_restricted_mode(self, session, assistant_reply: str):
diff --git a/superagi/agent/tool_executor.py b/superagi/agent/tool_executor.py
index 7aed83673..017094164 100644
--- a/superagi/agent/tool_executor.py
+++ b/superagi/agent/tool_executor.py
@@ -9,10 +9,11 @@ class ToolExecutor:
"""Executes the tool with the given args."""
FINISH = "finish"
- def __init__(self, organisation_id: int, agent_id: int, tools: list):
+ def __init__(self, organisation_id: int, agent_id: int, tools: list, agent_execution_id: int):
self.organisation_id = organisation_id
self.agent_id = agent_id
self.tools = tools
+ self.agent_execution_id = agent_execution_id
def execute(self, session, tool_name, tool_args):
"""Executes the tool with the given args.
@@ -31,7 +32,7 @@ def execute(self, session, tool_name, tool_args):
status = "SUCCESS"
tool = tools[tool_name]
retry = False
- EventHandler(session=session).create_event('tool_used', {'tool_name': tool_name}, self.agent_id,
+ EventHandler(session=session).create_event('tool_used', {'tool_name': tool_name, 'agent_execution_id': self.agent_execution_id}, self.agent_id,
self.organisation_id),
try:
parsed_args = self.clean_tool_args(tool_args)
diff --git a/superagi/apm/knowledge_handler.py b/superagi/apm/knowledge_handler.py
new file mode 100644
index 000000000..098791d4c
--- /dev/null
+++ b/superagi/apm/knowledge_handler.py
@@ -0,0 +1,117 @@
+from sqlalchemy.orm import Session
+from superagi.models.events import Event
+from superagi.models.knowledges import Knowledges
+from sqlalchemy import Integer, or_, label, case, and_
+from fastapi import HTTPException
+from typing import List, Dict, Union, Any
+from sqlalchemy.sql import func
+from sqlalchemy.orm import aliased
+from superagi.models.agent_config import AgentConfiguration
+import pytz
+from datetime import datetime
+
+
+class KnowledgeHandler:
+ def __init__(self, session: Session, organisation_id: int):
+ self.session = session
+ self.organisation_id = organisation_id
+
+
+ def get_knowledge_usage_by_name(self, knowledge_name: str) -> Dict[str, Dict[str, int]]:
+
+ is_knowledge_valid = self.session.query(Knowledges.id).filter_by(name=knowledge_name).filter(Knowledges.organisation_id == self.organisation_id).first()
+ if not is_knowledge_valid:
+ raise HTTPException(status_code=404, detail="Knowledge not found")
+ EventAlias = aliased(Event)
+
+ knowledge_used_event = self.session.query(
+ Event.event_property['knowledge_name'].label('knowledge_name'),
+ func.count(Event.agent_id.distinct()).label('knowledge_unique_agents')
+ ).filter(
+ Event.event_name == 'knowledge_picked',
+ Event.org_id == self.organisation_id,
+ Event.event_property['knowledge_name'].astext == knowledge_name
+ ).group_by(
+ Event.event_property['knowledge_name']
+ ).first()
+
+ if knowledge_used_event is None:
+ return {}
+
+ knowledge_data = {
+ 'knowledge_unique_agents': knowledge_used_event.knowledge_unique_agents,
+ 'knowledge_calls': self.session.query(
+ EventAlias
+ ).filter(
+ EventAlias.event_property['tool_name'].astext == 'knowledgesearch',
+ EventAlias.event_name == 'tool_used',
+ EventAlias.org_id == self.organisation_id,
+ EventAlias.agent_id.in_(self.session.query(Event.agent_id).filter(
+ Event.event_name == 'knowledge_picked',
+ Event.org_id == self.organisation_id,
+ Event.event_property['knowledge_name'].astext == knowledge_name
+ ))
+ ).count()
+ }
+
+ return knowledge_data
+
+
+ def get_knowledge_events_by_name(self, knowledge_name: str) -> List[Dict[str, Union[str, int, List[str]]]]:
+
+ is_knowledge_valid = self.session.query(Knowledges.id).filter_by(name=knowledge_name).filter(Knowledges.organisation_id == self.organisation_id).first()
+
+ if not is_knowledge_valid:
+ raise HTTPException(status_code=404, detail="Knowledge not found")
+
+ knowledge_events = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ Event.event_name == 'knowledge_picked',
+ Event.event_property['knowledge_name'].astext == knowledge_name
+ ).all()
+
+ knowledge_events = [ke for ke in knowledge_events if 'agent_execution_id' in ke.event_property]
+
+ event_runs = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed')
+ ).all()
+
+ agent_created_events = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ Event.event_name == 'agent_created'
+ ).all()
+
+ results = []
+
+ for knowledge_event in knowledge_events:
+ agent_execution_id = knowledge_event.event_property['agent_execution_id']
+
+ event_run = next((er for er in event_runs if er.agent_id == knowledge_event.agent_id and er.event_property['agent_execution_id'] == agent_execution_id), None)
+ agent_created_event = next((ace for ace in agent_created_events if ace.agent_id == knowledge_event.agent_id), None)
+ try:
+ user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(session=self.session, key='user_timezone', agent_id=knowledge_event.agent_id)
+ if user_timezone and user_timezone.value != 'None':
+ tz = pytz.timezone(user_timezone.value)
+ else:
+ tz = pytz.timezone('GMT')
+ except AttributeError:
+ tz = pytz.timezone('GMT')
+
+ if event_run and agent_created_event:
+ actual_time = knowledge_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M")
+
+ result_dict = {
+ 'agent_execution_id': agent_execution_id,
+ 'created_at': actual_time,
+ 'tokens_consumed': event_run.event_property['tokens_consumed'],
+ 'calls': event_run.event_property['calls'],
+ 'agent_execution_name': event_run.event_property['name'],
+ 'agent_name': agent_created_event.event_property['agent_name'],
+ 'model': agent_created_event.event_property['model']
+ }
+ if agent_execution_id not in [i['agent_execution_id'] for i in results]:
+ results.append(result_dict)
+
+ results = sorted(results, key=lambda x: datetime.strptime(x['created_at'], '%d %B %Y %H:%M'), reverse=True)
+ return results
\ No newline at end of file
diff --git a/superagi/apm/tools_handler.py b/superagi/apm/tools_handler.py
index 72048d529..da3f97cc6 100644
--- a/superagi/apm/tools_handler.py
+++ b/superagi/apm/tools_handler.py
@@ -1,12 +1,16 @@
-from typing import List, Dict
-
-from sqlalchemy import func
+from typing import List, Dict, Union
+from sqlalchemy import func, distinct, and_
from sqlalchemy.orm import Session
-
+from sqlalchemy import Integer
+from fastapi import HTTPException
from superagi.models.events import Event
from superagi.models.tool import Tool
from superagi.models.toolkit import Toolkit
-
+from sqlalchemy import or_
+from sqlalchemy.sql import label
+from datetime import datetime
+from superagi.models.agent_config import AgentConfiguration
+import pytz
class ToolsHandler:
def __init__(self, session: Session, organisation_id: int):
@@ -54,4 +58,109 @@ def calculate_tool_usage(self) -> List[Dict[str, int]]:
'toolkit': tool_and_toolkit.get(row.tool_name, None)
} for row in result]
- return tool_usage
\ No newline at end of file
+ return tool_usage
+
+ def get_tool_usage_by_name(self, tool_name: str) -> Dict[str, Dict[str, int]]:
+ is_tool_name_valid = self.session.query(Tool).filter_by(name=tool_name).first()
+
+ if not is_tool_name_valid:
+ raise HTTPException(status_code=404, detail="Tool not found")
+ formatted_tool_name = tool_name.lower().replace(" ", "")
+
+ tool_used_event = self.session.query(
+ Event.event_property['tool_name'].label('tool_name'),
+ func.count(Event.id).label('tool_calls'),
+ func.count(distinct(Event.agent_id)).label('tool_unique_agents')
+ ).filter(
+ Event.event_name == 'tool_used',
+ Event.org_id == self.organisation_id,
+ Event.event_property['tool_name'].astext == formatted_tool_name
+ ).group_by(
+ Event.event_property['tool_name']
+ ).first()
+
+ if tool_used_event is None:
+ return {}
+
+ tool_data = {
+ 'tool_calls': tool_used_event.tool_calls,
+ 'tool_unique_agents': tool_used_event.tool_unique_agents
+ }
+
+ return tool_data
+
+
+ def get_tool_events_by_name(self, tool_name: str) -> List[Dict[str, Union[str, int, List[str]]]]:
+ is_tool_name_valid = self.session.query(Tool).filter_by(name=tool_name).first()
+
+ if not is_tool_name_valid:
+ raise HTTPException(status_code=404, detail="Tool not found")
+
+ formatted_tool_name = tool_name.lower().replace(" ", "")
+
+ tool_events = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ Event.event_name == 'tool_used',
+ Event.event_property['tool_name'].astext == formatted_tool_name
+ ).all()
+
+ tool_events = [te for te in tool_events if 'agent_execution_id' in te.event_property]
+
+ event_runs = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed')
+ ).all()
+
+ agent_created_events = self.session.query(Event).filter(
+ Event.org_id == self.organisation_id,
+ Event.event_name == 'agent_created'
+ ).all()
+
+ results = []
+
+ for tool_event in tool_events:
+ agent_execution_id = tool_event.event_property['agent_execution_id']
+
+ event_run = next((er for er in event_runs if er.agent_id == tool_event.agent_id and er.event_property['agent_execution_id'] == agent_execution_id), None)
+ agent_created_event = next((ace for ace in agent_created_events if ace.agent_id == tool_event.agent_id), None)
+ try:
+ user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(session=self.session, key='user_timezone', agent_id=tool_event.agent_id)
+ if user_timezone and user_timezone.value != 'None':
+ tz = pytz.timezone(user_timezone.value)
+ else:
+ tz = pytz.timezone('GMT')
+ except AttributeError:
+ tz = pytz.timezone('GMT')
+
+ if event_run and agent_created_event:
+ actual_time = tool_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M")
+ other_tools_events = self.session.query(
+ Event
+ ).filter(
+ Event.org_id == self.organisation_id,
+ Event.event_name == 'tool_used',
+ Event.event_property['tool_name'].astext != formatted_tool_name,
+ Event.agent_id == tool_event.agent_id,
+ Event.id.between(tool_event.id, event_run.id)
+ ).all()
+
+ other_tools = [ote.event_property['tool_name'] for ote in other_tools_events]
+
+ result_dict = {
+ 'created_at': actual_time,
+ 'agent_execution_id': agent_execution_id,
+ 'tokens_consumed': event_run.event_property['tokens_consumed'],
+ 'calls': event_run.event_property['calls'],
+ 'agent_execution_name': event_run.event_property['name'],
+ 'other_tools': other_tools,
+ 'agent_name': agent_created_event.event_property['agent_name'],
+ 'model': agent_created_event.event_property['model']
+ }
+
+ if agent_execution_id not in [i['agent_execution_id'] for i in results]:
+ results.append(result_dict)
+
+ results = sorted(results, key=lambda x: datetime.strptime(x['created_at'], '%d %B %Y %H:%M'), reverse=True)
+
+ return results
+
\ No newline at end of file
diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py
index a912a722b..2d108d7d9 100644
--- a/superagi/controllers/agent.py
+++ b/superagi/controllers/agent.py
@@ -132,18 +132,27 @@ def create_agent_with_config(agent_with_config: AgentConfigInput,
agent = db.session.query(Agent).filter(Agent.id == db_agent.id, ).first()
organisation = agent.get_agent_organisation(db.session)
- EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': execution.id,
- 'agent_execution_name': execution.name}, db_agent.id,
-
- organisation.id if organisation else 0),
- EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name,
-
- 'model': agent_with_config.model}, db_agent.id,
-
+
+ EventHandler(session=db.session).create_event('run_created',
+ {'agent_execution_id': execution.id,
+ 'agent_execution_name': execution.name},
+ db_agent.id,
+ organisation.id if organisation else 0),
+
+ if agent_with_config.knowledge:
+ knowledge_name = db.session.query(Knowledges.name).filter(Knowledges.id == agent_with_config.knowledge).first()[0]
+ EventHandler(session=db.session).create_event('knowledge_picked',
+ {'knowledge_name': knowledge_name,
+ 'agent_execution_id': execution.id},
+ db_agent.id,
+ organisation.id if organisation else 0)
+
+ EventHandler(session=db.session).create_event('agent_created',
+ {'agent_name': agent_with_config.name,
+ 'model': agent_with_config.model},
+ db_agent.id,
organisation.id if organisation else 0)
- # execute_agent.delay(execution.id, datetime.now())
-
db.session.commit()
return {
diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py
index fca7def57..9658712ec 100644
--- a/superagi/controllers/agent_execution.py
+++ b/superagi/controllers/agent_execution.py
@@ -24,6 +24,7 @@
from superagi.apm.event_handler import EventHandler
from superagi.controllers.tool import ToolOut
from superagi.models.agent_config import AgentConfiguration
+from superagi.models.knowledges import Knowledges
router = APIRouter()
@@ -86,12 +87,12 @@ def create_agent_execution(agent_execution: AgentExecutionIn,
iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
- db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
+ db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0,
num_of_tokens=0,
current_agent_step_id=start_step.id,
iteration_workflow_step_id=iteration_step_id)
-
+
agent_execution_configs = {
"goal": agent_execution.goal,
"instruction": agent_execution.instruction
@@ -118,15 +119,31 @@ def create_agent_execution(agent_execution: AgentExecutionIn,
db.session.add(db_agent_execution)
db.session.commit()
db.session.flush()
+
+ #update status from CREATED to RUNNING
+ db_agent_execution.status = "RUNNING"
+ db.session.commit()
+
AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
agent_execution_configs=agent_execution_configs)
organisation = agent.get_agent_organisation(db.session)
- EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name},
- agent_execution.agent_id, organisation.id if organisation else 0)
-
+ agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_execution.agent_id)
+
+ EventHandler(session=db.session).create_event('run_created',
+ {'agent_execution_id': db_agent_execution.id,
+ 'agent_execution_name':db_agent_execution.name},
+ agent_execution.agent_id,
+ organisation.id if organisation else 0)
+ if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
+ knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name
+ if knowledge_name is not None:
+ EventHandler(session=db.session).create_event('knowledge_picked',
+ {'knowledge_name': knowledge_name,
+ 'agent_execution_id': db_agent_execution.id},
+ agent_execution.agent_id,
+ organisation.id if organisation else 0)
Models.api_key_from_configurations(session=db.session, organisation_id=organisation.id)
-
if db_agent_execution.status == "RUNNING":
execute_agent.delay(db_agent_execution.id, datetime.now())
@@ -147,6 +164,7 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c
Raises:
HTTPException (Status Code=404): If the agent is not found.
"""
+
agent = db.session.query(Agent).filter(Agent.id == agent_execution.agent_id, Agent.is_deleted == False).first()
if not agent:
raise HTTPException(status_code = 404, detail = "Agent not found")
@@ -159,7 +177,7 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c
iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session,
start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
- db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
+ db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0,
num_of_tokens=0,
current_agent_step_id=start_step.id,
@@ -183,13 +201,29 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c
db.session.add(db_agent_execution)
db.session.commit()
db.session.flush()
-
+
+ #update status from CREATED to RUNNING
+ db_agent_execution.status = "RUNNING"
+ db.session.commit()
+
AgentExecutionConfiguration.add_or_update_agent_execution_config(session = db.session, execution = db_agent_execution,
agent_execution_configs = agent_execution_configs)
organisation = agent.get_agent_organisation(db.session)
- EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name},
- agent_execution.agent_id, organisation.id if organisation else 0)
+ EventHandler(session=db.session).create_event('run_created',
+ {'agent_execution_id': db_agent_execution.id,
+ 'agent_execution_name':db_agent_execution.name},
+ agent_execution.agent_id,
+ organisation.id if organisation else 0)
+ agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_execution.agent_id)
+ if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
+ knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name
+ if knowledge_name is not None:
+ EventHandler(session=db.session).create_event('knowledge_picked',
+ {'knowledge_name': knowledge_name,
+ 'agent_execution_id': db_agent_execution.id},
+ agent_execution.agent_id,
+ organisation.id if organisation else 0)
if db_agent_execution.status == "RUNNING":
execute_agent.delay(db_agent_execution.id, datetime.now())
diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py
index 7e9abb0fe..799df391b 100644
--- a/superagi/controllers/agent_template.py
+++ b/superagi/controllers/agent_template.py
@@ -6,7 +6,10 @@
from pydantic import BaseModel
from main import get_config
+from superagi.controllers.types.agent_execution_config import AgentRunIn
+from superagi.controllers.types.agent_publish_config import AgentPublish
from superagi.helper.auth import get_user_organisation
+from superagi.helper.auth import get_current_user
from superagi.models.agent import Agent
from superagi.models.agent_config import AgentConfiguration
from superagi.models.agent_execution import AgentExecution
@@ -138,58 +141,58 @@ def edit_agent_template(agent_template_id: int,
db.session.commit()
db.session.flush()
-@router.put("/update_agent_template/{agent_template_id}", status_code=200)
-def edit_agent_template(agent_template_id: int,
- updated_agent_configs: dict,
- organisation=Depends(get_user_organisation)):
+# @router.put("/update_agent_template/{agent_template_id}", status_code=200)
+# def edit_agent_template(agent_template_id: int,
+# updated_agent_configs: dict,
+# organisation=Depends(get_user_organisation)):
- """
- Update the details of an agent template.
+# """
+# Update the details of an agent template.
- Args:
- agent_template_id (int): The ID of the agent template to update.
- edited_agent_configs (dict): The updated agent configurations.
- organisation (Depends): Dependency to get the user organisation.
+# Args:
+# agent_template_id (int): The ID of the agent template to update.
+# edited_agent_configs (dict): The updated agent configurations.
+# organisation (Depends): Dependency to get the user organisation.
- Returns:
- HTTPException (status_code=200): If the agent gets successfully edited.
+# Returns:
+# HTTPException (status_code=200): If the agent gets successfully edited.
- Raises:
- HTTPException (status_code=404): If the agent template is not found.
- """
+# Raises:
+# HTTPException (status_code=404): If the agent template is not found.
+# """
- db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
- AgentTemplate.id == agent_template_id).first()
- if db_agent_template is None:
- raise HTTPException(status_code=404, detail="Agent Template not found")
+# db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id,
+# AgentTemplate.id == agent_template_id).first()
+# if db_agent_template is None:
+# raise HTTPException(status_code=404, detail="Agent Template not found")
- db_agent_template.name = updated_agent_configs["name"]
- db_agent_template.description = updated_agent_configs["description"]
+# db_agent_template.name = updated_agent_configs["name"]
+# db_agent_template.description = updated_agent_configs["description"]
- db.session.commit()
+# db.session.commit()
- agent_config_values = updated_agent_configs.get('agent_configs', {})
+# agent_config_values = updated_agent_configs.get('agent_configs', {})
- for key, value in agent_config_values.items():
- if isinstance(value, (list, dict)):
- value = json.dumps(value)
- config = db.session.query(AgentTemplateConfig).filter(
- AgentTemplateConfig.agent_template_id == agent_template_id,
- AgentTemplateConfig.key == key
- ).first()
+# for key, value in agent_config_values.items():
+# if isinstance(value, (list, dict)):
+# value = json.dumps(value)
+# config = db.session.query(AgentTemplateConfig).filter(
+# AgentTemplateConfig.agent_template_id == agent_template_id,
+# AgentTemplateConfig.key == key
+# ).first()
- if config is not None:
- config.value = value
- else:
- new_config = AgentTemplateConfig(
- agent_template_id=agent_template_id,
- key=key,
- value= value
- )
- db.session.add(new_config)
+# if config is not None:
+# config.value = value
+# else:
+# new_config = AgentTemplateConfig(
+# agent_template_id=agent_template_id,
+# key=key,
+# value= value
+# )
+# db.session.add(new_config)
- db.session.commit()
- db.session.flush()
+# db.session.commit()
+# db.session.flush()
@router.post("/save_agent_as_template/agent_id/{agent_id}/agent_execution_id/{agent_execution_id}")
@@ -411,4 +414,129 @@ def fetch_agent_config_from_template(agent_template_id: int,
agent_workflow = AgentWorkflow.find_by_id(db.session, agent_template.agent_workflow_id)
template_config_dict["agent_workflow"] = agent_workflow.name
- return template_config_dict
\ No newline at end of file
+ return template_config_dict
+
+
+@router.post("/publish_template/agent_execution_id/{agent_execution_id}", status_code=201)
+def publish_template(agent_execution_id: str, organisation=Depends(get_user_organisation), user=Depends(get_current_user)):
+
+ """
+ Publish an agent execution as a template.
+
+ Args:
+ agent_execution_id (str): The ID of the agent execution to save as a template.
+ organisation (Depends): Dependency to get the user organisation.
+ user (Depends): Dependency to get the user.
+
+ Returns:
+ dict: The saved agent template.
+
+ Raises:
+ HTTPException (status_code=404): If the agent or agent execution configurations are not found.
+ """
+
+ if agent_execution_id == 'undefined':
+ raise HTTPException(status_code = 404, detail = "Agent Execution Id undefined")
+
+ agent_executions = AgentExecution.get_agent_execution_from_id(db.session, agent_execution_id)
+ if agent_executions is None:
+ raise HTTPException(status_code = 404, detail = "Agent Execution not found")
+ agent_id = agent_executions.agent_id
+
+ agent = db.session.query(Agent).filter(Agent.id == agent_id).first()
+ if agent is None:
+ raise HTTPException(status_code=404, detail="Agent not found")
+
+ agent_execution_configurations = db.session.query(AgentExecutionConfiguration).filter(AgentExecutionConfiguration.agent_execution_id == agent_execution_id).all()
+ if not agent_execution_configurations:
+ raise HTTPException(status_code=404, detail="Agent execution configurations not found")
+
+ agent_template = AgentTemplate(name=agent.name, description=agent.description,
+ agent_workflow_id=agent.agent_workflow_id,
+ organisation_id=organisation.id)
+ db.session.add(agent_template)
+ db.session.commit()
+
+ main_keys = AgentTemplate.main_keys()
+ for agent_execution_configuration in agent_execution_configurations:
+ config_value = agent_execution_configuration.value
+ if agent_execution_configuration.key not in main_keys:
+ continue
+ if agent_execution_configuration.key == "tools":
+ config_value = str(Tool.convert_tool_ids_to_names(db, eval(agent_execution_configuration.value)))
+ agent_template_config = AgentTemplateConfig(agent_template_id=agent_template.id, key=agent_execution_configuration.key,
+ value=config_value)
+ db.session.add(agent_template_config)
+
+ agent_template_configs = [
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="status", value="UNDER REVIEW"),
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Name", value=user.name),
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Email", value=user.email)]
+ db.session.add_all(agent_template_configs)
+
+ db.session.commit()
+ db.session.flush()
+ return agent_template.to_dict()
+
+@router.post("/publish_template", status_code=201)
+def handle_publish_template(updated_details: AgentPublish, organisation=Depends(get_user_organisation), user=Depends(get_current_user)):
+
+ """
+ Publish a template from edit template page.
+
+ Args:
+ organisation (Depends): Dependency to get the user organisation.
+ user (Depends): Dependency to get the user.
+
+ Returns:
+ dict: The saved agent template.
+
+ Raises:
+ HTTPException (status_code=404): If the agent template or workflow are not found.
+ """
+
+ old_template_id = updated_details.agent_template_id
+ old_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.id==old_template_id, AgentTemplate.organisation_id==organisation.id).first()
+ if old_agent_template is None:
+ raise HTTPException(status_code = 404, detail = "Agent Template not found")
+ agent_workflow_id = old_agent_template.agent_workflow_id
+ if agent_workflow_id is None:
+ raise HTTPException(status_code = 404, detail = "Agent Workflow not found")
+
+ agent_template = AgentTemplate(name=updated_details.name, description=updated_details.description,
+ agent_workflow_id=agent_workflow_id,
+ organisation_id=organisation.id)
+ db.session.add(agent_template)
+ db.session.commit()
+
+ agent_template_configs = {
+ "goal": updated_details.goal,
+ "instruction": updated_details.instruction,
+ "constraints": updated_details.constraints,
+ "toolkits": updated_details.toolkits,
+ "exit": updated_details.exit,
+ "tools": updated_details.tools,
+ "iteration_interval": updated_details.iteration_interval,
+ "model": updated_details.model,
+ "permission_type": updated_details.permission_type,
+ "LTM_DB": updated_details.LTM_DB,
+ "max_iterations": updated_details.max_iterations,
+ "user_timezone": updated_details.user_timezone,
+ "knowledge": updated_details.knowledge
+ }
+
+ for key, value in agent_template_configs.items():
+ if key == "tools":
+ value = Tool.convert_tool_ids_to_names(db, value)
+ agent_template_config = AgentTemplateConfig(agent_template_id=agent_template.id, key=key, value=str(value))
+ db.session.add(agent_template_config)
+
+ agent_template_configs = [
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="status", value="UNDER REVIEW"),
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Name", value=user.name),
+ AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Email", value=user.email)]
+ db.session.add_all(agent_template_configs)
+
+ db.session.commit()
+ db.session.flush()
+ return agent_template.to_dict()
\ No newline at end of file
diff --git a/superagi/controllers/analytics.py b/superagi/controllers/analytics.py
index 5dbfec7d9..c2c78f764 100644
--- a/superagi/controllers/analytics.py
+++ b/superagi/controllers/analytics.py
@@ -3,6 +3,7 @@
from superagi.apm.analytics_helper import AnalyticsHelper
from superagi.apm.event_handler import EventHandler
from superagi.apm.tools_handler import ToolsHandler
+from superagi.apm.knowledge_handler import KnowledgeHandler
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
import logging
@@ -59,3 +60,48 @@ def get_tools_used(organisation=Depends(get_user_organisation)):
except Exception as e:
logging.error(f"Error while calculating tool usage: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")
+
+
+@router.get("/tools/{tool_name}/usage", status_code=200)
+def get_tool_usage(tool_name: str, organisation=Depends(get_user_organisation)):
+ try:
+ return ToolsHandler(session=db.session, organisation_id=organisation.id).get_tool_usage_by_name(tool_name)
+ except Exception as e:
+ if hasattr(e, 'status_code'):
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
+ else:
+ raise HTTPException(status_code=500, detail="Internal Server Error")
+
+
+@router.get("/knowledge/{knowledge_name}/usage", status_code=200)
+def get_knowledge_usage(knowledge_name:str, organisation=Depends(get_user_organisation)):
+ try:
+ return KnowledgeHandler(session=db.session, organisation_id=organisation.id).get_knowledge_usage_by_name(knowledge_name)
+ except Exception as e:
+ if hasattr(e, 'status_code'):
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
+ else:
+ raise HTTPException(status_code=500, detail="Internal Server Error")
+
+
+@router.get("/tools/{tool_name}/logs", status_code=200)
+def get_tool_logs(tool_name: str, organisation=Depends(get_user_organisation)):
+ try:
+ return ToolsHandler(session=db.session, organisation_id=organisation.id).get_tool_events_by_name(tool_name)
+ except Exception as e:
+ logging.error(f"Error while getting tool event details: {str(e)}")
+ if hasattr(e, 'status_code'):
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
+ else:
+ raise HTTPException(status_code=500, detail="Internal Server Error")
+
+@router.get("/knowledge/{knowledge_name}/logs", status_code=200)
+def get_knowledge_logs(knowledge_name: str, organisation=Depends(get_user_organisation)):
+ try:
+ return KnowledgeHandler(session=db.session, organisation_id=organisation.id).get_knowledge_events_by_name(knowledge_name)
+ except Exception as e:
+ logging.error(f"Error while getting knowledge event details: {str(e)}")
+ if hasattr(e, 'status_code'):
+ raise HTTPException(status_code=e.status_code, detail=e.detail)
+ else:
+ raise HTTPException(status_code=500, detail="Internal Server Error")
\ No newline at end of file
diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py
index 95b35c4f1..d057e7f0b 100644
--- a/superagi/controllers/api/agent.py
+++ b/superagi/controllers/api/agent.py
@@ -14,6 +14,7 @@
from superagi.models.workflows.agent_workflow import AgentWorkflow
from superagi.models.agent_execution import AgentExecution
from superagi.models.organisation import Organisation
+from superagi.models.knowledges import Knowledges
from superagi.models.resource import Resource
from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput
from superagi.models.workflows.iteration_workflow import IterationWorkflow
@@ -117,14 +118,14 @@ def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Sec
db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id)
if db_schedule is not None:
raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run")
- start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
+ start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id)
db_agent_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED")
if db_agent_execution is None:
db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
agent_id=agent_id, name=agent_execution.name, num_of_calls=0,
num_of_tokens=0,
- current_step_id=start_step_id)
+ current_agent_step_id=start_step.id)
db.session.add(db_agent_execution)
else:
db_agent_execution.status = "RUNNING"
@@ -144,8 +145,23 @@ def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Sec
if agent_execution_configs != {}:
AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution,
agent_execution_configs=agent_execution_configs)
- EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name},
- agent_id, organisation.id if organisation else 0)
+ EventHandler(session=db.session).create_event('run_created',
+ {'agent_execution_id': db_agent_execution.id,
+ 'agent_execution_name':db_agent_execution.name
+ },
+ agent_id,
+ organisation.id if organisation else 0)
+
+ agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_id)
+ if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
+ knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name
+ if knowledge_name is not None:
+ EventHandler(session=db.session).create_event('knowledge_picked',
+ {'knowledge_name': knowledge_name,
+ 'agent_execution_id': db_agent_execution.id},
+ agent_id,
+ organisation.id if organisation else 0
+ )
if db_agent_execution.status == "RUNNING":
execute_agent.delay(db_agent_execution.id, datetime.now())
@@ -269,7 +285,8 @@ def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCha
db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "RUNNING")
- if len(db_execution_arr) != len(execution_state_change_input.run_ids):
+ if db_execution_arr is not None and execution_state_change_input.run_ids is not None \
+ and len(db_execution_arr) != len(execution_state_change_input.run_ids):
raise HTTPException(status_code=404, detail="One or more run id(s) not found")
for ind_execution in db_execution_arr:
@@ -298,7 +315,8 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh
db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "PAUSED")
- if len(db_execution_arr) != len(execution_state_change_input.run_ids):
+ if db_execution_arr is not None and execution_state_change_input.run_ids is not None\
+ and len(db_execution_arr) != len(execution_state_change_input.run_ids):
raise HTTPException(status_code=404, detail="One or more run id(s) not found")
for ind_execution in db_execution_arr:
@@ -312,7 +330,7 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh
"result":"success"
}
-@router.post("/resources/output",status_code=201)
+@router.post("/resources/output",status_code=200)
def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)):
if get_config('STORAGE_TYPE') != "S3":
raise HTTPException(status_code=400,detail="This endpoint only works when S3 is configured")
diff --git a/superagi/controllers/api_key.py b/superagi/controllers/api_key.py
index 57e5c739b..458e68db5 100644
--- a/superagi/controllers/api_key.py
+++ b/superagi/controllers/api_key.py
@@ -5,51 +5,64 @@
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic import BaseModel
-from superagi.helper.auth import get_user_organisation
+from superagi.helper.auth import get_user_organisation, validate_api_key
from superagi.helper.auth import check_auth
from superagi.models.api_key import ApiKey
from typing import Optional, Annotated
+
router = APIRouter()
+
class ApiKeyIn(BaseModel):
- id:int
+ id: int
name: str
+
class Config:
orm_mode = True
+
class ApiKeyDeleteIn(BaseModel):
- id:int
+ id: int
+
class Config:
orm_mode = True
+
@router.post("")
-def create_api_key(name: Annotated[str,Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)):
- api_key=str(uuid.uuid4())
- obj=ApiKey(key=api_key,name=name,org_id=organisation.id)
+def create_api_key(name: Annotated[str, Body(embed=True)], Authorize: AuthJWT = Depends(check_auth),
+ organisation=Depends(get_user_organisation)):
+ api_key = str(uuid.uuid4())
+ obj = ApiKey(key=api_key, name=name, org_id=organisation.id)
db.session.add(obj)
db.session.commit()
db.session.flush()
return {"api_key": api_key}
+
+@router.get("/validate")
+def get_api_key(api_key: str = Depends(validate_api_key)):
+ return {"success": True}
+
+
@router.get("")
def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)):
- api_keys=ApiKey.get_by_org_id(db.session, organisation.id)
+ api_keys = ApiKey.get_by_org_id(db.session, organisation.id)
return api_keys
+
@router.delete("/{api_key_id}")
-def delete_api_key(api_key_id:int, Authorize: AuthJWT = Depends(check_auth)):
- api_key=ApiKey.get_by_id(db.session, api_key_id)
+def delete_api_key(api_key_id: int, Authorize: AuthJWT = Depends(check_auth)):
+ api_key = ApiKey.get_by_id(db.session, api_key_id)
if api_key is None:
raise HTTPException(status_code=404, detail="API key not found")
ApiKey.delete_by_id(db.session, api_key_id)
return {"success": True}
+
@router.put("")
-def edit_api_key(api_key_in:ApiKeyIn,Authorize: AuthJWT = Depends(check_auth)):
- api_key=ApiKey.get_by_id(db.session, api_key_in.id)
+def edit_api_key(api_key_in: ApiKeyIn, Authorize: AuthJWT = Depends(check_auth)):
+ api_key = ApiKey.get_by_id(db.session, api_key_in.id)
if api_key is None:
raise HTTPException(status_code=404, detail="API key not found")
ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name)
return {"success": True}
-
-
diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py
index e82e6ad60..abb771060 100644
--- a/superagi/controllers/models_controller.py
+++ b/superagi/controllers/models_controller.py
@@ -102,7 +102,7 @@ async def fetch_data(request: ModelName, organisation=Depends(get_user_organisat
@router.get("/get/list", status_code=200)
-def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)):
+def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)):
"""
Get Marketplace Model list.
@@ -121,7 +121,7 @@ def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)):
@router.get("/marketplace/list/{page}", status_code=200)
-def get_marketplace_models_list(page: int = 0):
+def get_marketplace_knowledge_list(page: int = 0):
organisation_id = get_config("MARKETPLACE_ORGANISATION_ID")
if organisation_id is not None:
organisation_id = int(organisation_id)
diff --git a/superagi/controllers/organisation.py b/superagi/controllers/organisation.py
index aa738d64b..e366c5966 100644
--- a/superagi/controllers/organisation.py
+++ b/superagi/controllers/organisation.py
@@ -11,6 +11,7 @@
from superagi.helper.encyption_helper import decrypt_data
from superagi.helper.tool_helper import register_toolkits
from superagi.llms.google_palm import GooglePalm
+from superagi.llms.llm_model_factory import build_model_with_api_key
from superagi.llms.openai import OpenAi
from superagi.models.configuration import Configuration
from superagi.models.organisation import Organisation
@@ -170,11 +171,8 @@ def get_llm_models(organisation=Depends(get_user_organisation)):
detail="Organisation not found")
decrypted_api_key = decrypt_data(model_api_key.value)
- models = []
- if model_source.value == "OpenAi":
- models = OpenAi(api_key=decrypted_api_key).get_models()
- elif model_source.value == "Google Palm":
- models = GooglePalm(api_key=decrypted_api_key).get_models()
+ model = build_model_with_api_key(model_source.value, decrypted_api_key)
+ models = model.get_models() if model is not None else []
return models
diff --git a/superagi/controllers/toolkit.py b/superagi/controllers/toolkit.py
index 1b92e824c..8fb175e01 100644
--- a/superagi/controllers/toolkit.py
+++ b/superagi/controllers/toolkit.py
@@ -157,8 +157,8 @@ def install_toolkit_from_marketplace(toolkit_name: str,
folder_name=tool['folder_name'], class_name=tool['class_name'], file_name=tool['file_name'],
toolkit_id=db_toolkit.id)
for config in toolkit['configs']:
- ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'], value=config['value'], key_type = config['key_type'], is_secret = config['is_secret'], is_required = config['is_required'])
-
+ ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'], value=config['value'], key_type = config['key_type'], is_secret = config['is_secret'], is_required = config['is_required'])
+
return {"message": "ToolKit installed successfully"}
diff --git a/superagi/controllers/types/agent_publish_config.py b/superagi/controllers/types/agent_publish_config.py
new file mode 100644
index 000000000..47f83f040
--- /dev/null
+++ b/superagi/controllers/types/agent_publish_config.py
@@ -0,0 +1,23 @@
+from typing import List, Optional
+from pydantic import BaseModel
+
+class AgentPublish(BaseModel):
+ name: str
+ description: str
+ agent_template_id: int
+ goal: Optional[List[str]]
+ instruction: Optional[List[str]]
+ constraints: List[str]
+ toolkits: List[int]
+ tools: List[int]
+ exit: str
+ iteration_interval: int
+ model: str
+ permission_type: str
+ LTM_DB: str
+ max_iterations: int
+ user_timezone: Optional[str]
+ knowledge: Optional[int]
+
+ class Config:
+ orm_mode = True
\ No newline at end of file
diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py
index 0a55bd216..0f49bdb0f 100644
--- a/superagi/controllers/webhook.py
+++ b/superagi/controllers/webhook.py
@@ -1,6 +1,6 @@
from datetime import datetime
-
-from fastapi import APIRouter
+from typing import Optional
+from fastapi import APIRouter, HTTPException
from fastapi import Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
@@ -17,6 +17,7 @@ class WebHookIn(BaseModel):
name: str
url: str
headers: dict
+ filters: dict
class Config:
orm_mode = True
@@ -31,12 +32,21 @@ class WebHookOut(BaseModel):
is_deleted: bool
created_at: datetime
updated_at: datetime
+ filters: dict
+
+ class Config:
+ orm_mode = True
+
+class WebHookEdit(BaseModel):
+ url: str
+ filters: dict
class Config:
orm_mode = True
-# CRUD Operations
+
+# CRUD Operations`
@router.post("/add", response_model=WebHookOut, status_code=201)
def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth),
organisation=Depends(get_user_organisation)):
@@ -52,9 +62,54 @@ def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth),
HTTPException (Status Code=404): If the associated project is not found.
"""
db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id,
- is_deleted=False)
+ is_deleted=False, filters=webhook.filters)
db.session.add(db_webhook)
db.session.commit()
db.session.flush()
-
return db_webhook
+
+@router.get("/get", response_model=Optional[WebHookOut])
+def get_all_webhooks(
+ Authorize: AuthJWT = Depends(check_auth),
+ organisation=Depends(get_user_organisation),
+):
+ """
+ Retrieves a single webhook for the authenticated user's organisation.
+
+ Returns:
+ JSONResponse: A JSON response containing the retrieved webhook.
+
+ Raises:
+ """
+ webhook = db.session.query(Webhooks).filter(Webhooks.org_id == organisation.id, Webhooks.is_deleted == False).first()
+ return webhook
+
+@router.post("/edit/{webhook_id}", response_model=WebHookOut)
+def edit_webhook(
+ updated_webhook: WebHookEdit,
+ webhook_id: int,
+ Authorize: AuthJWT = Depends(check_auth),
+ organisation=Depends(get_user_organisation),
+):
+ """
+ Soft-deletes a webhook by setting the value of is_deleted to True.
+
+ Args:
+ webhook_id (int): The ID of the webhook to delete.
+
+ Returns:
+ WebHookOut: The deleted webhook.
+
+ Raises:
+ HTTPException (Status Code=404): If the webhook is not found.
+ """
+ webhook = db.session.query(Webhooks).filter(Webhooks.org_id == organisation.id, Webhooks.id == webhook_id, Webhooks.is_deleted == False).first()
+ if webhook is None:
+ raise HTTPException(status_code=404, detail="Webhook not found")
+
+ webhook.url = updated_webhook.url
+ webhook.filters = updated_webhook.filters
+
+ db.session.commit()
+
+ return webhook
\ No newline at end of file
diff --git a/superagi/helper/github_helper.py b/superagi/helper/github_helper.py
index 381e7392a..bb34eaf56 100644
--- a/superagi/helper/github_helper.py
+++ b/superagi/helper/github_helper.py
@@ -9,7 +9,8 @@
from superagi.types.storage_types import StorageType
from superagi.config.config import get_config
from superagi.helper.s3_helper import S3Helper
-
+from datetime import timedelta, datetime
+import json
class GithubHelper:
def __init__(self, github_access_token, github_username):
@@ -238,6 +239,7 @@ def add_file(self, repository_owner, repository_name, file_name, folder_path, he
logger.info('Failed to upload file content:', file_response.json()['message'])
return file_response.status_code
+
def create_pull_request(self, repository_owner, repository_name, head_branch, base_branch, headers):
"""
Creates a pull request in the given repository.
@@ -335,3 +337,135 @@ def _get_file_contents(self, file_name, agent_id, agent_execution_id, session):
with open(final_path, "r") as file:
attachment_data = file.read().decode('utf-8')
return attachment_data
+
+
+ def get_pull_request_content(self, repository_owner, repository_name, pull_request_number):
+ """
+ Gets the content of a specific pull request from a GitHub repository.
+
+ Args:
+ repository_owner (str): Owner of the repository.
+ repository_name (str): Name of the repository.
+ pull_request_number (int): pull request id.
+ headers (dict): Dictionary containing the headers, usually including the Authorization token.
+
+ Returns:
+ dict: Dictionary containing the pull request content or None if not found.
+ """
+ pull_request_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}'
+ headers = {
+ "Authorization": f"token {self.github_access_token}" if self.github_access_token else None,
+ "Content-Type": "application/vnd.github+json",
+ "Accept": "application/vnd.github.v3.diff",
+ }
+
+ response = requests.get(pull_request_url, headers=headers)
+
+ if response.status_code == 200:
+ logger.info('Successfully fetched pull request content.')
+ return response.text
+ elif response.status_code == 404:
+ logger.warning('Pull request not found.')
+ else:
+ logger.warning('Failed to fetch pull request content: ', response.text)
+
+ return None
+
+ def get_latest_commit_id_of_pull_request(self, repository_owner, repository_name, pull_request_number):
+ """
+ Gets the latest commit id of a specific pull request from a GitHub repository.
+ :param repository_owner: owner
+ :param repository_name: repository name
+ :param pull_request_number: pull request id
+
+ :return:
+ latest commit id of the pull request
+ """
+ url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}/commits'
+ headers = {
+ "Authorization": f"token {self.github_access_token}" if self.github_access_token else None,
+ "Content-Type": "application/json",
+ }
+ response = requests.get(url, headers=headers)
+ if response.status_code == 200:
+ commits = response.json()
+ latest_commit = commits[-1] # Assuming the last commit is the latest
+ return latest_commit.get('sha')
+ else:
+ logger.warning(f'Failed to fetch commits for pull request: {response.json()["message"]}')
+ return None
+
+
+ def add_line_comment_to_pull_request(self, repository_owner, repository_name, pull_request_number,
+ commit_id, file_path, position, comment_body):
+ """
+ Adds a line comment to a specific pull request from a GitHub repository.
+
+ :param repository_owner: owner
+ :param repository_name: repository name
+ :param pull_request_number: pull request id
+ :param commit_id: commit id
+ :param file_path: file path
+ :param position: position
+ :param comment_body: comment body
+
+ :return:
+ dict: Dictionary containing the comment content or None if not found.
+ """
+ comments_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}/comments'
+ headers = {
+ "Authorization": f"token {self.github_access_token}",
+ "Content-Type": "application/json",
+ "Accept": "application/vnd.github.v3+json"
+ }
+ data = {
+ "commit_id": commit_id,
+ "path": file_path,
+ "position": position,
+ "body": comment_body
+ }
+ response = requests.post(comments_url, headers=headers, json=data)
+ if response.status_code == 201:
+ logger.info('Successfully added line comment to pull request.')
+ return response.json()
+ else:
+ logger.warning(f'Failed to add line comment: {response.json()["message"]}')
+ return None
+
+ def get_pull_requests_created_in_last_x_seconds(self, repository_owner, repository_name, x_seconds):
+ """
+ Gets the pull requests created in the last x seconds.
+
+ Args:
+ repository_owner (str): Owner of the repository
+ repository_name (str): Repository name
+ x_seconds (int): The number of seconds in the past to look for PRs
+
+ Returns:
+ list: List of pull request objects that were created in the last x seconds
+ """
+ # Calculate the time x seconds ago
+ time_x_seconds_ago = datetime.utcnow() - timedelta(seconds=x_seconds)
+
+ # Convert to the ISO8601 format GitHub expects, remove milliseconds
+ time_x_seconds_ago_str = time_x_seconds_ago.strftime('%Y-%m-%dT%H:%M:%SZ')
+
+ # Search query
+ query = f'repo:{repository_owner}/{repository_name} type:pr created:>{time_x_seconds_ago_str}'
+
+ url = f'https://api.github.com/search/issues?q={query}'
+ headers = {
+ "Authorization": f"token {self.github_access_token}",
+ "Content-Type": "application/json",
+ }
+
+ response = requests.get(url, headers=headers)
+
+ if response.status_code == 200:
+ pull_request_urls = []
+ for pull_request in response.json()['items']:
+ pull_request_urls.append(pull_request['html_url'])
+ return pull_request_urls
+ else:
+ logger.warning(f'Failed to fetch PRs: {response.json()["message"]}')
+ return []
diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py
index cf5e988d2..aa7a1fb19 100644
--- a/superagi/helper/webhook_manager.py
+++ b/superagi/helper/webhook_manager.py
@@ -5,6 +5,7 @@
import requests
import json
from superagi.lib.logger import logger
+
class WebHookManager:
def __init__(self,session):
self.session=session
@@ -18,20 +19,21 @@ def agent_status_change_callback(self, agent_execution_id, curr_status, old_stat
org_webhooks=self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all()
for webhook_obj in org_webhooks:
- webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"}
- error=None
- request=None
- status='sent'
- try:
- request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers)
- except Exception as e:
- logger.error(f"Exception occured in webhooks {e}")
- error=str(e)
- if request is not None and request.status_code not in [200,201] and error is None:
- error=request.text
- if error is not None:
- status='Error'
- webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error)
- self.session.add(webhook_event)
- self.session.commit()
+ if "status" in webhook_obj.filters and curr_status in webhook_obj.filters["status"]:
+ webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"}
+ error=None
+ request=None
+ status='sent'
+ try:
+ request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers)
+ except Exception as e:
+ logger.error(f"Exception occured in webhooks {e}")
+ error=str(e)
+ if request is not None and request.status_code not in [200,201] and error is None:
+ error=request.text
+ if error is not None:
+ status='Error'
+ webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error)
+ self.session.add(webhook_event)
+ self.session.commit()
diff --git a/superagi/jobs/scheduling_executor.py b/superagi/jobs/scheduling_executor.py
index b7e62db4e..f0d5a542a 100644
--- a/superagi/jobs/scheduling_executor.py
+++ b/superagi/jobs/scheduling_executor.py
@@ -13,7 +13,7 @@
from superagi.models.agent_execution import AgentExecution
from superagi.models.agent_execution_config import AgentExecutionConfiguration
from superagi.apm.event_handler import EventHandler
-
+from superagi.models.knowledges import Knowledges
from superagi.models.db import connect_db
@@ -42,7 +42,7 @@ def execute_scheduled_agent(self, agent_id: int, name: str):
iteration_step_id = IterationWorkflow.fetch_trigger_step_id(session,
start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1
- db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(),
+ db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(),
agent_id=agent_id, name=name, num_of_calls=0,
num_of_tokens=0,
current_agent_step_id=start_step.id,
@@ -51,17 +51,32 @@ def execute_scheduled_agent(self, agent_id: int, name: str):
session.add(db_agent_execution)
session.commit()
+ #update status from CREATED to RUNNING
+ db_agent_execution.status = "RUNNING"
+ session.commit()
+
agent_execution_id = db_agent_execution.id
agent_configurations = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
for agent_config in agent_configurations:
agent_execution_config = AgentExecutionConfiguration(agent_execution_id=agent_execution_id, key=agent_config.key, value=agent_config.value)
session.add(agent_execution_config)
-
-
organisation = agent.get_agent_organisation(session)
model = session.query(AgentConfiguration.value).filter(AgentConfiguration.agent_id == agent_id).filter(AgentConfiguration.key == 'model').first()[0]
- EventHandler(session=session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, agent_id, organisation.id if organisation else 0),
+ EventHandler(session=session).create_event('run_created',
+ {'agent_execution_id': db_agent_execution.id,
+ 'agent_execution_name':db_agent_execution.name},
+ agent_id,
+ organisation.id if organisation else 0)
+ agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= session, key= 'knowledge', agent_id= agent_id)
+ if agent_execution_knowledge and agent_execution_knowledge.value != 'None':
+ knowledge_name = Knowledges.get_knowledge_from_id(session, int(agent_execution_knowledge.value)).name
+ if knowledge_name is not None:
+ EventHandler(session=session).create_event('knowledge_picked',
+ {'knowledge_name': knowledge_name,
+ 'agent_execution_id': db_agent_execution.id},
+ agent_id,
+ organisation.id if organisation else 0)
session.commit()
if db_agent_execution.status == "RUNNING":
diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py
index 9ba8c8892..251c64a71 100644
--- a/superagi/llms/llm_model_factory.py
+++ b/superagi/llms/llm_model_factory.py
@@ -33,5 +33,17 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs):
elif provider_name == 'Hugging Face':
print("Provider is Hugging Face")
return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs)
+ else:
+ print('Unknown provider.')
+
+def build_model_with_api_key(provider_name, api_key):
+ if provider_name.lower() == 'openai':
+ return OpenAi(api_key=api_key)
+ elif provider_name.lower() == 'replicate':
+ return Replicate(api_key=api_key)
+ elif provider_name.lower() == 'google palm':
+ return GooglePalm(api_key=api_key)
+ elif provider_name.lower() == 'hugging face':
+ return HuggingFace(api_key=api_key)
else:
print('Unknown provider.')
\ No newline at end of file
diff --git a/superagi/models/agent_config.py b/superagi/models/agent_config.py
index 67b377da3..3b44eee9e 100644
--- a/superagi/models/agent_config.py
+++ b/superagi/models/agent_config.py
@@ -79,6 +79,7 @@ def update_agent_configurations_table(cls, session, agent_id: Union[int, None],
# Fetch agent configurations
agent_configs = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all()
+
for agent_config in agent_configs:
if agent_config.key in updated_details_dict:
agent_config.value = str(updated_details_dict[agent_config.key])
@@ -115,3 +116,12 @@ def get_model_api_key(cls, session, agent_id: int, model: str):
# if selected_model_source == ModelSourceType.Replicate:
# return get_config("REPLICATE_API_TOKEN")
# return get_config("OPENAI_API_KEY")
+
+ @classmethod
+ def get_agent_config_by_key_and_agent_id(cls, session, key: str, agent_id: int):
+ agent_config = session.query(AgentConfiguration).filter(
+ AgentConfiguration.agent_id == agent_id,
+ AgentConfiguration.key == key
+ ).first()
+
+ return agent_config
\ No newline at end of file
diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py
index 1cc3e310a..8c61ac885 100644
--- a/superagi/models/api_key.py
+++ b/superagi/models/api_key.py
@@ -1,8 +1,6 @@
-from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey
-from sqlalchemy.orm import relationship
+from sqlalchemy import Column, Integer, String, Boolean
from superagi.models.base_model import DBBaseModel
-from superagi.models.agent_execution import AgentExecution
-from sqlalchemy import func, or_
+from sqlalchemy import or_
class ApiKey(DBBaseModel):
"""
diff --git a/superagi/models/db.py b/superagi/models/db.py
index e711280ff..c6844cfd4 100644
--- a/superagi/models/db.py
+++ b/superagi/models/db.py
@@ -1,12 +1,8 @@
from sqlalchemy import create_engine
from superagi.config.config import get_config
+from urllib.parse import urlparse
from superagi.lib.logger import logger
-database_url = get_config('POSTGRES_URL')
-db_username = get_config('DB_USERNAME')
-db_password = get_config('DB_PASSWORD')
-db_name = get_config('DB_NAME')
-
engine = None
@@ -23,11 +19,20 @@ def connect_db():
return engine
# Create the connection URL
- if db_username is None:
- db_url = f'postgresql://{database_url}/{db_name}'
+ db_host = get_config('DB_HOST', 'super__postgres')
+ db_username = get_config('DB_USERNAME')
+ db_password = get_config('DB_PASSWORD')
+ db_name = get_config('DB_NAME')
+ db_url = get_config('DB_URL', None)
+
+ if db_url is None:
+ if db_username is None:
+ db_url = f'postgresql://{db_host}/{db_name}'
+ else:
+ db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}'
else:
- db_url = f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}'
-
+ db_url = urlparse(db_url)
+ db_url = db_url.scheme + "://" + db_url.netloc + db_url.path
# Create the SQLAlchemy engine
engine = create_engine(db_url,
pool_size=20, # Maximum number of database connections in the pool
diff --git a/superagi/models/webhooks.py b/superagi/models/webhooks.py
index 14d683472..9b47c6105 100644
--- a/superagi/models/webhooks.py
+++ b/superagi/models/webhooks.py
@@ -20,3 +20,4 @@ class Webhooks(DBBaseModel):
url = Column(String)
headers=Column(JSON)
is_deleted=Column(Boolean)
+ filters=Column(JSON)
diff --git a/superagi/resource_manager/llama_document_summary.py b/superagi/resource_manager/llama_document_summary.py
index 5eca38913..767d2be27 100644
--- a/superagi/resource_manager/llama_document_summary.py
+++ b/superagi/resource_manager/llama_document_summary.py
@@ -1,12 +1,9 @@
import os
-from langchain.chat_models import ChatGooglePalm
from llama_index.indices.response import ResponseMode
from llama_index.schema import Document
from superagi.config.config import get_config
-from superagi.lib.logger import logger
-from superagi.types.model_source_types import ModelSourceType
class LlamaDocumentSummary:
diff --git a/superagi/tools/github/fetch_pull_request.py b/superagi/tools/github/fetch_pull_request.py
new file mode 100644
index 000000000..09bf17f3d
--- /dev/null
+++ b/superagi/tools/github/fetch_pull_request.py
@@ -0,0 +1,66 @@
+from typing import Type, Optional
+
+from pydantic import BaseModel, Field
+
+from superagi.helper.github_helper import GithubHelper
+from superagi.llms.base_llm import BaseLlm
+from superagi.tools.base_tool import BaseTool
+
+
+class GithubFetchPullRequestSchema(BaseModel):
+ repository_name: str = Field(
+ ...,
+ description="Repository name in which file hase to be added",
+ )
+ repository_owner: str = Field(
+ ...,
+ description="Owner of the github repository",
+ )
+ time_in_seconds: int = Field(
+ ...,
+ description="Gets pull requests from last `time_in_seconds` seconds",
+ )
+
+
+class GithubFetchPullRequest(BaseTool):
+ """
+ Fetch pull request tool
+
+ Attributes:
+ name : The name.
+ description : The description.
+ args_schema : The args schema.
+ agent_id: The agent id.
+ agent_execution_id: The agent execution id.
+ """
+ llm: Optional[BaseLlm] = None
+ name: str = "Github Fetch Pull Requests"
+ args_schema: Type[BaseModel] = GithubFetchPullRequestSchema
+ description: str = "Fetch pull requests from github"
+ agent_id: int = None
+ agent_execution_id: int = None
+
+ def _execute(self, repository_name: str, repository_owner: str, time_in_seconds: int = 86400) -> str:
+ """
+ Execute the add file tool.
+
+ Args:
+ repository_name: The name of the repository to add file to.
+ repository_owner: Owner of the GitHub repository.
+ time_in_seconds: Gets pull requests from last `time_in_seconds` seconds
+
+ Returns:
+ List of all pull request ids
+ """
+ try:
+ github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
+ github_username = self.get_tool_config("GITHUB_USERNAME")
+ github_helper = GithubHelper(github_access_token, github_username)
+
+ pull_request_urls = github_helper.get_pull_requests_created_in_last_x_seconds(repository_owner,
+ repository_name,
+ time_in_seconds)
+
+ return "Pull requests: " + str(pull_request_urls)
+ except Exception as err:
+ return f"Error: Unable to fetch pull requests {err}"
diff --git a/superagi/tools/github/github_toolkit.py b/superagi/tools/github/github_toolkit.py
index bb9ec2657..5321f979b 100644
--- a/superagi/tools/github/github_toolkit.py
+++ b/superagi/tools/github/github_toolkit.py
@@ -3,7 +3,9 @@
from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration
from superagi.tools.github.add_file import GithubAddFileTool
from superagi.tools.github.delete_file import GithubDeleteFileTool
+from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest
from superagi.tools.github.search_repo import GithubRepoSearchTool
+from superagi.tools.github.review_pull_request import GithubReviewPullRequest
from superagi.types.key_type import ToolConfigKeyType
@@ -12,7 +14,8 @@ class GitHubToolkit(BaseToolkit, ABC):
description: str = "GitHub Tool Kit contains all github related to tool"
def get_tools(self) -> List[BaseTool]:
- return [GithubAddFileTool(), GithubDeleteFileTool(), GithubRepoSearchTool()]
+ return [GithubAddFileTool(), GithubDeleteFileTool(), GithubRepoSearchTool(), GithubReviewPullRequest(),
+ GithubFetchPullRequest()]
def get_env_keys(self) -> List[ToolConfiguration]:
return [
diff --git a/superagi/tools/github/prompts/code_review.txt b/superagi/tools/github/prompts/code_review.txt
new file mode 100644
index 000000000..1aff59ea6
--- /dev/null
+++ b/superagi/tools/github/prompts/code_review.txt
@@ -0,0 +1,50 @@
+Your purpose is to act as a highly experienced software engineer and provide a thorough review of the code chunks and suggest code snippets to improve key areas such as:
+- Logic
+- Modularity
+- Maintainability
+- Complexity
+
+Do not comment on minor code style issues, missing comments/documentation. Identify and resolve significant concerns to improve overall code quality while deliberately disregarding minor issues
+
+Following is the github pull request diff content:
+```
+{{DIFF_CONTENT}}
+```
+
+Instructions:
+1. Do not comment on existing lines and deleted lines.
+2. Ignore the lines start with '-'.
+3. Only consider lines starting with '+' for review.
+4. Do not comment on frontend and graphql code.
+
+Respond with only valid JSON conforming to the following schema:
+{
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ "type": "object",
+ "properties": {
+ "comments": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "description": "The path to the file where the comment should be added."
+ },
+ "line": {
+ "type": "integer",
+ "description": "The line number where the comment should be added. "
+ },
+ "comment": {
+ "type": "string",
+ "description": "The content of the comment."
+ }
+ },
+ "required": ["file_name", "line", "comment"]
+ }
+ }
+ },
+ "required": ["comments"]
+}
+
+Ensure response is valid JSON conforming to the following schema.
\ No newline at end of file
diff --git a/superagi/tools/github/review_pull_request.py b/superagi/tools/github/review_pull_request.py
new file mode 100644
index 000000000..a230c0701
--- /dev/null
+++ b/superagi/tools/github/review_pull_request.py
@@ -0,0 +1,143 @@
+import ast
+from typing import Type, Optional
+
+from pydantic import BaseModel, Field
+
+from superagi.helper.github_helper import GithubHelper
+from superagi.helper.json_cleaner import JsonCleaner
+from superagi.helper.prompt_reader import PromptReader
+from superagi.helper.token_counter import TokenCounter
+from superagi.llms.base_llm import BaseLlm
+from superagi.models.agent import Agent
+from superagi.tools.base_tool import BaseTool
+
+
+class GithubReviewPullRequestSchema(BaseModel):
+ repository_name: str = Field(
+ ...,
+ description="Repository name in which file hase to be added",
+ )
+ repository_owner: str = Field(
+ ...,
+ description="Owner of the github repository",
+ )
+ pull_request_number: int = Field(
+ ...,
+ description="Pull request number",
+ )
+
+
+class GithubReviewPullRequest(BaseTool):
+ """
+ Reviews the github pull request and adds comments inline
+
+ Attributes:
+ name : The name.
+ description : The description.
+ args_schema : The args schema.
+ """
+ llm: Optional[BaseLlm] = None
+ name: str = "Github Review Pull Request"
+ args_schema: Type[BaseModel] = GithubReviewPullRequestSchema
+ description: str = "Add pull request for the github repository"
+ agent_id: int = None
+ agent_execution_id: int = None
+
+ def _execute(self, repository_name: str, repository_owner: str, pull_request_number: int) -> str:
+ """
+ Execute the add file tool.
+
+ Args:
+ repository_name: The name of the repository to add file to.
+ repository_owner: Owner of the GitHub repository.
+ pull_request_number: pull request number
+
+ Returns:
+ Pull request success message if pull request is created successfully else error message.
+ """
+ try:
+ github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN")
+ github_username = self.get_tool_config("GITHUB_USERNAME")
+ github_helper = GithubHelper(github_access_token, github_username)
+
+ pull_request_content = github_helper.get_pull_request_content(repository_owner, repository_name,
+ pull_request_number)
+ latest_commit_id = github_helper.get_latest_commit_id_of_pull_request(repository_owner, repository_name,
+ pull_request_number)
+
+ pull_request_arr = pull_request_content.split("diff --git")
+ organisation = Agent.find_org_by_agent_id(session=self.toolkit_config.session, agent_id=self.agent_id)
+
+ model_token_limit = TokenCounter(session=self.toolkit_config.session,
+ organisation_id=organisation.id).token_limit(self.llm.get_model())
+ pull_request_arr_parts = self.split_pull_request_content_into_multiple_parts(model_token_limit, pull_request_arr)
+ for content in pull_request_arr_parts:
+ self.run_code_review(github_helper, content, latest_commit_id, organisation, pull_request_number,
+ repository_name, repository_owner)
+ return "Added comments to the pull request:" + str(pull_request_number)
+ except Exception as err:
+ return f"Error: Unable to add comments to the pull request {err}"
+
+ def run_code_review(self, github_helper, content, latest_commit_id, organisation, pull_request_number,
+ repository_name, repository_owner):
+ prompt = PromptReader.read_tools_prompt(__file__, "code_review.txt")
+ prompt = prompt.replace("{{DIFF_CONTENT}}", content)
+ messages = [{"role": "system", "content": prompt}]
+ total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model())
+ token_limit = TokenCounter(session=self.toolkit_config.session,
+ organisation_id=organisation.id).token_limit(self.llm.get_model())
+ result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100))
+ response = result["content"]
+ if response.startswith("```") and response.endswith("```"):
+ response = "```".join(response.split("```")[1:-1])
+ response = JsonCleaner.extract_json_section(response)
+ comments = ast.literal_eval(response)
+
+ # Add comments in the pull request
+ for comment in comments['comments']:
+ line_number = self.get_exact_line_number(content, comment["file_path"], comment["line"])
+ github_helper.add_line_comment_to_pull_request(repository_owner, repository_name, pull_request_number,
+ latest_commit_id, comment["file_path"], line_number,
+ comment["comment"])
+
+ def split_pull_request_content_into_multiple_parts(self, model_token_limit: int, pull_request_arr):
+ pull_request_arr_parts = []
+ current_part = ""
+ for part in pull_request_arr:
+ total_tokens = TokenCounter.count_message_tokens([{"role": "user", "content": current_part}],
+ self.llm.get_model())
+ # we are using 60% of the model token limit
+ if total_tokens >= model_token_limit * 0.6:
+ # Add the current part to pull_request_arr_parts and reset current_sum and current_part
+ pull_request_arr_parts.append(current_part)
+ current_part = "diff --git" + part
+ else:
+ current_part += "diff --git" + part
+
+ pull_request_arr_parts.append(current_part)
+ return pull_request_arr_parts
+
+ def get_exact_line_number(self, diff_content, file_path, line_number):
+ last_content = diff_content[diff_content.index(file_path):]
+ last_content = last_content[last_content.index('@@'):]
+ return self.find_position_in_diff(last_content, line_number)
+
+ def find_position_in_diff(self, diff_content, target_line):
+ # Split the diff by lines and initialize variables
+ diff_lines = diff_content.split('\n')
+ position = 0
+ current_file_line_number = 0
+
+ # Loop through each line in the diff
+ for line in diff_lines:
+ position += 1 # Increment position for each line
+ if line.startswith('@@'):
+ # Reset the current file line number when encountering a new hunk
+ current_file_line_number = int(line.split('+')[1].split(',')[0]) - 1
+ elif not line.startswith('-'):
+ # Increment the current file line number for lines that are not deletions
+ current_file_line_number += 1
+ if current_file_line_number >= target_line:
+ # Return the position when the target line number is reached
+ return position
+ return position
diff --git a/superagi/tools/image_generation/image_generation_toolkit.py b/superagi/tools/image_generation/image_generation_toolkit.py
index bebb94411..8b2361e06 100644
--- a/superagi/tools/image_generation/image_generation_toolkit.py
+++ b/superagi/tools/image_generation/image_generation_toolkit.py
@@ -15,7 +15,7 @@ def get_tools(self) -> List[BaseTool]:
def get_env_keys(self) -> List[ToolConfiguration]:
return [
- ToolConfiguration(key="STABILITY_API_KEY", key_type=ToolConfigKeyType.STRING, is_required= False, is_secret = True),
+ ToolConfiguration(key="STABILITY_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret = True),
ToolConfiguration(key="ENGINE_ID", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=False),
- ToolConfiguration(key= "OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True),
+ ToolConfiguration(key="OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
]
diff --git a/superagi/tools/knowledge_search/knowledge_search.py b/superagi/tools/knowledge_search/knowledge_search.py
index b649000bb..272d8865b 100644
--- a/superagi/tools/knowledge_search/knowledge_search.py
+++ b/superagi/tools/knowledge_search/knowledge_search.py
@@ -45,7 +45,7 @@ def _execute(self, query: str):
vector_db = Vectordbs.get_vector_db_from_id(session, vector_db_index.vector_db_id)
db_creds = VectordbConfigs.get_vector_db_config_from_db_id(session, vector_db.id)
model_api_key = self.get_tool_config('OPENAI_API_KEY')
- model_source = "OpenAI"
+ model_source = 'OpenAI'
embedding_model = AgentExecutor.get_embedding(model_source, model_api_key)
try:
if vector_db_index.state == "Custom":
diff --git a/superagi/tools/knowledge_search/knowledge_search_toolkit.py b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
index ae3a937f3..6e0cdf0c5 100644
--- a/superagi/tools/knowledge_search/knowledge_search_toolkit.py
+++ b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
@@ -13,5 +13,5 @@ def get_tools(self) -> List[BaseTool]:
def get_env_keys(self) -> List[ToolConfiguration]:
return [
- ToolConfiguration(key= "OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
+ ToolConfiguration(key="OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
]
\ No newline at end of file
diff --git a/superagi/tools/tool_response_query_manager.py b/superagi/tools/tool_response_query_manager.py
index c4da9b9a6..0bf708b39 100644
--- a/superagi/tools/tool_response_query_manager.py
+++ b/superagi/tools/tool_response_query_manager.py
@@ -11,11 +11,12 @@ def __init__(self, session: Session, agent_execution_id: int,memory:VectorStore)
def get_last_response(self, tool_name: str = None):
return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name)
-
+
def get_relevant_response(self, query: str,metadata:dict, top_k: int = 5):
if self.memory is None:
return ""
- documents = self.memory.get_matching_text(query, metadata=metadata)
+ documents = self.memory.get_matching_text(query, metadata=metadata,
+ top_k=top_k)
relevant_responses = ""
for document in documents["documents"]:
relevant_responses += document.text_content
diff --git a/superagi/worker.py b/superagi/worker.py
index 83afa7253..e0bace7a1 100644
--- a/superagi/worker.py
+++ b/superagi/worker.py
@@ -1,4 +1,5 @@
from __future__ import absolute_import
+import sys
from sqlalchemy.orm import sessionmaker
@@ -36,12 +37,11 @@
}
app.conf.beat_schedule = beat_schedule
-# @event.listens_for(AgentExecution.status, "set")
-# def agent_status_change(target, val,old_val,initiator):
-# if not get_config("IN_TESTING",False):
-# webhook_callback.delay(target.id,val,old_val)
-
-
+@event.listens_for(AgentExecution.status, "set")
+def agent_status_change(target, val,old_val,initiator):
+ if not hasattr(sys, '_called_from_test'):
+ webhook_callback.delay(target.id,val,old_val)
+
@app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
def initialize_schedule_agent_task():
"""Executing agent scheduling in the background."""
diff --git a/tests/unit_tests/agent/test_tool_executor.py b/tests/unit_tests/agent/test_tool_executor.py
index 246db63fd..35e1735f3 100644
--- a/tests/unit_tests/agent/test_tool_executor.py
+++ b/tests/unit_tests/agent/test_tool_executor.py
@@ -19,7 +19,7 @@ def mock_tools():
@pytest.fixture
def executor(mock_tools):
- return ToolExecutor(organisation_id=1, agent_id=1, tools=mock_tools)
+ return ToolExecutor(organisation_id=1, agent_id=1, tools=mock_tools, agent_execution_id=1)
def test_tool_executor_finish(executor):
res = executor.execute(None, 'finish', {})
@@ -29,7 +29,7 @@ def test_tool_executor_finish(executor):
@patch('superagi.agent.tool_executor.EventHandler')
def test_tool_executor_success(mock_event_handler, executor, mock_tools):
for i, tool in enumerate(mock_tools):
- res = executor.execute(None, f'tool{i}', {})
+ res = executor.execute(None, f'tool{i}', {'agent_execution_id': 1})
assert res.status == 'SUCCESS'
assert res.result == f'Tool {tool.name} returned: {tool.name}'
assert res.retry == False
diff --git a/tests/unit_tests/apm/test_knowledge_handler.py b/tests/unit_tests/apm/test_knowledge_handler.py
new file mode 100644
index 000000000..acb810ff4
--- /dev/null
+++ b/tests/unit_tests/apm/test_knowledge_handler.py
@@ -0,0 +1,100 @@
+import pytest
+from unittest.mock import MagicMock
+from superagi.apm.knowledge_handler import KnowledgeHandler
+from fastapi import HTTPException
+from datetime import datetime
+import pytz
+
+@pytest.fixture
+def organisation_id():
+ return 1
+
+@pytest.fixture
+def mock_session():
+ return MagicMock()
+
+@pytest.fixture
+def knowledge_handler(mock_session, organisation_id):
+ return KnowledgeHandler(mock_session, organisation_id)
+
+def test_get_knowledge_usage_by_name(knowledge_handler, mock_session):
+ knowledge_handler.session = mock_session
+ knowledge_name = 'Knowledge1'
+ mock_knowledge_event = MagicMock()
+ mock_knowledge_event.knowledge_unique_agents = 5
+ mock_knowledge_event.knowledge_name = knowledge_name
+ mock_knowledge_event.id = 1
+
+ mock_session.query.return_value.filter_by.return_value.filter.return_value.first.return_value = mock_knowledge_event
+ mock_session.query.return_value.filter.return_value.group_by.return_value.first.return_value = mock_knowledge_event
+ mock_session.query.return_value.filter.return_value.count.return_value = 10
+
+ result = knowledge_handler.get_knowledge_usage_by_name(knowledge_name)
+
+ assert isinstance(result, dict)
+ assert result == {
+ 'knowledge_unique_agents': 5,
+ 'knowledge_calls': 10
+ }
+
+ mock_session.query.return_value.filter_by.return_value.filter.return_value.first.return_value = None
+
+ with pytest.raises(HTTPException):
+ knowledge_handler.get_knowledge_usage_by_name('NonexistentKnowledge')
+
+def test_get_knowledge_events_by_name(knowledge_handler, mock_session):
+ knowledge_name = 'knowledge1'
+ knowledge_handler.session = mock_session
+ knowledge_handler.organisation_id = 1
+
+ mock_knowledge = MagicMock()
+ mock_knowledge.id = 1
+ mock_session.query().filter_by().filter().first.return_value = mock_knowledge
+
+ result_obj = MagicMock()
+ result_obj.agent_id = 1
+ result_obj.created_at = datetime.now()
+ result_obj.event_name = 'knowledge_picked'
+ result_obj.event_property = {'knowledge_name': 'knowledge1', 'agent_execution_id': '1'}
+ result_obj2 = MagicMock()
+ result_obj2.agent_id = 1
+ result_obj2.event_name = 'run_completed'
+ result_obj2.event_property = {'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'}
+ result_obj3 = MagicMock()
+ result_obj3.agent_id = 1
+ result_obj3.event_name = 'agent_created'
+ result_obj3.event_property = {'agent_name': 'A1', 'model': 'M1'}
+
+ mock_session.query().filter().all.side_effect = [[result_obj], [result_obj2], [result_obj3]]
+
+ user_timezone = MagicMock()
+ user_timezone.value = 'America/New_York'
+ mock_session.query().filter().first.return_value = user_timezone
+
+ result = knowledge_handler.get_knowledge_events_by_name(knowledge_name)
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ for item in result:
+ assert 'agent_execution_id' in item
+ assert 'created_at' in item
+ assert 'tokens_consumed' in item
+ assert 'calls' in item
+ assert 'agent_execution_name' in item
+ assert 'agent_name' in item
+ assert 'model' in item
+
+
+def test_get_knowledge_events_by_name_knowledge_not_found(knowledge_handler, mock_session):
+ knowledge_name = "knowledge1"
+ not_found_message = 'Knowledge not found'
+
+ mock_session.query().filter_by().filter().first.return_value = None
+
+ try:
+ knowledge_handler.get_knowledge_events_by_name(knowledge_name)
+ assert False, "Expected HTTPException has not been raised"
+ except HTTPException as e:
+ assert str(e.detail) == not_found_message, f"Expected {not_found_message}, got {e.detail}"
+ finally:
+ assert mock_session.query().filter_by().filter().first.called, "first() function not called"
\ No newline at end of file
diff --git a/tests/unit_tests/apm/test_tools_handler.py b/tests/unit_tests/apm/test_tools_handler.py
index f58650cc7..f805bbdf1 100644
--- a/tests/unit_tests/apm/test_tools_handler.py
+++ b/tests/unit_tests/apm/test_tools_handler.py
@@ -1,8 +1,12 @@
import pytest
-from unittest.mock import MagicMock
-
+from unittest.mock import MagicMock, patch
+from fastapi import HTTPException
from superagi.apm.tools_handler import ToolsHandler
from sqlalchemy.orm import Session
+from superagi.models.agent_config import AgentConfiguration
+
+from datetime import datetime
+import pytz
@pytest.fixture
def organisation_id():
@@ -17,6 +21,129 @@ def tools_handler(mock_session, organisation_id):
return ToolsHandler(mock_session, organisation_id)
def test_calculate_tool_usage(tools_handler, mock_session):
- mock_session.query().all.return_value = [MagicMock()]
+ tool_used_subquery = MagicMock()
+ agent_count_subquery = MagicMock()
+ total_usage_subquery = MagicMock()
+
+ tool_used_subquery.c.tool_name = 'Tool1'
+ tool_used_subquery.c.agent_id = 1
+
+ agent_count_subquery.c.tool_name = 'Tool1'
+ agent_count_subquery.c.unique_agents = 1
+
+ total_usage_subquery.c.tool_name = 'Tool1'
+ total_usage_subquery.c.total_usage = 5
+
+ tools_handler.get_tool_and_toolkit = MagicMock()
+ tools_handler.get_tool_and_toolkit.return_value = {'Tool1': 'Toolkit1'}
+
+ mock_session.query().filter_by().subquery.return_value = tool_used_subquery
+ mock_session.query().group_by().subquery.return_value = agent_count_subquery
+ mock_session.query().group_by().subquery.return_value = total_usage_subquery
+
+ result_obj = MagicMock()
+ result_obj.tool_name = 'Tool1'
+ result_obj.unique_agents = 1
+ result_obj.total_usage = 5
+ mock_session.query().join().all.return_value = [result_obj]
+
result = tools_handler.calculate_tool_usage()
- assert isinstance(result, list)
\ No newline at end of file
+
+ assert isinstance(result, list)
+
+ expected_output = [{'tool_name': 'Tool1', 'unique_agents': 1, 'total_usage': 5, 'toolkit': 'Toolkit1'}]
+ assert result == expected_output
+
+def test_get_tool_and_toolkit(tools_handler, mock_session):
+ result_obj = MagicMock()
+ result_obj.tool_name = 'tool 1'
+ result_obj.toolkit_name = 'toolkit 1'
+
+ mock_session.query().join().all.return_value = [result_obj]
+
+ output = tools_handler.get_tool_and_toolkit()
+
+ assert isinstance(output, dict)
+ assert output == {'tool 1': 'toolkit 1'}
+
+def test_get_tool_usage_by_name(tools_handler, mock_session):
+ tools_handler.session = mock_session
+ tool_name = 'Tool1'
+ formatted_tool_name = tool_name.lower().replace(" ", "")
+
+ mock_tool = MagicMock()
+ mock_tool.name = tool_name
+
+ mock_tool_event = MagicMock()
+ mock_tool_event.tool_name = formatted_tool_name
+ mock_tool_event.tool_calls = 10
+ mock_tool_event.tool_unique_agents = 5
+
+ mock_session.query.return_value.filter_by.return_value.first.return_value = mock_tool
+ mock_session.query.return_value.filter.return_value.group_by.return_value.first.return_value = mock_tool_event
+
+ result = tools_handler.get_tool_usage_by_name(tool_name=tool_name)
+
+ assert isinstance(result, dict)
+ assert result == {
+ 'tool_calls': 10,
+ 'tool_unique_agents': 5
+ }
+
+ mock_session.query.return_value.filter_by.return_value.first.return_value = None
+
+ with pytest.raises(HTTPException):
+ tools_handler.get_tool_usage_by_name(tool_name="NonexistentTool")
+
+def test_get_tool_events_by_name(tools_handler, mock_session):
+ tool_name = 'Tool1'
+ tools_handler.session = mock_session
+ tools_handler.organisation_id = 1
+
+ mock_tool = MagicMock()
+ mock_tool.id = 1
+ mock_session.query().filter_by().first.return_value = mock_tool
+
+ result_obj = MagicMock()
+ result_obj.agent_id = 1
+ result_obj.id = 1
+ result_obj.created_at = datetime.now()
+ result_obj.event_name = 'tool_used'
+ result_obj.event_property = {'tool_name': 'tool1', 'agent_execution_id': '1'}
+ result_obj2 = MagicMock()
+ result_obj2.agent_id = 1
+ result_obj2.id = 2
+ result_obj2.event_name = 'run_completed'
+ result_obj2.event_property = {'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'}
+ result_obj3 = MagicMock()
+ result_obj3.agent_id = 1
+ result_obj3.event_name = 'agent_created'
+ result_obj3.event_property = {'agent_name': 'A1', 'model': 'M1'}
+
+ mock_session.query().filter().all.side_effect = [[result_obj], [result_obj2], [result_obj3], []]
+
+ user_timezone = MagicMock()
+ user_timezone.value = 'America/New_York'
+ mock_session.query().filter().first.return_value = user_timezone
+
+ result = tools_handler.get_tool_events_by_name(tool_name)
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ for item in result:
+ assert 'agent_execution_id' in item
+ assert 'created_at' in item
+ assert 'tokens_consumed' in item
+ assert 'calls' in item
+ assert 'agent_execution_name' in item
+ assert 'agent_name' in item
+ assert 'model' in item
+
+def test_get_tool_events_by_name_tool_not_found(tools_handler, mock_session):
+ tool_name = "tool1"
+
+ mock_session.query().filter_by().first.return_value = None
+ with pytest.raises(HTTPException):
+ tools_handler.get_tool_events_by_name(tool_name)
+
+ assert mock_session.query().filter_by().first.called
\ No newline at end of file
diff --git a/tests/unit_tests/controllers/test_publish_agent.py b/tests/unit_tests/controllers/test_publish_agent.py
new file mode 100644
index 000000000..615a8a26a
--- /dev/null
+++ b/tests/unit_tests/controllers/test_publish_agent.py
@@ -0,0 +1,41 @@
+import pytest
+from fastapi.testclient import TestClient
+from unittest.mock import create_autospec, patch
+from main import app
+from superagi.models.agent import Agent
+from superagi.models.agent_config import AgentConfiguration
+from superagi.models.agent_execution import AgentExecution
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+from superagi.models.organisation import Organisation
+from superagi.models.user import User
+from sqlalchemy.orm import Session
+
+client = TestClient(app)
+
+@pytest.fixture
+def mocks():
+ # Mock tool kit data for testing
+ mock_agent = Agent(id=1, name="test_agent", project_id=1, description="testing", agent_workflow_id=1, is_deleted=False)
+ mock_agent_config = AgentConfiguration(id=1, agent_id=1, key="test_key", value="['test']")
+ mock_execution = AgentExecution(id=1, agent_id=1, name="test_execution")
+ mock_execution_config = [AgentExecutionConfiguration(id=1, agent_execution_id=1, key="test_key", value="['test']")]
+ return mock_agent,mock_agent_config,mock_execution,mock_execution_config
+
+def test_publish_template(mocks):
+ with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
+ patch('superagi.helper.auth.get_current_user') as mock_get_user, \
+ patch('superagi.helper.auth.db') as mock_auth_db,\
+ patch('superagi.controllers.agent_template.db') as mock_db:
+
+ mock_session = create_autospec(Session)
+ mock_agent, mock_agent_config, mock_execution, mock_execution_config = mocks
+
+ mock_session.query.return_value.filter.return_value.first.return_value = mock_agent
+ mock_session.query.return_value.filter.return_value.all.return_value = [mock_agent_config]
+ mock_session.query.return_value.filter.return_value.order_by.return_value.first.return_value = mock_execution
+ mock_session.query.return_value.filter.return_value.all.return_value = mock_execution_config
+
+ with patch('superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id') as mock_get_exec:
+ mock_get_exec.return_value = mock_execution
+ response = client.post("/agent_templates/publish_template/agent_execution_id/1")
+ assert response.status_code == 201
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_github_helper.py b/tests/unit_tests/helper/test_github_helper.py
index 04105c59a..e947e4703 100644
--- a/tests/unit_tests/helper/test_github_helper.py
+++ b/tests/unit_tests/helper/test_github_helper.py
@@ -155,7 +155,47 @@ def test_create_pull_request(self, mock_post):
headers={'header': 'value'}
)
- # ... more tests for other methods
+ @patch('requests.get')
+ def test_get_pull_request_content_success(self, mock_get):
+ mock_get.return_value.status_code = 200
+ mock_get.return_value.text = "some_content"
+
+ github_api = GithubHelper('access_token', 'username')
+ result = github_api.get_pull_request_content("owner", "repo", 1)
+
+ self.assertEqual(result, "some_content")
+
+ @patch('requests.get')
+ def test_get_pull_request_content_not_found(self, mock_get):
+ mock_get.return_value.status_code = 404
+
+ github_api = GithubHelper('access_token', 'username')
+ result = github_api.get_pull_request_content("owner", "repo", 1)
+
+ self.assertIsNone(result)
+
+ @patch('requests.get')
+ def test_get_latest_commit_id_of_pull_request(self, mock_get):
+ mock_get.return_value.status_code = 200
+ mock_get.return_value.json.return_value = [{"sha": "123"}, {"sha": "456"}]
+
+ github_api = GithubHelper('access_token', 'username')
+ result = github_api.get_latest_commit_id_of_pull_request("owner", "repo", 1)
+
+ self.assertEqual(result, "456")
+
+ @patch('requests.post')
+ def test_add_line_comment_to_pull_request(self, mock_post):
+ mock_post.return_value.status_code = 201
+ mock_post.return_value.json.return_value = {"id": 1, "body": "comment"}
+
+ github_api = GithubHelper('access_token', 'username')
+ result = github_api.add_line_comment_to_pull_request("owner", "repo", 1, "commit_id", "file_path", 1, "comment")
+
+ self.assertEqual(result, {"id": 1, "body": "comment"})
+
+
+# ... more tests for other methods
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_webhooks.py b/tests/unit_tests/helper/test_webhooks.py
new file mode 100644
index 000000000..b52edf507
--- /dev/null
+++ b/tests/unit_tests/helper/test_webhooks.py
@@ -0,0 +1,75 @@
+import json
+from unittest.mock import Mock, patch
+import pytest
+from superagi.helper.webhook_manager import WebHookManager
+from superagi.models.webhooks import Webhooks
+
+@pytest.fixture
+def mock_session():
+ return Mock()
+
+@pytest.fixture
+def mock_agent_execution():
+ return Mock()
+
+@pytest.fixture
+def mock_agent():
+ return Mock()
+
+@pytest.fixture
+def mock_webhook():
+ return Mock()
+
+@pytest.fixture
+def mock_org():
+ org_mock = Mock()
+ org_mock.id = "mock_org_id"
+ return org_mock
+
+def test_agent_status_change_callback(
+ mock_session, mock_agent_execution, mock_agent, mock_org, mock_webhook
+):
+ curr_status = "NEW_STATUS"
+ old_status = "OLD_STATUS"
+ mock_agent_id = "mock_agent_id"
+ mock_org_id = "mock_org_id"
+
+ # Create a mock instance of AgentExecution and set its attributes
+ mock_agent_execution_instance = Mock()
+ mock_agent_execution_instance.agent_id = "mock_agent_id"
+
+ # Create a mock instance of Agent and set its attributes
+ mock_agent_instance = Mock()
+ mock_agent_instance.get_agent_organisation.return_value = mock_org
+
+ # Create a mock instance of Webhooks and set its attributes
+ mock_webhook_instance = Mock(spec=Webhooks)
+ mock_webhook_instance.filters = {"status": ["PAUSED", "RUNNING"]}
+
+ # Set up session.query().filter().all() to return the mock_webhook_instance
+ mock_session.query.return_value.filter.return_value.all.return_value = [mock_webhook_instance]
+
+ # Patch required functions/methods
+ with patch(
+ 'superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id',
+ return_value=mock_agent_execution_instance
+ ), patch(
+ 'superagi.models.agent.Agent.get_agent_from_id',
+ return_value=mock_agent_instance
+ ), patch(
+ 'requests.post',
+ return_value=Mock(status_code=200) # Mock the status_code response
+ ) as mock_post, patch(
+ 'json.dumps'
+ ) as mock_json_dumps:
+
+ # Create the WebHookManager instance
+ web_hook_manager = WebHookManager(mock_session)
+
+ # Call the function
+ web_hook_manager.agent_status_change_callback(
+ mock_agent_execution_instance, curr_status, old_status
+ )
+
+ assert mock_agent_execution_instance.agent_status_change_callback
+
diff --git a/tests/unit_tests/jobs/conftest.py b/tests/unit_tests/jobs/conftest.py
new file mode 100644
index 000000000..e1344072a
--- /dev/null
+++ b/tests/unit_tests/jobs/conftest.py
@@ -0,0 +1,8 @@
+# content of conftest.py
+def pytest_configure(config):
+ import sys
+ sys._called_from_test = True
+
+def pytest_unconfigure(config):
+ import sys # This was missing from the manual
+ del sys._called_from_test
diff --git a/tests/unit_tests/llms/test_model_factory.py b/tests/unit_tests/llms/test_model_factory.py
index 54240476f..2aa49a542 100644
--- a/tests/unit_tests/llms/test_model_factory.py
+++ b/tests/unit_tests/llms/test_model_factory.py
@@ -1,52 +1,82 @@
-# from unittest.mock import MagicMock, patch
-# from superagi.llms.openai import OpenAi
-# from superagi.llms.replicate import Replicate
-# from superagi.models.models_config import ModelsConfig
-# from superagi.models.models import Models
-# from superagi.models.db import connect_db
-# from sqlalchemy.orm import sessionmaker
-# import pytest
-#
-#
-# @pytest.fixture
-# def mock_db_session():
-# db_session = MagicMock()
-# db_session.query().filter().first().return_value = Models(model_name="gpt-3.5-turbo", org_id=1, model_provider_id=1)
-# db_session.query().filter().first().return_value = ModelsConfig(provider="OpenAI")
-# return db_session
-#
-#
-# @patch("superagi.models.db.connect_db")
-# @patch("sqlalchemy.orm.sessionmaker")
-# def test_get_model_openai(mock_sessionmaker, mock_connect_db, mock_db_session):
-# mock_sessionmaker.return_value = MagicMock(return_value=mock_db_session)
-# mock_connect_db.return_value = MagicMock()
-#
-# from superagi.llms.openai import OpenAi
-# result = OpenAi.get_model(organisation_id=1, api_key="TEST_KEY")
-#
-# assert isinstance(result, OpenAi)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.replicate.Replicate')
-# # def test_get_model_replicate(mock_replicate, mock_db, mock_session_maker):
-# # mock_session_maker.query().filter().filter().first().return_value.provider = 'Replicate'
-# # from superagi.models.db import get_model
-# # result = get_model(organisation_id=1, api_key="TEST_KEY")
-# # assert isinstance(result, superagi.llms.replicate.Replicate)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.google_palm.GooglePalm')
-# # def test_get_model_google_palm(mock_google_palm, mock_db, mock_session_maker):
-# # mock_session_maker.query().filter().filter().first().return_value.provider = 'Google Palm'
-# # from superagi.models.db import get_model
-# # result = get_model(organisation_id=1, api_key="TEST_KEY")
-# # assert isinstance(result, superagi.llms.google_palm.GooglePalm)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.hugging_face.HuggingFace')
-# # def test_get_model_hugging_face(mock_hugging_face, mock_db, mock_session_maker):
-# # mock_session_maker.query().filter().filter().first().return_value.provider = 'Hugging Face'
-# # from superagi.models.db import get_model
-# # result = get_model(organisation_id=1, api_key="TEST_KEY")
-# # assert isinstance(result, superagi.llms.hugging_face.HuggingFace)
\ No newline at end of file
+import pytest
+from unittest.mock import Mock
+
+from superagi.llms.google_palm import GooglePalm
+from superagi.llms.hugging_face import HuggingFace
+from superagi.llms.llm_model_factory import get_model, build_model_with_api_key
+from superagi.llms.openai import OpenAi
+from superagi.llms.replicate import Replicate
+
+
+# Fixtures for the mock objects
+@pytest.fixture
+def mock_openai():
+ return Mock(spec=OpenAi)
+
+@pytest.fixture
+def mock_replicate():
+ return Mock(spec=Replicate)
+
+@pytest.fixture
+def mock_google_palm():
+ return Mock(spec=GooglePalm)
+
+@pytest.fixture
+def mock_hugging_face():
+ return Mock(spec=HuggingFace)
+
+@pytest.fixture
+def mock_replicate():
+ return Mock(spec=Replicate)
+
+@pytest.fixture
+def mock_google_palm():
+ return Mock(spec=GooglePalm)
+
+@pytest.fixture
+def mock_hugging_face():
+ return Mock(spec=HuggingFace)
+
+# Test build_model_with_api_key function
+def test_build_model_with_openai(mock_openai, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.OpenAi', mock_openai)
+ model = build_model_with_api_key('OpenAi', 'fake_key')
+ mock_openai.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+def test_build_model_with_replicate(mock_replicate, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.Replicate', mock_replicate)
+ model = build_model_with_api_key('Replicate', 'fake_key')
+ mock_replicate.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+
+def test_build_model_with_openai(mock_openai, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.OpenAi', mock_openai) # Replace 'your_module' with the actual module name
+ model = build_model_with_api_key('OpenAi', 'fake_key')
+ mock_openai.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+def test_build_model_with_replicate(mock_replicate, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.Replicate', mock_replicate) # Replace 'your_module' with the actual module name
+ model = build_model_with_api_key('Replicate', 'fake_key')
+ mock_replicate.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+def test_build_model_with_google_palm(mock_google_palm, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.GooglePalm', mock_google_palm) # Replace 'your_module' with the actual module name
+ model = build_model_with_api_key('Google Palm', 'fake_key')
+ mock_google_palm.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+def test_build_model_with_hugging_face(mock_hugging_face, monkeypatch):
+ monkeypatch.setattr('superagi.llms.llm_model_factory.HuggingFace', mock_hugging_face) # Replace 'your_module' with the actual module name
+ model = build_model_with_api_key('Hugging Face', 'fake_key')
+ mock_hugging_face.assert_called_once_with(api_key='fake_key')
+ assert isinstance(model, Mock)
+
+def test_build_model_with_unknown_provider(capsys): # capsys is a built-in pytest fixture for capturing print output
+ model = build_model_with_api_key('Unknown', 'fake_key')
+ assert model is None
+ captured = capsys.readouterr()
+ assert "Unknown provider." in captured.out
\ No newline at end of file
diff --git a/tests/unit_tests/tools/github/test_fetch_pull_request.py b/tests/unit_tests/tools/github/test_fetch_pull_request.py
new file mode 100644
index 000000000..270ec2280
--- /dev/null
+++ b/tests/unit_tests/tools/github/test_fetch_pull_request.py
@@ -0,0 +1,58 @@
+import pytest
+from unittest.mock import patch, Mock
+from pydantic import ValidationError
+
+from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest, GithubFetchPullRequestSchema
+
+
+@pytest.fixture
+def mock_github_helper():
+ with patch('superagi.tools.github.fetch_pull_request.GithubHelper') as MockGithubHelper:
+ yield MockGithubHelper
+
+
+@pytest.fixture
+def tool(mock_github_helper):
+ tool = GithubFetchPullRequest()
+ tool.toolkit_config = Mock()
+ tool.toolkit_config.side_effect = ['dummy_token', 'dummy_username']
+ mock_github_helper_instance = mock_github_helper.return_value
+ mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.return_value = ['url1', 'url2']
+ return tool
+
+
+def test_execute(tool, mock_github_helper):
+ mock_github_helper_instance = mock_github_helper.return_value
+
+ # Execute the method
+ result = tool._execute('repo_name', 'repo_owner', 86400)
+
+ # Verify results
+ assert result == "Pull requests: ['url1', 'url2']"
+ mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.assert_called_once_with('repo_owner',
+ 'repo_name', 86400)
+
+
+def test_schema_validation():
+ # Valid data
+ valid_data = {'repository_name': 'repo', 'repository_owner': 'owner', 'time_in_seconds': 86400}
+ GithubFetchPullRequestSchema(**valid_data)
+
+ # Invalid data
+ invalid_data = {'repository_name': 'repo', 'repository_owner': 'owner', 'time_in_seconds': 'string'}
+ with pytest.raises(ValidationError):
+ GithubFetchPullRequestSchema(**invalid_data)
+
+
+def test_execute_error(mock_github_helper):
+ tool = GithubFetchPullRequest()
+ tool.toolkit_config = Mock()
+ tool.toolkit_config.side_effect = ['dummy_token', 'dummy_username']
+ mock_github_helper_instance = mock_github_helper.return_value
+ mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.side_effect = Exception('An error occurred')
+
+ # Execute the method
+ result = tool._execute('repo_name', 'repo_owner', 86400)
+
+ # Verify results
+ assert result == 'Error: Unable to fetch pull requests An error occurred'
diff --git a/tests/unit_tests/tools/github/test_review_pull_request.py b/tests/unit_tests/tools/github/test_review_pull_request.py
new file mode 100644
index 000000000..4e5c12903
--- /dev/null
+++ b/tests/unit_tests/tools/github/test_review_pull_request.py
@@ -0,0 +1,82 @@
+import pytest
+from unittest.mock import patch, Mock
+import pytest_mock
+from pydantic import ValidationError
+
+from superagi.tools.github.review_pull_request import GithubReviewPullRequest
+
+class MockLLM:
+ def get_model(self):
+ return "some_model"
+
+class MockTokenCounter:
+ @staticmethod
+ def count_message_tokens(message, model):
+ # Mocking the token count based on the length of the content.
+ # Replace this logic as needed.
+ return len(message[0]['content'])
+
+
+def test_split_pull_request_content_into_multiple_parts():
+ tool = GithubReviewPullRequest()
+ tool.llm = MockLLM()
+
+ # Mocking the pull_request_arr
+ pull_request_arr = ["part1", "part2", "part3"]
+
+ # Calling the method to be tested
+ result = tool.split_pull_request_content_into_multiple_parts(4000,pull_request_arr)
+
+ # Validate the result (this depends on what you expect the output to be)
+ # For instance, if you expect the result to be a list of all parts concatenated with 'diff --git'
+ expected = ["diff --gitpart1diff --gitpart2diff --gitpart3"]
+ assert result == expected
+
+
+@pytest.mark.parametrize("diff_content, file_path, line_number, expected", [
+ ("file_path_1\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n+ line3", "file_path_1", 3, 4),
+ ("file_path_2\n@@ -1,3 +1,3 @@\n+ line1\n- line2", "file_path_2", 1, 2),
+ ("file_path_3\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n- line3", "file_path_3", 2, 3)
+])
+def test_get_exact_line_number(diff_content, file_path, line_number, expected):
+ tool = GithubReviewPullRequest()
+
+ # Calling the method to be tested
+ result = tool.get_exact_line_number(diff_content, file_path, line_number)
+
+ # Validate the result
+ assert result == expected
+
+
+class MockGithubHelper:
+ def __init__(self, access_token, username):
+ pass
+
+ def get_pull_request_content(self, owner, repo, pr_number):
+ return 'mock_content'
+
+ def get_latest_commit_id_of_pull_request(self, owner, repo, pr_number):
+ return 'mock_commit_id'
+
+ def add_line_comment_to_pull_request(self, *args, **kwargs):
+ return True
+
+
+# Your test case
+def test_execute():
+ with patch('superagi.tools.github.review_pull_request.GithubHelper', MockGithubHelper), \
+ patch('superagi.tools.github.review_pull_request.TokenCounter.count_message_tokens', return_value=3000), \
+ patch('superagi.tools.github.review_pull_request.Agent.find_org_by_agent_id', return_value=Mock()), \
+ patch.object(GithubReviewPullRequest, 'get_tool_config', return_value='mock_value'), \
+ patch.object(GithubReviewPullRequest, 'run_code_review', return_value=None):
+ # Replace 'your_module' with the actual module name
+
+ tool = GithubReviewPullRequest()
+ tool.llm = Mock()
+ tool.llm.get_model = Mock(return_value='mock_model')
+ tool.toolkit_config = Mock()
+ tool.toolkit_config.session = 'mock_session'
+
+ result = tool._execute('mock_repo', 'mock_owner', 42)
+
+ assert result == 'Added comments to the pull request:42'
diff --git a/tools.json b/tools.json
index 9091ede76..94de16ccc 100644
--- a/tools.json
+++ b/tools.json
@@ -1,4 +1,4 @@
{
"tools": {
}
-}
+}
\ No newline at end of file
|