Skip to content

Commit

Permalink
Merge httpchannel-proxy-8193-3: Add proxy HTTPChannel object to allow…
Browse files Browse the repository at this point in the history
… swapping of channel after connection establishment

Author: lukasa
Reviewers: glyph, adiroiban, hawkowl
Fixes: twisted#8193

git-svn-id: svn://svn.twistedmatrix.com/svn/Twisted/trunk@46999 bbbe8e31-12d6-0310-92fd-ac37d47ddeeb
  • Loading branch information
hawkowl committed Mar 15, 2016
1 parent b462a59 commit 4b23d59
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 6 deletions.
Empty file added twisted/topfiles/8193.misc
Empty file.
91 changes: 90 additions & 1 deletion twisted/web/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,14 @@ def _parseHeader(line):
from twisted.python.components import proxyForInterface
from twisted.internet import interfaces, reactor, protocol, address
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IProtocol
from twisted.protocols import policies, basic

from twisted.web.iweb import IRequest, IAccessLogFormatter
from twisted.web.http_headers import Headers

H2_ENABLED = False

from twisted.web._responses import (
SWITCHING,

Expand Down Expand Up @@ -1978,6 +1981,92 @@ def proxiedLogFormatter(timestamp, request):



class _GenericHTTPChannelProtocol(proxyForInterface(IProtocol, "_channel")):
"""
A proxy object that wraps one of the HTTP protocol objects, and switches
between them depending on TLS negotiated protocol.
@ivar _negotiatedProtocol: The protocol negotiated with ALPN or NPN, if
any.
@type _negotiatedProtocol: Either a bytestring containing the ALPN token
for the negotiated protocol, or C{None} if no protocol has yet been
negotiated.
@ivar _channel: The object capable of behaving like a L{HTTPChannel} that
is backing this object. By default this is a L{HTTPChannel}, but if a
HTTP protocol upgrade takes place this may be a different channel
object. Must implement L{IProtocol}.
@type _channel: L{HTTPChannel}
@ivar _requestFactory: A callable to use to build L{IRequest} objects.
@type _requestFactory: L{IRequest}
@ivar _site: A reference to the creating L{twisted.web.server.Site} object.
@type _site: L{twisted.web.server.Site}
"""
_negotiatedProtocol = None
_requestFactory = Request
_site = None


@property
def requestFactory(self):
return self._channel.requestFactory


@requestFactory.setter
def requestFactory(self, value):
self._requestFactory = value
self._channel.requestFactory = value


@property
def site(self):
return self._channel.site


@site.setter
def site(self, value):
self._site = value
self._channel.site = value


def dataReceived(self, data):
"""
A override of L{IProtocol.dataReceived} that checks what protocol we're
using.
"""
if self._negotiatedProtocol is None:
try:
negotiatedProtocol = self._channel.transport.negotiatedProtocol
except AttributeError:
# Plaintext HTTP, always HTTP/1.1
negotiatedProtocol = b'http/1.1'

if negotiatedProtocol is None:
negotiatedProtocol = b'http/1.1'

if negotiatedProtocol == b'h2':
assert H2_ENABLED, "Cannot negotiate HTTP/2 without support."
else:
# Only HTTP/2 and HTTP/1.1 are supported right now.
assert negotiatedProtocol == b'http/1.1', \
"Unsupported protocol negotiated"

self._negotiatedProtocol = negotiatedProtocol

return self._channel.dataReceived(data)



def _genericHTTPChannelProtocolFactory(self):
"""
Returns an appropriately initialized _GenericHTTPChannelProtocol.
"""
return _GenericHTTPChannelProtocol(HTTPChannel())



class HTTPFactory(protocol.ServerFactory):
"""
Factory for HTTP server.
Expand All @@ -2001,7 +2090,7 @@ class HTTPFactory(protocol.ServerFactory):
timestamps.
"""

protocol = HTTPChannel
protocol = _genericHTTPChannelProtocolFactory

logPath = None

Expand Down
88 changes: 88 additions & 0 deletions twisted/web/test/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,94 @@ def assertResponseEquals(self, response, expectedResponse):
self.assertEqual(response, expectedResponse)



class ProtocolNegotiationTests(unittest.TestCase):
requests = (
b"GET / HTTP/1.1\r\n"
b"Accept: text/html\r\n"
b"Connection: close\r\n"
b"\r\n"
b"GET / HTTP/1.0\r\n"
b"\r\n")


def _negotiatedProtocolForTransportInstance(self, t):
"""
Run a request using the specific instance of a transport. Returns the
negotiated protocol string.
"""
a = http._genericHTTPChannelProtocolFactory(b'')
a.requestFactory = DummyHTTPHandler
a.makeConnection(t)
# one byte at a time, to stress it.
for byte in iterbytes(self.requests):
a.dataReceived(byte)
a.connectionLost(IOError("all done"))
return a._negotiatedProtocol


def test_protocolUnspecified(self):
"""
If the transport has no support for protocol negotiation (no
negotiatedProtocol attribute), HTTP/1.1 is assumed.
"""
b = StringTransport()
negotiatedProtocol = self._negotiatedProtocolForTransportInstance(b)
self.assertEqual(negotiatedProtocol, b'http/1.1')


def test_protocolNone(self):
"""
If the transport has no support for protocol negotiation (returns None
for negotiatedProtocol), HTTP/1.1 is assumed.
"""
b = StringTransport()
b.negotiatedProtocol = None
negotiatedProtocol = self._negotiatedProtocolForTransportInstance(b)
self.assertEqual(negotiatedProtocol, b'http/1.1')


def test_http11(self):
"""
If the transport reports that HTTP/1.1 is negotiated, that's what's
negotiated.
"""
b = StringTransport()
b.negotiatedProtocol = b'http/1.1'
negotiatedProtocol = self._negotiatedProtocolForTransportInstance(b)
self.assertEqual(negotiatedProtocol, b'http/1.1')


def test_http2(self):
"""
If the transport reports that HTTP/2 is negotiated, that's what's
negotiated. Currently HTTP/2 is unsupported, so this raises an
AssertionError.
"""
b = StringTransport()
b.negotiatedProtocol = b'h2'
self.assertRaises(
AssertionError,
self._negotiatedProtocolForTransportInstance,
b,
)


def test_unknownProtocol(self):
"""
If the transport reports that a protocol other than HTTP/1.1 or HTTP/2
is negotiated, an error occurs.
"""
b = StringTransport()
b.negotiatedProtocol = b'smtp'
self.assertRaises(
AssertionError,
self._negotiatedProtocolForTransportInstance,
b,
)



class HTTPLoopbackTests(unittest.TestCase):

expectedHeaders = {b'request': b'/foo/bar',
Expand Down
10 changes: 5 additions & 5 deletions twisted/web/test/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _modifiedTest(self, modifiedSince=None, etag=None):
else:
validator = b"If-Not-Match: " + etag
for line in [b"GET / HTTP/1.1", validator, b""]:
self.channel.lineReceived(line)
self.channel.dataReceived(line + b'\r\n')
result = self.transport.getvalue()
self.assertEqual(httpCode(result), http.OK)
self.assertEqual(httpBody(result), b"correct")
Expand All @@ -346,7 +346,7 @@ def test_unmodified(self):
"""
for line in [b"GET / HTTP/1.1",
b"If-Modified-Since: " + http.datetimeToString(100), b""]:
self.channel.lineReceived(line)
self.channel.dataReceived(line + b'\r\n')
result = self.transport.getvalue()
self.assertEqual(httpCode(result), http.NOT_MODIFIED)
self.assertEqual(httpBody(result), b"")
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_etagMatched(self):
with an empty response body.
"""
for line in [b"GET / HTTP/1.1", b"If-None-Match: MatchingTag", b""]:
self.channel.lineReceived(line)
self.channel.dataReceived(line + b'\r\n')
result = self.transport.getvalue()
self.assertEqual(httpHeader(result, b"ETag"), b"MatchingTag")
self.assertEqual(httpCode(result), http.NOT_MODIFIED)
Expand All @@ -432,7 +432,7 @@ def test_unmodifiedWithContentType(self):
"""
for line in [b"GET /with-content-type HTTP/1.1",
b"If-None-Match: MatchingTag", b""]:
self.channel.lineReceived(line)
self.channel.dataReceived(line + b'\r\n')
result = self.transport.getvalue()
self.assertEqual(httpCode(result), http.NOT_MODIFIED)
self.assertEqual(httpBody(result), b"")
Expand Down Expand Up @@ -1311,7 +1311,7 @@ def test_userAgentQuote(self):
"""
self.site._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % (
25, 'Oct', 2004, 12, 31, 59)
self.request.requestHeaders.addRawHeader(b'user-agent',
self.request.requestHeaders.addRawHeader(b'user-agent',
b'Malicious Web" Evil')
self.assertLogs(
b'"1.2.3.4" - - [25/Oct/2004:12:31:59 +0000] '
Expand Down

0 comments on commit 4b23d59

Please sign in to comment.