forked from oppia/oppia
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implemented handler * files added * fixes * updated return codes * completed version * addressed changes * end-to-end implementation * sorting message dicts * addressed changes * updated controller utilising mapping * addressed final review comments * addressed review changes * moved func to exp_domain * addressed review comments * changed decorator to verify function * checked for validity of vm_id * addressed all changes, major upheaval * addressed review comments * changed save_job into create and update job methods * addressed review changes * added status change mapping
- Loading branch information
1 parent
45df731
commit 1533383
Showing
12 changed files
with
475 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright 2017 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Controllers for communicating with the VM for training classifiers.""" | ||
|
||
import hashlib | ||
import hmac | ||
import json | ||
|
||
from core.controllers import base | ||
from core.domain import classifier_services | ||
from core.domain import config_domain | ||
|
||
import feconf | ||
|
||
|
||
# NOTE TO DEVELOPERS: This function should be kept in sync with its counterpart | ||
# in Oppia-ml. | ||
def generate_signature(secret, message): | ||
"""Generates digital signature for given data. | ||
Args: | ||
secret: str. The secret used to communicate with Oppia-ml. | ||
message: dict. The message payload data. | ||
Returns: | ||
str. The signature of the payload data. | ||
""" | ||
message_json = json.dumps(message, sort_keys=True) | ||
return hmac.new(secret, message_json, digestmod=hashlib.sha256).hexdigest() | ||
|
||
|
||
def validate_job_result_message_dict(message): | ||
"""Validates the data-type of the message payload data. | ||
Args: | ||
message: dict. The message payload data. | ||
Returns: | ||
bool. Whether the payload dict is valid. | ||
""" | ||
job_id = message.get('job_id') | ||
classifier_data = message.get('classifier_data') | ||
|
||
if not isinstance(job_id, basestring): | ||
return False | ||
if not isinstance(classifier_data, dict): | ||
return False | ||
return True | ||
|
||
|
||
def verify_signature(message, vm_id, received_signature): | ||
"""Function that checks if the signature received from the VM is valid. | ||
Args: | ||
message: dict. The message payload data. | ||
vm_id: str. The ID of the VM instance. | ||
received_signature: str. The signature received from the VM. | ||
Returns: | ||
bool. Whether the incoming request is valid. | ||
""" | ||
secret = None | ||
for val in config_domain.VMID_SHARED_SECRET_KEY_MAPPING.value: | ||
if val['vm_id'] == vm_id: | ||
secret = str(val['shared_secret_key']) | ||
break | ||
if secret is None: | ||
return False | ||
|
||
generated_signature = generate_signature(secret, message) | ||
if generated_signature != received_signature: | ||
return False | ||
return True | ||
|
||
|
||
class TrainedClassifierHandler(base.BaseHandler): | ||
"""This handler stores the result of the training job in datastore and | ||
updates the status of the job. | ||
""" | ||
|
||
REQUIRE_PAYLOAD_CSRF_CHECK = False | ||
|
||
def post(self): | ||
"""Handles POST requests.""" | ||
signature = self.payload.get('signature') | ||
message = self.payload.get('message') | ||
vm_id = self.payload.get('vm_id') | ||
if vm_id == feconf.DEFAULT_VM_ID and not feconf.DEV_MODE: | ||
raise self.UnauthorizedUserException | ||
|
||
if not validate_job_result_message_dict(message): | ||
raise self.InvalidInputException | ||
if not verify_signature(message, vm_id, signature): | ||
raise self.UnauthorizedUserException | ||
|
||
job_id = message['job_id'] | ||
classifier_data = message['classifier_data'] | ||
classifier_training_job = ( | ||
classifier_services.get_classifier_training_job_by_id(job_id)) | ||
if classifier_training_job.status == ( | ||
feconf.TRAINING_JOB_STATUS_FAILED): | ||
raise self.InternalErrorException( | ||
'The current status of the job cannot transition to COMPLETE.') | ||
|
||
try: | ||
classifier_services.create_classifier(job_id, classifier_data) | ||
except Exception as e: | ||
raise self.InternalErrorException(e) | ||
|
||
# Update status of the training job to 'COMPLETE'. | ||
classifier_services.mark_training_job_complete(job_id) | ||
|
||
return self.render_json({}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2017 The Oppia Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS-IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for the controllers that communicate with VM for training | ||
classifiers.""" | ||
|
||
import os | ||
|
||
from core.controllers import classifier | ||
from core.domain import classifier_services | ||
from core.domain import exp_services | ||
from core.tests import test_utils | ||
import feconf | ||
|
||
|
||
class TrainedClassifierHandlerTest(test_utils.GenericTestBase): | ||
"""Test the handler for storing job result of training job.""" | ||
|
||
def setUp(self): | ||
super(TrainedClassifierHandlerTest, self).setUp() | ||
|
||
self.exp_id = 'exp_id1' | ||
self.title = 'Testing Classifier storing' | ||
self.category = 'Test' | ||
yaml_path = os.path.join( | ||
feconf.TESTS_DATA_DIR, 'string_classifier_test.yaml') | ||
with open(yaml_path, 'r') as yaml_file: | ||
self.yaml_content = yaml_file.read() | ||
|
||
assets_list = [] | ||
exp_services.save_new_exploration_from_yaml_and_assets( | ||
feconf.SYSTEM_COMMITTER_ID, self.yaml_content, self.exp_id, | ||
assets_list) | ||
self.exploration = exp_services.get_exploration_by_id(self.exp_id) | ||
|
||
state = self.exploration.states['Home'] | ||
algorithm_id = feconf.INTERACTION_CLASSIFIER_MAPPING[ | ||
state.interaction.id]['algorithm_id'] | ||
interaction_id = 'TextInput' | ||
training_data = state.get_training_data() | ||
self.classifier_data = { | ||
'_alpha': 0.1, | ||
'_beta': 0.001, | ||
'_prediction_threshold': 0.5, | ||
'_training_iterations': 25, | ||
'_prediction_iterations': 5, | ||
'_num_labels': 10, | ||
'_num_docs': 12, | ||
'_num_words': 20, | ||
'_label_to_id': {'text': 1}, | ||
'_word_to_id': {'hello': 2}, | ||
'_w_dp': [], | ||
'_b_dl': [], | ||
'_l_dp': [], | ||
'_c_dl': [], | ||
'_c_lw': [], | ||
'_c_l': [] | ||
} | ||
self.job_id = classifier_services.create_classifier_training_job( | ||
algorithm_id, interaction_id, self.exp_id, self.exploration.version, | ||
'Home', training_data, feconf.TRAINING_JOB_STATUS_PENDING) | ||
|
||
self.job_result_dict = { | ||
'job_id' : self.job_id, | ||
'classifier_data' : self.classifier_data, | ||
} | ||
|
||
self.payload = {} | ||
self.payload['vm_id'] = feconf.DEFAULT_VM_ID | ||
self.payload['message'] = self.job_result_dict | ||
secret = feconf.DEFAULT_VM_SHARED_SECRET | ||
self.payload['signature'] = classifier.generate_signature( | ||
secret, self.payload['message']) | ||
|
||
def test_trained_classifier_handler(self): | ||
# Normal end-to-end test. | ||
self.post_json('/ml/trainedclassifierhandler', self.payload, | ||
expect_errors=False, expected_status_int=200) | ||
classifier_obj = ( | ||
classifier_services.get_classifier_from_exploration_attributes( | ||
self.exp_id, self.exploration.version, 'Home')) | ||
self.assertEqual(classifier_obj.id, self.job_id) | ||
self.assertEqual(classifier_obj.exp_id, self.exp_id) | ||
self.assertEqual(classifier_obj.state_name, 'Home') | ||
self.assertEqual(classifier_obj.algorithm_id, 'LDAStringClassifier') | ||
self.assertEqual(classifier_obj.classifier_data, self.classifier_data) | ||
|
||
def test_error_on_prod_mode_and_default_vm_id(self): | ||
# Turn off DEV_MODE. | ||
with self.swap(feconf, 'DEV_MODE', False): | ||
self.post_json('/ml/trainedclassifierhandler', self.payload, | ||
expect_errors=True, expected_status_int=401) | ||
|
||
def test_error_on_different_signatures(self): | ||
# Altering data to result in different signatures. | ||
self.payload['message']['job_id'] = 'different_job_id' | ||
self.post_json('/ml/trainedclassifierhandler', self.payload, | ||
expect_errors=True, expected_status_int=401) | ||
|
||
def test_error_on_invalid_message(self): | ||
# Altering message dict to result in invalid dict. | ||
self.payload['message']['job_id'] = 1 | ||
self.post_json('/ml/trainedclassifierhandler', self.payload, | ||
expect_errors=True, expected_status_int=400) | ||
|
||
def test_error_on_existing_classifier(self): | ||
# Create ClassifierDataModel before the controller is called. | ||
classifier_services.create_classifier(self.job_id, self.classifier_data) | ||
self.post_json('/ml/trainedclassifierhandler', self.payload, | ||
expect_errors=True, expected_status_int=500) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.