Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed error in test_call functions relating to tools #379

Merged
merged 9 commits into from
Sep 9, 2024
24 changes: 15 additions & 9 deletions swarmauri/community/tools/concrete/EntityRecognitionTool.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
import json
from typing import List, Literal, Dict

from transformers import pipeline, logging as hf_logging
from swarmauri.standard.tools.base.ToolBase import ToolBase
from swarmauri.standard.tools.concrete.Parameter import Parameter
from typing_extensions import Required

hf_logging.set_verbosity_error()

class EntityRecognitionTool(ToolBase):
def __init__(self):
parameters = [
Parameter("text","string","The text for entity recognition",True)
]
super().__init__(name="EntityRecognitionTool",
description="Extracts named entities from text",
parameters=parameters)

name: str = "EntityRecognitionTool"
description: str = "Extracts named entities from text"
parameters = [
Parameter(
name="text",
type="string",
description="The text for entity recognition",
required=True
)
]
type: Literal['EntityRecognitionTool'] = 'EntityRecognitionTool'

def __call__(self, text: str) -> dict:
def __call__(self, text: str) -> Dict[str, str]:
try:
self.nlp = pipeline("ner")
entities = self.nlp(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def __call__(self, input_text: str) -> Dict[str, str]:
num_words = self.count_words(text)
num_characters = self.count_characters(text)
if num_sentences == 0 or num_words == 0:
return "0.0"
return {"ari_score": 0.0}
characters_per_word = num_characters / num_words
words_per_sentence = num_words / num_sentences
ari_score = 4.71 * characters_per_word + 0.5 * words_per_sentence - 21.43
return {"ari_score": str(ari_score)}
return {"ari_score": ari_score}
else:
raise ValueError("Invalid input for AutomatedReadabilityIndexTool.")

Expand Down
6 changes: 6 additions & 0 deletions swarmauri/standard/tools/concrete/TemperatureConverterTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,22 @@ def __call__(self, from_unit: str, to_unit: str, value: float) -> Dict[str, str]
result = (value * 9 / 5) + 32
elif to_unit == "kelvin":
result = value + 273.15
else:
return "Error: Unknown temperature unit."
elif from_unit == "fahrenheit":
if to_unit == "celsius":
result = (value - 32) * 5 / 9
elif to_unit == "kelvin":
result = (value - 32) * 5 / 9 + 273.15
else:
return "Error: Unknown temperature unit."
elif from_unit == "kelvin":
if to_unit == "celsius":
result = value - 273.15
elif to_unit == "fahrenheit":
result = (value - 273.15) * 9 / 5 + 32
else:
return "Error: Unknown temperature unit."
else:
return "Error: Unknown temperature unit."

Expand Down
2 changes: 1 addition & 1 deletion swarmauri/standard/tools/concrete/TextLengthTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __call__(self, text: str) -> Dict[str, int]:
Returns:
dict: A dictionary containing the number of characters, words, and sentences.
"""
return {"text_length": self.calculate_text_length(text)}
return self.calculate_text_length(text)

def calculate_text_length(self, text: str) -> Dict[str, int]:
"""
Expand Down
5 changes: 2 additions & 3 deletions tests/tools/AdditionTool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ def test_serialization():
(2, 3, {'sum'}, str, "5"), # Test case 1: positive integers
(-2, -3, {'sum'}, str, "-5"), # Test case 2: negative integers
(0, 0, {'sum'}, str, "0"), # Test case 3: zero values
(2.5, 3.5, {'sum'}, str, "6.0"),# Test case 4: floating-point numbers
("2", "3", {'sum'}, str, "5") # Test case 5: string inputs (assuming the Tool can handle this)
(2.5, 3.5, {'sum'}, str, "6.0"), # Test case 4: floating-point numbers
]
)
def test_call(input_a, input_b, expected_keys, expected_type, expected_sum):
tool = Tool()

result = tool(int(input_a), int(input_b))
result = tool(input_a, input_b)

assert isinstance(result, dict), f"Expected dict, but got {type(result).__name__}"
assert expected_keys.issubset(result.keys()), f"Expected keys {expected_keys} but got {result.keys()}"
Expand Down
12 changes: 6 additions & 6 deletions tests/tools/AutomatedReadabilityIndex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ def test_serialization():
),
(
'Lorem ipsum dolor sit amet, consectetur adipiscing elit.',
2.50 # Replace with the expected ARI score for this input
11.41 # Replace with the expected ARI score for this input
),
(
'A short sentence.',
5.00 # Replace with the expected ARI score for this input
3.62 # Replace with the expected ARI score for this input
),
(
'',
0.00 # Replace with the expected ARI score for this input
),
(
'Some longer text to test the algorithm with more complexity and variability in sentence length and word choice.',
3.25 # Replace with the expected ARI score for this input
12.16 # Replace with the expected ARI score for this input
)
]
)
Expand All @@ -57,7 +57,7 @@ def test_call(input_text, expected_score):
assert isinstance(result, dict), f"Expected dict, but got {type(result).__name__}"
assert expected_keys.issubset(result.keys()), f"Expected keys {expected_keys} but got {result.keys()}"
assert isinstance(result.get("ari_score"),
str), f"Expected str, but got {type(result.get('ari_score')).__name__}"
float), f"Expected float, but got {type(result.get('ari_score')).__name__}"

