Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
luciferlinx101 authored Jun 29, 2023
1 parent b7fbc71 commit 7f1c1a9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions superagi/controllers/tool_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.post("/add/{toolkit_name}", status_code=201)
def update_tool_config(toolkit_name: str, configs: list):
def update_tool_config(toolkit_name: str, configs: list, organisation: Organisation = Depends(get_user_organisation)):
"""
Update tool configurations for a specific tool kit.
Expand All @@ -34,7 +34,7 @@ def update_tool_config(toolkit_name: str, configs: list):

try:
# Check if the tool kit exists
toolkit = Toolkit.get_toolkit_from_name(db.session, toolkit_name)
toolkit = Toolkit.get_toolkit_from_name(db.session, toolkit_name,organisation)
if toolkit is None:
raise HTTPException(status_code=404, detail="Tool kit not found")

Expand Down
4 changes: 2 additions & 2 deletions superagi/models/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ def fetch_marketplace_detail(cls, search_str, toolkit_name):
return None

@staticmethod
def get_toolkit_from_name(session, toolkit_name):
toolkit = session.query(Toolkit).filter_by(name=toolkit_name).first()
def get_toolkit_from_name(session, toolkit_name, organisation):
toolkit = session.query(Toolkit).filter_by(name=toolkit_name, organisation_id=organisation.id).first()
if toolkit:
return toolkit
return None
Expand Down
12 changes: 7 additions & 5 deletions tests/unit_tests/models/test_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,19 @@ def test_fetch_marketplace_detail_error():
def test_get_toolkit_from_name_existing_toolkit(mock_session):
# Arrange
toolkit_name = "example_toolkit"
expected_toolkit = Toolkit(name=toolkit_name)
organisation = Organisation(id=1)
expected_toolkit = Toolkit(name=toolkit_name,organisation_id=organisation.id)

# Mock the session.query method
mock_session.query.return_value.filter_by.return_value.first.return_value = expected_toolkit

# Act
result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name)
result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name,organisation)

# Assert
assert result == expected_toolkit
mock_session.query.assert_called_once_with(Toolkit)
mock_session.query.return_value.filter_by.assert_called_once_with(name=toolkit_name)
mock_session.query.return_value.filter_by.assert_called_once_with(name=toolkit_name,organisation_id=organisation.id)
mock_session.query.return_value.filter_by.return_value.first.assert_called_once()

def test_get_toolkit_from_name_nonexistent_toolkit(mock_session):
Expand All @@ -174,14 +175,15 @@ def test_get_toolkit_from_name_nonexistent_toolkit(mock_session):

# Mock the session.query method to return None
mock_session.query.return_value.filter_by.return_value.first.return_value = None
organisation = Organisation(id=1)

# Act
result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name)
result = Toolkit.get_toolkit_from_name(mock_session, toolkit_name,organisation)

# Assert
assert result is None
mock_session.query.assert_called_once_with(Toolkit)
mock_session.query.return_value.filter_by.assert_called_once_with(name=toolkit_name)
mock_session.query.return_value.filter_by.assert_called_once_with(name=toolkit_name,organisation_id=organisation.id)
mock_session.query.return_value.filter_by.return_value.first.assert_called_once()

def test_get_toolkit_installed_details(mock_session):
Expand Down

0 comments on commit 7f1c1a9

Please sign in to comment.