Skip to content

Commit

Permalink
Merge pull request #245 from wtsi-hgi/feature/accept-best-match
Browse files Browse the repository at this point in the history
Fixes content negotiation to respect q-values
  • Loading branch information
joshfriend committed Jan 20, 2015
2 parents bf983cf + a9667e2 commit f59fbdf
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 17 deletions.
63 changes: 46 additions & 17 deletions flask_restful/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import re
from flask import request, url_for, current_app
from flask import abort as original_flask_abort
from flask import make_response as original_flask_make_response
from flask.views import MethodView
from flask.signals import got_request_exception
from werkzeug.exceptions import HTTPException, MethodNotAllowed, NotFound
from werkzeug.exceptions import HTTPException, MethodNotAllowed, NotFound, NotAcceptable, InternalServerError
from werkzeug.http import HTTP_STATUS_CODES
from werkzeug.wrappers import Response as ResponseBase
from flask.ext.restful.utils import error_data, unpack, OrderedDict
from flask.ext.restful.representations.json import output_json
import sys
from flask.helpers import _endpoint_from_view_func
from types import MethodType
import operator


__all__ = ('Api', 'Resource', 'marshal', 'marshal_with', 'marshal_with_field', 'abort')
Expand Down Expand Up @@ -278,7 +280,6 @@ def handle_error(self, e):
raise
else:
raise e

code = getattr(e, 'code', 500)
data = getattr(e, 'data', error_data(code))
headers = {}
Expand Down Expand Up @@ -320,7 +321,21 @@ def handle_error(self, e):
code = custom_data.get('status', 500)
data.update(custom_data)

resp = self.make_response(data, code, headers)
if code == 406 and self.default_mediatype is None:
# if we are handling NotAcceptable (406), make sure that
# make_response uses a representation we support as the
# default mediatype (so that make_response doesn't throw
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys())
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
resp = self.make_response(
data,
code,
headers,
fallback_mediatype = fallback_mediatype
)
else:
resp = self.make_response(data, code, headers)

if code == 401:
resp = self.unauthorized(resp)
Expand Down Expand Up @@ -442,20 +457,34 @@ def url_for(self, resource, **values):
def make_response(self, data, *args, **kwargs):
"""Looks up the representation transformer for the requested media
type, invoking the transformer to create a response object. This
defaults to (application/json) if no transformer is found for the
requested mediatype.
defaults to default_mediatype if no transformer is found for the
requested mediatype. If default_mediatype is None, a 406 Not
Acceptable response will be sent as per RFC 2616 section 14.1
:param data: Python object containing response data to be transformed
"""
for mediatype in self.mediatypes() + [self.default_mediatype]:
if mediatype in self.representations:
resp = self.representations[mediatype](data, *args, **kwargs)
resp.headers['Content-Type'] = mediatype
return resp
default_mediatype = kwargs.pop('fallback_mediatype', None) or self.default_mediatype
mediatype = request.accept_mimetypes.best_match(
self.representations,
default=default_mediatype,
)
if mediatype is None:
raise NotAcceptable()
if mediatype in self.representations:
resp = self.representations[mediatype](data, *args, **kwargs)
resp.headers['Content-Type'] = mediatype
return resp
elif mediatype == 'text/plain':
resp = original_flask_make_response(str(data), *args, **kwargs)
resp.headers['Content-Type'] = 'text/plain'
return resp
else:
raise InternalServerError()

def mediatypes(self):
"""Returns a list of requested mediatypes sent in the Accept header"""
return [h for h, q in request.accept_mimetypes]
return [h for h, q in sorted(request.accept_mimetypes,
key=operator.itemgetter(1), reverse=True)]

def representation(self, mediatype):
"""Allows additional representation transformers to be declared for the
Expand Down Expand Up @@ -526,12 +555,12 @@ def dispatch_request(self, *args, **kwargs):
representations = self.representations or {}

#noinspection PyUnresolvedReferences
for mediatype in self.mediatypes():
if mediatype in representations:
data, code, headers = unpack(resp)
resp = representations[mediatype](data, code, headers)
resp.headers['Content-Type'] = mediatype
return resp
mediatype = request.accept_mimetypes.best_match(representations, default=None)
if mediatype in representations:
data, code, headers = unpack(resp)
resp = representations[mediatype](data, code, headers)
resp.headers['Content-Type'] = mediatype
return resp

return resp

Expand Down
237 changes: 237 additions & 0 deletions tests/test_accept.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import unittest
from flask import Flask
import flask_restful
from werkzeug import exceptions
from nose.tools import assert_equals
from nose import SkipTest
import functools


def expected_failure(test):
@functools.wraps(test)
def inner(*args, **kwargs):
try:
test(*args, **kwargs)
except Exception:
raise SkipTest
else:
raise AssertionError('Failure expected')
return inner


class AcceptTestCase(unittest.TestCase):

def test_accept_default_application_json(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app)

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'application/json')


def test_accept_no_default_match_acceptable(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'application/json')


def test_accept_default_override_accept(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app)

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'text/plain')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'application/json')


def test_accept_no_default_no_match_not_acceptable(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'text/plain')])
assert_equals(res.status_code, 406)
assert_equals(res.content_type, 'application/json')


def test_accept_no_default_custom_repr_match(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)
api.representations = {}

@api.representation('text/plain')
def text_rep(data, status_code, headers=None):
resp = app.make_response((str(data), status_code, headers))
return resp

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'text/plain')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'text/plain')


def test_accept_no_default_custom_repr_not_acceptable(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)
api.representations = {}

@api.representation('text/plain')
def text_rep(data, status_code, headers=None):
resp = app.make_response((str(data), status_code, headers))
return resp

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json')])
assert_equals(res.status_code, 406)
assert_equals(res.content_type, 'text/plain')


@expected_failure
def test_accept_no_default_match_q0_not_acceptable(self):
"""
q=0 should be considered NotAcceptable,
but this depends on werkzeug >= 1.0 which is not yet released
so this test is expected to fail until we depend on werkzeug >= 1.0
"""
class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json; q=0')])
assert_equals(res.status_code, 406)
assert_equals(res.content_type, 'application/json')

def test_accept_no_default_accept_highest_quality_of_two(self):
class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)

@api.representation('text/plain')
def text_rep(data, status_code, headers=None):
resp = app.make_response((str(data), status_code, headers))
return resp

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json; q=0.1, text/plain; q=1.0')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'text/plain')


def test_accept_no_default_accept_highest_quality_of_three(self):
class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)

@api.representation('text/html')
@api.representation('text/plain')
def text_rep(data, status_code, headers=None):
resp = app.make_response((str(data), status_code, headers))
return resp

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'application/json; q=0.1, text/plain; q=0.3, text/html; q=0.2')])
assert_equals(res.status_code, 200)
assert_equals(res.content_type, 'text/plain')


def test_accept_no_default_no_representations(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype=None)
api.representations = {}

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'text/plain')])
assert_equals(res.status_code, 406)
assert_equals(res.content_type, 'text/plain')

def test_accept_invalid_default_no_representations(self):

class Foo(flask_restful.Resource):
def get(self):
return "data"

app = Flask(__name__)
api = flask_restful.Api(app, default_mediatype='nonexistant/mediatype')
api.representations = {}

api.add_resource(Foo, '/')

with app.test_client() as client:
res = client.get('/', headers=[('Accept', 'text/plain')])
assert_equals(res.status_code, 500)




0 comments on commit f59fbdf

Please sign in to comment.