Skip to content

Commit

Permalink
Merge branch '8188' of https://github.com/lukasa/twisted into h2-alpn…
Browse files Browse the repository at this point in the history
…-8188-2
  • Loading branch information
hawkowl committed Jun 14, 2016
2 parents 31c126f + 40a12a1 commit 8e94ffa
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 103 deletions.
101 changes: 59 additions & 42 deletions twisted/internet/_sslverify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,52 +1668,11 @@ def _verifyCallback(conn, cert, errno, depth, preverify_ok):
if self._acceptableProtocols:
# Try to set NPN and ALPN. _acceptableProtocols cannot be set by
# the constructor unless at least one mechanism is supported.
self._setUpNextProtocolMechanisms(ctx)
_setAcceptableProtocols(ctx, self._acceptableProtocols)

return ctx


def _setUpNextProtocolMechanisms(self, ctx):
"""
Called to set up the C{ctx} for doing NPN and/or ALPN negotiation.
@param ctx: The context which is set up.
@type ctx: L{OpenSSL.SSL.Context}
"""
supported = protocolNegotiationMechanisms()

if supported & ProtocolNegotiationSupport.NPN:
def npnAdvertiseCallback(conn):
return self._acceptableProtocols

ctx.set_npn_advertise_callback(npnAdvertiseCallback)
ctx.set_npn_select_callback(self._protoSelectCallback)

if supported & ProtocolNegotiationSupport.ALPN:
ctx.set_alpn_select_callback(self._protoSelectCallback)
ctx.set_alpn_protos(self._acceptableProtocols)


def _protoSelectCallback(self, conn, protocols):
"""
NPN client-side and ALPN server-side callback used to select
the next protocol. Prefers protocols found earlier in
C{_acceptableProtocols}.
@param conn: The context which is set up.
@type conn: L{OpenSSL.SSL.Connection}
@param conn: Protocols advertised by the other side.
@type conn: C{list} of C{bytes}
"""
overlap = set(protocols) & set(self._acceptableProtocols)

for p in self._acceptableProtocols:
if p in overlap:
return p
else:
return b''

OpenSSLCertificateOptions.__getstate__ = deprecated(
Version("Twisted", 15, 0, 0),
"a real persistence system")(OpenSSLCertificateOptions.__getstate__)
Expand Down Expand Up @@ -1942,3 +1901,61 @@ def fromFile(cls, filePath):
<twisted.internet.ssl.DiffieHellmanParameters>}
"""
return cls(filePath)


def _setAcceptableProtocols(context, acceptableProtocols):
"""
Called to set up the L{OpenSSL.SSL.Context} for doing NPN and/or ALPN
negotiation.
@param context: The context which is set up.
@type context: L{OpenSSL.SSL.Context}
@param acceptableProtocols: The protocols this peer is willing to speak
after the TLS negotation has completed, advertised over both ALPN and
NPN. If this argument is specified, and no overlap can be found with
the other peer, the connection will fail to be established. If the
remote peer does not offer NPN or ALPN, the connection will be
established, but no protocol wil be negotiated. Protocols earlier in
the list are preferred over those later in the list.
@type acceptableProtocols: C{list} of C{bytes}
"""
def protoSelectCallback(conn, protocols):
"""
NPN client-side and ALPN server-side callback used to select
the next protocol. Prefers protocols found earlier in
C{_acceptableProtocols}.
@param conn: The context which is set up.
@type conn: L{OpenSSL.SSL.Connection}
@param conn: Protocols advertised by the other side.
@type conn: C{list} of C{bytes}
"""
overlap = set(protocols) & set(acceptableProtocols)

for p in acceptableProtocols:
if p in overlap:
return p
else:
return b''

# If we don't actually have protocols to negotiate, don't set anything up.
# Depending on OpenSSL version, failing some of the selection callbacks can
# cause the handshake to fail, which is presumably not what was intended
# here.
if not acceptableProtocols:
return

supported = protocolNegotiationMechanisms()

if supported & ProtocolNegotiationSupport.NPN:
def npnAdvertiseCallback(conn):
return acceptableProtocols

context.set_npn_advertise_callback(npnAdvertiseCallback)
context.set_npn_select_callback(protoSelectCallback)

if supported & ProtocolNegotiationSupport.ALPN:
context.set_alpn_select_callback(protoSelectCallback)
context.set_alpn_protos(acceptableProtocols)
22 changes: 22 additions & 0 deletions twisted/internet/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,6 +2207,28 @@ def clientConnectionForTLS(tlsProtocol):



class IProtocolNegotiationFactory(Interface):
"""
A provider of L{IProtocolNegotiationFactory} can provide information about
the various protocols that the factory can create implementations of. This
can be used, for example, to provide protocol names for Next Protocol
Negotation and Application Layer Protocol Negotiation.
@see: L{twisted.internet.ssl}
"""

def acceptableProtocols():
"""
Returns a list of protocols that can be spoken by the connection
factory in the form of ALPN tokens, as laid out in the IANA registry
for ALPN tokens.
@return: a list of ALPN tokens in order of preference.
@rtype: L{list} of L{bytes}
"""



class ITLSTransport(ITCPTransport):
"""
A TCP transport that supports switching to TLS midstream.
Expand Down
219 changes: 218 additions & 1 deletion twisted/protocols/test/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import division, absolute_import

