-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
M-3.1 - Adding LesNet_labels file, adding test for inference and mode…
…l services
- Loading branch information
1 parent
de8661b
commit 6dedb8d
Showing
4 changed files
with
190 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
[ | ||
"acrochordon", | ||
"actinic keratosis", | ||
"AIMP", | ||
"angiofibroma or fibrous papule", | ||
"angiokeratoma", | ||
"angioma", | ||
"atypical melanocytic proliferation", | ||
"atypical spitz tumor", | ||
"basal cell carcinoma", | ||
"benign", | ||
"cafe-au-lait macule", | ||
"clear cell acanthoma", | ||
"dermatofibroma", | ||
"lentigo NOS", | ||
"lentigo simplex", | ||
"lichenoid keratosis", | ||
"malignant", | ||
"melanoma", | ||
"neurofibroma", | ||
"nevus", | ||
"pigmented benign keratosis", | ||
"scar", | ||
"sebaceous hyperplasia", | ||
"seborrheic keratosis", | ||
"solar lentigo", | ||
"squamous cell carcinoma", | ||
"vascular lesion", | ||
"verruca" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from io import BytesIO | ||
from unittest.mock import MagicMock, patch | ||
|
||
import numpy as np | ||
from pyramid.response import Response | ||
import pytest | ||
from PIL import Image | ||
|
||
from skinvestigatorai.services.inference import Inference | ||
from skinvestigatorai.services.model import SVModel | ||
|
||
|
||
@pytest.fixture | ||
def mock_svmodel(): | ||
sv_model = MagicMock(SVModel) | ||
sv_model.load_model.return_value = (MagicMock(), ["class1", "class2"]) | ||
sv_model.preprocess_image_for_tflite = lambda x: x | ||
return sv_model | ||
|
||
|
||
@pytest.fixture | ||
def inference(mock_svmodel): | ||
with patch('skinvestigatorai.services.model.SVModel', return_value=mock_svmodel): | ||
return Inference() | ||
|
||
|
||
def create_mock_image(): | ||
image = Image.new('RGB', (100, 100)) | ||
img_byte_arr = BytesIO() | ||
image.save(img_byte_arr, format='PNG') | ||
img_byte_arr = BytesIO(img_byte_arr.getvalue()) | ||
return img_byte_arr | ||
|
||
|
||
def test_predict_success(inference): | ||
mock_image = create_mock_image() | ||
|
||
inference.model.predict = MagicMock(return_value=np.array([[0.1, 0.9]])) | ||
|
||
result = inference.predict(mock_image) | ||
|
||
assert isinstance(result, dict) | ||
assert 'prediction' in result | ||
assert 'confidence' in result | ||
|
||
|
||
def test_predict_failure(inference): | ||
mock_image = create_mock_image() | ||
|
||
inference.model.predict = MagicMock(return_value=np.array([[0.3, 0.2]])) | ||
|
||
result = inference.predict(mock_image) | ||
|
||
assert isinstance(result, Response) | ||
assert result.status_code == 400 | ||
|
||
|
||
def test_is_image_similar(inference): | ||
mock_image = np.random.rand(100, 100, 3) | ||
|
||
inference.dataset_embedding = np.random.rand(2048) | ||
inference._predict_similar = MagicMock(return_value=np.random.rand(2048)) | ||
|
||
result = inference.is_image_similar(mock_image, threshold=0.5) | ||
|
||
assert result in [True, False] | ||
|
||
|
||
def test__predict_similar_keras(inference): | ||
mock_image = np.random.rand(100, 100, 3) | ||
inference.model.predict = MagicMock(return_value=np.random.rand(1, 2048)) | ||
|
||
with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'KERAS'): | ||
result = inference._predict_similar(mock_image) | ||
|
||
assert result is not None | ||
|
||
|
||
def test__predict_similar_tflite(inference): | ||
mock_image = np.random.rand(100, 100, 3) | ||
inference.model.get_input_details = MagicMock(return_value=[{'index': 0, 'dtype': np.float32}]) | ||
inference.model.get_output_details = MagicMock(return_value=[{'index': 1}]) | ||
inference.model.set_tensor = MagicMock() | ||
inference.model.invoke = MagicMock() | ||
inference.model.get_tensor = MagicMock(return_value=np.random.rand(1, 2048)) | ||
|
||
with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'TFLITE'): | ||
result = inference._predict_similar(mock_image) | ||
|
||
assert result is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from unittest.mock import patch, MagicMock | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from skinvestigatorai.config.model import ModelConfig | ||
from skinvestigatorai.services.model import SVModel | ||
|
||
|
||
@pytest.fixture | ||
def sv_model(): | ||
return SVModel() | ||
|
||
|
||
def test_create_feature_extractor_tflite(sv_model): | ||
sv_model.model_type = 'TFLITE' | ||
mock_model = MagicMock() | ||
sv_model.model = mock_model | ||
sv_model.create_feature_extractor() | ||
assert sv_model.feature_extractor == mock_model | ||
|
||
|
||
def test_create_feature_extractor_invalid_model_type(sv_model): | ||
sv_model.model_type = 'INVALID' | ||
with pytest.raises(ValueError, match="Unsupported model type. Please use 'KERAS' or 'TFLITE'."): | ||
sv_model.create_feature_extractor() | ||
|
||
|
||
def test_preprocess_image_for_tflite(sv_model): | ||
img = np.random.rand(224, 224, 3).astype(np.float32) | ||
processed_img = sv_model.preprocess_image_for_tflite(img) | ||
assert processed_img.shape == (ModelConfig.IMG_SIZE[0], ModelConfig.IMG_SIZE[1], 3) | ||
assert np.max(processed_img) <= 1.0 | ||
assert np.min(processed_img) >= 0.0 | ||
|
||
|
||
def test_evaluate_model(sv_model): | ||
sv_model.model = MagicMock() | ||
sv_model.model.evaluate.return_value = [0.5, 0.8, 0.7, 0.6] | ||
test_datagen = MagicMock() | ||
test_loss, test_acc, test_precision, test_recall = sv_model.evaluate_model(test_datagen) | ||
assert test_loss == 0.5 | ||
assert test_acc == 0.8 | ||
assert test_precision == 0.7 | ||
assert test_recall == 0.6 | ||
|
||
|
||
@patch('tensorflow.summary.create_file_writer') | ||
def test_run_experiments(mock_create_file_writer, sv_model): | ||
sv_model.run_experiments = MagicMock() | ||
train_ds = MagicMock() | ||
val_ds = MagicMock() | ||
sv_model.run_experiments(train_ds, val_ds) | ||
sv_model.run_experiments.assert_called_once_with(train_ds, val_ds) | ||
|
||
|
||
def test_save_model(sv_model): | ||
sv_model.model = MagicMock() | ||
with patch('builtins.open', MagicMock()): | ||
with patch('tensorflow.keras.models.Model.save', MagicMock()): | ||
sv_model.save_model() | ||
sv_model.model.save.assert_called_once() | ||
|
||
|
||
def test_load_model(sv_model): | ||
with patch('os.path.exists', return_value=True): | ||
with patch('tensorflow.keras.models.load_model', return_value=MagicMock()): | ||
sv_model.load_model() | ||
assert isinstance(sv_model.model, MagicMock) |