diff --git a/.do/app.yaml b/.do/app.yaml new file mode 100644 index 000000000..19fe1f4ad --- /dev/null +++ b/.do/app.yaml @@ -0,0 +1,76 @@ +alerts: + - rule: DEPLOYMENT_FAILED + - rule: DOMAIN_FAILED +databases: + - engine: PG + name: super-agi-main + num_nodes: 1 + size: basic-xs + version: "12" +ingress: + rules: + - component: + name: superagi-backend + match: + path: + prefix: /api +name: superagi +services: + - dockerfile_path: DockerfileRedis + github: + branch: main + deploy_on_push: true + repo: TransformerOptimus/SuperAGI + internal_ports: + - 6379 + instance_count: 1 + instance_size_slug: basic-xs + source_dir: / + name: superagi-redis + - dockerfile_path: Dockerfile + envs: + - key: REDIS_URL + scope: RUN_TIME + value: superagi-redis:6379 + - key: DB_URL + scope: RUN_TIME + value: ${super-agi-main.DATABASE_URL} + github: + branch: main + deploy_on_push: true + repo: TransformerOptimus/SuperAGI + http_port: 8001 + instance_count: 1 + instance_size_slug: basic-xs + run_command: /app/entrypoint.sh + source_dir: / + name: superagi-backend + - dockerfile_path: ./gui/DockerfileProd + github: + branch: main + deploy_on_push: true + repo: TransformerOptimus/SuperAGI + http_port: 3000 + instance_count: 1 + instance_size_slug: basic-xs + source_dir: ./gui + name: superagi-gui +workers: + - dockerfile_path: Dockerfile + envs: + - key: REDIS_URL + scope: RUN_TIME + value: superagi-redis:6379 + - key: DB_URL + scope: RUN_TIME + value: ${super-agi-main.DATABASE_URL} + github: + branch: main + deploy_on_push: true + repo: TransformerOptimus/SuperAGI + instance_count: 1 + instance_size_slug: basic-xs + run_command: celery -A superagi.worker worker --beat --loglevel=info + source_dir: / + name: superagi-celery + diff --git a/.do/deploy.template.yaml b/.do/deploy.template.yaml new file mode 100644 index 000000000..18ffa82dd --- /dev/null +++ b/.do/deploy.template.yaml @@ -0,0 +1,73 @@ +spec: + alerts: + - rule: DEPLOYMENT_FAILED + - rule: DOMAIN_FAILED + databases: + - engine: PG + name: super-agi-main + num_nodes: 1 + size: basic-xs + version: "12" + ingress: + rules: + - component: + name: superagi-backend + match: + path: + prefix: /api + name: superagi + services: + - dockerfile_path: DockerfileRedis + git: + branch: main + repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git + internal_ports: + - 6379 + instance_count: 1 + instance_size_slug: basic-xs + source_dir: / + name: superagi-redis + - dockerfile_path: Dockerfile + envs: + - key: REDIS_URL + scope: RUN_TIME + value: superagi-redis:6379 + - key: DB_URL + scope: RUN_TIME + value: ${super-agi-main.DATABASE_URL} + git: + branch: main + repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git + http_port: 8001 + instance_count: 1 + instance_size_slug: basic-xs + run_command: /app/entrypoint.sh + source_dir: / + name: superagi-backend + - dockerfile_path: ./gui/DockerfileProd + git: + branch: main + repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git + http_port: 3000 + instance_count: 1 + instance_size_slug: basic-xs + source_dir: ./gui + name: superagi-gui + workers: + - dockerfile_path: Dockerfile + envs: + - key: REDIS_URL + scope: RUN_TIME + value: superagi-redis:6379 + - key: DB_URL + scope: RUN_TIME + value: ${super-agi-main.DATABASE_URL} + git: + branch: main + repo_clone_url: https://github.com/TransformerOptimus/SuperAGI.git + instance_count: 1 + instance_size_slug: basic-xs + run_command: celery -A superagi.worker worker --beat --loglevel=info + source_dir: / + name: superagi-celery + diff 
--git a/DockerfileRedis b/DockerfileRedis new file mode 100644 index 000000000..04624e797 --- /dev/null +++ b/DockerfileRedis @@ -0,0 +1 @@ +FROM redis/redis-stack-server:latest \ No newline at end of file diff --git a/README.MD b/README.MD index adbe1ce5f..31fac79b9 100644 --- a/README.MD +++ b/README.MD @@ -85,6 +85,10 @@
Not sure how to set up? Learn here

+
+

+
Deploy SuperAGI to DigitalOcean with one click. +
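The one-click flow submits `.do/deploy.template.yaml` to App Platform; the flat spec in `.do/app.yaml` can also be submitted programmatically. A minimal sketch, assuming PyYAML and requests are installed and a valid token is exported as DIGITALOCEAN_TOKEN (everything here other than the two checked-in spec files is an assumption):

# Sketch: create the app from the checked-in spec via DigitalOcean's App Platform API.
# Assumes DIGITALOCEAN_TOKEN holds a valid token; .do/app.yaml is the flat spec
# (deploy.template.yaml wraps the same content under a top-level `spec:` key).
import os

import requests
import yaml

with open(".do/app.yaml") as f:
    spec = yaml.safe_load(f)

resp = requests.post(
    "https://api.digitalocean.com/v2/apps",
    headers={"Authorization": f"Bearer {os.environ['DIGITALOCEAN_TOKEN']}"},
    json={"spec": spec},
)
resp.raise_for_status()
print("created app:", resp.json()["app"]["id"])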

### 💡 Features @@ -129,7 +133,6 @@ - If you wish to change the port it's running on, open the `docker-compose.yml` file and update the `proxy` container port forwarding, for example: `"3000:80"` - ### 🌐 Architecture diff --git a/config_template.yaml b/config_template.yaml index 51c090dec..c5e83df19 100644 --- a/config_template.yaml +++ b/config_template.yaml @@ -22,9 +22,10 @@ MAX_MODEL_TOKEN_LIMIT: 4032 # set to 2048 for llama #DATABASE INFO # redis details DB_NAME: super_agi_main -POSTGRES_URL: super__postgres +DB_HOST: super__postgres DB_USERNAME: superagi DB_PASSWORD: password +DB_URL: postgresql://superagi:password@super__postgres:5432/super_agi_main REDIS_URL: "super__redis:6379" #STORAGE TYPE ("FILE" or "S3") diff --git a/gui/pages/Content/APM/ApmDashboard.js b/gui/pages/Content/APM/ApmDashboard.js index d56f12466..c82b9b820 100644 --- a/gui/pages/Content/APM/ApmDashboard.js +++ b/gui/pages/Content/APM/ApmDashboard.js @@ -250,25 +250,19 @@ export default function ApmDashboard() { Agent Name - Model arrow_down + Model - Tokens Consumed arrow_down + Tokens Consumed - Runs arrow_down + Runs - Avg tokens per run arrow_down - Tools arrow_down + Avg tokens per run - Calls arrow_down + Tools - Avg Run Time arrow_down + Calls + + Avg Run Time diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index b3d136373..1f5afa3d1 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -11,7 +11,7 @@ import { updateExecution, uploadFile, getAgentDetails, addAgentRun, fetchModels, - getAgentWorkflows, validateOrAddModels + getAgentWorkflows, validateOrAddModels, publishTemplateToMarketplace } from "@/pages/api/DashboardService"; import { formatBytes, @@ -119,6 +119,9 @@ export default function AgentCreate({ const [editModal, setEditModal] = useState(false) const [editButtonClicked, setEditButtonClicked] = useState(false); + const [dropdown, setDropdown] = useState(false); + const [publishModal, setPublishModal] = useState(false); + useEffect(() => { getOrganisationConfig(organisationId, "model_api_key") @@ -502,30 +505,7 @@ export default function AgentCreate({ setCreateClickable(false); - let permission_type = permission; - if (permission.includes("RESTRICTED")) { - permission_type = "RESTRICTED"; - } - - const agentData = { - "name": agentName, - "project_id": selectedProjectId, - "description": agentDescription, - "goal": goals, - "instruction": instructions, - "agent_workflow": agentWorkflow, - "constraints": constraints, - "toolkits": [], - "tools": selectedTools, - "exit": exitCriterion, - "iteration_interval": stepTime, - "model": model, - "max_iterations": maxIterations, - "permission_type": permission_type, - "LTM_DB": longTermMemory ? database : null, - "user_timezone": getUserTimezone(), - "knowledge": toolNames.includes('Knowledge Search') ? 
selectedKnowledgeId : null, - }; + const agentData = setAgentData() const scheduleAgentData = { "agent_config": agentData, @@ -564,7 +544,34 @@ export default function AgentCreate({ }); } }; + const setAgentData= () => { + let permission_type = permission; + if (permission.includes("RESTRICTED")) { + permission_type = "RESTRICTED"; + } + + const agentData = { + "name": agentName, + "project_id": selectedProjectId, + "description": agentDescription, + "goal": goals, + "instruction": instructions, + "agent_workflow": agentWorkflow, + "constraints": constraints, + "toolkits": [], + "tools": selectedTools, + "exit": exitCriterion, + "iteration_interval": stepTime, + "model": model, + "max_iterations": maxIterations, + "permission_type": permission_type, + "LTM_DB": longTermMemory ? database : null, + "user_timezone": getUserTimezone(), + "knowledge": toolNames.includes('Knowledge Search') ? selectedKnowledgeId : null, + }; + return agentData + } const uploadResources = (agentId, name, executionId) => { if (addResources && input.length > 0) { const uploadPromises = input.map(fileData => { @@ -877,6 +884,20 @@ export default function AgentCreate({ localStorage.setItem('marketplace_tab', 'market_models'); } + const handleAddToMarketplace = () => { + const agentData = setAgentData() + agentData.agent_template_id = template.id + publishTemplateToMarketplace(agentData) + .then((response) => { + setDropdown(false) + setPublishModal(true) + }) + .catch((error) => { + toast.error("Error Publishing to marketplace") + console.error('Error Publishing to marketplace:', error); + }); + } + return (<>
@@ -1158,7 +1179,7 @@ export default function AgentCreate({

-
setAgentDropdown(!agentDropdown)} +
{setAgentDropdown(!edit ? !agentDropdown : false)}} style={{width: '100%'}}> {agentWorkflow} +
+
+ {dropdown && (
+   <div onMouseOver={() => setDropdown(true)} onMouseOut={() => setDropdown(false)}>
+     <ul>
+       <li onClick={() => updateTemplate()}>Update template</li>
+       {env === 'PROD' &&
+         <li onClick={() => handleAddToMarketplace()}>Publish to Marketplace</li>}
+     </ul>
+   </div>
+ )}
+
+ {showButton &&
+ +
} +
- {showButton && ( - - )} {!edit ?
{createDropdown && (
{ setCreateModal(true); @@ -1370,6 +1398,23 @@ export default function AgentCreate({
)} + {publishModal &&
} +
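handleAddToMarketplace() above reuses setAgentData() and posts the result through publishTemplateToMarketplace, which (per the DashboardService.js change further down) hits POST /agent_templates/publish_template. A minimal sketch of the equivalent request — the base URL, auth, and every field value are illustrative assumptions, not values from this PR:

# Sketch: publish an agent configuration as a marketplace template, mirroring
# the payload setAgentData() assembles. All values are placeholders; auth is
# omitted here (the GUI's axios instance normally attaches a bearer token).
import requests

BASE_URL = "http://localhost:8001"  # assumption: backend http_port from .do/app.yaml

agent_data = {
    "name": "Example agent",
    "project_id": 1,
    "description": "Example description",
    "goal": ["Summarise yesterday's AI news"],
    "instruction": [],
    "agent_workflow": "Goal Based Workflow",
    "constraints": [],
    "toolkits": [],
    "tools": [],
    "exit": "No exit criterion",
    "iteration_interval": 500,
    "model": "gpt-4",
    "max_iterations": 25,
    "permission_type": "God Mode",
    "LTM_DB": None,
    "user_timezone": "GMT",
    "knowledge": None,
    "agent_template_id": 1,  # hypothetical id; handleAddToMarketplace sets template.id
}

requests.post(f"{BASE_URL}/agent_templates/publish_template", json=agent_data).raise_for_status()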
diff --git a/gui/pages/Content/Agents/AgentWorkspace.js b/gui/pages/Content/Agents/AgentWorkspace.js index 736ac6f50..212cd651f 100644 --- a/gui/pages/Content/Agents/AgentWorkspace.js +++ b/gui/pages/Content/Agents/AgentWorkspace.js @@ -19,7 +19,7 @@ import { saveAgentAsTemplate, stopSchedule, getDateTime, - deleteAgent + deleteAgent, publishToMarketplace } from "@/pages/api/DashboardService"; import {EventBus} from "@/utils/eventBus"; import 'moment-timezone'; @@ -50,6 +50,9 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a const [createStopModal, setCreateStopModal] = useState(false); const [agentScheduleDetails, setAgentScheduleDetails] = useState(null) + const [publishModal, setPublishModal] = useState(false); + const [publishModalState, setPublishModalState] = useState(false); + const closeCreateModal = () => { setCreateModal(false); setCreateEditModal(false); @@ -61,6 +64,24 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a setDropdown(false); }; + const handlePublishToMarketplace =() => { + if(agent?.is_scheduled && agentExecutions?.length < 1){ + setDropdown(false) + setPublishModalState(true) + setPublishModal(true) + return; + } + publishToMarketplace(selectedRun?.id) + .then((response) => { + setDropdown(false) + setPublishModalState(false) + setPublishModal(true) + }) + .catch((error) => { + console.error('Error publishing to marketplace:', error); + }); + }; + const handleStopScheduleClick = () => { setCreateStopModal(true); setCreateModal(false); @@ -369,12 +390,11 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a
{dropdown &&
setDropdown(true)} onMouseLeave={() => setDropdown(false)}> -
    -
  • saveAgentTemplate()}>Save as Template
  • +
      {selectedRun && selectedRun.status === 'RUNNING' &&
    • { updateRunStatus("PAUSED") }}>Pause
    • } @@ -385,17 +405,13 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a {agentExecutions && agentExecutions.length > 1 &&
    • { updateRunStatus("TERMINATED") }}>Delete Run
    • } - - {agent?.is_scheduled ? (
      -
    • Edit Schedule
    • -
    • Stop Schedule
    • -
      ) : (
      - {agent && !agent?.is_running && !agent?.is_scheduled && -
    • { - setDropdown(false); - setCreateModal(true) - }}>Schedule Run
    • } -
      )} + {agentExecutions && selectedRun && (selectedRun.status === 'CREATED' || selectedRun.status === 'PAUSED' || selectedRun.status === 'RUNNING' || agentExecutions.length > 1) &&
      } +
    • saveAgentTemplate()}>Save as Template
    • + {agent && env === 'PROD' && +
    • { + handlePublishToMarketplace() + }}>Publish to marketplace
    • } +
    • sendAgentData({ id: agentId, name: "Edit Agent", @@ -407,6 +423,17 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a setDeleteModal(true) }}>Delete Agent
    • +
      + {agent?.is_scheduled ? (
      +
    • Edit Schedule
    • +
    • Stop Schedule
    • +
      ) : (
      + {agent && !agent?.is_running && !agent?.is_scheduled && +
    • { + setDropdown(false); + setCreateModal(true) + }}>Schedule Run
    • } +
      )}
} @@ -587,6 +614,23 @@ export default function AgentWorkspace({env, agentId, agentName, selectedView, a
)} + {publishModal && (
{setPublishModal(false)}}> +
+ {publishModalState ?
Run the agent at least once to publish!
:
Template submitted successfully!
} +
+ {!publishModalState ? + #agent-templates-submission + channel. : } +
+
+ +
+
+
)} + ); diff --git a/gui/pages/Content/Agents/Agents.module.css b/gui/pages/Content/Agents/Agents.module.css index c6553a8f9..14204edd4 100644 --- a/gui/pages/Content/Agents/Agents.module.css +++ b/gui/pages/Content/Agents/Agents.module.css @@ -460,4 +460,29 @@ .button_margin{ margin-top: -10px; +} + +.dropdown_separator{ + border-bottom: 1px solid rgba(255, 255, 255, 0.08); + margin-left:-5px; + width: 180px +} + +.dropdown_container_agent{ + bottom: 40px; + z-index: 9999; + padding: 0px; + width: fit-content; + height: fit-content; + margin-right: 40px; + background: #3B3B49; + border-radius: 8px; + position: absolute; + box-shadow: 0 2px 7px rgba(0,0,0,.4), 0 0 2px rgba(0,0,0,.22); +} + +.dropdown_item_agent{ + height:30px; + paddingTop: 2px; + paddingBottom: 2px; } \ No newline at end of file diff --git a/gui/pages/Content/Agents/Details.js b/gui/pages/Content/Agents/Details.js index a3c05dc17..72c325163 100644 --- a/gui/pages/Content/Agents/Details.js +++ b/gui/pages/Content/Agents/Details.js @@ -1,4 +1,4 @@ -import React, {useEffect, useState} from 'react'; +import React, {useEffect, useState, useRef} from 'react'; import styles from './Agents.module.css'; import Image from "next/image"; import {formatNumber} from "@/utils/utils"; @@ -11,6 +11,10 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails, const [filteredInstructions, setFilteredInstructions] = useState([]); const [scheduleText, setScheduleText] = useState(''); const [agentDetails, setAgentDetails] = useState(null) + const goalBoxRef = useRef(null); + const instructionBoxRef = useRef(null); + const constrainBoxRef = useRef(null); + const [isOverflowing, setIsOverflowing] = useState([false, false, false]); const info_text = { marginLeft: '7px', }; @@ -23,6 +27,22 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails, fontSize: '11px' }; + useEffect(() => { + const newOverflowing = [...isOverflowing]; + + if (goalBoxRef.current) { + newOverflowing[0] = goalBoxRef.current.scrollHeight > goalBoxRef.current.clientHeight; + } + if (instructionBoxRef.current) { + newOverflowing[1] = instructionBoxRef.current.scrollHeight > instructionBoxRef.current.clientHeight; + } + if (constrainBoxRef.current) { + newOverflowing[2] = constrainBoxRef.current.scrollHeight > constrainBoxRef.current.clientHeight; + } + + setIsOverflowing(newOverflowing); + }, [agentDetails?.goal, filteredInstructions, agentDetails?.constraints]); + const openToolkitTab = (toolId) => { EventBus.emit('openToolkitTab', {toolId: toolId}); } @@ -80,18 +100,18 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails, style={{display: 'flex', marginBottom: '5px', alignItems: 'center', justifyContent: 'flex-start', gap: '7.5%'}}>
-
calls-icon
-
Total Calls
+
runs-icon
+
Total Runs
-
{formatNumber(agentDetails?.calls || 0)}
+
{runCount || 0}
-
runs-icon
-
Total Runs
+
calls-icon
+
Total Calls
-
{runCount || 0}
+
{formatNumber(agentDetails?.calls || 0)}
@@ -108,15 +128,15 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails,
{agentDetails?.goal?.length || 0} Goals
{agentDetails?.goal && agentDetails?.goal?.length > 0 &&
-
+
{agentDetails?.goal?.map((goal, index) => (
{index + 1}. {goal || ''}
{index !== agentDetails?.goal?.length - 1 &&
}
))}
-
setShowGoals(!showGoals)}> + {isOverflowing[0] &&
setShowGoals(!showGoals)}> {showGoals ? 'Show Less' : 'Show More'} -
+
}
} {filteredInstructions && filteredInstructions.length > 0 &&
@@ -125,15 +145,15 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails,
{filteredInstructions.length || 0} Instructions
-
{filteredInstructions.map((instruction, index) => (
{index + 1}. {instruction || ''}
{index !== filteredInstructions.length - 1 &&
}
))}
-
setShowInstructions(!showInstructions)}>{showInstructions ? 'Show Less' : 'Show More'}
+ {isOverflowing[1] &&
setShowInstructions(!showInstructions)}>{showInstructions ? 'Show Less' : 'Show More'}
}
} {agentDetails &&
{agentDetails.tools && agentDetails.tools.length > 0 &&
@@ -156,15 +176,15 @@ export default function Details({agentDetails1, runCount, agentScheduleDetails,
constraint-icon
{agentDetails?.constraints.length || 0} Constraints
-
{agentDetails?.constraints?.map((constraint, index) => (
{index + 1}. {constraint || ''}
{index !== agentDetails.constraints.length - 1 &&
}
))}
-
setShowConstraints(!showConstraints)}>{showConstraints ? 'Show Less' : 'Show More'}
+ {isOverflowing[2] &&
setShowConstraints(!showConstraints)}>{showConstraints ? 'Show Less' : 'Show More'}
}
}
}
diff --git a/gui/pages/Content/Knowledge/Knowledge.js b/gui/pages/Content/Knowledge/Knowledge.js index 7074488cd..a90a50b93 100644 --- a/gui/pages/Content/Knowledge/Knowledge.js +++ b/gui/pages/Content/Knowledge/Knowledge.js @@ -37,7 +37,7 @@ export default function Knowledge({sendKnowledgeData, knowledge}) {
{item.name} {item.is_marketplace && markteplace-icon}
-
by {item.contributed_by}
+
by {item.contributed_by}
diff --git a/gui/pages/Content/Knowledge/Knowledge.module.css b/gui/pages/Content/Knowledge/Knowledge.module.css index 7768fb5c3..f433338cb 100644 --- a/gui/pages/Content/Knowledge/Knowledge.module.css +++ b/gui/pages/Content/Knowledge/Knowledge.module.css @@ -73,4 +73,9 @@ padding: 4px 8px; align-items: center; gap: 6px; +} + +.knowledge_options_dropdown{ + right: 25px; + width: 165px; } \ No newline at end of file diff --git a/gui/pages/Content/Knowledge/KnowledgeDetails.js b/gui/pages/Content/Knowledge/KnowledgeDetails.js index de4d45703..857c9ab7c 100644 --- a/gui/pages/Content/Knowledge/KnowledgeDetails.js +++ b/gui/pages/Content/Knowledge/KnowledgeDetails.js @@ -5,8 +5,9 @@ import styles from "@/pages/Content/Toolkits/Tool.module.css"; import Image from "next/image"; import KnowledgeForm from "@/pages/Content/Knowledge/KnowledgeForm"; import {deleteCustomKnowledge, deleteMarketplaceKnowledge, getKnowledgeDetails} from "@/pages/api/DashboardService"; -import {removeTab} from "@/utils/utils"; +import {removeTab, returnToolkitIcon, setLocalStorageValue} from "@/utils/utils"; import {EventBus} from "@/utils/eventBus"; +import Metrics from "@/pages/Content/Toolkits/Metrics"; export default function KnowledgeDetails({internalId, knowledgeId}) { const [showDescription, setShowDescription] = useState(false); @@ -24,6 +25,8 @@ export default function KnowledgeDetails({internalId, knowledgeId}) { const [chunkOverlap, setChunkOverlap] = useState(''); const [dimension, setDimension] = useState(''); const [vectorDBIndex, setVectorDBIndex] = useState(''); + const [activeTab, setActiveTab] = useState('metrics'); + const uninstallKnowledge = () => { setDropdown(false); @@ -89,8 +92,45 @@ export default function KnowledgeDetails({internalId, knowledgeId}) { return (<>
-
-
+
+
+
+
{knowledgeName}
+
+ {knowledgeDescription} +
+
+
+ + {dropdown &&
setDropdown(true)} onMouseLeave={() => setDropdown(false)}> +
    + {installationType !== 'Marketplace' && + //
  • View in marketplace
  • : +
  • Edit details
  • } +
  • Uninstall knowledge
  • +
+
} +
+
+
+
setActiveTab('metrics')}> +
Metrics
+
+
setActiveTab('configuration')}> +
Configuration
+
+
+ {activeTab === 'metrics' &&
+ +
} + { activeTab === "configuration" &&
+
+
{isEditing ? :
-
-
-
-
-
{knowledgeName}
-
- {`${showDescription ? knowledgeDescription : knowledgeDescription.slice(0, 70)}`} - {knowledgeDescription.length > 70 && - setShowDescription(!showDescription)}> - {showDescription ? '...less' : '...more'} - } -
-
-
-
- - {dropdown &&
setDropdown(true)} onMouseLeave={() => setDropdown(false)}> -
    - {installationType !== 'Marketplace' && - //
  • View in marketplace
  • : -
  • Edit details
  • } -
  • Uninstall knowledge
  • -
-
} -
-
-
- {installationType === 'Marketplace' &&
-
+ {installationType === 'Marketplace' &&
+
{installationType}
@@ -201,8 +210,10 @@ export default function KnowledgeDetails({internalId, knowledgeId}) {
}
} +
+
+
}
-
); diff --git a/gui/pages/Content/Knowledge/KnowledgeForm.js b/gui/pages/Content/Knowledge/KnowledgeForm.js index 7ba785812..3362c0c42 100644 --- a/gui/pages/Content/Knowledge/KnowledgeForm.js +++ b/gui/pages/Content/Knowledge/KnowledgeForm.js @@ -137,8 +137,13 @@ export default function KnowledgeForm({ } const handleIndexSelect = (index) => { - setLocalStorageArray("knowledge_index_" + String(internalId), index, setSelectedIndex); - setIndexDropdown(false); + if(index.is_valid_state) { + setLocalStorageArray("knowledge_index_" + String(internalId), index, setSelectedIndex); + setIndexDropdown(false); + } + else{ + toast.error('Select valid index', {autoClose: 1800}) + } } const checkIndexValidity = (validState) => { diff --git a/gui/pages/Content/Marketplace/KnowledgeTemplate.js b/gui/pages/Content/Marketplace/KnowledgeTemplate.js index 2ccbebe4a..6727e63fe 100644 --- a/gui/pages/Content/Marketplace/KnowledgeTemplate.js +++ b/gui/pages/Content/Marketplace/KnowledgeTemplate.js @@ -88,7 +88,16 @@ export default function KnowledgeTemplate({template, env}) { } }, []); - const handleInstallClick = (indexId) => { + const handleInstallClick = (index) => { + const indexId = index.id + if(!index.is_valid_state){ + toast.error("Select valid index", {autoClose : 1800}) + return + } + if (template && template.is_installed) { + toast.error("Template is already installed", {autoClose: 1800}); + return; + } setInstalled("Installing"); if (window.location.href.toLowerCase().includes('marketplace')) { @@ -103,11 +112,6 @@ export default function KnowledgeTemplate({template, env}) { return; } - if (template && template.is_installed) { - toast.error("Template is already installed", {autoClose: 1800}); - return; - } - setIndexDropdown(false); installKnowledgeTemplate(template.name, indexId) @@ -201,7 +205,7 @@ export default function KnowledgeTemplate({template, env}) {
Pinecone
{pinconeIndices.map((index) => (
handleInstallClick(index.id)} style={{ + onClick={() => handleInstallClick(index)} style={{ padding: '12px 14px', maxWidth: '100%', display: 'flex', @@ -223,7 +227,7 @@ export default function KnowledgeTemplate({template, env}) {
Qdrant
{qdrantIndices.map((index) => (
handleInstallClick(index.id)} style={{ + onClick={() => handleInstallClick(index)} style={{ padding: '12px 14px', maxWidth: '100%', display: 'flex', @@ -245,7 +249,7 @@ export default function KnowledgeTemplate({template, env}) {
Weaviate
{weaviateIndices.map((index) => (
handleInstallClick(index.id)} style={{ + onClick={() => handleInstallClick(index)} style={{ padding: '12px 14px', maxWidth: '100%', display: 'flex', @@ -356,4 +360,4 @@ export default function KnowledgeTemplate({template, env}) { ); -} \ No newline at end of file +} diff --git a/gui/pages/Content/Marketplace/Market.module.css b/gui/pages/Content/Marketplace/Market.module.css index c309f41cd..5d9f775a2 100644 --- a/gui/pages/Content/Marketplace/Market.module.css +++ b/gui/pages/Content/Marketplace/Market.module.css @@ -529,3 +529,24 @@ .settings_tab_img{ margin-top: -1px; } + +.checkboxGroup { + display: flex; + justify-content: space-between; + flex-wrap: wrap; + height: 15vh; +} + +.checkboxLabel { + display: flex; + align-items: center; + width: 15vw; + cursor:pointer +} + +.checkboxText { + font-weight: 400; + font-size: 12px; + color: #FFF; + margin-left:5px; +} \ No newline at end of file diff --git a/gui/pages/Content/Marketplace/ToolkitTemplate.js b/gui/pages/Content/Marketplace/ToolkitTemplate.js index b705c802e..7e416bd7a 100644 --- a/gui/pages/Content/Marketplace/ToolkitTemplate.js +++ b/gui/pages/Content/Marketplace/ToolkitTemplate.js @@ -50,27 +50,27 @@ export default function ToolkitTemplate({template, env}) { if(installed === "Update"){ updateMarketplaceToolTemplate(template.name) .then((response) => { - toast.success("Template Updated", {autoClose: 1800}); + toast.success("Toolkit Updated", {autoClose: 1800}); setInstalled('Installed'); }) .catch((error) => { - console.error('Error installing template:', error); + console.error('Error installing Toolkit:', error); }); return; } if (template && template.is_installed) { - toast.error("Template is already installed", {autoClose: 1800}); + toast.error("Toolkit is already installed", {autoClose: 1800}); return; } installToolkitTemplate(template.name) .then((response) => { - toast.success("Template installed", {autoClose: 1800}); + toast.success("Toolkit installed", {autoClose: 1800}); setInstalled('Installed'); }) .catch((error) => { - console.error('Error installing template:', error); + console.error('Error installing Toolkit', error); }); } diff --git a/gui/pages/Content/Models/ModelDetails.js b/gui/pages/Content/Models/ModelDetails.js index 751b50568..1b65ee7ac 100644 --- a/gui/pages/Content/Models/ModelDetails.js +++ b/gui/pages/Content/Models/ModelDetails.js @@ -3,17 +3,21 @@ import Image from "next/image"; import ModelMetrics from "./ModelMetrics"; import ModelInfo from "./ModelInfo"; import {fetchModel} from "@/pages/api/DashboardService"; +import {loadingTextEffect} from "@/utils/utils"; export default function ModelDetails({modelId, modelName}){ const [modelDetails, setModelDetails] = useState([]) const [selectedOption, setSelectedOption] = useState('metrics') + const [isLoading, setIsLoading] = useState(true) + const [loadingText, setLoadingText] = useState("Loading Models"); useEffect(() => { + loadingTextEffect('Loading Models', setLoadingText, 500); const fetchModelDetails = async () => { try { const response = await fetchModel(modelId); - console.log(response.data) setModelDetails(response.data) + setIsLoading(false) } catch(error) { console.log(`Error Fetching the Details of the Model ${modelName}`, error) } @@ -23,8 +27,8 @@ export default function ModelDetails({modelId, modelName}){ },[]) return( -
-
+
+ {!isLoading &&
{ modelDetails.name ? (modelDetails.name.split('/')[1] || modelDetails.name) : ""} {modelDetails.description}
@@ -33,9 +37,10 @@ export default function ModelDetails({modelId, modelName}){
-
- {selectedOption === 'metrics' && } - {selectedOption === 'details' && } +
} + {selectedOption === 'metrics' && !isLoading && } + {selectedOption === 'details' && !isLoading && } + {isLoading &&
{loadingText}
}
) } \ No newline at end of file diff --git a/gui/pages/Content/Models/ModelMetrics.js b/gui/pages/Content/Models/ModelMetrics.js index 45708616b..377ab0e65 100644 --- a/gui/pages/Content/Models/ModelMetrics.js +++ b/gui/pages/Content/Models/ModelMetrics.js @@ -69,7 +69,7 @@ export default function ModelMetrics(modelDetails) { {modelRunData.map((data, index) => ( - {data.created_at} + {data.created_at.slice(0, 19).replace('T', ' ')} {data.agent_name} {data.agent_execution_name} diff --git a/gui/pages/Content/Toolkits/Metrics.js b/gui/pages/Content/Toolkits/Metrics.js new file mode 100644 index 000000000..2887d5d10 --- /dev/null +++ b/gui/pages/Content/Toolkits/Metrics.js @@ -0,0 +1,145 @@ +import React, {useState, useEffect, useRef} from 'react'; +import {ToastContainer, toast} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import { + getApiKeys, getToolMetrics, getToolLogs, getKnowledgeMetrics, getKnowledgeLogs +} from "@/pages/api/DashboardService"; +import { + loadingTextEffect, +} from "@/utils/utils"; + +export default function Metrics({toolName, knowledgeName}) { + const [apiKeys, setApiKeys] = useState([]); + const [totalTokens, setTotalTokens] = useState(0) + const [totalAgentsUsing, setTotalAgentsUsing] = useState(0) + const [totalCalls, setTotalCalls] = useState(0) + const [callLogs, setCallLogs] = useState([]) + const [isLoading, setIsLoading] = useState(true) + const [loadingText, setLoadingText] = useState("Loading Metrics"); + const metricsData = [ + { label: 'Total Calls', value: totalCalls }, + { label: 'Total Agents Using', value: totalAgentsUsing } + ]; + + useEffect(() => { + loadingTextEffect('Loading Metrics', setLoadingText, 500); + }, []); + + useEffect(() => { + if(toolName && !knowledgeName){ + fetchToolMetrics() + fetchToolLogs() + return; + } + if(!toolName && knowledgeName){ + fetchKnowledgeMetrics() + fetchKnowledgeLogs() + return; + } + }, [toolName, knowledgeName]); + + const fetchToolMetrics = () => { + getToolMetrics(toolName) + .then((response) => { + setTotalAgentsUsing(response.data.tool_unique_agents ? response.data.tool_unique_agents : 0) + setTotalCalls(response.data.tool_calls ? response.data.tool_calls : 0) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Metrics', error); + }); + } + + const fetchToolLogs = () => { + getToolLogs(toolName) + .then((response) => { + setCallLogs(response.data ? response.data : []) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Metrics', error); + }); + } + + const fetchKnowledgeMetrics = () => { + getKnowledgeMetrics(knowledgeName) + .then((response) => { + setTotalAgentsUsing(response.data.knowledge_unique_agents ? response.data.knowledge_unique_agents : 0) + setTotalCalls(response.data.knowledge_calls ? response.data.knowledge_calls : 0) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Metrics', error); + }); + } + + const fetchKnowledgeLogs = () => { + getKnowledgeLogs(knowledgeName) + .then((response) => { + setCallLogs(response.data ? response.data : []) + setIsLoading(false) + }) + .catch((error) => { + console.error('Error fetching Metrics', error); + }); + } + + return (<> +
+
+ {!isLoading ? +
+
+ {metricsData.map((metric, index) => ( +
+ {metric.label} +
+ {metric.value} +
+
+ ))} +
+
+ Call Logs + {callLogs.length > 0 ?
+ + + + + + + + + + +
Log TimestampAgent NameRun NameModelTokens Used
+
+ + + {callLogs.map((item, index) => ( + + + + + + + + ))} + +
{item.created_at}{item.agent_name}{item.agent_execution_name}{item.model}{item.tokens_consumed}
+
+
: +
+ No Data + No logs to show! +
} +
+
+ :
+
{loadingText}
+
} +
+
+ + ) +} \ No newline at end of file diff --git a/gui/pages/Content/Toolkits/ToolkitWorkspace.js b/gui/pages/Content/Toolkits/ToolkitWorkspace.js index 8356a1eb6..942a280ce 100644 --- a/gui/pages/Content/Toolkits/ToolkitWorkspace.js +++ b/gui/pages/Content/Toolkits/ToolkitWorkspace.js @@ -1,4 +1,4 @@ -import React, {useEffect, useState} from 'react'; +import React, {useEffect, useRef, useState} from 'react'; import Image from 'next/image'; import {ToastContainer, toast} from 'react-toastify'; import { @@ -9,14 +9,19 @@ import { } from "@/pages/api/DashboardService"; import styles from './Tool.module.css'; import {setLocalStorageValue, setLocalStorageArray, returnToolkitIcon, convertToTitleCase} from "@/utils/utils"; +import Metrics from "@/pages/Content/Toolkits/Metrics"; export default function ToolkitWorkspace({env, toolkitDetails, internalId}) { - const [activeTab, setActiveTab] = useState('configuration') + const [activeTab, setActiveTab] = useState('metrics') const [showDescription, setShowDescription] = useState(false) const [apiConfigs, setApiConfigs] = useState([]); const [toolsIncluded, setToolsIncluded] = useState([]); const [loading, setLoading] = useState(true); const authenticateToolkits = ['Google Calendar Toolkit', 'Twitter Toolkit']; + const [toolDropdown, setToolDropdown] = useState(false); + const toolRef = useRef(null); + const [currTool, setCurrTool] = useState(false); + let handleKeyChange = (event, index) => { const updatedData = [...apiConfigs]; @@ -44,6 +49,7 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) { if (toolkitDetails !== null) { if (toolkitDetails.tools) { setToolsIncluded(toolkitDetails.tools); + setCurrTool(toolkitDetails.tools[0].name) } getToolConfig(toolkitDetails.name) @@ -115,10 +121,20 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) { } }, [internalId]); + useEffect(() => { + function handleClickOutside(event) { + if (toolRef.current && !toolRef.current.contains(event.target)) { + setToolDropdown(false) + } + } + document.addEventListener('mousedown', handleClickOutside); + return () => { + document.removeEventListener('mousedown', handleClickOutside); + }; + }, []); return (<>
-
-
+
toolkit-icon @@ -133,7 +149,12 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
-
+
+
+
setLocalStorageValue('toolkit_tab_' + String(internalId), 'metrics', setActiveTab)}> +
Metrics
+
setLocalStorageValue('toolkit_tab_' + String(internalId), 'configuration', setActiveTab)}>
Configuration
@@ -142,7 +163,28 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) { onClick={() => setLocalStorageValue('toolkit_tab_' + String(internalId), 'tools_included', setActiveTab)}>
Tools Included
+
+ {!loading && activeTab === 'metrics' &&
+
+
setToolDropdown(!toolDropdown)}> + {currTool}expand-icon +
+
+ {toolDropdown &&
+ {toolsIncluded.map((tool, index) => ( +
{setCurrTool(tool.name); setToolDropdown(false)}}> + {tool.name} +
))} +
} +
+
+
}
+
+
+
{!loading && activeTab === 'configuration' &&
{apiConfigs.length > 0 ? (apiConfigs.map((config, index) => (
@@ -176,9 +218,14 @@ export default function ToolkitWorkspace({env, toolkitDetails, internalId}) {
))}
} +
+
+
+ {activeTab === 'metrics' &&
+ +
}
-
); diff --git a/gui/pages/Content/Toolkits/Toolkits.js b/gui/pages/Content/Toolkits/Toolkits.js index 0a0bfe337..b56ce12f6 100644 --- a/gui/pages/Content/Toolkits/Toolkits.js +++ b/gui/pages/Content/Toolkits/Toolkits.js @@ -18,7 +18,7 @@ export default function Toolkits({sendToolkitData, toolkits, env}) {
} {toolkits && toolkits.length > 0 ? ( -
+
{toolkits.map((tool, index) => tool.name !== null && !excludedToolkits().includes(tool.name) && (
sendToolkitData(tool)}> diff --git a/gui/pages/Dashboard/Settings/Settings.js b/gui/pages/Dashboard/Settings/Settings.js index 7d1263383..66570367e 100644 --- a/gui/pages/Dashboard/Settings/Settings.js +++ b/gui/pages/Dashboard/Settings/Settings.js @@ -5,6 +5,7 @@ import Image from "next/image"; import Model from "@/pages/Dashboard/Settings/Model"; import Database from "@/pages/Dashboard/Settings/Database"; import ApiKeys from "@/pages/Dashboard/Settings/ApiKeys"; +import Webhooks from "@/pages/Dashboard/Settings/Webhooks"; export default function Settings({organisationId, sendDatabaseData}) { const [activeTab, setActiveTab] = useState('model'); @@ -38,12 +39,17 @@ export default function Settings({organisationId, sendDatabaseData}) { api-key-icon API Keys +
{activeTab === 'model' && } {activeTab === 'database' && } {activeTab === 'apikeys' && } + {activeTab === 'webhooks' && }
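The Webhooks tab added below registers a callback URL plus a status filter — saveWebhook posts {name, url, headers, filters: {status: [...]}} with statuses drawn from RUNNING, PAUSED, COMPLETED, TERMINATED and MAX ITERATION REACHED. A minimal receiver sketch; the outbound payload shape is an assumption, since only the filter values appear in this PR:

# Sketch: a local endpoint that could be registered in the new Webhooks tab.
# The request body layout (a "status" field) is an assumption; only the
# status values themselves are visible in this diff.
from fastapi import FastAPI, Request

app = FastAPI()

HANDLED_STATUSES = {"RUNNING", "PAUSED", "COMPLETED", "TERMINATED", "MAX ITERATION REACHED"}

@app.post("/superagi-webhook")
async def superagi_webhook(request: Request):
    payload = await request.json()
    if payload.get("status") in HANDLED_STATUSES:
        print("agent run update:", payload)  # replace with real handling
    return {"ok": True}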
diff --git a/gui/pages/Dashboard/Settings/Webhooks.js b/gui/pages/Dashboard/Settings/Webhooks.js new file mode 100644 index 000000000..749266c6e --- /dev/null +++ b/gui/pages/Dashboard/Settings/Webhooks.js @@ -0,0 +1,148 @@ +import React, {useState, useEffect} from 'react'; +import {ToastContainer, toast} from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import agentStyles from "@/pages/Content/Agents/Agents.module.css"; +import { + editWebhook, + getWebhook, saveWebhook, +} from "@/pages/api/DashboardService"; +import {loadingTextEffect, removeTab} from "@/utils/utils"; +import styles from "@/pages/Content/Marketplace/Market.module.css"; +export default function Webhooks() { + const [webhookUrl, setWebhookUrl] = useState(''); + const [webhookId, setWebhookId] = useState(-1); + const [isLoading, setIsLoading] = useState(true) + const [existingWebhook, setExistingWebhook] = useState(false) + const [isEdtiting, setIsEdtiting] = useState(false) + const [loadingText, setLoadingText] = useState("Loading Webhooks"); + const [selectedCheckboxes, setSelectedCheckboxes] = useState([]); + const checkboxes = [ + { label: 'Agent is running', value: 'RUNNING' }, + { label: 'Agent run is paused', value: 'PAUSED' }, + { label: 'Agent run is completed', value: 'COMPLETED' }, + { label: 'Agent is terminated ', value: 'TERMINATED' }, + { label: 'Agent run max iteration reached', value: 'MAX ITERATION REACHED' }, + ]; + + + useEffect(() => { + loadingTextEffect('Loading Webhooks', setLoadingText, 500); + fetchWebhooks(); + }, []); + + const handleWebhookChange = (event) => { + setWebhookUrl(event.target.value); + }; + + const handleSaveWebhook = () => { + if(!webhookUrl || webhookUrl.trim() === ""){ + toast.error("Enter valid webhook", {autoClose: 1800}); + return; + } + if(isEdtiting){ + editWebhook(webhookId, { url: webhookUrl, filters: {status: selectedCheckboxes}}) + .then((response) => { + setIsEdtiting(false) + fetchWebhooks() + toast.success("Webhook edited successfully", {autoClose: 1800}); + }) + .catch((error) => { + console.error('Error fetching webhook', error); + }); + return; + } + saveWebhook({name : "Webhook 1", url: webhookUrl, headers: {}, filters: {status: selectedCheckboxes}}) + .then((response) => { + setExistingWebhook(true) + setWebhookId(response.data.id) + toast.success("Webhook created successfully", {autoClose: 1800}); + }) + .catch((error) => { + toast.error("Unable to create webhook", {autoClose: 1800}); + console.error('Error saving webhook', error); + }); + } + + const fetchWebhooks = () => { + getWebhook() + .then((response) => { + setIsLoading(false) + if(response.data){ + setWebhookUrl(response.data.url) + setExistingWebhook(true) + setWebhookId(response.data.id) + setSelectedCheckboxes(response.data.filters.status) + } + else{ + setWebhookUrl('') + setExistingWebhook(false) + setWebhookId(-1) + } + }) + .catch((error) => { + console.error('Error fetching webhook', error); + }); + } + + const toggleCheckbox = (value) => { + if (selectedCheckboxes.includes(value)) { + setSelectedCheckboxes(selectedCheckboxes.filter((item) => item !== value)); + } else { + setSelectedCheckboxes([...selectedCheckboxes, value]); + } + }; + + return (<> +
+
+
+ {!isLoading ?
+
+
Webhooks
+ {existingWebhook && + } +
+ +
+ + +
+ +
+ {checkboxes.map((checkbox) => ( + + ))} +
+
+ + {!existingWebhook &&
+ + +
} + +
:
+
{loadingText}
+
} +
+
+
+ + ) +} \ No newline at end of file diff --git a/gui/pages/_app.css b/gui/pages/_app.css index fc813e8ec..8698f03a3 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -770,6 +770,7 @@ p { .mt_70{margin-top: 70px;} .mt_74{margin-top: 74px;} .mt_80{margin-top: 80px;} +.mt_90{margin-top: 90px;} .mb_1{margin-bottom: 1px;} .mb_2{margin-bottom: 2px;} @@ -815,6 +816,7 @@ p { .mb_60{margin-bottom: 60px;} .mb_70{margin-bottom: 70px;} .mb_74{margin-bottom: 74px;} +.mb_90{margin-bottom: 90px;} .ml_1{margin-left: 1px;} @@ -968,6 +970,14 @@ p { line-height: normal; } +.text_20 { + color: #FFF; + font-size: 20px; + font-style: normal; + font-weight: 400; + line-height: normal; +} + .text_20_bold{ color: #FFF; font-size: 20px; @@ -1051,10 +1061,13 @@ p { .r_0{right: 0} .w_120p{width: 120px} +.w_3{width: 3%} +.w_180p{width: 180px} .w_4{width: 4%} .w_6{width: 6%} .w_10{width: 10%} .w_12{width: 12%} +.w_15{width: 15%} .w_18{width: 18%} .w_20{width: 20%} .w_22{width: 22%} @@ -1062,6 +1075,8 @@ p { .w_50{width: 50%} .w_56{width: 56%} .w_60{width: 60%} +.w_73{width: 73%} +.w_97{width: 97%} .w_100{width: 100%} .w_inherit{width: inherit} .w_fit_content{width:fit-content} @@ -1070,12 +1085,16 @@ p { .mxw_100{max-width: 100%} .mxw_360{max-width: 360px} +.h_31p{height: 31px} .h_32p{height: 32px} .h_44p{height: 44px} .h_100{height: 100%} .h_auto{height: auto} .h_60vh{height: 60vh} .h_75vh{height: 75vh} +.h_80vh{height: 80vh} +.h_calc92{height: calc(100vh - 92px)} +.h_calc_add40{height: calc(80vh + 40px)} .mxh_78vh{max-height: 78vh} @@ -1087,6 +1106,7 @@ p { .justify_space_between{justify-content: space-between} .display_flex{display: inline-flex} +.display_flex_container{display: flex} .align_center{align-items: center} .align_start{align-items: flex-start} @@ -1103,6 +1123,7 @@ p { .cursor_pointer{cursor: pointer} .cursor_default{cursor: default} +.cursor_not_allowed{cursor: not-allowed} .overflow_auto{overflow: auto} .overflowY_scroll{overflow-y: scroll} @@ -1111,12 +1132,17 @@ p { .overflowX_auto{overflow-x: auto} .gap_4{gap:4px;} +.gap_5{gap:5px;} .gap_6{gap:6px;} .gap_8{gap:8px;} .gap_16{gap:16px;} .gap_20{gap:20px;} +.border_gray{border: 1px solid rgba(255, 255, 255, 0.08)} +.border_left_none{border-left: none;} .border_top_none{border-top: none;} +.border_bottom_none{border-bottom: none;} +.border_bottom_grey{border-bottom: 1px solid rgba(255, 255, 255, 0.08)} .border_radius_8{border-radius: 8px;} .border_radius_25{border-radius: 25px;} @@ -1143,6 +1169,8 @@ p { .padding_12_14{padding: 12px 14px;} .padding_0_15{padding: 0px 15px;} +.pd_bottom_5{padding-bottom: 5px} + .flex_1{flex: 1} .flex_wrap{flex-wrap: wrap;} @@ -1452,6 +1480,7 @@ tr{ .bg_black{background: black} .bg_white{background: white} +.bg_none{background: none} .container { height: 100%; @@ -1490,7 +1519,7 @@ tr{ .item_publisher { font-style: normal; font-weight: 500; - font-size: 9px; + font-size: 11px; line-height: 12px; display: flex; align-items: center; diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 52782be6f..c26330d25 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -324,6 +324,21 @@ export const deleteApiKey = (apiId) => { return api.delete(`/api-keys/${apiId}`); }; +export const saveWebhook = (webhook) => { + return api.post(`/webhook/add`, webhook); +}; + +export const getWebhook = () => { + return api.get(`/webhook/get`); +}; + +export const editWebhook = (webhook_id, webook_data) => { + return api.post(`/webhook/edit/${webhook_id}`, 
webook_data); +}; + +export const publishToMarketplace = (executionId) => { + return api.post(`/agent_templates/publish_template/agent_execution_id/${executionId}`); +}; export const storeApiKey = (model_provider, model_api_key) => { return api.post(`/models_controller/store_api_keys`, {model_provider, model_api_key}); @@ -363,3 +378,21 @@ export const fetchMarketPlaceModel = () => { return api.get(`/models_controller/get/list`) } +export const getToolMetrics = (toolName) => { + return api.get(`analytics/tools/${toolName}/usage`) +} + +export const getToolLogs = (toolName) => { + return api.get(`analytics/tools/${toolName}/logs`) +} + +export const publishTemplateToMarketplace = (agentData) => { + return api.post(`/agent_templates/publish_template`, agentData); +}; +export const getKnowledgeMetrics = (knowledgeName) => { + return api.get(`analytics/knowledge/${knowledgeName}/usage`) +} + +export const getKnowledgeLogs = (knowledgeName) => { + return api.get(`analytics/knowledge/${knowledgeName}/logs`) +} \ No newline at end of file diff --git a/gui/public/images/twitter_icon.svg b/gui/public/images/twitter_icon.svg index cbd05ea9d..2c6e581da 100644 --- a/gui/public/images/twitter_icon.svg +++ b/gui/public/images/twitter_icon.svg @@ -2,8 +2,8 @@ - + - + diff --git a/gui/public/images/webhook_icon.svg b/gui/public/images/webhook_icon.svg new file mode 100644 index 000000000..0e316e430 --- /dev/null +++ b/gui/public/images/webhook_icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/main.py b/main.py index 8a349e2d5..a0cd36223 100644 --- a/main.py +++ b/main.py @@ -45,6 +45,7 @@ from superagi.helper.tool_helper import register_toolkits, register_marketplace_toolkits from superagi.lib.logger import logger from superagi.llms.google_palm import GooglePalm +from superagi.llms.llm_model_factory import build_model_with_api_key from superagi.llms.openai import OpenAi from superagi.llms.replicate import Replicate from superagi.llms.hugging_face import HuggingFace @@ -56,18 +57,24 @@ from superagi.models.workflows.agent_workflow import AgentWorkflow from superagi.models.workflows.iteration_workflow import IterationWorkflow from superagi.models.workflows.iteration_workflow_step import IterationWorkflowStep +from urllib.parse import urlparse app = FastAPI() -database_url = get_config('POSTGRES_URL') +db_host = get_config('DB_HOST', 'super__postgres') +db_url = get_config('DB_URL', None) db_username = get_config('DB_USERNAME') db_password = get_config('DB_PASSWORD') db_name = get_config('DB_NAME') env = get_config('ENV', "DEV") -if db_username is None: - db_url = f'postgresql://{database_url}/{db_name}' +if db_url is None: + if db_username is None: + db_url = f'postgresql://{db_host}/{db_name}' + else: + db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}' else: - db_url = f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}' + db_url = urlparse(db_url) + db_url = db_url.scheme + "://" + db_url.netloc + db_url.path engine = create_engine(db_url, pool_size=20, # Maximum number of database connections in the pool @@ -344,15 +351,8 @@ async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJW """API to validate LLM API Key""" source = request.model_source api_key = request.model_api_key - valid_api_key = False - if source == "OpenAi": - valid_api_key = OpenAi(api_key=api_key).verify_access_key() - elif source == "Google Palm": - valid_api_key = GooglePalm(api_key=api_key).verify_access_key() - elif source == "Replicate": - valid_api_key = 
Replicate(api_key=api_key).verify_access_key() - elif source == "Hugging Face": - valid_api_key = HuggingFace(api_key=api_key).verify_access_key() + model = build_model_with_api_key(source, api_key) + valid_api_key = model.verify_access_key() if model is not None else False if valid_api_key: return {"message": "Valid API Key", "status": "success"} else: diff --git a/migrations/env.py b/migrations/env.py index f6a48ea87..a46ae289c 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -2,10 +2,8 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool - from alembic import context - -from superagi.config.config import get_config +from urllib.parse import urlparse # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -23,16 +21,18 @@ from superagi.models.base_model import DBBaseModel target_metadata = DBBaseModel.metadata from superagi.models import * +from superagi.config.config import get_config # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") # ... etc. -database_url = get_config('POSTGRES_URL') +db_host = get_config('DB_HOST', 'super__postgres') db_username = get_config('DB_USERNAME') db_password = get_config('DB_PASSWORD') db_name = get_config('DB_NAME') +database_url = get_config('DB_URL', None) def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -47,10 +47,15 @@ def run_migrations_offline() -> None: """ - if db_username is None: - db_url = f'postgresql://{database_url}/{db_name}' + db_url = database_url + if db_url is None: + if db_username is None: + db_url = f'postgresql://{db_host}/{db_name}' + else: + db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}' else: - db_url = f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}' + db_url = urlparse(db_url) + db_url = db_url.scheme + "://" + db_url.netloc + db_url.path config.set_main_option("sqlalchemy.url", db_url) @@ -73,6 +78,23 @@ def run_migrations_online() -> None: and associate a connection with the context. """ + + db_host = get_config('DB_HOST', 'super__postgres') + db_username = get_config('DB_USERNAME') + db_password = get_config('DB_PASSWORD') + db_name = get_config('DB_NAME') + db_url = get_config('DB_URL', None) + + if db_url is None: + if db_username is None: + db_url = f'postgresql://{db_host}/{db_name}' + else: + db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}' + else: + db_url = urlparse(db_url) + db_url = db_url.scheme + "://" + db_url.netloc + db_url.path + + config.set_main_option('sqlalchemy.url', db_url) connectable = engine_from_config( config.get_section(config.config_ini_section, {}), prefix="sqlalchemy.", diff --git a/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py b/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py new file mode 100644 index 000000000..32f39be18 --- /dev/null +++ b/migrations/versions/40affbf3022b_add_filter_colume_in_webhooks.py @@ -0,0 +1,28 @@ +"""add filter colume in webhooks + +Revision ID: 40affbf3022b +Revises: 5d5f801f28e7 +Create Date: 2023-08-28 12:30:35.171176 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '40affbf3022b' +down_revision = '5d5f801f28e7' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('webhooks', sa.Column('filters', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('webhooks', 'filters') + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 554fd54ea..ab45bb1c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -156,3 +156,5 @@ html2text==2020.1.16 duckduckgo-search==3.8.3 google-generativeai==0.1.0 unstructured==0.8.1 +ai21==1.2.6 +typing-extensions==4.5.0 diff --git a/superagi/agent/output_handler.py b/superagi/agent/output_handler.py index 2a0aa950f..1fdeb7531 100644 --- a/superagi/agent/output_handler.py +++ b/superagi/agent/output_handler.py @@ -100,7 +100,7 @@ def handle_tool_response(self, session, assistant_reply): action = self.output_parser.parse(assistant_reply) agent = session.query(Agent).filter(Agent.id == self.agent_config["agent_id"]).first() organisation = agent.get_agent_organisation(session) - tool_executor = ToolExecutor(organisation_id=organisation.id, agent_id=agent.id, tools=self.tools) + tool_executor = ToolExecutor(organisation_id=organisation.id, agent_id=agent.id, tools=self.tools, agent_execution_id=self.agent_execution_id) return tool_executor.execute(session, action.name, action.args) def _check_permission_in_restricted_mode(self, session, assistant_reply: str): diff --git a/superagi/agent/tool_executor.py b/superagi/agent/tool_executor.py index 7aed83673..017094164 100644 --- a/superagi/agent/tool_executor.py +++ b/superagi/agent/tool_executor.py @@ -9,10 +9,11 @@ class ToolExecutor: """Executes the tool with the given args.""" FINISH = "finish" - def __init__(self, organisation_id: int, agent_id: int, tools: list): + def __init__(self, organisation_id: int, agent_id: int, tools: list, agent_execution_id: int): self.organisation_id = organisation_id self.agent_id = agent_id self.tools = tools + self.agent_execution_id = agent_execution_id def execute(self, session, tool_name, tool_args): """Executes the tool with the given args. 
@@ -31,7 +32,7 @@ def execute(self, session, tool_name, tool_args): status = "SUCCESS" tool = tools[tool_name] retry = False - EventHandler(session=session).create_event('tool_used', {'tool_name': tool_name}, self.agent_id, + EventHandler(session=session).create_event('tool_used', {'tool_name': tool_name, 'agent_execution_id': self.agent_execution_id}, self.agent_id, self.organisation_id), try: parsed_args = self.clean_tool_args(tool_args) diff --git a/superagi/apm/knowledge_handler.py b/superagi/apm/knowledge_handler.py new file mode 100644 index 000000000..098791d4c --- /dev/null +++ b/superagi/apm/knowledge_handler.py @@ -0,0 +1,117 @@ +from sqlalchemy.orm import Session +from superagi.models.events import Event +from superagi.models.knowledges import Knowledges +from sqlalchemy import Integer, or_, label, case, and_ +from fastapi import HTTPException +from typing import List, Dict, Union, Any +from sqlalchemy.sql import func +from sqlalchemy.orm import aliased +from superagi.models.agent_config import AgentConfiguration +import pytz +from datetime import datetime + + +class KnowledgeHandler: + def __init__(self, session: Session, organisation_id: int): + self.session = session + self.organisation_id = organisation_id + + + def get_knowledge_usage_by_name(self, knowledge_name: str) -> Dict[str, Dict[str, int]]: + + is_knowledge_valid = self.session.query(Knowledges.id).filter_by(name=knowledge_name).filter(Knowledges.organisation_id == self.organisation_id).first() + if not is_knowledge_valid: + raise HTTPException(status_code=404, detail="Knowledge not found") + EventAlias = aliased(Event) + + knowledge_used_event = self.session.query( + Event.event_property['knowledge_name'].label('knowledge_name'), + func.count(Event.agent_id.distinct()).label('knowledge_unique_agents') + ).filter( + Event.event_name == 'knowledge_picked', + Event.org_id == self.organisation_id, + Event.event_property['knowledge_name'].astext == knowledge_name + ).group_by( + Event.event_property['knowledge_name'] + ).first() + + if knowledge_used_event is None: + return {} + + knowledge_data = { + 'knowledge_unique_agents': knowledge_used_event.knowledge_unique_agents, + 'knowledge_calls': self.session.query( + EventAlias + ).filter( + EventAlias.event_property['tool_name'].astext == 'knowledgesearch', + EventAlias.event_name == 'tool_used', + EventAlias.org_id == self.organisation_id, + EventAlias.agent_id.in_(self.session.query(Event.agent_id).filter( + Event.event_name == 'knowledge_picked', + Event.org_id == self.organisation_id, + Event.event_property['knowledge_name'].astext == knowledge_name + )) + ).count() + } + + return knowledge_data + + + def get_knowledge_events_by_name(self, knowledge_name: str) -> List[Dict[str, Union[str, int, List[str]]]]: + + is_knowledge_valid = self.session.query(Knowledges.id).filter_by(name=knowledge_name).filter(Knowledges.organisation_id == self.organisation_id).first() + + if not is_knowledge_valid: + raise HTTPException(status_code=404, detail="Knowledge not found") + + knowledge_events = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + Event.event_name == 'knowledge_picked', + Event.event_property['knowledge_name'].astext == knowledge_name + ).all() + + knowledge_events = [ke for ke in knowledge_events if 'agent_execution_id' in ke.event_property] + + event_runs = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed') + ).all() + + 
agent_created_events = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + Event.event_name == 'agent_created' + ).all() + + results = [] + + for knowledge_event in knowledge_events: + agent_execution_id = knowledge_event.event_property['agent_execution_id'] + + event_run = next((er for er in event_runs if er.agent_id == knowledge_event.agent_id and er.event_property['agent_execution_id'] == agent_execution_id), None) + agent_created_event = next((ace for ace in agent_created_events if ace.agent_id == knowledge_event.agent_id), None) + try: + user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(session=self.session, key='user_timezone', agent_id=knowledge_event.agent_id) + if user_timezone and user_timezone.value != 'None': + tz = pytz.timezone(user_timezone.value) + else: + tz = pytz.timezone('GMT') + except AttributeError: + tz = pytz.timezone('GMT') + + if event_run and agent_created_event: + actual_time = knowledge_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M") + + result_dict = { + 'agent_execution_id': agent_execution_id, + 'created_at': actual_time, + 'tokens_consumed': event_run.event_property['tokens_consumed'], + 'calls': event_run.event_property['calls'], + 'agent_execution_name': event_run.event_property['name'], + 'agent_name': agent_created_event.event_property['agent_name'], + 'model': agent_created_event.event_property['model'] + } + if agent_execution_id not in [i['agent_execution_id'] for i in results]: + results.append(result_dict) + + results = sorted(results, key=lambda x: datetime.strptime(x['created_at'], '%d %B %Y %H:%M'), reverse=True) + return results \ No newline at end of file diff --git a/superagi/apm/tools_handler.py b/superagi/apm/tools_handler.py index 72048d529..da3f97cc6 100644 --- a/superagi/apm/tools_handler.py +++ b/superagi/apm/tools_handler.py @@ -1,12 +1,16 @@ -from typing import List, Dict - -from sqlalchemy import func +from typing import List, Dict, Union +from sqlalchemy import func, distinct, and_ from sqlalchemy.orm import Session - +from sqlalchemy import Integer +from fastapi import HTTPException from superagi.models.events import Event from superagi.models.tool import Tool from superagi.models.toolkit import Toolkit - +from sqlalchemy import or_ +from sqlalchemy.sql import label +from datetime import datetime +from superagi.models.agent_config import AgentConfiguration +import pytz class ToolsHandler: def __init__(self, session: Session, organisation_id: int): @@ -54,4 +58,109 @@ def calculate_tool_usage(self) -> List[Dict[str, int]]: 'toolkit': tool_and_toolkit.get(row.tool_name, None) } for row in result] - return tool_usage \ No newline at end of file + return tool_usage + + def get_tool_usage_by_name(self, tool_name: str) -> Dict[str, Dict[str, int]]: + is_tool_name_valid = self.session.query(Tool).filter_by(name=tool_name).first() + + if not is_tool_name_valid: + raise HTTPException(status_code=404, detail="Tool not found") + formatted_tool_name = tool_name.lower().replace(" ", "") + + tool_used_event = self.session.query( + Event.event_property['tool_name'].label('tool_name'), + func.count(Event.id).label('tool_calls'), + func.count(distinct(Event.agent_id)).label('tool_unique_agents') + ).filter( + Event.event_name == 'tool_used', + Event.org_id == self.organisation_id, + Event.event_property['tool_name'].astext == formatted_tool_name + ).group_by( + Event.event_property['tool_name'] + ).first() + + if tool_used_event is None: + return {} + + tool_data = { + 'tool_calls': 
tool_used_event.tool_calls, + 'tool_unique_agents': tool_used_event.tool_unique_agents + } + + return tool_data + + + def get_tool_events_by_name(self, tool_name: str) -> List[Dict[str, Union[str, int, List[str]]]]: + is_tool_name_valid = self.session.query(Tool).filter_by(name=tool_name).first() + + if not is_tool_name_valid: + raise HTTPException(status_code=404, detail="Tool not found") + + formatted_tool_name = tool_name.lower().replace(" ", "") + + tool_events = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + Event.event_name == 'tool_used', + Event.event_property['tool_name'].astext == formatted_tool_name + ).all() + + tool_events = [te for te in tool_events if 'agent_execution_id' in te.event_property] + + event_runs = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + or_(Event.event_name == 'run_completed', Event.event_name == 'run_iteration_limit_crossed') + ).all() + + agent_created_events = self.session.query(Event).filter( + Event.org_id == self.organisation_id, + Event.event_name == 'agent_created' + ).all() + + results = [] + + for tool_event in tool_events: + agent_execution_id = tool_event.event_property['agent_execution_id'] + + event_run = next((er for er in event_runs if er.agent_id == tool_event.agent_id and er.event_property['agent_execution_id'] == agent_execution_id), None) + agent_created_event = next((ace for ace in agent_created_events if ace.agent_id == tool_event.agent_id), None) + try: + user_timezone = AgentConfiguration.get_agent_config_by_key_and_agent_id(session=self.session, key='user_timezone', agent_id=tool_event.agent_id) + if user_timezone and user_timezone.value != 'None': + tz = pytz.timezone(user_timezone.value) + else: + tz = pytz.timezone('GMT') + except AttributeError: + tz = pytz.timezone('GMT') + + if event_run and agent_created_event: + actual_time = tool_event.created_at.astimezone(tz).strftime("%d %B %Y %H:%M") + other_tools_events = self.session.query( + Event + ).filter( + Event.org_id == self.organisation_id, + Event.event_name == 'tool_used', + Event.event_property['tool_name'].astext != formatted_tool_name, + Event.agent_id == tool_event.agent_id, + Event.id.between(tool_event.id, event_run.id) + ).all() + + other_tools = [ote.event_property['tool_name'] for ote in other_tools_events] + + result_dict = { + 'created_at': actual_time, + 'agent_execution_id': agent_execution_id, + 'tokens_consumed': event_run.event_property['tokens_consumed'], + 'calls': event_run.event_property['calls'], + 'agent_execution_name': event_run.event_property['name'], + 'other_tools': other_tools, + 'agent_name': agent_created_event.event_property['agent_name'], + 'model': agent_created_event.event_property['model'] + } + + if agent_execution_id not in [i['agent_execution_id'] for i in results]: + results.append(result_dict) + + results = sorted(results, key=lambda x: datetime.strptime(x['created_at'], '%d %B %Y %H:%M'), reverse=True) + + return results + \ No newline at end of file diff --git a/superagi/controllers/agent.py b/superagi/controllers/agent.py index a912a722b..2d108d7d9 100644 --- a/superagi/controllers/agent.py +++ b/superagi/controllers/agent.py @@ -132,18 +132,27 @@ def create_agent_with_config(agent_with_config: AgentConfigInput, agent = db.session.query(Agent).filter(Agent.id == db_agent.id, ).first() organisation = agent.get_agent_organisation(db.session) - EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': execution.id, - 'agent_execution_name': 
execution.name}, db_agent.id, - - organisation.id if organisation else 0), - EventHandler(session=db.session).create_event('agent_created', {'agent_name': agent_with_config.name, - - 'model': agent_with_config.model}, db_agent.id, - + + EventHandler(session=db.session).create_event('run_created', + {'agent_execution_id': execution.id, + 'agent_execution_name': execution.name}, + db_agent.id, + organisation.id if organisation else 0), + + if agent_with_config.knowledge: + knowledge_name = db.session.query(Knowledges.name).filter(Knowledges.id == agent_with_config.knowledge).first()[0] + EventHandler(session=db.session).create_event('knowledge_picked', + {'knowledge_name': knowledge_name, + 'agent_execution_id': execution.id}, + db_agent.id, + organisation.id if organisation else 0) + + EventHandler(session=db.session).create_event('agent_created', + {'agent_name': agent_with_config.name, + 'model': agent_with_config.model}, + db_agent.id, organisation.id if organisation else 0) - # execute_agent.delay(execution.id, datetime.now()) - db.session.commit() return { diff --git a/superagi/controllers/agent_execution.py b/superagi/controllers/agent_execution.py index fca7def57..9658712ec 100644 --- a/superagi/controllers/agent_execution.py +++ b/superagi/controllers/agent_execution.py @@ -24,6 +24,7 @@ from superagi.apm.event_handler import EventHandler from superagi.controllers.tool import ToolOut from superagi.models.agent_config import AgentConfiguration +from superagi.models.knowledges import Knowledges router = APIRouter() @@ -86,12 +87,12 @@ def create_agent_execution(agent_execution: AgentExecutionIn, iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 - db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), + db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(), agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0, num_of_tokens=0, current_agent_step_id=start_step.id, iteration_workflow_step_id=iteration_step_id) - + agent_execution_configs = { "goal": agent_execution.goal, "instruction": agent_execution.instruction @@ -118,15 +119,31 @@ def create_agent_execution(agent_execution: AgentExecutionIn, db.session.add(db_agent_execution) db.session.commit() db.session.flush() + + #update status from CREATED to RUNNING + db_agent_execution.status = "RUNNING" + db.session.commit() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution, agent_execution_configs=agent_execution_configs) organisation = agent.get_agent_organisation(db.session) - EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, - agent_execution.agent_id, organisation.id if organisation else 0) - + agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_execution.agent_id) + + EventHandler(session=db.session).create_event('run_created', + {'agent_execution_id': db_agent_execution.id, + 'agent_execution_name':db_agent_execution.name}, + agent_execution.agent_id, + organisation.id if organisation else 0) + if agent_execution_knowledge and agent_execution_knowledge.value != 'None': + knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name + 
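+            # Record a knowledge_picked event so APM usage metrics can attribute this run to the selected knowledge base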
if knowledge_name is not None: + EventHandler(session=db.session).create_event('knowledge_picked', + {'knowledge_name': knowledge_name, + 'agent_execution_id': db_agent_execution.id}, + agent_execution.agent_id, + organisation.id if organisation else 0) Models.api_key_from_configurations(session=db.session, organisation_id=organisation.id) - if db_agent_execution.status == "RUNNING": execute_agent.delay(db_agent_execution.id, datetime.now()) @@ -147,6 +164,7 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c Raises: HTTPException (Status Code=404): If the agent is not found. """ + agent = db.session.query(Agent).filter(Agent.id == agent_execution.agent_id, Agent.is_deleted == False).first() if not agent: raise HTTPException(status_code = 404, detail = "Agent not found") @@ -159,7 +177,7 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c iteration_step_id = IterationWorkflow.fetch_trigger_step_id(db.session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 - db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), + db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(), agent_id=agent_execution.agent_id, name=agent_execution.name, num_of_calls=0, num_of_tokens=0, current_agent_step_id=start_step.id, @@ -183,13 +201,29 @@ def create_agent_run(agent_execution: AgentRunIn, Authorize: AuthJWT = Depends(c db.session.add(db_agent_execution) db.session.commit() db.session.flush() - + + #update status from CREATED to RUNNING + db_agent_execution.status = "RUNNING" + db.session.commit() + AgentExecutionConfiguration.add_or_update_agent_execution_config(session = db.session, execution = db_agent_execution, agent_execution_configs = agent_execution_configs) organisation = agent.get_agent_organisation(db.session) - EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, - agent_execution.agent_id, organisation.id if organisation else 0) + EventHandler(session=db.session).create_event('run_created', + {'agent_execution_id': db_agent_execution.id, + 'agent_execution_name':db_agent_execution.name}, + agent_execution.agent_id, + organisation.id if organisation else 0) + agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_execution.agent_id) + if agent_execution_knowledge and agent_execution_knowledge.value != 'None': + knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name + if knowledge_name is not None: + EventHandler(session=db.session).create_event('knowledge_picked', + {'knowledge_name': knowledge_name, + 'agent_execution_id': db_agent_execution.id}, + agent_execution.agent_id, + organisation.id if organisation else 0) if db_agent_execution.status == "RUNNING": execute_agent.delay(db_agent_execution.id, datetime.now()) diff --git a/superagi/controllers/agent_template.py b/superagi/controllers/agent_template.py index 7e9abb0fe..799df391b 100644 --- a/superagi/controllers/agent_template.py +++ b/superagi/controllers/agent_template.py @@ -6,7 +6,10 @@ from pydantic import BaseModel from main import get_config +from superagi.controllers.types.agent_execution_config import AgentRunIn +from superagi.controllers.types.agent_publish_config import AgentPublish from superagi.helper.auth import 
get_user_organisation +from superagi.helper.auth import get_current_user from superagi.models.agent import Agent from superagi.models.agent_config import AgentConfiguration from superagi.models.agent_execution import AgentExecution @@ -138,58 +141,58 @@ def edit_agent_template(agent_template_id: int, db.session.commit() db.session.flush() -@router.put("/update_agent_template/{agent_template_id}", status_code=200) -def edit_agent_template(agent_template_id: int, - updated_agent_configs: dict, - organisation=Depends(get_user_organisation)): +# @router.put("/update_agent_template/{agent_template_id}", status_code=200) +# def edit_agent_template(agent_template_id: int, +# updated_agent_configs: dict, +# organisation=Depends(get_user_organisation)): - """ - Update the details of an agent template. +# """ +# Update the details of an agent template. - Args: - agent_template_id (int): The ID of the agent template to update. - edited_agent_configs (dict): The updated agent configurations. - organisation (Depends): Dependency to get the user organisation. +# Args: +# agent_template_id (int): The ID of the agent template to update. +# edited_agent_configs (dict): The updated agent configurations. +# organisation (Depends): Dependency to get the user organisation. - Returns: - HTTPException (status_code=200): If the agent gets successfully edited. +# Returns: +# HTTPException (status_code=200): If the agent gets successfully edited. - Raises: - HTTPException (status_code=404): If the agent template is not found. - """ +# Raises: +# HTTPException (status_code=404): If the agent template is not found. +# """ - db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id, - AgentTemplate.id == agent_template_id).first() - if db_agent_template is None: - raise HTTPException(status_code=404, detail="Agent Template not found") +# db_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.organisation_id == organisation.id, +# AgentTemplate.id == agent_template_id).first() +# if db_agent_template is None: +# raise HTTPException(status_code=404, detail="Agent Template not found") - db_agent_template.name = updated_agent_configs["name"] - db_agent_template.description = updated_agent_configs["description"] +# db_agent_template.name = updated_agent_configs["name"] +# db_agent_template.description = updated_agent_configs["description"] - db.session.commit() +# db.session.commit() - agent_config_values = updated_agent_configs.get('agent_configs', {}) +# agent_config_values = updated_agent_configs.get('agent_configs', {}) - for key, value in agent_config_values.items(): - if isinstance(value, (list, dict)): - value = json.dumps(value) - config = db.session.query(AgentTemplateConfig).filter( - AgentTemplateConfig.agent_template_id == agent_template_id, - AgentTemplateConfig.key == key - ).first() +# for key, value in agent_config_values.items(): +# if isinstance(value, (list, dict)): +# value = json.dumps(value) +# config = db.session.query(AgentTemplateConfig).filter( +# AgentTemplateConfig.agent_template_id == agent_template_id, +# AgentTemplateConfig.key == key +# ).first() - if config is not None: - config.value = value - else: - new_config = AgentTemplateConfig( - agent_template_id=agent_template_id, - key=key, - value= value - ) - db.session.add(new_config) +# if config is not None: +# config.value = value +# else: +# new_config = AgentTemplateConfig( +# agent_template_id=agent_template_id, +# key=key, +# value= value +# ) +# 
db.session.add(new_config) - db.session.commit() - db.session.flush() +# db.session.commit() +# db.session.flush() @router.post("/save_agent_as_template/agent_id/{agent_id}/agent_execution_id/{agent_execution_id}") @@ -411,4 +414,129 @@ def fetch_agent_config_from_template(agent_template_id: int, agent_workflow = AgentWorkflow.find_by_id(db.session, agent_template.agent_workflow_id) template_config_dict["agent_workflow"] = agent_workflow.name - return template_config_dict \ No newline at end of file + return template_config_dict + + +@router.post("/publish_template/agent_execution_id/{agent_execution_id}", status_code=201) +def publish_template(agent_execution_id: str, organisation=Depends(get_user_organisation), user=Depends(get_current_user)): + + """ + Publish an agent execution as a template. + + Args: + agent_execution_id (str): The ID of the agent execution to save as a template. + organisation (Depends): Dependency to get the user organisation. + user (Depends): Dependency to get the user. + + Returns: + dict: The saved agent template. + + Raises: + HTTPException (status_code=404): If the agent or agent execution configurations are not found. + """ + + if agent_execution_id == 'undefined': + raise HTTPException(status_code = 404, detail = "Agent Execution Id undefined") + + agent_executions = AgentExecution.get_agent_execution_from_id(db.session, agent_execution_id) + if agent_executions is None: + raise HTTPException(status_code = 404, detail = "Agent Execution not found") + agent_id = agent_executions.agent_id + + agent = db.session.query(Agent).filter(Agent.id == agent_id).first() + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + + agent_execution_configurations = db.session.query(AgentExecutionConfiguration).filter(AgentExecutionConfiguration.agent_execution_id == agent_execution_id).all() + if not agent_execution_configurations: + raise HTTPException(status_code=404, detail="Agent execution configurations not found") + + agent_template = AgentTemplate(name=agent.name, description=agent.description, + agent_workflow_id=agent.agent_workflow_id, + organisation_id=organisation.id) + db.session.add(agent_template) + db.session.commit() + + main_keys = AgentTemplate.main_keys() + for agent_execution_configuration in agent_execution_configurations: + config_value = agent_execution_configuration.value + if agent_execution_configuration.key not in main_keys: + continue + if agent_execution_configuration.key == "tools": + config_value = str(Tool.convert_tool_ids_to_names(db, eval(agent_execution_configuration.value))) + agent_template_config = AgentTemplateConfig(agent_template_id=agent_template.id, key=agent_execution_configuration.key, + value=config_value) + db.session.add(agent_template_config) + + agent_template_configs = [ + AgentTemplateConfig(agent_template_id=agent_template.id, key="status", value="UNDER REVIEW"), + AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Name", value=user.name), + AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Email", value=user.email)] + db.session.add_all(agent_template_configs) + + db.session.commit() + db.session.flush() + return agent_template.to_dict() + +@router.post("/publish_template", status_code=201) +def handle_publish_template(updated_details: AgentPublish, organisation=Depends(get_user_organisation), user=Depends(get_current_user)): + + """ + Publish a template from edit template page. 
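+    Creates a new agent template with status UNDER REVIEW from the edited details instead of modifying the original template.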
+ + Args: + organisation (Depends): Dependency to get the user organisation. + user (Depends): Dependency to get the user. + + Returns: + dict: The saved agent template. + + Raises: + HTTPException (status_code=404): If the agent template or workflow are not found. + """ + + old_template_id = updated_details.agent_template_id + old_agent_template = db.session.query(AgentTemplate).filter(AgentTemplate.id==old_template_id, AgentTemplate.organisation_id==organisation.id).first() + if old_agent_template is None: + raise HTTPException(status_code = 404, detail = "Agent Template not found") + agent_workflow_id = old_agent_template.agent_workflow_id + if agent_workflow_id is None: + raise HTTPException(status_code = 404, detail = "Agent Workflow not found") + + agent_template = AgentTemplate(name=updated_details.name, description=updated_details.description, + agent_workflow_id=agent_workflow_id, + organisation_id=organisation.id) + db.session.add(agent_template) + db.session.commit() + + agent_template_configs = { + "goal": updated_details.goal, + "instruction": updated_details.instruction, + "constraints": updated_details.constraints, + "toolkits": updated_details.toolkits, + "exit": updated_details.exit, + "tools": updated_details.tools, + "iteration_interval": updated_details.iteration_interval, + "model": updated_details.model, + "permission_type": updated_details.permission_type, + "LTM_DB": updated_details.LTM_DB, + "max_iterations": updated_details.max_iterations, + "user_timezone": updated_details.user_timezone, + "knowledge": updated_details.knowledge + } + + for key, value in agent_template_configs.items(): + if key == "tools": + value = Tool.convert_tool_ids_to_names(db, value) + agent_template_config = AgentTemplateConfig(agent_template_id=agent_template.id, key=key, value=str(value)) + db.session.add(agent_template_config) + + agent_template_configs = [ + AgentTemplateConfig(agent_template_id=agent_template.id, key="status", value="UNDER REVIEW"), + AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Name", value=user.name), + AgentTemplateConfig(agent_template_id=agent_template.id, key="Contributor Email", value=user.email)] + db.session.add_all(agent_template_configs) + + db.session.commit() + db.session.flush() + return agent_template.to_dict() \ No newline at end of file diff --git a/superagi/controllers/analytics.py b/superagi/controllers/analytics.py index 5dbfec7d9..c2c78f764 100644 --- a/superagi/controllers/analytics.py +++ b/superagi/controllers/analytics.py @@ -3,6 +3,7 @@ from superagi.apm.analytics_helper import AnalyticsHelper from superagi.apm.event_handler import EventHandler from superagi.apm.tools_handler import ToolsHandler +from superagi.apm.knowledge_handler import KnowledgeHandler from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db import logging @@ -59,3 +60,48 @@ def get_tools_used(organisation=Depends(get_user_organisation)): except Exception as e: logging.error(f"Error while calculating tool usage: {str(e)}") raise HTTPException(status_code=500, detail="Internal Server Error") + + +@router.get("/tools/{tool_name}/usage", status_code=200) +def get_tool_usage(tool_name: str, organisation=Depends(get_user_organisation)): + try: + return ToolsHandler(session=db.session, organisation_id=organisation.id).get_tool_usage_by_name(tool_name) + except Exception as e: + if hasattr(e, 'status_code'): + raise HTTPException(status_code=e.status_code, detail=e.detail) + else: + raise HTTPException(status_code=500, detail="Internal 
Server Error") + + +@router.get("/knowledge/{knowledge_name}/usage", status_code=200) +def get_knowledge_usage(knowledge_name:str, organisation=Depends(get_user_organisation)): + try: + return KnowledgeHandler(session=db.session, organisation_id=organisation.id).get_knowledge_usage_by_name(knowledge_name) + except Exception as e: + if hasattr(e, 'status_code'): + raise HTTPException(status_code=e.status_code, detail=e.detail) + else: + raise HTTPException(status_code=500, detail="Internal Server Error") + + +@router.get("/tools/{tool_name}/logs", status_code=200) +def get_tool_logs(tool_name: str, organisation=Depends(get_user_organisation)): + try: + return ToolsHandler(session=db.session, organisation_id=organisation.id).get_tool_events_by_name(tool_name) + except Exception as e: + logging.error(f"Error while getting tool event details: {str(e)}") + if hasattr(e, 'status_code'): + raise HTTPException(status_code=e.status_code, detail=e.detail) + else: + raise HTTPException(status_code=500, detail="Internal Server Error") + +@router.get("/knowledge/{knowledge_name}/logs", status_code=200) +def get_knowledge_logs(knowledge_name: str, organisation=Depends(get_user_organisation)): + try: + return KnowledgeHandler(session=db.session, organisation_id=organisation.id).get_knowledge_events_by_name(knowledge_name) + except Exception as e: + logging.error(f"Error while getting knowledge event details: {str(e)}") + if hasattr(e, 'status_code'): + raise HTTPException(status_code=e.status_code, detail=e.detail) + else: + raise HTTPException(status_code=500, detail="Internal Server Error") \ No newline at end of file diff --git a/superagi/controllers/api/agent.py b/superagi/controllers/api/agent.py index 95b35c4f1..d057e7f0b 100644 --- a/superagi/controllers/api/agent.py +++ b/superagi/controllers/api/agent.py @@ -14,6 +14,7 @@ from superagi.models.workflows.agent_workflow import AgentWorkflow from superagi.models.agent_execution import AgentExecution from superagi.models.organisation import Organisation +from superagi.models.knowledges import Knowledges from superagi.models.resource import Resource from superagi.controllers.types.agent_with_config import AgentConfigExtInput,AgentConfigUpdateExtInput from superagi.models.workflows.iteration_workflow import IterationWorkflow @@ -117,14 +118,14 @@ def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Sec db_schedule=AgentSchedule.find_by_agent_id(db.session, agent_id) if db_schedule is not None: raise HTTPException(status_code=409, detail="Agent is already scheduled,cannot run") - start_step_id = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id) + start_step = AgentWorkflow.fetch_trigger_step_id(db.session, agent.agent_workflow_id) db_agent_execution=AgentExecution.get_execution_by_agent_id_and_status(db.session, agent_id, "CREATED") if db_agent_execution is None: db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), agent_id=agent_id, name=agent_execution.name, num_of_calls=0, num_of_tokens=0, - current_step_id=start_step_id) + current_agent_step_id=start_step.id) db.session.add(db_agent_execution) else: db_agent_execution.status = "RUNNING" @@ -144,8 +145,23 @@ def create_run(agent_id:int,agent_execution: AgentExecutionIn,api_key: str = Sec if agent_execution_configs != {}: AgentExecutionConfiguration.add_or_update_agent_execution_config(session=db.session, execution=db_agent_execution, agent_execution_configs=agent_execution_configs) - 
EventHandler(session=db.session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, - agent_id, organisation.id if organisation else 0) + EventHandler(session=db.session).create_event('run_created', + {'agent_execution_id': db_agent_execution.id, + 'agent_execution_name':db_agent_execution.name + }, + agent_id, + organisation.id if organisation else 0) + + agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= db.session, key= 'knowledge', agent_id= agent_id) + if agent_execution_knowledge and agent_execution_knowledge.value != 'None': + knowledge_name = Knowledges.get_knowledge_from_id(db.session, int(agent_execution_knowledge.value)).name + if knowledge_name is not None: + EventHandler(session=db.session).create_event('knowledge_picked', + {'knowledge_name': knowledge_name, + 'agent_execution_id': db_agent_execution.id}, + agent_id, + organisation.id if organisation else 0 + ) if db_agent_execution.status == "RUNNING": execute_agent.delay(db_agent_execution.id, datetime.now()) @@ -269,7 +285,8 @@ def pause_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCha db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "RUNNING") - if len(db_execution_arr) != len(execution_state_change_input.run_ids): + if db_execution_arr is not None and execution_state_change_input.run_ids is not None \ + and len(db_execution_arr) != len(execution_state_change_input.run_ids): raise HTTPException(status_code=404, detail="One or more run id(s) not found") for ind_execution in db_execution_arr: @@ -298,7 +315,8 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh db_execution_arr=AgentExecution.get_all_executions_by_status_and_agent_id(db.session, agent.id, execution_state_change_input, "PAUSED") - if len(db_execution_arr) != len(execution_state_change_input.run_ids): + if db_execution_arr is not None and execution_state_change_input.run_ids is not None\ + and len(db_execution_arr) != len(execution_state_change_input.run_ids): raise HTTPException(status_code=404, detail="One or more run id(s) not found") for ind_execution in db_execution_arr: @@ -312,7 +330,7 @@ def resume_agent_runs(agent_id:int,execution_state_change_input:ExecutionStateCh "result":"success" } -@router.post("/resources/output",status_code=201) +@router.post("/resources/output",status_code=200) def get_run_resources(run_id_config:RunIDConfig,api_key: str = Security(validate_api_key),organisation:Organisation = Depends(get_organisation_from_api_key)): if get_config('STORAGE_TYPE') != "S3": raise HTTPException(status_code=400,detail="This endpoint only works when S3 is configured") diff --git a/superagi/controllers/api_key.py b/superagi/controllers/api_key.py index 57e5c739b..458e68db5 100644 --- a/superagi/controllers/api_key.py +++ b/superagi/controllers/api_key.py @@ -5,51 +5,64 @@ from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db from pydantic import BaseModel -from superagi.helper.auth import get_user_organisation +from superagi.helper.auth import get_user_organisation, validate_api_key from superagi.helper.auth import check_auth from superagi.models.api_key import ApiKey from typing import Optional, Annotated + router = APIRouter() + class ApiKeyIn(BaseModel): - id:int + id: int name: str + class Config: orm_mode = True + class ApiKeyDeleteIn(BaseModel): - id:int + id: int + class Config: 
orm_mode = True + @router.post("") -def create_api_key(name: Annotated[str,Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): - api_key=str(uuid.uuid4()) - obj=ApiKey(key=api_key,name=name,org_id=organisation.id) +def create_api_key(name: Annotated[str, Body(embed=True)], Authorize: AuthJWT = Depends(check_auth), + organisation=Depends(get_user_organisation)): + api_key = str(uuid.uuid4()) + obj = ApiKey(key=api_key, name=name, org_id=organisation.id) db.session.add(obj) db.session.commit() db.session.flush() return {"api_key": api_key} + +@router.get("/validate") +def get_api_key(api_key: str = Depends(validate_api_key)): + return {"success": True} + + @router.get("") def get_all(Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): - api_keys=ApiKey.get_by_org_id(db.session, organisation.id) + api_keys = ApiKey.get_by_org_id(db.session, organisation.id) return api_keys + @router.delete("/{api_key_id}") -def delete_api_key(api_key_id:int, Authorize: AuthJWT = Depends(check_auth)): - api_key=ApiKey.get_by_id(db.session, api_key_id) +def delete_api_key(api_key_id: int, Authorize: AuthJWT = Depends(check_auth)): + api_key = ApiKey.get_by_id(db.session, api_key_id) if api_key is None: raise HTTPException(status_code=404, detail="API key not found") ApiKey.delete_by_id(db.session, api_key_id) return {"success": True} + @router.put("") -def edit_api_key(api_key_in:ApiKeyIn,Authorize: AuthJWT = Depends(check_auth)): - api_key=ApiKey.get_by_id(db.session, api_key_in.id) +def edit_api_key(api_key_in: ApiKeyIn, Authorize: AuthJWT = Depends(check_auth)): + api_key = ApiKey.get_by_id(db.session, api_key_in.id) if api_key is None: raise HTTPException(status_code=404, detail="API key not found") ApiKey.update_api_key(db.session, api_key_in.id, api_key_in.name) return {"success": True} - - diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index e82e6ad60..abb771060 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -102,7 +102,7 @@ async def fetch_data(request: ModelName, organisation=Depends(get_user_organisat @router.get("/get/list", status_code=200) -def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)): +def get_knowledge_list(page: int = 0, organisation=Depends(get_user_organisation)): """ Get Marketplace Model list. 
@@ -121,7 +121,7 @@ def get_models_list(page: int = 0, organisation=Depends(get_user_organisation)): @router.get("/marketplace/list/{page}", status_code=200) -def get_marketplace_models_list(page: int = 0): +def get_marketplace_knowledge_list(page: int = 0): organisation_id = get_config("MARKETPLACE_ORGANISATION_ID") if organisation_id is not None: organisation_id = int(organisation_id) diff --git a/superagi/controllers/organisation.py b/superagi/controllers/organisation.py index aa738d64b..e366c5966 100644 --- a/superagi/controllers/organisation.py +++ b/superagi/controllers/organisation.py @@ -11,6 +11,7 @@ from superagi.helper.encyption_helper import decrypt_data from superagi.helper.tool_helper import register_toolkits from superagi.llms.google_palm import GooglePalm +from superagi.llms.llm_model_factory import build_model_with_api_key from superagi.llms.openai import OpenAi from superagi.models.configuration import Configuration from superagi.models.organisation import Organisation @@ -170,11 +171,8 @@ def get_llm_models(organisation=Depends(get_user_organisation)): detail="Organisation not found") decrypted_api_key = decrypt_data(model_api_key.value) - models = [] - if model_source.value == "OpenAi": - models = OpenAi(api_key=decrypted_api_key).get_models() - elif model_source.value == "Google Palm": - models = GooglePalm(api_key=decrypted_api_key).get_models() + model = build_model_with_api_key(model_source.value, decrypted_api_key) + models = model.get_models() if model is not None else [] return models diff --git a/superagi/controllers/toolkit.py b/superagi/controllers/toolkit.py index 1b92e824c..8fb175e01 100644 --- a/superagi/controllers/toolkit.py +++ b/superagi/controllers/toolkit.py @@ -157,8 +157,8 @@ def install_toolkit_from_marketplace(toolkit_name: str, folder_name=tool['folder_name'], class_name=tool['class_name'], file_name=tool['file_name'], toolkit_id=db_toolkit.id) for config in toolkit['configs']: - ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'], value=config['value'], key_type = config['key_type'], is_secret = config['is_secret'], is_required = config['is_required']) - + ToolConfig.add_or_update(session=db.session, toolkit_id=db_toolkit.id, key=config['key'], value=config['value'], key_type = config['key_type'], is_secret = config['is_secret'], is_required = config['is_required']) + return {"message": "ToolKit installed successfully"} diff --git a/superagi/controllers/types/agent_publish_config.py b/superagi/controllers/types/agent_publish_config.py new file mode 100644 index 000000000..47f83f040 --- /dev/null +++ b/superagi/controllers/types/agent_publish_config.py @@ -0,0 +1,23 @@ +from typing import List, Optional +from pydantic import BaseModel + +class AgentPublish(BaseModel): + name: str + description: str + agent_template_id: int + goal: Optional[List[str]] + instruction: Optional[List[str]] + constraints: List[str] + toolkits: List[int] + tools: List[int] + exit: str + iteration_interval: int + model: str + permission_type: str + LTM_DB: str + max_iterations: int + user_timezone: Optional[str] + knowledge: Optional[int] + + class Config: + orm_mode = True \ No newline at end of file diff --git a/superagi/controllers/webhook.py b/superagi/controllers/webhook.py index 0a55bd216..0f49bdb0f 100644 --- a/superagi/controllers/webhook.py +++ b/superagi/controllers/webhook.py @@ -1,6 +1,6 @@ from datetime import datetime - -from fastapi import APIRouter +from typing import Optional +from fastapi import APIRouter, 
HTTPException from fastapi import Depends from fastapi_jwt_auth import AuthJWT from fastapi_sqlalchemy import db @@ -17,6 +17,7 @@ class WebHookIn(BaseModel): name: str url: str headers: dict + filters: dict class Config: orm_mode = True @@ -31,12 +32,21 @@ class WebHookOut(BaseModel): is_deleted: bool created_at: datetime updated_at: datetime + filters: dict + + class Config: + orm_mode = True + +class WebHookEdit(BaseModel): + url: str + filters: dict class Config: orm_mode = True -# CRUD Operations + +# CRUD Operations @router.post("/add", response_model=WebHookOut, status_code=201) def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth), organisation=Depends(get_user_organisation)): @@ -52,9 +62,54 @@ def create_webhook(webhook: WebHookIn, Authorize: AuthJWT = Depends(check_auth), HTTPException (Status Code=404): If the associated project is not found. """ db_webhook = Webhooks(name=webhook.name, url=webhook.url, headers=webhook.headers, org_id=organisation.id, - is_deleted=False) + is_deleted=False, filters=webhook.filters) db.session.add(db_webhook) db.session.commit() db.session.flush() - return db_webhook + +@router.get("/get", response_model=Optional[WebHookOut]) +def get_all_webhooks( + Authorize: AuthJWT = Depends(check_auth), + organisation=Depends(get_user_organisation), +): + """ + Retrieves a single webhook for the authenticated user's organisation. + + Returns: + JSONResponse: A JSON response containing the retrieved webhook, or None if no webhook exists. + """ + webhook = db.session.query(Webhooks).filter(Webhooks.org_id == organisation.id, Webhooks.is_deleted == False).first() + return webhook + +@router.post("/edit/{webhook_id}", response_model=WebHookOut) +def edit_webhook( + updated_webhook: WebHookEdit, + webhook_id: int, + Authorize: AuthJWT = Depends(check_auth), + organisation=Depends(get_user_organisation), +): + """ + Updates the URL and filters of an existing webhook. + + Args: + updated_webhook (WebHookEdit): The new URL and filters. + webhook_id (int): The ID of the webhook to update. + + Returns: + WebHookOut: The updated webhook. + + Raises: + HTTPException (Status Code=404): If the webhook is not found. + """ + webhook = db.session.query(Webhooks).filter(Webhooks.org_id == organisation.id, Webhooks.id == webhook_id, Webhooks.is_deleted == False).first() + if webhook is None: + raise HTTPException(status_code=404, detail="Webhook not found") + + webhook.url = updated_webhook.url + webhook.filters = updated_webhook.filters + + db.session.commit() + + return webhook \ No newline at end of file diff --git a/superagi/helper/github_helper.py b/superagi/helper/github_helper.py index 381e7392a..bb34eaf56 100644 --- a/superagi/helper/github_helper.py +++ b/superagi/helper/github_helper.py @@ -9,7 +9,8 @@ from superagi.types.storage_types import StorageType from superagi.config.config import get_config from superagi.helper.s3_helper import S3Helper - +from datetime import timedelta, datetime +import json class GithubHelper: def __init__(self, github_access_token, github_username): @@ -238,6 +239,7 @@ def add_file(self, repository_owner, repository_name, file_name, folder_path, he logger.info('Failed to upload file content:', file_response.json()['message']) return file_response.status_code + def create_pull_request(self, repository_owner, repository_name, head_branch, base_branch, headers): """ Creates a pull request in the given repository. 
@@ -335,3 +337,135 @@ def _get_file_contents(self, file_name, agent_id, agent_execution_id, session): with open(final_path, "r") as file: attachment_data = file.read().decode('utf-8') return attachment_data + + + def get_pull_request_content(self, repository_owner, repository_name, pull_request_number): + """ + Gets the content of a specific pull request from a GitHub repository. + + Args: + repository_owner (str): Owner of the repository. + repository_name (str): Name of the repository. + pull_request_number (int): pull request id. + + Returns: + str: The pull request diff content, or None if not found. + """ + pull_request_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}' + headers = { + "Authorization": f"token {self.github_access_token}" if self.github_access_token else None, + "Content-Type": "application/vnd.github+json", + "Accept": "application/vnd.github.v3.diff", + } + + response = requests.get(pull_request_url, headers=headers) + + if response.status_code == 200: + logger.info('Successfully fetched pull request content.') + return response.text + elif response.status_code == 404: + logger.warning('Pull request not found.') + else: + logger.warning(f'Failed to fetch pull request content: {response.text}') + + return None + + def get_latest_commit_id_of_pull_request(self, repository_owner, repository_name, pull_request_number): + """ + Gets the latest commit id of a specific pull request from a GitHub repository. + :param repository_owner: owner + :param repository_name: repository name + :param pull_request_number: pull request id + + :return: + latest commit id of the pull request + """ + url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}/commits' + headers = { + "Authorization": f"token {self.github_access_token}" if self.github_access_token else None, + "Content-Type": "application/json", + } + response = requests.get(url, headers=headers) + if response.status_code == 200: + commits = response.json() + latest_commit = commits[-1] # Assuming the last commit is the latest + return latest_commit.get('sha') + else: + logger.warning(f'Failed to fetch commits for pull request: {response.json()["message"]}') + return None + + + def add_line_comment_to_pull_request(self, repository_owner, repository_name, pull_request_number, + commit_id, file_path, position, comment_body): + """ + Adds a line comment to a specific pull request from a GitHub repository. + + :param repository_owner: owner + :param repository_name: repository name + :param pull_request_number: pull request id + :param commit_id: commit id + :param file_path: file path + :param position: position + :param comment_body: comment body + + :return: + dict: Dictionary containing the comment content or None if not found. 
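+        Note: position is the comment's offset within the unified diff, not the absolute line number in the file.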
+ """ + comments_url = f'https://api.github.com/repos/{repository_owner}/{repository_name}/pulls/{pull_request_number}/comments' + headers = { + "Authorization": f"token {self.github_access_token}", + "Content-Type": "application/json", + "Accept": "application/vnd.github.v3+json" + } + data = { + "commit_id": commit_id, + "path": file_path, + "position": position, + "body": comment_body + } + response = requests.post(comments_url, headers=headers, json=data) + if response.status_code == 201: + logger.info('Successfully added line comment to pull request.') + return response.json() + else: + logger.warning(f'Failed to add line comment: {response.json()["message"]}') + return None + + def get_pull_requests_created_in_last_x_seconds(self, repository_owner, repository_name, x_seconds): + """ + Gets the pull requests created in the last x seconds. + + Args: + repository_owner (str): Owner of the repository + repository_name (str): Repository name + x_seconds (int): The number of seconds in the past to look for PRs + + Returns: + list: List of pull request objects that were created in the last x seconds + """ + # Calculate the time x seconds ago + time_x_seconds_ago = datetime.utcnow() - timedelta(seconds=x_seconds) + + # Convert to the ISO8601 format GitHub expects, remove milliseconds + time_x_seconds_ago_str = time_x_seconds_ago.strftime('%Y-%m-%dT%H:%M:%SZ') + + # Search query + query = f'repo:{repository_owner}/{repository_name} type:pr created:>{time_x_seconds_ago_str}' + + url = f'https://api.github.com/search/issues?q={query}' + headers = { + "Authorization": f"token {self.github_access_token}", + "Content-Type": "application/json", + } + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + pull_request_urls = [] + for pull_request in response.json()['items']: + pull_request_urls.append(pull_request['html_url']) + return pull_request_urls + else: + logger.warning(f'Failed to fetch PRs: {response.json()["message"]}') + return [] diff --git a/superagi/helper/webhook_manager.py b/superagi/helper/webhook_manager.py index cf5e988d2..aa7a1fb19 100644 --- a/superagi/helper/webhook_manager.py +++ b/superagi/helper/webhook_manager.py @@ -5,6 +5,7 @@ import requests import json from superagi.lib.logger import logger + class WebHookManager: def __init__(self,session): self.session=session @@ -18,20 +19,21 @@ def agent_status_change_callback(self, agent_execution_id, curr_status, old_stat org_webhooks=self.session.query(Webhooks).filter(Webhooks.org_id == org.id).all() for webhook_obj in org_webhooks: - webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"} - error=None - request=None - status='sent' - try: - request = requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) - except Exception as e: - logger.error(f"Exception occured in webhooks {e}") - error=str(e) - if request is not None and request.status_code not in [200,201] and error is None: - error=request.text - if error is not None: - status='Error' - webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error) - self.session.add(webhook_event) - self.session.commit() + if "status" in webhook_obj.filters and curr_status in webhook_obj.filters["status"]: + webhook_obj_body={"agent_id":agent_id,"org_id":org.id,"event":f"{old_status} to {curr_status}"} + error=None + request=None + status='sent' + try: + request = 
requests.post(webhook_obj.url.strip(), data=json.dumps(webhook_obj_body), headers=webhook_obj.headers) + except Exception as e: + logger.error(f"Exception occurred in webhooks {e}") + error=str(e) + if request is not None and request.status_code not in [200,201] and error is None: + error=request.text + if error is not None: + status='Error' + webhook_event=WebhookEvents(agent_id=agent_id, run_id=agent_execution_id, event=f"{old_status} to {curr_status}", status=status, errors=error) + self.session.add(webhook_event) + self.session.commit() diff --git a/superagi/jobs/scheduling_executor.py b/superagi/jobs/scheduling_executor.py index b7e62db4e..f0d5a542a 100644 --- a/superagi/jobs/scheduling_executor.py +++ b/superagi/jobs/scheduling_executor.py @@ -13,7 +13,7 @@ from superagi.models.agent_execution import AgentExecution from superagi.models.agent_execution_config import AgentExecutionConfiguration from superagi.apm.event_handler import EventHandler - +from superagi.models.knowledges import Knowledges from superagi.models.db import connect_db @@ -42,7 +42,7 @@ def execute_scheduled_agent(self, agent_id: int, name: str): iteration_step_id = IterationWorkflow.fetch_trigger_step_id(session, start_step.action_reference_id).id if start_step.action_type == "ITERATION_WORKFLOW" else -1 - db_agent_execution = AgentExecution(status="RUNNING", last_execution_time=datetime.now(), + db_agent_execution = AgentExecution(status="CREATED", last_execution_time=datetime.now(), agent_id=agent_id, name=name, num_of_calls=0, num_of_tokens=0, current_agent_step_id=start_step.id, @@ -51,17 +51,32 @@ def execute_scheduled_agent(self, agent_id: int, name: str): session.add(db_agent_execution) session.commit() + #update status from CREATED to RUNNING + db_agent_execution.status = "RUNNING" + session.commit() + agent_execution_id = db_agent_execution.id agent_configurations = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all() for agent_config in agent_configurations: agent_execution_config = AgentExecutionConfiguration(agent_execution_id=agent_execution_id, key=agent_config.key, value=agent_config.value) session.add(agent_execution_config) - - organisation = agent.get_agent_organisation(session) model = session.query(AgentConfiguration.value).filter(AgentConfiguration.agent_id == agent_id).filter(AgentConfiguration.key == 'model').first()[0] - EventHandler(session=session).create_event('run_created', {'agent_execution_id': db_agent_execution.id,'agent_execution_name':db_agent_execution.name}, agent_id, organisation.id if organisation else 0), + EventHandler(session=session).create_event('run_created', + {'agent_execution_id': db_agent_execution.id, + 'agent_execution_name':db_agent_execution.name}, + agent_id, + organisation.id if organisation else 0) + agent_execution_knowledge = AgentConfiguration.get_agent_config_by_key_and_agent_id(session= session, key= 'knowledge', agent_id= agent_id) + if agent_execution_knowledge and agent_execution_knowledge.value != 'None': + knowledge_name = Knowledges.get_knowledge_from_id(session, int(agent_execution_knowledge.value)).name + if knowledge_name is not None: + EventHandler(session=session).create_event('knowledge_picked', + {'knowledge_name': knowledge_name, + 'agent_execution_id': db_agent_execution.id}, + agent_id, + organisation.id if organisation else 0) session.commit() if db_agent_execution.status == "RUNNING": diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py index 9ba8c8892..251c64a71 100644 --- 
a/superagi/llms/llm_model_factory.py +++ b/superagi/llms/llm_model_factory.py @@ -33,5 +33,17 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs): elif provider_name == 'Hugging Face': print("Provider is Hugging Face") return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs) + else: + print('Unknown provider.') + +def build_model_with_api_key(provider_name, api_key): + if provider_name.lower() == 'openai': + return OpenAi(api_key=api_key) + elif provider_name.lower() == 'replicate': + return Replicate(api_key=api_key) + elif provider_name.lower() == 'google palm': + return GooglePalm(api_key=api_key) + elif provider_name.lower() == 'hugging face': + return HuggingFace(api_key=api_key) else: print('Unknown provider.') \ No newline at end of file diff --git a/superagi/models/agent_config.py b/superagi/models/agent_config.py index 67b377da3..3b44eee9e 100644 --- a/superagi/models/agent_config.py +++ b/superagi/models/agent_config.py @@ -79,6 +79,7 @@ def update_agent_configurations_table(cls, session, agent_id: Union[int, None], # Fetch agent configurations agent_configs = session.query(AgentConfiguration).filter(AgentConfiguration.agent_id == agent_id).all() + for agent_config in agent_configs: if agent_config.key in updated_details_dict: agent_config.value = str(updated_details_dict[agent_config.key]) @@ -115,3 +116,12 @@ def get_model_api_key(cls, session, agent_id: int, model: str): # if selected_model_source == ModelSourceType.Replicate: # return get_config("REPLICATE_API_TOKEN") # return get_config("OPENAI_API_KEY") + + @classmethod + def get_agent_config_by_key_and_agent_id(cls, session, key: str, agent_id: int): + agent_config = session.query(AgentConfiguration).filter( + AgentConfiguration.agent_id == agent_id, + AgentConfiguration.key == key + ).first() + + return agent_config \ No newline at end of file diff --git a/superagi/models/api_key.py b/superagi/models/api_key.py index 1cc3e310a..8c61ac885 100644 --- a/superagi/models/api_key.py +++ b/superagi/models/api_key.py @@ -1,8 +1,6 @@ -from sqlalchemy import Column, Integer, Text, String, Boolean, ForeignKey -from sqlalchemy.orm import relationship +from sqlalchemy import Column, Integer, String, Boolean from superagi.models.base_model import DBBaseModel -from superagi.models.agent_execution import AgentExecution -from sqlalchemy import func, or_ +from sqlalchemy import or_ class ApiKey(DBBaseModel): """ diff --git a/superagi/models/db.py b/superagi/models/db.py index e711280ff..c6844cfd4 100644 --- a/superagi/models/db.py +++ b/superagi/models/db.py @@ -1,12 +1,8 @@ from sqlalchemy import create_engine from superagi.config.config import get_config +from urllib.parse import urlparse from superagi.lib.logger import logger -database_url = get_config('POSTGRES_URL') -db_username = get_config('DB_USERNAME') -db_password = get_config('DB_PASSWORD') -db_name = get_config('DB_NAME') - engine = None @@ -23,11 +19,20 @@ def connect_db(): return engine # Create the connection URL - if db_username is None: - db_url = f'postgresql://{database_url}/{db_name}' + db_host = get_config('DB_HOST', 'super__postgres') + db_username = get_config('DB_USERNAME') + db_password = get_config('DB_PASSWORD') + db_name = get_config('DB_NAME') + db_url = get_config('DB_URL', None) + + if db_url is None: + if db_username is None: + db_url = f'postgresql://{db_host}/{db_name}' + else: + db_url = f'postgresql://{db_username}:{db_password}@{db_host}/{db_name}' else: - db_url = 
f'postgresql://{db_username}:{db_password}@{database_url}/{db_name}' - + db_url = urlparse(db_url) + db_url = db_url.scheme + "://" + db_url.netloc + db_url.path # Create the SQLAlchemy engine engine = create_engine(db_url, pool_size=20, # Maximum number of database connections in the pool diff --git a/superagi/models/webhooks.py b/superagi/models/webhooks.py index 14d683472..9b47c6105 100644 --- a/superagi/models/webhooks.py +++ b/superagi/models/webhooks.py @@ -20,3 +20,4 @@ class Webhooks(DBBaseModel): url = Column(String) headers=Column(JSON) is_deleted=Column(Boolean) + filters=Column(JSON) diff --git a/superagi/resource_manager/llama_document_summary.py b/superagi/resource_manager/llama_document_summary.py index 5eca38913..767d2be27 100644 --- a/superagi/resource_manager/llama_document_summary.py +++ b/superagi/resource_manager/llama_document_summary.py @@ -1,12 +1,9 @@ import os -from langchain.chat_models import ChatGooglePalm from llama_index.indices.response import ResponseMode from llama_index.schema import Document from superagi.config.config import get_config -from superagi.lib.logger import logger -from superagi.types.model_source_types import ModelSourceType class LlamaDocumentSummary: diff --git a/superagi/tools/github/fetch_pull_request.py b/superagi/tools/github/fetch_pull_request.py new file mode 100644 index 000000000..09bf17f3d --- /dev/null +++ b/superagi/tools/github/fetch_pull_request.py @@ -0,0 +1,66 @@ +from typing import Type, Optional + +from pydantic import BaseModel, Field + +from superagi.helper.github_helper import GithubHelper +from superagi.llms.base_llm import BaseLlm +from superagi.tools.base_tool import BaseTool + + +class GithubFetchPullRequestSchema(BaseModel): + repository_name: str = Field( + ..., + description="Repository name from which pull requests have to be fetched", + ) + repository_owner: str = Field( + ..., + description="Owner of the github repository", + ) + time_in_seconds: int = Field( + ..., + description="Gets pull requests from last `time_in_seconds` seconds", + ) + + +class GithubFetchPullRequest(BaseTool): + """ + Fetch pull request tool + + Attributes: + name : The name. + description : The description. + args_schema : The args schema. + agent_id: The agent id. + agent_execution_id: The agent execution id. + """ + llm: Optional[BaseLlm] = None + name: str = "Github Fetch Pull Requests" + args_schema: Type[BaseModel] = GithubFetchPullRequestSchema + description: str = "Fetch pull requests from github" + agent_id: int = None + agent_execution_id: int = None + + def _execute(self, repository_name: str, repository_owner: str, time_in_seconds: int = 86400) -> str: + """ + Execute the fetch pull request tool. + + Args: + repository_name: The name of the repository to fetch pull requests from. + repository_owner: Owner of the GitHub repository. 
+ time_in_seconds: Gets pull requests from last `time_in_seconds` seconds + + Returns: + List of URLs of all matching pull requests + """ + try: + github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN") + github_username = self.get_tool_config("GITHUB_USERNAME") + github_helper = GithubHelper(github_access_token, github_username) + + pull_request_urls = github_helper.get_pull_requests_created_in_last_x_seconds(repository_owner, + repository_name, + time_in_seconds) + + return "Pull requests: " + str(pull_request_urls) + except Exception as err: + return f"Error: Unable to fetch pull requests {err}" diff --git a/superagi/tools/github/github_toolkit.py b/superagi/tools/github/github_toolkit.py index bb9ec2657..5321f979b 100644 --- a/superagi/tools/github/github_toolkit.py +++ b/superagi/tools/github/github_toolkit.py @@ -3,7 +3,9 @@ from superagi.tools.base_tool import BaseTool, BaseToolkit, ToolConfiguration from superagi.tools.github.add_file import GithubAddFileTool from superagi.tools.github.delete_file import GithubDeleteFileTool +from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest from superagi.tools.github.search_repo import GithubRepoSearchTool +from superagi.tools.github.review_pull_request import GithubReviewPullRequest from superagi.types.key_type import ToolConfigKeyType @@ -12,7 +14,8 @@ class GitHubToolkit(BaseToolkit, ABC): description: str = "GitHub Tool Kit contains all github related to tool" def get_tools(self) -> List[BaseTool]: - return [GithubAddFileTool(), GithubDeleteFileTool(), GithubRepoSearchTool()] + return [GithubAddFileTool(), GithubDeleteFileTool(), GithubRepoSearchTool(), GithubReviewPullRequest(), + GithubFetchPullRequest()] def get_env_keys(self) -> List[ToolConfiguration]: return [ diff --git a/superagi/tools/github/prompts/code_review.txt b/superagi/tools/github/prompts/code_review.txt new file mode 100644 index 000000000..1aff59ea6 --- /dev/null +++ b/superagi/tools/github/prompts/code_review.txt @@ -0,0 +1,50 @@ +Your purpose is to act as a highly experienced software engineer and provide a thorough review of the code chunks and suggest code snippets to improve key areas such as: +- Logic +- Modularity +- Maintainability +- Complexity + +Do not comment on minor code style issues or missing comments/documentation. Identify and resolve significant concerns to improve overall code quality while deliberately disregarding minor issues. + +Following is the github pull request diff content: +``` +{{DIFF_CONTENT}} +``` + +Instructions: +1. Do not comment on existing lines and deleted lines. +2. Ignore the lines that start with '-'. +3. Only consider lines starting with '+' for review. +4. Do not comment on frontend and graphql code. + +Respond with only valid JSON conforming to the following schema: +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "comments": { + "type": "array", + "items": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "The path to the file where the comment should be added." + }, + "line": { + "type": "integer", + "description": "The line number where the comment should be added." + }, + "comment": { + "type": "string", + "description": "The content of the comment." + } + }, + "required": ["file_path", "line", "comment"] + } + } + }, + "required": ["comments"] +} + +Ensure the response is valid JSON conforming to the above schema. 
\ No newline at end of file
diff --git a/superagi/tools/github/review_pull_request.py b/superagi/tools/github/review_pull_request.py
new file mode 100644
index 000000000..a230c0701
--- /dev/null
+++ b/superagi/tools/github/review_pull_request.py
@@ -0,0 +1,143 @@
+import ast
+from typing import Type, Optional
+
+from pydantic import BaseModel, Field
+
+from superagi.helper.github_helper import GithubHelper
+from superagi.helper.json_cleaner import JsonCleaner
+from superagi.helper.prompt_reader import PromptReader
+from superagi.helper.token_counter import TokenCounter
+from superagi.llms.base_llm import BaseLlm
+from superagi.models.agent import Agent
+from superagi.tools.base_tool import BaseTool
+
+
+class GithubReviewPullRequestSchema(BaseModel):
+    repository_name: str = Field(
+        ...,
+        description="Repository name containing the pull request",
+    )
+    repository_owner: str = Field(
+        ...,
+        description="Owner of the github repository",
+    )
+    pull_request_number: int = Field(
+        ...,
+        description="Pull request number",
+    )
+
+
+class GithubReviewPullRequest(BaseTool):
+    """
+    Reviews the github pull request and adds comments inline
+
+    Attributes:
+        name : The name.
+        description : The description.
+        args_schema : The args schema.
+    """
+    llm: Optional[BaseLlm] = None
+    name: str = "Github Review Pull Request"
+    args_schema: Type[BaseModel] = GithubReviewPullRequestSchema
+    description: str = "Reviews a pull request in the github repository"
+    agent_id: int = None
+    agent_execution_id: int = None
+
+    def _execute(self, repository_name: str, repository_owner: str, pull_request_number: int) -> str:
+        """
+        Execute the review pull request tool.
+
+        Args:
+            repository_name: The name of the repository containing the pull request.
+            repository_owner: Owner of the GitHub repository.
+            pull_request_number: pull request number
+
+        Returns:
+            Success message if the review comments are added successfully else error message.
+ """ + try: + github_access_token = self.get_tool_config("GITHUB_ACCESS_TOKEN") + github_username = self.get_tool_config("GITHUB_USERNAME") + github_helper = GithubHelper(github_access_token, github_username) + + pull_request_content = github_helper.get_pull_request_content(repository_owner, repository_name, + pull_request_number) + latest_commit_id = github_helper.get_latest_commit_id_of_pull_request(repository_owner, repository_name, + pull_request_number) + + pull_request_arr = pull_request_content.split("diff --git") + organisation = Agent.find_org_by_agent_id(session=self.toolkit_config.session, agent_id=self.agent_id) + + model_token_limit = TokenCounter(session=self.toolkit_config.session, + organisation_id=organisation.id).token_limit(self.llm.get_model()) + pull_request_arr_parts = self.split_pull_request_content_into_multiple_parts(model_token_limit, pull_request_arr) + for content in pull_request_arr_parts: + self.run_code_review(github_helper, content, latest_commit_id, organisation, pull_request_number, + repository_name, repository_owner) + return "Added comments to the pull request:" + str(pull_request_number) + except Exception as err: + return f"Error: Unable to add comments to the pull request {err}" + + def run_code_review(self, github_helper, content, latest_commit_id, organisation, pull_request_number, + repository_name, repository_owner): + prompt = PromptReader.read_tools_prompt(__file__, "code_review.txt") + prompt = prompt.replace("{{DIFF_CONTENT}}", content) + messages = [{"role": "system", "content": prompt}] + total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) + token_limit = TokenCounter(session=self.toolkit_config.session, + organisation_id=organisation.id).token_limit(self.llm.get_model()) + result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100)) + response = result["content"] + if response.startswith("```") and response.endswith("```"): + response = "```".join(response.split("```")[1:-1]) + response = JsonCleaner.extract_json_section(response) + comments = ast.literal_eval(response) + + # Add comments in the pull request + for comment in comments['comments']: + line_number = self.get_exact_line_number(content, comment["file_path"], comment["line"]) + github_helper.add_line_comment_to_pull_request(repository_owner, repository_name, pull_request_number, + latest_commit_id, comment["file_path"], line_number, + comment["comment"]) + + def split_pull_request_content_into_multiple_parts(self, model_token_limit: int, pull_request_arr): + pull_request_arr_parts = [] + current_part = "" + for part in pull_request_arr: + total_tokens = TokenCounter.count_message_tokens([{"role": "user", "content": current_part}], + self.llm.get_model()) + # we are using 60% of the model token limit + if total_tokens >= model_token_limit * 0.6: + # Add the current part to pull_request_arr_parts and reset current_sum and current_part + pull_request_arr_parts.append(current_part) + current_part = "diff --git" + part + else: + current_part += "diff --git" + part + + pull_request_arr_parts.append(current_part) + return pull_request_arr_parts + + def get_exact_line_number(self, diff_content, file_path, line_number): + last_content = diff_content[diff_content.index(file_path):] + last_content = last_content[last_content.index('@@'):] + return self.find_position_in_diff(last_content, line_number) + + def find_position_in_diff(self, diff_content, target_line): + # Split the diff by lines and initialize variables + diff_lines = 
+        position = 0
+        current_file_line_number = 0
+
+        # Loop through each line in the diff
+        for line in diff_lines:
+            position += 1  # Increment position for each line
+            if line.startswith('@@'):
+                # Reset the current file line number when encountering a new hunk
+                current_file_line_number = int(line.split('+')[1].split(',')[0]) - 1
+            elif not line.startswith('-'):
+                # Increment the current file line number for lines that are not deletions
+                current_file_line_number += 1
+                if current_file_line_number >= target_line:
+                    # Return the position when the target line number is reached
+                    return position
+        return position
diff --git a/superagi/tools/image_generation/image_generation_toolkit.py b/superagi/tools/image_generation/image_generation_toolkit.py
index bebb94411..8b2361e06 100644
--- a/superagi/tools/image_generation/image_generation_toolkit.py
+++ b/superagi/tools/image_generation/image_generation_toolkit.py
@@ -15,7 +15,7 @@ def get_tools(self) -> List[BaseTool]:
 
     def get_env_keys(self) -> List[ToolConfiguration]:
         return [
-            ToolConfiguration(key="STABILITY_API_KEY", key_type=ToolConfigKeyType.STRING, is_required= False, is_secret = True),
+            ToolConfiguration(key="STABILITY_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret = True),
             ToolConfiguration(key="ENGINE_ID", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=False),
-            ToolConfiguration(key= "OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True),
+            ToolConfiguration(key="OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
         ]
diff --git a/superagi/tools/knowledge_search/knowledge_search.py b/superagi/tools/knowledge_search/knowledge_search.py
index b649000bb..272d8865b 100644
--- a/superagi/tools/knowledge_search/knowledge_search.py
+++ b/superagi/tools/knowledge_search/knowledge_search.py
@@ -45,7 +45,7 @@ def _execute(self, query: str):
             vector_db = Vectordbs.get_vector_db_from_id(session, vector_db_index.vector_db_id)
             db_creds = VectordbConfigs.get_vector_db_config_from_db_id(session, vector_db.id)
             model_api_key = self.get_tool_config('OPENAI_API_KEY')
-            model_source = "OpenAI"
+            model_source = 'OpenAI'
             embedding_model = AgentExecutor.get_embedding(model_source, model_api_key)
             try:
                 if vector_db_index.state == "Custom":
diff --git a/superagi/tools/knowledge_search/knowledge_search_toolkit.py b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
index ae3a937f3..6e0cdf0c5 100644
--- a/superagi/tools/knowledge_search/knowledge_search_toolkit.py
+++ b/superagi/tools/knowledge_search/knowledge_search_toolkit.py
@@ -13,5 +13,5 @@ def get_tools(self) -> List[BaseTool]:
 
     def get_env_keys(self) -> List[ToolConfiguration]:
         return [
-            ToolConfiguration(key= "OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
+            ToolConfiguration(key="OPENAI_API_KEY", key_type=ToolConfigKeyType.STRING, is_required=False, is_secret=True)
         ]
\ No newline at end of file
diff --git a/superagi/tools/tool_response_query_manager.py b/superagi/tools/tool_response_query_manager.py
index c4da9b9a6..0bf708b39 100644
--- a/superagi/tools/tool_response_query_manager.py
+++ b/superagi/tools/tool_response_query_manager.py
@@ -11,11 +11,12 @@ def __init__(self, session: Session, agent_execution_id: int,memory:VectorStore)
     def get_last_response(self, tool_name: str = None):
         return AgentExecutionFeed.get_last_tool_response(self.session, self.agent_execution_id, tool_name)
 
-
+
     def get_relevant_response(self, query: str,metadata:dict, top_k: int = 5):
         if self.memory is None:
             return ""
-        documents = self.memory.get_matching_text(query, metadata=metadata)
+        documents = self.memory.get_matching_text(query, metadata=metadata,
+                                                  top_k=top_k)
         relevant_responses = ""
         for document in documents["documents"]:
             relevant_responses += document.text_content
diff --git a/superagi/worker.py b/superagi/worker.py
index 83afa7253..e0bace7a1 100644
--- a/superagi/worker.py
+++ b/superagi/worker.py
@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+import sys
 
 from sqlalchemy.orm import sessionmaker
 
@@ -36,12 +37,11 @@
 }
 app.conf.beat_schedule = beat_schedule
 
-# @event.listens_for(AgentExecution.status, "set")
-# def agent_status_change(target, val,old_val,initiator):
-#     if not get_config("IN_TESTING",False):
-#         webhook_callback.delay(target.id,val,old_val)
-
-
+@event.listens_for(AgentExecution.status, "set")
+def agent_status_change(target, val,old_val,initiator):
+    if not hasattr(sys, '_called_from_test'):
+        webhook_callback.delay(target.id,val,old_val)
+
 @app.task(name="initialize-schedule-agent", autoretry_for=(Exception,), retry_backoff=2, max_retries=5)
 def initialize_schedule_agent_task():
     """Executing agent scheduling in the background."""
diff --git a/tests/unit_tests/agent/test_tool_executor.py b/tests/unit_tests/agent/test_tool_executor.py
index 246db63fd..35e1735f3 100644
--- a/tests/unit_tests/agent/test_tool_executor.py
+++ b/tests/unit_tests/agent/test_tool_executor.py
@@ -19,7 +19,7 @@ def mock_tools():
 
 @pytest.fixture
 def executor(mock_tools):
-    return ToolExecutor(organisation_id=1, agent_id=1, tools=mock_tools)
+    return ToolExecutor(organisation_id=1, agent_id=1, tools=mock_tools, agent_execution_id=1)
 
 def test_tool_executor_finish(executor):
     res = executor.execute(None, 'finish', {})
@@ -29,7 +29,7 @@ def test_tool_executor_finish(executor):
 @patch('superagi.agent.tool_executor.EventHandler')
 def test_tool_executor_success(mock_event_handler, executor, mock_tools):
     for i, tool in enumerate(mock_tools):
-        res = executor.execute(None, f'tool{i}', {})
+        res = executor.execute(None, f'tool{i}', {'agent_execution_id': 1})
         assert res.status == 'SUCCESS'
         assert res.result == f'Tool {tool.name} returned: {tool.name}'
         assert res.retry == False
diff --git a/tests/unit_tests/apm/test_knowledge_handler.py b/tests/unit_tests/apm/test_knowledge_handler.py
new file mode 100644
index 000000000..acb810ff4
--- /dev/null
+++ b/tests/unit_tests/apm/test_knowledge_handler.py
@@ -0,0 +1,100 @@
+import pytest
+from unittest.mock import MagicMock
+from superagi.apm.knowledge_handler import KnowledgeHandler
+from fastapi import HTTPException
+from datetime import datetime
+import pytz
+
+@pytest.fixture
+def organisation_id():
+    return 1
+
+@pytest.fixture
+def mock_session():
+    return MagicMock()
+
+@pytest.fixture
+def knowledge_handler(mock_session, organisation_id):
+    return KnowledgeHandler(mock_session, organisation_id)
+
+def test_get_knowledge_usage_by_name(knowledge_handler, mock_session):
+    knowledge_handler.session = mock_session
+    knowledge_name = 'Knowledge1'
+    mock_knowledge_event = MagicMock()
+    mock_knowledge_event.knowledge_unique_agents = 5
+    mock_knowledge_event.knowledge_name = knowledge_name
+    mock_knowledge_event.id = 1
+
+    mock_session.query.return_value.filter_by.return_value.filter.return_value.first.return_value = mock_knowledge_event
+    mock_session.query.return_value.filter.return_value.group_by.return_value.first.return_value = mock_knowledge_event
+    mock_session.query.return_value.filter.return_value.count.return_value = 10
+
+    result = knowledge_handler.get_knowledge_usage_by_name(knowledge_name)
+
+    assert isinstance(result, dict)
+    assert result == {
+        'knowledge_unique_agents': 5,
+        'knowledge_calls': 10
+    }
+
+    mock_session.query.return_value.filter_by.return_value.filter.return_value.first.return_value = None
+
+    with pytest.raises(HTTPException):
+        knowledge_handler.get_knowledge_usage_by_name('NonexistentKnowledge')
+
+def test_get_knowledge_events_by_name(knowledge_handler, mock_session):
+    knowledge_name = 'knowledge1'
+    knowledge_handler.session = mock_session
+    knowledge_handler.organisation_id = 1
+
+    mock_knowledge = MagicMock()
+    mock_knowledge.id = 1
+    mock_session.query().filter_by().filter().first.return_value = mock_knowledge
+
+    result_obj = MagicMock()
+    result_obj.agent_id = 1
+    result_obj.created_at = datetime.now()
+    result_obj.event_name = 'knowledge_picked'
+    result_obj.event_property = {'knowledge_name': 'knowledge1', 'agent_execution_id': '1'}
+    result_obj2 = MagicMock()
+    result_obj2.agent_id = 1
+    result_obj2.event_name = 'run_completed'
+    result_obj2.event_property = {'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'}
+    result_obj3 = MagicMock()
+    result_obj3.agent_id = 1
+    result_obj3.event_name = 'agent_created'
+    result_obj3.event_property = {'agent_name': 'A1', 'model': 'M1'}
+
+    mock_session.query().filter().all.side_effect = [[result_obj], [result_obj2], [result_obj3]]
+
+    user_timezone = MagicMock()
+    user_timezone.value = 'America/New_York'
+    mock_session.query().filter().first.return_value = user_timezone
+
+    result = knowledge_handler.get_knowledge_events_by_name(knowledge_name)
+
+    assert isinstance(result, list)
+    assert len(result) == 1
+    for item in result:
+        assert 'agent_execution_id' in item
+        assert 'created_at' in item
+        assert 'tokens_consumed' in item
+        assert 'calls' in item
+        assert 'agent_execution_name' in item
+        assert 'agent_name' in item
+        assert 'model' in item
+
+
+def test_get_knowledge_events_by_name_knowledge_not_found(knowledge_handler, mock_session):
+    knowledge_name = "knowledge1"
+    not_found_message = 'Knowledge not found'
+
+    mock_session.query().filter_by().filter().first.return_value = None
+
+    try:
+        knowledge_handler.get_knowledge_events_by_name(knowledge_name)
+        assert False, "Expected HTTPException was not raised"
+    except HTTPException as e:
+        assert str(e.detail) == not_found_message, f"Expected {not_found_message}, got {e.detail}"
+    finally:
+        assert mock_session.query().filter_by().filter().first.called, "first() function not called"
\ No newline at end of file
diff --git a/tests/unit_tests/apm/test_tools_handler.py b/tests/unit_tests/apm/test_tools_handler.py
index f58650cc7..f805bbdf1 100644
--- a/tests/unit_tests/apm/test_tools_handler.py
+++ b/tests/unit_tests/apm/test_tools_handler.py
@@ -1,8 +1,12 @@
 import pytest
-from unittest.mock import MagicMock
-
+from unittest.mock import MagicMock, patch
+from fastapi import HTTPException
 from superagi.apm.tools_handler import ToolsHandler
 from sqlalchemy.orm import Session
+from superagi.models.agent_config import AgentConfiguration
+
+from datetime import datetime
+import pytz
 
 @pytest.fixture
 def organisation_id():
@@ -17,6 +21,129 @@ def tools_handler(mock_session, organisation_id):
     return ToolsHandler(mock_session, organisation_id)
 
 def test_calculate_tool_usage(tools_handler, mock_session):
-    mock_session.query().all.return_value = [MagicMock()]
+    tool_used_subquery = MagicMock()
+    agent_count_subquery = MagicMock()
+    total_usage_subquery = MagicMock()
+
+    tool_used_subquery.c.tool_name = 'Tool1'
+    tool_used_subquery.c.agent_id = 1
+
+    agent_count_subquery.c.tool_name = 'Tool1'
+    agent_count_subquery.c.unique_agents = 1
+
+    total_usage_subquery.c.tool_name = 'Tool1'
+    total_usage_subquery.c.total_usage = 5
+
+    tools_handler.get_tool_and_toolkit = MagicMock()
+    tools_handler.get_tool_and_toolkit.return_value = {'Tool1': 'Toolkit1'}
+
+    mock_session.query().filter_by().subquery.return_value = tool_used_subquery
+    mock_session.query().group_by().subquery.return_value = agent_count_subquery
+    mock_session.query().group_by().subquery.return_value = total_usage_subquery
+
+    result_obj = MagicMock()
+    result_obj.tool_name = 'Tool1'
+    result_obj.unique_agents = 1
+    result_obj.total_usage = 5
+    mock_session.query().join().all.return_value = [result_obj]
+
     result = tools_handler.calculate_tool_usage()
-    assert isinstance(result, list)
\ No newline at end of file
+
+    assert isinstance(result, list)
+
+    expected_output = [{'tool_name': 'Tool1', 'unique_agents': 1, 'total_usage': 5, 'toolkit': 'Toolkit1'}]
+    assert result == expected_output
+
+def test_get_tool_and_toolkit(tools_handler, mock_session):
+    result_obj = MagicMock()
+    result_obj.tool_name = 'tool 1'
+    result_obj.toolkit_name = 'toolkit 1'
+
+    mock_session.query().join().all.return_value = [result_obj]
+
+    output = tools_handler.get_tool_and_toolkit()
+
+    assert isinstance(output, dict)
+    assert output == {'tool 1': 'toolkit 1'}
+
+def test_get_tool_usage_by_name(tools_handler, mock_session):
+    tools_handler.session = mock_session
+    tool_name = 'Tool1'
+    formatted_tool_name = tool_name.lower().replace(" ", "")
+
+    mock_tool = MagicMock()
+    mock_tool.name = tool_name
+
+    mock_tool_event = MagicMock()
+    mock_tool_event.tool_name = formatted_tool_name
+    mock_tool_event.tool_calls = 10
+    mock_tool_event.tool_unique_agents = 5
+
+    mock_session.query.return_value.filter_by.return_value.first.return_value = mock_tool
+    mock_session.query.return_value.filter.return_value.group_by.return_value.first.return_value = mock_tool_event
+
+    result = tools_handler.get_tool_usage_by_name(tool_name=tool_name)
+
+    assert isinstance(result, dict)
+    assert result == {
+        'tool_calls': 10,
+        'tool_unique_agents': 5
+    }
+
+    mock_session.query.return_value.filter_by.return_value.first.return_value = None
+
+    with pytest.raises(HTTPException):
+        tools_handler.get_tool_usage_by_name(tool_name="NonexistentTool")
+
+def test_get_tool_events_by_name(tools_handler, mock_session):
+    tool_name = 'Tool1'
+    tools_handler.session = mock_session
+    tools_handler.organisation_id = 1
+
+    mock_tool = MagicMock()
+    mock_tool.id = 1
+    mock_session.query().filter_by().first.return_value = mock_tool
+
+    result_obj = MagicMock()
+    result_obj.agent_id = 1
+    result_obj.id = 1
+    result_obj.created_at = datetime.now()
+    result_obj.event_name = 'tool_used'
+    result_obj.event_property = {'tool_name': 'tool1', 'agent_execution_id': '1'}
+    result_obj2 = MagicMock()
+    result_obj2.agent_id = 1
+    result_obj2.id = 2
+    result_obj2.event_name = 'run_completed'
+    result_obj2.event_property = {'tokens_consumed': 10, 'calls': 5, 'name': 'Runner', 'agent_execution_id': '1'}
+    result_obj3 = MagicMock()
+    result_obj3.agent_id = 1
+    result_obj3.event_name = 'agent_created'
+    result_obj3.event_property = {'agent_name': 'A1', 'model': 'M1'}
+
+    mock_session.query().filter().all.side_effect = [[result_obj], [result_obj2], [result_obj3], []]
+
+    user_timezone = MagicMock()
+    user_timezone.value = 'America/New_York'
+    mock_session.query().filter().first.return_value = user_timezone
+
+    result = tools_handler.get_tool_events_by_name(tool_name)
+
+    assert isinstance(result, list)
+    assert len(result) == 1
+    for item in result:
+        assert 'agent_execution_id' in item
+        assert 'created_at' in item
+        assert 'tokens_consumed' in item
+        assert 'calls' in item
+        assert 'agent_execution_name' in item
+        assert 'agent_name' in item
+        assert 'model' in item
+
+def test_get_tool_events_by_name_tool_not_found(tools_handler, mock_session):
+    tool_name = "tool1"
+
+    mock_session.query().filter_by().first.return_value = None
+    with pytest.raises(HTTPException):
+        tools_handler.get_tool_events_by_name(tool_name)
+
+    assert mock_session.query().filter_by().first.called
\ No newline at end of file
diff --git a/tests/unit_tests/controllers/test_publish_agent.py b/tests/unit_tests/controllers/test_publish_agent.py
new file mode 100644
index 000000000..615a8a26a
--- /dev/null
+++ b/tests/unit_tests/controllers/test_publish_agent.py
@@ -0,0 +1,41 @@
+import pytest
+from fastapi.testclient import TestClient
+from unittest.mock import create_autospec, patch
+from main import app
+from superagi.models.agent import Agent
+from superagi.models.agent_config import AgentConfiguration
+from superagi.models.agent_execution import AgentExecution
+from superagi.models.agent_execution_config import AgentExecutionConfiguration
+from superagi.models.organisation import Organisation
+from superagi.models.user import User
+from sqlalchemy.orm import Session
+
+client = TestClient(app)
+
+@pytest.fixture
+def mocks():
+    # Mock tool kit data for testing
+    mock_agent = Agent(id=1, name="test_agent", project_id=1, description="testing", agent_workflow_id=1, is_deleted=False)
+    mock_agent_config = AgentConfiguration(id=1, agent_id=1, key="test_key", value="['test']")
+    mock_execution = AgentExecution(id=1, agent_id=1, name="test_execution")
+    mock_execution_config = [AgentExecutionConfiguration(id=1, agent_execution_id=1, key="test_key", value="['test']")]
+    return mock_agent,mock_agent_config,mock_execution,mock_execution_config
+
+def test_publish_template(mocks):
+    with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
+         patch('superagi.helper.auth.get_current_user') as mock_get_user, \
+         patch('superagi.helper.auth.db') as mock_auth_db,\
+         patch('superagi.controllers.agent_template.db') as mock_db:
+
+        mock_session = create_autospec(Session)
+        mock_agent, mock_agent_config, mock_execution, mock_execution_config = mocks
+
+        mock_session.query.return_value.filter.return_value.first.return_value = mock_agent
+        mock_session.query.return_value.filter.return_value.all.return_value = [mock_agent_config]
+        mock_session.query.return_value.filter.return_value.order_by.return_value.first.return_value = mock_execution
+        mock_session.query.return_value.filter.return_value.all.return_value = mock_execution_config
+
+        with patch('superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id') as mock_get_exec:
+            mock_get_exec.return_value = mock_execution
+            response = client.post("/agent_templates/publish_template/agent_execution_id/1")
+            assert response.status_code == 201
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_github_helper.py b/tests/unit_tests/helper/test_github_helper.py
index 04105c59a..e947e4703 100644
--- a/tests/unit_tests/helper/test_github_helper.py
+++ b/tests/unit_tests/helper/test_github_helper.py
@@ -155,7 +155,47 @@ def test_create_pull_request(self, mock_post):
             headers={'header': 'value'}
         )
 
-    # ... more tests for other methods
+    @patch('requests.get')
+    def test_get_pull_request_content_success(self, mock_get):
+        mock_get.return_value.status_code = 200
+        mock_get.return_value.text = "some_content"
+
+        github_api = GithubHelper('access_token', 'username')
+        result = github_api.get_pull_request_content("owner", "repo", 1)
+
+        self.assertEqual(result, "some_content")
+
+    @patch('requests.get')
+    def test_get_pull_request_content_not_found(self, mock_get):
+        mock_get.return_value.status_code = 404
+
+        github_api = GithubHelper('access_token', 'username')
+        result = github_api.get_pull_request_content("owner", "repo", 1)
+
+        self.assertIsNone(result)
+
+    @patch('requests.get')
+    def test_get_latest_commit_id_of_pull_request(self, mock_get):
+        mock_get.return_value.status_code = 200
+        mock_get.return_value.json.return_value = [{"sha": "123"}, {"sha": "456"}]
+
+        github_api = GithubHelper('access_token', 'username')
+        result = github_api.get_latest_commit_id_of_pull_request("owner", "repo", 1)
+
+        self.assertEqual(result, "456")
+
+    @patch('requests.post')
+    def test_add_line_comment_to_pull_request(self, mock_post):
+        mock_post.return_value.status_code = 201
+        mock_post.return_value.json.return_value = {"id": 1, "body": "comment"}
+
+        github_api = GithubHelper('access_token', 'username')
+        result = github_api.add_line_comment_to_pull_request("owner", "repo", 1, "commit_id", "file_path", 1, "comment")
+
+        self.assertEqual(result, {"id": 1, "body": "comment"})
+
+
+# ... more tests for other methods
 
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/tests/unit_tests/helper/test_webhooks.py b/tests/unit_tests/helper/test_webhooks.py
new file mode 100644
index 000000000..b52edf507
--- /dev/null
+++ b/tests/unit_tests/helper/test_webhooks.py
@@ -0,0 +1,75 @@
+import json
+from unittest.mock import Mock, patch
+import pytest
+from superagi.helper.webhook_manager import WebHookManager
+from superagi.models.webhooks import Webhooks
+
+@pytest.fixture
+def mock_session():
+    return Mock()
+
+@pytest.fixture
+def mock_agent_execution():
+    return Mock()
+
+@pytest.fixture
+def mock_agent():
+    return Mock()
+
+@pytest.fixture
+def mock_webhook():
+    return Mock()
+
+@pytest.fixture
+def mock_org():
+    org_mock = Mock()
+    org_mock.id = "mock_org_id"
+    return org_mock
+
+def test_agent_status_change_callback(
+    mock_session, mock_agent_execution, mock_agent, mock_org, mock_webhook
+):
+    curr_status = "NEW_STATUS"
+    old_status = "OLD_STATUS"
+    mock_agent_id = "mock_agent_id"
+    mock_org_id = "mock_org_id"
+
+    # Create a mock instance of AgentExecution and set its attributes
+    mock_agent_execution_instance = Mock()
+    mock_agent_execution_instance.agent_id = "mock_agent_id"
+
+    # Create a mock instance of Agent and set its attributes
+    mock_agent_instance = Mock()
+    mock_agent_instance.get_agent_organisation.return_value = mock_org
+
+    # Create a mock instance of Webhooks and set its attributes
+    mock_webhook_instance = Mock(spec=Webhooks)
+    mock_webhook_instance.filters = {"status": ["PAUSED", "RUNNING"]}
+
+    # Set up session.query().filter().all() to return the mock_webhook_instance
+    mock_session.query.return_value.filter.return_value.all.return_value = [mock_webhook_instance]
+
+    # Patch required functions/methods
+    with patch(
+        'superagi.controllers.agent_execution_config.AgentExecution.get_agent_execution_from_id',
+        return_value=mock_agent_execution_instance
+    ), patch(
+        'superagi.models.agent.Agent.get_agent_from_id',
+        return_value=mock_agent_instance
+    ), patch(
+        'requests.post',
+        return_value=Mock(status_code=200)  # Mock the status_code response
+    ) as mock_post, patch(
+        'json.dumps'
+    ) as mock_json_dumps:
+
+        # Create the WebHookManager instance
+        web_hook_manager = WebHookManager(mock_session)
+
+        # Call the function
+        web_hook_manager.agent_status_change_callback(
+            mock_agent_execution_instance, curr_status, old_status
+        )
+
+        assert mock_agent_execution_instance.agent_status_change_callback
+
diff --git a/tests/unit_tests/jobs/conftest.py b/tests/unit_tests/jobs/conftest.py
new file mode 100644
index 000000000..e1344072a
--- /dev/null
+++ b/tests/unit_tests/jobs/conftest.py
@@ -0,0 +1,8 @@
+# content of conftest.py
+def pytest_configure(config):
+    import sys
+    sys._called_from_test = True
+
+def pytest_unconfigure(config):
+    import sys  # This was missing from the manual
+    del sys._called_from_test
diff --git a/tests/unit_tests/llms/test_model_factory.py b/tests/unit_tests/llms/test_model_factory.py
index 54240476f..2aa49a542 100644
--- a/tests/unit_tests/llms/test_model_factory.py
+++ b/tests/unit_tests/llms/test_model_factory.py
@@ -1,52 +1,57 @@
-# from unittest.mock import MagicMock, patch
-# from superagi.llms.openai import OpenAi
-# from superagi.llms.replicate import Replicate
-# from superagi.models.models_config import ModelsConfig
-# from superagi.models.models import Models
-# from superagi.models.db import connect_db
-# from sqlalchemy.orm import sessionmaker
-# import pytest
-#
-#
-# @pytest.fixture
-# def mock_db_session():
-#     db_session = MagicMock()
-#     db_session.query().filter().first().return_value = Models(model_name="gpt-3.5-turbo", org_id=1, model_provider_id=1)
-#     db_session.query().filter().first().return_value = ModelsConfig(provider="OpenAI")
-#     return db_session
-#
-#
-# @patch("superagi.models.db.connect_db")
-# @patch("sqlalchemy.orm.sessionmaker")
-# def test_get_model_openai(mock_sessionmaker, mock_connect_db, mock_db_session):
-#     mock_sessionmaker.return_value = MagicMock(return_value=mock_db_session)
-#     mock_connect_db.return_value = MagicMock()
-#
-#     from superagi.llms.openai import OpenAi
-#     result = OpenAi.get_model(organisation_id=1, api_key="TEST_KEY")
-#
-#     assert isinstance(result, OpenAi)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.replicate.Replicate')
-# # def test_get_model_replicate(mock_replicate, mock_db, mock_session_maker):
-# #     mock_session_maker.query().filter().filter().first().return_value.provider = 'Replicate'
-# #     from superagi.models.db import get_model
-# #     result = get_model(organisation_id=1, api_key="TEST_KEY")
-# #     assert isinstance(result, superagi.llms.replicate.Replicate)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.google_palm.GooglePalm')
-# # def test_get_model_google_palm(mock_google_palm, mock_db, mock_session_maker):
-# #     mock_session_maker.query().filter().filter().first().return_value.provider = 'Google Palm'
-# #     from superagi.models.db import get_model
-# #     result = get_model(organisation_id=1, api_key="TEST_KEY")
-# #     assert isinstance(result, superagi.llms.google_palm.GooglePalm)
-# #
-# # @patch('superagi.models.db.connect_db')
-# # @patch('superagi.llms.hugging_face.HuggingFace')
-# # def test_get_model_hugging_face(mock_hugging_face, mock_db, mock_session_maker):
-# #     mock_session_maker.query().filter().filter().first().return_value.provider = 'Hugging Face'
-# #     from superagi.models.db import get_model
-# #     result = get_model(organisation_id=1, api_key="TEST_KEY")
-# #     assert isinstance(result, superagi.llms.hugging_face.HuggingFace)
\ No newline at end of file
+import pytest
+from unittest.mock import Mock
+
+from superagi.llms.google_palm import GooglePalm
+from superagi.llms.hugging_face import HuggingFace
+from superagi.llms.llm_model_factory import get_model, build_model_with_api_key
+from superagi.llms.openai import OpenAi
+from superagi.llms.replicate import Replicate
+
+
+# Fixtures for the mock objects
+@pytest.fixture
+def mock_openai():
+    return Mock(spec=OpenAi)
+
+@pytest.fixture
+def mock_replicate():
+    return Mock(spec=Replicate)
+
+@pytest.fixture
+def mock_google_palm():
+    return Mock(spec=GooglePalm)
+
+@pytest.fixture
+def mock_hugging_face():
+    return Mock(spec=HuggingFace)
+
+# Test build_model_with_api_key function
+def test_build_model_with_openai(mock_openai, monkeypatch):
+    monkeypatch.setattr('superagi.llms.llm_model_factory.OpenAi', mock_openai)
+    model = build_model_with_api_key('OpenAi', 'fake_key')
+    mock_openai.assert_called_once_with(api_key='fake_key')
+    assert isinstance(model, Mock)
+
+def test_build_model_with_replicate(mock_replicate, monkeypatch):
+    monkeypatch.setattr('superagi.llms.llm_model_factory.Replicate', mock_replicate)
+    model = build_model_with_api_key('Replicate', 'fake_key')
+    mock_replicate.assert_called_once_with(api_key='fake_key')
+    assert isinstance(model, Mock)
+
+def test_build_model_with_google_palm(mock_google_palm, monkeypatch):
+    monkeypatch.setattr('superagi.llms.llm_model_factory.GooglePalm', mock_google_palm)
+    model = build_model_with_api_key('Google Palm', 'fake_key')
+    mock_google_palm.assert_called_once_with(api_key='fake_key')
+    assert isinstance(model, Mock)
+
+def test_build_model_with_hugging_face(mock_hugging_face, monkeypatch):
+    monkeypatch.setattr('superagi.llms.llm_model_factory.HuggingFace', mock_hugging_face)
+    model = build_model_with_api_key('Hugging Face', 'fake_key')
+    mock_hugging_face.assert_called_once_with(api_key='fake_key')
+    assert isinstance(model, Mock)
+
+def test_build_model_with_unknown_provider(capsys):  # capsys is a built-in pytest fixture for capturing print output
+    model = build_model_with_api_key('Unknown', 'fake_key')
+    assert model is None
+    captured = capsys.readouterr()
+    assert "Unknown provider." in captured.out
\ No newline at end of file
diff --git a/tests/unit_tests/tools/github/test_fetch_pull_request.py b/tests/unit_tests/tools/github/test_fetch_pull_request.py
new file mode 100644
index 000000000..270ec2280
--- /dev/null
+++ b/tests/unit_tests/tools/github/test_fetch_pull_request.py
@@ -0,0 +1,58 @@
+import pytest
+from unittest.mock import patch, Mock
+from pydantic import ValidationError
+
+from superagi.tools.github.fetch_pull_request import GithubFetchPullRequest, GithubFetchPullRequestSchema
+
+
+@pytest.fixture
+def mock_github_helper():
+    with patch('superagi.tools.github.fetch_pull_request.GithubHelper') as MockGithubHelper:
+        yield MockGithubHelper
+
+
+@pytest.fixture
+def tool(mock_github_helper):
+    tool = GithubFetchPullRequest()
+    tool.toolkit_config = Mock()
+    tool.toolkit_config.side_effect = ['dummy_token', 'dummy_username']
+    mock_github_helper_instance = mock_github_helper.return_value
+    mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.return_value = ['url1', 'url2']
+    return tool
+
+
+def test_execute(tool, mock_github_helper):
+    mock_github_helper_instance = mock_github_helper.return_value
+
+    # Execute the method
+    result = tool._execute('repo_name', 'repo_owner', 86400)
+
+    # Verify results
+    assert result == "Pull requests: ['url1', 'url2']"
+    mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.assert_called_once_with('repo_owner',
+                                                                                                    'repo_name', 86400)
+
+
+def test_schema_validation():
+    # Valid data
+    valid_data = {'repository_name': 'repo', 'repository_owner': 'owner', 'time_in_seconds': 86400}
+    GithubFetchPullRequestSchema(**valid_data)
+
+    # Invalid data
+    invalid_data = {'repository_name': 'repo', 'repository_owner': 'owner', 'time_in_seconds': 'string'}
+    with pytest.raises(ValidationError):
+        GithubFetchPullRequestSchema(**invalid_data)
+
+
+def test_execute_error(mock_github_helper):
+    tool = GithubFetchPullRequest()
+    tool.toolkit_config = Mock()
+    tool.toolkit_config.side_effect = ['dummy_token', 'dummy_username']
+    mock_github_helper_instance = mock_github_helper.return_value
+    mock_github_helper_instance.get_pull_requests_created_in_last_x_seconds.side_effect = Exception('An error occurred')
+
+    # Execute the method
+    result = tool._execute('repo_name', 'repo_owner', 86400)
+
+    # Verify results
+    assert result == 'Error: Unable to fetch pull requests An error occurred'
diff --git a/tests/unit_tests/tools/github/test_review_pull_request.py b/tests/unit_tests/tools/github/test_review_pull_request.py
new file mode 100644
index 000000000..4e5c12903
--- /dev/null
+++ b/tests/unit_tests/tools/github/test_review_pull_request.py
@@ -0,0 +1,81 @@
+import pytest
+from unittest.mock import patch, Mock
+import pytest_mock
+from pydantic import ValidationError
+
+from superagi.tools.github.review_pull_request import GithubReviewPullRequest
+
+class MockLLM:
+    def get_model(self):
+        return "some_model"
+
+class MockTokenCounter:
+    @staticmethod
+    def count_message_tokens(message, model):
+        # Mocking the token count based on the length of the content.
+        # Replace this logic as needed.
+        return len(message[0]['content'])
+
+
+def test_split_pull_request_content_into_multiple_parts():
+    tool = GithubReviewPullRequest()
+    tool.llm = MockLLM()
+
+    # Mocking the pull_request_arr
+    pull_request_arr = ["part1", "part2", "part3"]
+
+    # Calling the method to be tested
+    result = tool.split_pull_request_content_into_multiple_parts(4000, pull_request_arr)
+
+    # Validate the result (this depends on what you expect the output to be)
+    # For instance, if you expect the result to be a list of all parts concatenated with 'diff --git'
+    expected = ["diff --gitpart1diff --gitpart2diff --gitpart3"]
+    assert result == expected
+
+
+@pytest.mark.parametrize("diff_content, file_path, line_number, expected", [
+    ("file_path_1\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n+ line3", "file_path_1", 3, 4),
+    ("file_path_2\n@@ -1,3 +1,3 @@\n+ line1\n- line2", "file_path_2", 1, 2),
+    ("file_path_3\n@@ -1,3 +1,4 @@\n+ line1\n+ line2\n- line3", "file_path_3", 2, 3)
+])
+def test_get_exact_line_number(diff_content, file_path, line_number, expected):
+    tool = GithubReviewPullRequest()
+
+    # Calling the method to be tested
+    result = tool.get_exact_line_number(diff_content, file_path, line_number)
+
+    # Validate the result
+    assert result == expected
+
+
+class MockGithubHelper:
+    def __init__(self, access_token, username):
+        pass
+
+    def get_pull_request_content(self, owner, repo, pr_number):
+        return 'mock_content'
+
+    def get_latest_commit_id_of_pull_request(self, owner, repo, pr_number):
+        return 'mock_commit_id'
+
+    def add_line_comment_to_pull_request(self, *args, **kwargs):
+        return True
+
+
+# Your test case
+def test_execute():
+    with patch('superagi.tools.github.review_pull_request.GithubHelper', MockGithubHelper), \
+         patch('superagi.tools.github.review_pull_request.TokenCounter.count_message_tokens', return_value=3000), \
+         patch('superagi.tools.github.review_pull_request.Agent.find_org_by_agent_id', return_value=Mock()), \
+         patch.object(GithubReviewPullRequest, 'get_tool_config', return_value='mock_value'), \
+         patch.object(GithubReviewPullRequest, 'run_code_review', return_value=None):
+
+        tool = GithubReviewPullRequest()
+        tool.llm = Mock()
+        tool.llm.get_model = Mock(return_value='mock_model')
+        tool.toolkit_config = Mock()
+        tool.toolkit_config.session = 'mock_session'
+
+        result = tool._execute('mock_repo', 'mock_owner', 42)
+
+        assert result == 'Added comments to the pull request:42'
diff --git a/tools.json b/tools.json
index 9091ede76..94de16ccc 100644
--- a/tools.json
+++ b/tools.json
@@ -1,4 +1,4 @@
 {
   "tools": {
   }
-}
+}
\ No newline at end of file
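
Reviewer's note on the position arithmetic in `GithubReviewPullRequest.find_position_in_diff` above: GitHub's pull-request review-comment API expects a `position` — the 1-based index of a line within the diff text, counting hunk headers, context lines and additions, while deletions do not advance the file's line number — rather than a line number in the new file. The sketch below replays the same logic outside the tool class so it can be reasoned about in isolation. It is illustrative only, not part of the patch, and the sample hunk at the bottom is invented for the example.

```
# Minimal, self-contained replay of the position mapping implemented in
# review_pull_request.py (illustrative sketch, not part of this patch).

def find_position_in_diff(diff_content: str, target_line: int) -> int:
    """Map a line number in the new file to a 1-based position in the diff."""
    position = 0
    current_file_line_number = 0
    for line in diff_content.split('\n'):
        position += 1  # every diff line, including '@@' headers, advances the position
        if line.startswith('@@'):
            # '@@ -a,b +c,d @@': the next added/context line is file line c
            current_file_line_number = int(line.split('+')[1].split(',')[0]) - 1
        elif not line.startswith('-'):
            # context and added lines exist in the new file; deletions do not
            current_file_line_number += 1
            if current_file_line_number >= target_line:
                return position
    return position


# Invented sample hunk: 'added line' is file line 2 and sits at diff position 3
# (position 1 is the '@@' header, position 2 the first context line).
sample_hunk = "@@ -1,2 +1,3 @@\n context\n+added line\n context"
assert find_position_in_diff(sample_hunk, 2) == 3
```

This is also why `get_exact_line_number` first slices the diff content from the target file's `@@` header: positions are counted per file, starting at the first hunk header for that file.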