Skip to content

Commit

Permalink
Fix: oppia#3515 - StoreJobResult controller (oppia#3559)
Browse files Browse the repository at this point in the history
* implemented handler

* files added

* fixes

* updated return codes

* completed version

* addressed changes

* end-to-end implementation

* sorting message dicts

* addressed changes

* updated controller utilising mapping

* addressed final review comments

* addressed review changes

* moved func to exp_domain

* addressed review comments

* changed decorator to verify function

* checked for validity of vm_id

* addressed all changes, major upheaval

* addressed review comments

* changed save_job into create and update job methods

* addressed review changes

* added status change mapping
  • Loading branch information
pranavsid98 authored and seanlip committed Jul 16, 2017
1 parent 45df731 commit 1533383
Show file tree
Hide file tree
Showing 12 changed files with 475 additions and 167 deletions.
125 changes: 125 additions & 0 deletions core/controllers/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2017 The Oppia Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Controllers for communicating with the VM for training classifiers."""

import hashlib
import hmac
import json

from core.controllers import base
from core.domain import classifier_services
from core.domain import config_domain

import feconf


# NOTE TO DEVELOPERS: This function should be kept in sync with its counterpart
# in Oppia-ml.
def generate_signature(secret, message):
    """Generates the HMAC-SHA256 digital signature for the given data.

    The message is serialised to JSON with sorted keys, so two dicts with the
    same contents always produce the same signature regardless of key order.

    Args:
        secret: str. The secret used to communicate with Oppia-ml. A text
            (non-bytes) secret is encoded as UTF-8 before it is used as the
            HMAC key.
        message: dict. The message payload data.

    Returns:
        str. The hex-encoded signature of the payload data.
    """
    message_json = json.dumps(message, sort_keys=True)
    # hmac only accepts byte strings for the key and the message. Encoding
    # text values here leaves existing byte-string callers untouched (same
    # digest) while also accepting text secrets.
    if not isinstance(secret, bytes):
        secret = secret.encode('utf-8')
    if not isinstance(message_json, bytes):
        message_json = message_json.encode('utf-8')
    return hmac.new(
        secret, message_json, digestmod=hashlib.sha256).hexdigest()


def validate_job_result_message_dict(message):
    """Validates the data-type of the message payload data.

    A valid message is a dict whose 'job_id' value is a string and whose
    'classifier_data' value is a dict.

    Args:
        message: dict. The message payload data. Any other type (including
            None, which is what the handler passes when the request carries
            no message) is reported as invalid.

    Returns:
        bool. Whether the payload dict is valid.
    """
    # The message is taken from an incoming request, so it may be absent or
    # of the wrong type. Report that as invalid input instead of failing on
    # the key lookups below.
    if not isinstance(message, dict):
        return False

    if not isinstance(message.get('classifier_data'), dict):
        return False
    if not isinstance(message.get('job_id'), basestring):
        return False
    return True


def verify_signature(message, vm_id, received_signature):
    """Function that checks if the signature received from the VM is valid.

    The shared secret registered for the given VM is looked up in
    config_domain.VMID_SHARED_SECRET_KEY_MAPPING, the signature of the message
    is recomputed with it, and the result is compared with the received one.

    Args:
        message: dict. The message payload data.
        vm_id: str. The ID of the VM instance.
        received_signature: str. The signature received from the VM.

    Returns:
        bool. Whether the incoming request is valid. False is returned when
        the VM id has no registered secret, when the received signature is
        missing or not text, or when the signatures differ.
    """
    secret = None
    for val in config_domain.VMID_SHARED_SECRET_KEY_MAPPING.value:
        if val['vm_id'] == vm_id:
            secret = str(val['shared_secret_key'])
            break
    if secret is None:
        return False

    generated_signature = generate_signature(secret, message)
    try:
        # Compare in constant time so that response timing does not reveal
        # how much of a forged signature was correct. Both values are
        # encoded to bytes because compare_digest requires its arguments to
        # be of the same type.
        return hmac.compare_digest(
            generated_signature.encode('utf-8'),
            received_signature.encode('utf-8'))
    except (AttributeError, TypeError, UnicodeError):
        # The received signature was missing (None) or not an encodable
        # text value, so it cannot match.
        return False


class TrainedClassifierHandler(base.BaseHandler):
    """Handler that receives the result of a training job, stores the
    trained classifier in the datastore and marks the job as complete.
    """

    # The CSRF payload check is switched off for this handler; incoming
    # requests are checked against the sender's signature in post().
    REQUIRE_PAYLOAD_CSRF_CHECK = False

    def post(self):
        """Handles POST requests.

        Raises:
            UnauthorizedUserException. The default VM id is used outside of
                DEV_MODE, or the signature does not match the message.
            InvalidInputException. The message payload is malformed.
            InternalErrorException. The job has already failed, or the
                classifier could not be created.
        """
        payload = self.payload
        vm_id = payload.get('vm_id')
        message = payload.get('message')
        signature = payload.get('signature')

        # The default VM id is only accepted while running in DEV_MODE.
        if not feconf.DEV_MODE and vm_id == feconf.DEFAULT_VM_ID:
            raise self.UnauthorizedUserException

        message_is_valid = validate_job_result_message_dict(message)
        if not message_is_valid:
            raise self.InvalidInputException

        signature_is_valid = verify_signature(message, vm_id, signature)
        if not signature_is_valid:
            raise self.UnauthorizedUserException

        job_id = message['job_id']
        classifier_data = message['classifier_data']

        training_job = (
            classifier_services.get_classifier_training_job_by_id(job_id))
        job_has_failed = (
            training_job.status == feconf.TRAINING_JOB_STATUS_FAILED)
        if job_has_failed:
            raise self.InternalErrorException(
                'The current status of the job cannot transition to COMPLETE.')

        try:
            classifier_services.create_classifier(job_id, classifier_data)
        except Exception as e:
            raise self.InternalErrorException(e)

        # The classifier has been stored, so the training job can now be
        # moved to the 'COMPLETE' status.
        classifier_services.mark_training_job_complete(job_id)

        return self.render_json({})
121 changes: 121 additions & 0 deletions core/controllers/classifier_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2017 The Oppia Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the controllers that communicate with VM for training
classifiers."""