from zope.interface.verify import verifyObject
from zope.interface import Interface, directlyProvides
from zope.interface import Interface, directlyProvides, implementer

from twisted.python.compat import intToBytes, iterbytes
try:
Expand All @@ -32,6 +32,7 @@
from twisted.python import log
from twisted.internet.interfaces import ISystemHandle, ISSLTransport
from twisted.internet.interfaces import IPushProducer
from twisted.internet.interfaces import IProtocolNegotiationFactory
from twisted.internet.error import ConnectionDone, ConnectionLost
from twisted.internet.defer import Deferred, gatherResults
from twisted.internet.protocol import Protocol, ClientFactory, ServerFactory
Expand Down Expand Up @@ -1522,3 +1523,219 @@ def test_interface(self):
nsProducer = NonStreamingProducer(consumer)
streamingProducer = _PullToPush(nsProducer, consumer)
self.assertTrue(verifyObject(IPushProducer, streamingProducer))



@implementer(IProtocolNegotiationFactory)
class ClientNegotiationFactory(ClientFactory):
"""
A L{ClientFactory} that has a set of acceptable protocols for NPN/ALPN
negotiation.
"""
def __init__(self, acceptableProtocols):
"""
Create a L{ClientNegotiationFactory}.
@param acceptableProtocols: The protocols the client will accept
speaking after the TLS handshake is complete.
@type acceptableProtocols: L{list} of L{bytes}
"""
self._acceptableProtocols = acceptableProtocols


def acceptableProtocols(self):
"""
Returns a list of protocols that can be spoken by the connection
factory in the form of ALPN tokens, as laid out in the IANA registry
for ALPN tokens.
@return: a list of ALPN tokens in order of preference.
@rtype: L{list} of L{bytes}
"""
return self._acceptableProtocols



@implementer(IProtocolNegotiationFactory)
class ServerNegotiationFactory(ServerFactory):
"""
A L{ServerFactory} that has a set of acceptable protocols for NPN/ALPN
negotiation.
"""
def __init__(self, acceptableProtocols):
"""
Create a L{ServerNegotiationFactory}.
@param acceptableProtocols: The protocols the server will accept
speaking after the TLS handshake is complete.
@type acceptableProtocols: L{list} of L{bytes}
"""
self._acceptableProtocols = acceptableProtocols


def acceptableProtocols(self):
"""
Returns a list of protocols that can be spoken by the connection
factory in the form of ALPN tokens, as laid out in the IANA registry
for ALPN tokens.
@return: a list of ALPN tokens in order of preference.
@rtype: L{list} of L{bytes}
"""
return self._acceptableProtocols



class IProtocolNegotiationFactoryTests(TestCase):
"""
Tests for L{IProtocolNegotiationFactory} inside L{TLSMemoryBIOFactory}.
These tests expressly don't include the case where both server and client
advertise protocols but don't have any overlap. This is because the
behaviour here is platform-dependent and changes from version to version.
Prior to version 1.1.0 of OpenSSL, failing the ALPN negotiation does not
fail the handshake. At least in 1.0.2h, failing NPN *does* fail the
handshake, at least with the callback implemented by PyOpenSSL.
This is sufficiently painful to test that we simply don't. It's not
necessary to validate that our offering logic works anyway: all we need to
see is that it works in the successful case and that it degrades properly.
"""
def handshakeProtocols(self, clientProtocols, serverProtocols):
"""
Start handshake between TLS client and server.
@param clientProtocols: The protocols the client will accept speaking
after the TLS handshake is complete.
@type clientProtocols: L{list} of L{bytes}
@param serverProtocols: The protocols the server will accept speaking
after the TLS handshake is complete.
@type serverProtocols: L{list} of L{bytes}
@return: A L{tuple} of four different items: the client L{Protocol},
the server L{Protocol}, a L{Deferred} that fires when the client
first receives bytes (and so the TLS connection is complete), and a
L{Deferred} that fires when the server first receives bytes.
@rtype: A L{tuple} of (L{Protocol}, L{Protocol}, L{Deferred},
L{Deferred})
"""
bytes = b'some bytes'

