Skip to content

Commit

Permalink
Fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
luciferlinx101 committed Jun 29, 2023
1 parent d085c2f commit 4366836
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 25 deletions.
13 changes: 5 additions & 8 deletions superagi/controllers/tool_config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from fastapi import APIRouter, HTTPException, Depends, Path
from fastapi import APIRouter, HTTPException, Depends
from fastapi_jwt_auth import AuthJWT
from fastapi_sqlalchemy import db
from pydantic_sqlalchemy import sqlalchemy_to_pydantic

from superagi.helper.auth import check_auth
from superagi.helper.auth import get_user_organisation
from superagi.models.organisation import Organisation
from superagi.models.tool_config import ToolConfig
from superagi.models.toolkit import Toolkit
from fastapi_jwt_auth import AuthJWT
from superagi.helper.auth import check_auth
from superagi.helper.auth import get_user_organisation
from typing import List

router = APIRouter()

Expand Down Expand Up @@ -115,13 +115,10 @@ def get_all_tool_configs(toolkit_name: str, organisation: Organisation = Depends
HTTPException (status_code=403): If the user is not authorized to access the tool kit.
"""

user_toolkits = db.session.query(Toolkit).filter(Toolkit.organisation_id == organisation.id).all()
toolkit = db.session.query(Toolkit).filter(Toolkit.name == toolkit_name,
Toolkit.organisation_id == organisation.id).first()
if not toolkit:
raise HTTPException(status_code=404, detail='ToolKit not found')
if toolkit.name not in [user_toolkit.name for user_toolkit in user_toolkits]:
raise HTTPException(status_code=403, detail='Unauthorized')

tool_configs = db.session.query(ToolConfig).filter(ToolConfig.toolkit_id == toolkit.id).all()
return tool_configs
Expand Down
18 changes: 1 addition & 17 deletions tests/unit_tests/controllers/test_tool_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_get_all_tool_configs_success(mocks):
patch('superagi.helper.auth.db') as mock_auth_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = toolkit_1
mock_db.session.query.return_value.filter.return_value.all.side_effect = [
[toolkit_1, toolkit_2],
[tool_config]
]
response = client.get(f"/tool_configs/get/toolkit/test_toolkit_1")
Expand All @@ -98,28 +97,13 @@ def test_get_all_tool_configs_toolkit_not_found(mocks):
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.controllers.tool_config.db') as mock_db, \
patch('superagi.helper.auth.db') as mock_auth_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
mock_db.session.query.return_value.filter.return_value.first.return_value = None
response = client.get(f"/tool_configs/get/toolkit/nonexistent_toolkit")

# Assertions
assert response.status_code == 404
assert response.json() == {'detail': 'ToolKit not found'}


def test_get_all_tool_configs_unauthorized_access(mocks):
user_organisation, _, _, toolkit_1, toolkit_2 = mocks

with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.controllers.tool_config.db') as mock_db, \
patch('superagi.helper.auth.db') as mock_auth_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = toolkit_1
response = client.get(f"/tool_configs/get/toolkit/test_toolkit_3")

# Assertions
assert response.status_code == 403
assert response.json() == {'detail': 'Unauthorized'}


def test_get_tool_config_success(mocks):
# Unpack the fixture data
user_organisation, user_toolkits, tool_config, toolkit_1, toolkit_2 = mocks
Expand Down

0 comments on commit 4366836

Please sign in to comment.