import os

from core.controllers import classifier
from core.domain import classifier_services
from core.domain import exp_services
from core.tests import test_utils
import feconf


class TrainedClassifierHandlerTest(test_utils.GenericTestBase):
    """Tests for the handler that stores the result of a training job."""

    def setUp(self):
        super(TrainedClassifierHandlerTest, self).setUp()

        self.exp_id = 'exp_id1'
        self.title = 'Testing Classifier storing'
        self.category = 'Test'

        # Load the exploration used by all tests from its YAML fixture.
        fixture_path = os.path.join(
            feconf.TESTS_DATA_DIR, 'string_classifier_test.yaml')
        with open(fixture_path, 'r') as fixture_file:
            self.yaml_content = fixture_file.read()

        exp_services.save_new_exploration_from_yaml_and_assets(
            feconf.SYSTEM_COMMITTER_ID, self.yaml_content, self.exp_id, [])
        self.exploration = exp_services.get_exploration_by_id(self.exp_id)

        home_state = self.exploration.states['Home']
        classifier_mapping = feconf.INTERACTION_CLASSIFIER_MAPPING[
            home_state.interaction.id]
        algorithm_id = classifier_mapping['algorithm_id']
        interaction_id = 'TextInput'
        training_data = home_state.get_training_data()

        self.classifier_data = {
            '_alpha': 0.1,
            '_beta': 0.001,
            '_prediction_threshold': 0.5,
            '_training_iterations': 25,
            '_prediction_iterations': 5,
            '_num_labels': 10,
            '_num_docs': 12,
            '_num_words': 20,
            '_label_to_id': {'text': 1},
            '_word_to_id': {'hello': 2},
            '_w_dp': [],
            '_b_dl': [],
            '_l_dp': [],
            '_c_dl': [],
            '_c_lw': [],
            '_c_l': []
        }

        # Create a pending training job whose result the handler will store.
        self.job_id = classifier_services.create_classifier_training_job(
            algorithm_id, interaction_id, self.exp_id,
            self.exploration.version, 'Home', training_data,
            feconf.TRAINING_JOB_STATUS_PENDING)

        self.job_result_dict = {
            'job_id': self.job_id,
            'classifier_data': self.classifier_data,
        }

        # The payload is signed with the default secret over the message.
        self.payload = {
            'vm_id': feconf.DEFAULT_VM_ID,
            'message': self.job_result_dict,
            'signature': classifier.generate_signature(
                feconf.DEFAULT_VM_SHARED_SECRET, self.job_result_dict),
        }

    def _post_payload(self, expect_errors, expected_status_int):
        """Posts self.payload to the trained classifier handler.

        Args:
            expect_errors: bool. Whether an error response is expected.
            expected_status_int: int. The expected HTTP status code.
        """
        self.post_json(
            '/ml/trainedclassifierhandler', self.payload,
            expect_errors=expect_errors,
            expected_status_int=expected_status_int)

    def test_trained_classifier_handler(self):
        # Normal end-to-end test.
        self._post_payload(False, 200)
        stored_classifier = (
            classifier_services.get_classifier_from_exploration_attributes(
                self.exp_id, self.exploration.version, 'Home'))
        self.assertEqual(stored_classifier.id, self.job_id)
        self.assertEqual(stored_classifier.exp_id, self.exp_id)
        self.assertEqual(stored_classifier.state_name, 'Home')
        self.assertEqual(
            stored_classifier.algorithm_id, 'LDAStringClassifier')
        self.assertEqual(
            stored_classifier.classifier_data, self.classifier_data)

    def test_error_on_prod_mode_and_default_vm_id(self):
        # With DEV_MODE turned off the default VM id must be rejected.
        with self.swap(feconf, 'DEV_MODE', False):
            self._post_payload(True, 401)

    def test_error_on_different_signatures(self):
        # Changing the message after signing makes the signatures differ.
        self.payload['message']['job_id'] = 'different_job_id'
        self._post_payload(True, 401)

    def test_error_on_invalid_message(self):
        # A non-string job_id makes the message dict invalid.
        self.payload['message']['job_id'] = 1
        self._post_payload(True, 400)

    def test_error_on_existing_classifier(self):
        # Create ClassifierDataModel before the controller is called.
        classifier_services.create_classifier(
            self.job_id, self.classifier_data)
        self._post_payload(True, 500)