class NotifyingSender(Protocol):
def __init__(self, notifier):
self.notifier = notifier

def connectionMade(self):
self.transport.writeSequence(list(iterbytes(bytes)))

def dataReceived(self, bytes):
if self.notifier is not None:
self.notifier.callback(self)
self.notifier = None


clientDataReceived = Deferred()
clientFactory = ClientNegotiationFactory(clientProtocols)
clientFactory.protocol = lambda: NotifyingSender(
clientDataReceived
)

clientContextFactory, _ = (
HandshakeCallbackContextFactory.factoryAndDeferred())
wrapperFactory = TLSMemoryBIOFactory(
clientContextFactory, True, clientFactory)
sslClientProtocol = wrapperFactory.buildProtocol(None)

serverDataReceived = Deferred()
serverFactory = ServerNegotiationFactory(serverProtocols)
serverFactory.protocol = lambda: NotifyingSender(
serverDataReceived
)

serverContextFactory = ServerTLSContext()
wrapperFactory = TLSMemoryBIOFactory(
serverContextFactory, False, serverFactory)
sslServerProtocol = wrapperFactory.buildProtocol(None)

loopbackAsync(
sslServerProtocol, sslClientProtocol
)
return (sslClientProtocol, sslServerProtocol, clientDataReceived,
serverDataReceived)


def test_negotiationWithNoProtocols(self):
"""
When factories support L{IProtocolNegotiationFactory} but don't
advertise support for any protocols, no protocols are negotiated.
"""
client, server, clientDataReceived, serverDataReceived = (
self.handshakeProtocols([], [])
)

def checkNegotiatedProtocol(ignored):
self.assertEqual(client.negotiatedProtocol, None)
self.assertEqual(server.negotiatedProtocol, None)

clientDataReceived.addCallback(lambda ignored: serverDataReceived)
serverDataReceived.addCallback(checkNegotiatedProtocol)

return clientDataReceived


def test_negotiationWithProtocolOverlap(self):
"""
When factories support L{IProtocolNegotiationFactory} and support
overlapping protocols, the first protocol is negotiated.
"""
client, server, clientDataReceived, serverDataReceived = (
self.handshakeProtocols([b'h2', b'http/1.1'], [b'h2', b'http/1.1'])
)

def checkNegotiatedProtocol(ignored):
self.assertEqual(client.negotiatedProtocol, b'h2')
self.assertEqual(server.negotiatedProtocol, b'h2')

clientDataReceived.addCallback(lambda ignored: serverDataReceived)
serverDataReceived.addCallback(checkNegotiatedProtocol)

return clientDataReceived


def test_negotiationClientOnly(self):
"""
When factories support L{IProtocolNegotiationFactory} and only the
client advertises, nothing is negotiated.
"""
client, server, clientDataReceived, serverDataReceived = (
self.handshakeProtocols([b'h2', b'http/1.1'], [])
)

def checkNegotiatedProtocol(ignored):
self.assertEqual(client.negotiatedProtocol, None)
self.assertEqual(server.negotiatedProtocol, None)

clientDataReceived.addCallback(lambda ignored: serverDataReceived)
serverDataReceived.addCallback(checkNegotiatedProtocol)

return clientDataReceived


def test_negotiationServerOnly(self):
"""
When factories support L{IProtocolNegotiationFactory} and only the
server advertises, nothing is negotiated.
"""
client, server, clientDataReceived, serverDataReceived = (
self.handshakeProtocols([], [b'h2', b'http/1.1'])
)

def checkNegotiatedProtocol(ignored):
self.assertEqual(client.negotiatedProtocol, None)
self.assertEqual(server.negotiatedProtocol, None)

clientDataReceived.addCallback(lambda ignored: serverDataReceived)
serverDataReceived.addCallback(checkNegotiatedProtocol)

return clientDataReceived
Loading

0 comments on commit 8e94ffa

Please sign in to comment.