Skip to content

Commit

Permalink
M-3.1 - Adding LesNet_labels file, adding test for inference and mode…
Browse files Browse the repository at this point in the history
…l services
  • Loading branch information
Thomasbehan committed Jun 3, 2024
1 parent de8661b commit 6dedb8d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 1 deletion.
30 changes: 30 additions & 0 deletions models/LesNet_labels.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[
"acrochordon",
"actinic keratosis",
"AIMP",
"angiofibroma or fibrous papule",
"angiokeratoma",
"angioma",
"atypical melanocytic proliferation",
"atypical spitz tumor",
"basal cell carcinoma",
"benign",
"cafe-au-lait macule",
"clear cell acanthoma",
"dermatofibroma",
"lentigo NOS",
"lentigo simplex",
"lichenoid keratosis",
"malignant",
"melanoma",
"neurofibroma",
"nevus",
"pigmented benign keratosis",
"scar",
"sebaceous hyperplasia",
"seborrheic keratosis",
"solar lentigo",
"squamous cell carcinoma",
"vascular lesion",
"verruca"
]
2 changes: 1 addition & 1 deletion skinvestigatorai/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ class ModelConfig(object):
MAX_AUG_PER_IMAGE = 5000000
TRAIN_DIR = 'data/train'
MODEL_TYPE = "KERAS"
MODEL_NAME = "LesNetM31.keras"
MODEL_NAME = "LesNet.keras"
LABELS_NAME = "LesNet_labels.json"
90 changes: 90 additions & 0 deletions tests/test_inference_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from io import BytesIO
from unittest.mock import MagicMock, patch

import numpy as np
from pyramid.response import Response
import pytest
from PIL import Image

from skinvestigatorai.services.inference import Inference
from skinvestigatorai.services.model import SVModel


@pytest.fixture
def mock_svmodel():
sv_model = MagicMock(SVModel)
sv_model.load_model.return_value = (MagicMock(), ["class1", "class2"])
sv_model.preprocess_image_for_tflite = lambda x: x
return sv_model


@pytest.fixture
def inference(mock_svmodel):
with patch('skinvestigatorai.services.model.SVModel', return_value=mock_svmodel):
return Inference()


def create_mock_image():
image = Image.new('RGB', (100, 100))
img_byte_arr = BytesIO()
image.save(img_byte_arr, format='PNG')
img_byte_arr = BytesIO(img_byte_arr.getvalue())
return img_byte_arr


def test_predict_success(inference):
mock_image = create_mock_image()

inference.model.predict = MagicMock(return_value=np.array([[0.1, 0.9]]))

result = inference.predict(mock_image)

assert isinstance(result, dict)
assert 'prediction' in result
assert 'confidence' in result


def test_predict_failure(inference):
mock_image = create_mock_image()

inference.model.predict = MagicMock(return_value=np.array([[0.3, 0.2]]))

result = inference.predict(mock_image)

assert isinstance(result, Response)
assert result.status_code == 400


def test_is_image_similar(inference):
mock_image = np.random.rand(100, 100, 3)

inference.dataset_embedding = np.random.rand(2048)
inference._predict_similar = MagicMock(return_value=np.random.rand(2048))

result = inference.is_image_similar(mock_image, threshold=0.5)

assert result in [True, False]


def test__predict_similar_keras(inference):
mock_image = np.random.rand(100, 100, 3)
inference.model.predict = MagicMock(return_value=np.random.rand(1, 2048))

with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'KERAS'):
result = inference._predict_similar(mock_image)

assert result is not None


def test__predict_similar_tflite(inference):
mock_image = np.random.rand(100, 100, 3)
inference.model.get_input_details = MagicMock(return_value=[{'index': 0, 'dtype': np.float32}])
inference.model.get_output_details = MagicMock(return_value=[{'index': 1}])
inference.model.set_tensor = MagicMock()
inference.model.invoke = MagicMock()
inference.model.get_tensor = MagicMock(return_value=np.random.rand(1, 2048))

with patch('skinvestigatorai.config.model.ModelConfig.MODEL_TYPE', 'TFLITE'):
result = inference._predict_similar(mock_image)

assert result is not None
69 changes: 69 additions & 0 deletions tests/test_model_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from unittest.mock import patch, MagicMock

import numpy as np
import pytest

from skinvestigatorai.config.model import ModelConfig
from skinvestigatorai.services.model import SVModel


@pytest.fixture
def sv_model():
return SVModel()


def test_create_feature_extractor_tflite(sv_model):
sv_model.model_type = 'TFLITE'
mock_model = MagicMock()
sv_model.model = mock_model
sv_model.create_feature_extractor()
assert sv_model.feature_extractor == mock_model


def test_create_feature_extractor_invalid_model_type(sv_model):
sv_model.model_type = 'INVALID'
with pytest.raises(ValueError, match="Unsupported model type. Please use 'KERAS' or 'TFLITE'."):
sv_model.create_feature_extractor()


def test_preprocess_image_for_tflite(sv_model):
img = np.random.rand(224, 224, 3).astype(np.float32)
processed_img = sv_model.preprocess_image_for_tflite(img)
assert processed_img.shape == (ModelConfig.IMG_SIZE[0], ModelConfig.IMG_SIZE[1], 3)
assert np.max(processed_img) <= 1.0
assert np.min(processed_img) >= 0.0


def test_evaluate_model(sv_model):
sv_model.model = MagicMock()
sv_model.model.evaluate.return_value = [0.5, 0.8, 0.7, 0.6]
test_datagen = MagicMock()
test_loss, test_acc, test_precision, test_recall = sv_model.evaluate_model(test_datagen)
assert test_loss == 0.5
assert test_acc == 0.8
assert test_precision == 0.7
assert test_recall == 0.6


@patch('tensorflow.summary.create_file_writer')
def test_run_experiments(mock_create_file_writer, sv_model):
sv_model.run_experiments = MagicMock()
train_ds = MagicMock()
val_ds = MagicMock()
sv_model.run_experiments(train_ds, val_ds)
sv_model.run_experiments.assert_called_once_with(train_ds, val_ds)


def test_save_model(sv_model):
sv_model.model = MagicMock()
with patch('builtins.open', MagicMock()):
with patch('tensorflow.keras.models.Model.save', MagicMock()):
sv_model.save_model()
sv_model.model.save.assert_called_once()


def test_load_model(sv_model):
with patch('os.path.exists', return_value=True):
with patch('tensorflow.keras.models.load_model', return_value=MagicMock()):
sv_model.load_model()
assert isinstance(sv_model.model, MagicMock)

0 comments on commit 6dedb8d

Please sign in to comment.