assert result.get("ari_score") == pytest.approx(expected_score,
rel=1e-2), f"Expected ARI score {pytest.approx(expected_score, rel=1e-2)}, but got {result.get('ari_score')}"
assert result.get("ari_score") == pytest.approx(expected_score, rel=1e-2), \
f"Expected ARI score 22. {expected_score} ± {1e-2 * expected_score}, but got {result.get('ari_score')}"
2 changes: 1 addition & 1 deletion tests/tools/CalculatorTool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_serialization():
('add', 2, 3, '5'), # Addition
('subtract', 5, 3, '2'), # Subtraction
('multiply', 2, 3, '6'), # Multiplication
('divide', 6, 3, '2'), # Division
('divide', 6, 3, '2.0'), # Division
('divide', 5, 0, 'Error: Division by zero.'), # Division by zero, adjust based on your expected behavior
('unknown_ops', 5, 0, 'Error: Unknown operation.')
]
Expand Down
91 changes: 44 additions & 47 deletions tests/tools/CodeExtractorTool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,68 +158,68 @@ def test_stop_collecting_lines():
"file_contents, extract_documentation, to_be_ignored, expected_code",
[
(
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
True,
[],
'def foo():\n pass\n\ndef bar():\n pass\n'
),
(
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
False,
[],
'def foo():\n pass\n\ndef bar():\n pass\n'
),
(
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
'''
"""
This is a docstring.
"""
def foo():
pass
def bar():
pass
# non-essentials
''',
True,
['foo'],
'def bar():\n pass\n'
),
(
'''
def foo():
pass
def bar():
pass
''',
'''
def foo():
pass
def bar():
pass
''',
True,
[],
'def foo():\n pass\n\ndef bar():\n pass\n'
),
(
'''
def foo():
pass
def bar():
pass
''',
'''
def foo():
pass
def bar():
pass
''',
False,
['foo'],
'def bar():\n pass\n'
Expand All @@ -234,10 +234,7 @@ def test_call(file_contents, extract_documentation, to_be_ignored, expected_code
with patch("builtins.open", mock_open(read_data=file_contents)):
result = tool(file_name="fake_file.py", extract_documentation=extract_documentation, to_be_ignored=to_be_ignored)

assert result == expected_code

assert isinstance(result, dict), f"Expected dict, but got {type(result).__name__}"
assert expected_keys.issubset(result.keys()), f"Expected keys {expected_keys} but got {result.keys()}"
assert isinstance(result.get("code"),
str), f"Expected str, but got {type(result.get('code')).__name__}"
assert result.get("code") == expected_code, f"Expected Extracted Code {expected_code}, but got {result.get('code')}"
assert isinstance(result.get("code"), str), f"Expected str, but got {type(result.get('code')).__name__}"
assert result.get("code") == expected_code, f"Expected Extracted Code {expected_code}, but got {result.get('code')}"
2 changes: 1 addition & 1 deletion tests/tools/GunningFogTool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_serialization():
@pytest.mark.parametrize(
"input_text, num_of_major_punctuations, num_of_words, num_of_three_plus_syllable_words, expected_score",
[
("This is a sample sentence. It is used to test the Gunning-Fog tool.", 2, 13, 1, 1.24), # Test case 1
("This is a sample sentence. It is used to test the Gunning-Fog tool.", 2, 13, 1, 5.65), # Test case 1
("Another example with more complex sentences; used for testing.", 3, 10, 2, 16.93), # Test case 2
("Short sentence.", 1, 3, 0, 20.8), # Test case 3
("Punctuation-heavy text! Is it really? Yes, it is! 42", 5, 10, 1, 5.0), # Test case 4
Expand Down
1 change: 0 additions & 1 deletion tests/tools/TemperatureConverterTool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_ubc_type():
)
def test_call(from_unit, to_unit, value, expected_result):
tool = Tool()

expected_keys = {f"temperature_in_{to_unit}"}

result = tool(from_unit=from_unit, to_unit=to_unit, value=value)
Expand Down