expect_errors=True, expected_status_int=500)
33 changes: 21 additions & 12 deletions core/domain/classifier_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,6 @@ def classifier_data(self):
def data_schema_version(self):
return self._data_schema_version

def update_state_name(self, state_name):
"""Updates the state_name attribute of the ClassifierData domain object.
Args:
state_name: str. The name of the updated state to which the
classifier belongs.
"""

self._state_name = state_name

def to_dict(self):
"""Constructs a dict representation of ClassifierData domain object.
Expand Down Expand Up @@ -185,6 +175,8 @@ class ClassifierTrainingJob(object):
job_id: str. The unique id of the classifier training job.
algorithm_id: str. The id of the algorithm that will be used for
generating the classifier.
interaction_id: str. The id of the interaction to which the algorithm
belongs.
exp_id: str. The id of the exploration that contains the state
for which the classifier will be generated.
exp_version: str. The version of the exploration when
Expand All @@ -210,14 +202,16 @@ class ClassifierTrainingJob(object):
"""

def __init__(self, job_id, algorithm_id, exp_id, exp_version,
state_name, status, training_data):
def __init__(self, job_id, algorithm_id, interaction_id, exp_id,
exp_version, state_name, status, training_data):
"""Constructs a ClassifierTrainingJob domain object.
Args:
job_id: str. The unique id of the classifier training job.
algorithm_id: str. The id of the algorithm that will be used for
generating the classifier.
interaction_id: str. The id of the interaction to which the algorithm
belongs.
exp_id: str. The id of the exploration id that contains the state
for which classifier will be generated.
exp_version: str. The version of the exploration when
Expand All @@ -243,6 +237,7 @@ def __init__(self, job_id, algorithm_id, exp_id, exp_version,
"""
self._job_id = job_id
self._algorithm_id = algorithm_id
self._interaction_id = interaction_id
self._exp_id = exp_id
self._exp_version = exp_version
self._state_name = state_name
Expand All @@ -257,6 +252,10 @@ def job_id(self):
def algorithm_id(self):
return self._algorithm_id

@property
def interaction_id(self):
    """Returns the interaction id stored on this training job."""
    return self._interaction_id

@property
def exp_id(self):
    """Returns the exploration id stored on this training job."""
    return self._exp_id
Expand Down Expand Up @@ -297,6 +296,7 @@ def to_dict(self):
return {
'job_id': self._job_id,
'algorithm_id': self._algorithm_id,
'interaction_id': self._interaction_id,
'exp_id': self._exp_id,
'exp_version': self._exp_version,
'state_name': self._state_name,
Expand Down Expand Up @@ -332,6 +332,15 @@ def validate(self):
feconf.ALLOWED_TRAINING_JOB_STATUSES,
self.exp_version)

if not isinstance(self.interaction_id, basestring):
raise utils.ValidationError(
'Expected interaction_id to be a string, received %s' %
self.interaction_id)

if self.interaction_id not in feconf.INTERACTION_CLASSIFIER_MAPPING:
raise utils.ValidationError(
'Invalid interaction id: %s' % self.interaction_id)

if not isinstance(self.algorithm_id, basestring):
raise utils.ValidationError(
'Expected algorithm_id to be a string, received %s' %
Expand Down
3 changes: 3 additions & 0 deletions core/domain/classifier_domain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def _get_training_job_from_dict(self, training_job_dict):
training_job = classifier_domain.ClassifierTrainingJob(
training_job_dict['job_id'],
training_job_dict['algorithm_id'],
training_job_dict['interaction_id'],
training_job_dict['exp_id'],
training_job_dict['exp_version'],
training_job_dict['state_name'],
Expand All @@ -184,6 +185,7 @@ def test_to_dict(self):
expected_training_job_dict = {
'job_id': 'exp_id1.SOME_RANDOM_STRING',
'algorithm_id': 'LDAStringClassifier',
'interaction_id': 'TextInput',
'exp_id': 'exp_id1',
'exp_version': 1,
'state_name': 'a state name',
Expand Down Expand Up @@ -224,6 +226,7 @@ def test_validation(self):
'exp_version': 1,
'state_name': 'some state',
'algorithm_id': 'LDAStringClassifier',
'interaction_id': 'TextInput',
'training_data': training_data,
'status': 'NEW'
}
Expand Down
Loading

0 comments on commit 1533383

Please sign in to comment.