Skip to content

Commit

Permalink
imap4.IMAP4Server: test two STARTTLS commands.
Browse files Browse the repository at this point in the history
  • Loading branch information
markrwilliams committed Aug 18, 2017
1 parent 96b968e commit 2233a76
Showing 1 changed file with 67 additions and 27 deletions.
94 changes: 67 additions & 27 deletions src/twisted/mail/test/test_imap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
IClientAuthentication,
ICloseableMailboxIMAP)
from twisted.mail.imap4 import MessageSet
from twisted.protocols import loopback
from twisted.protocols import loopback, basic
from twisted.python import failure
from twisted.python import util, log
from twisted.python.compat import (intToBytes, range, nativeString,
Expand Down Expand Up @@ -1960,6 +1960,25 @@ def loopback(self):
return loopback.loopbackAsync(self.server, self.client)


def assertClientFailureMessage(self, failure, expected):
"""
Assert that the provided failure is an L{IMAP4Exception} with
the given message.
@param failure: A failure whose value L{IMAP4Exception}
@type failure: L{failure.Failure}
@param expected: The expected failure message.
@type expected: L{bytes}
"""
failure.trap(imap4.IMAP4Exception)
message = str(failure.value)
if _PY3:
expected = repr(expected)

self.assertEqual(message, expected)



class IMAP4ServerTests(IMAP4HelperMixin, unittest.TestCase):
def testCapability(self):
Expand Down Expand Up @@ -3373,25 +3392,6 @@ def setUp(self):
self.account = realm.theAccount


def assertClientFailureMessage(self, failure, expected):
"""
Assert that the provided failure is an L{IMAP4Exception} with
the given message.
@param failure: A failure whose value L{IMAP4Exception}
@type failure: L{failure.Failure}
@param expected: The expected failure message.
@type expected: L{bytes}
"""
failure.trap(imap4.IMAP4Exception)
message = str(failure.value)
if _PY3:
expected = repr(expected)

self.assertEqual(message, expected)


def test_customChallengers(self):
"""
L{imap4.IMAP4Server} accepts a L{dict} mapping challenge type
Expand Down Expand Up @@ -6916,27 +6916,67 @@ def testLoginLogin(self):
return d


def test_startTLS(self):
def startTLSAndAssertSession(self):
"""
L{IMAP4Client.startTLS} triggers TLS negotiation and returns a
L{Deferred} which fires after the client's transport is using
encryption.
Begin a C{STARTTLS} sequence and assert that it results in a
TLS session.
@return: A L{Deferred} that fires when the underlying
connection between the client and server has been terminated.
"""
success = []
self.connected.addCallback(lambda _: self.client.startTLS())
self.connected.addCallback(strip(self.client.startTLS))
def checkSecure(ignored):
self.assertTrue(
interfaces.ISSLTransport.providedBy(self.client.transport))
self.connected.addCallback(checkSecure)
self.connected.addCallback(self._cbStopClient)
self.connected.addCallback(success.append)
self.connected.addErrback(self._ebGeneral)

d = self.loopback()
d.addCallback(lambda x : self.assertTrue(success))
return defer.gatherResults([d, self.connected])


def test_startTLS(self):
"""
L{IMAP4Client.startTLS} triggers TLS negotiation and returns a
L{Deferred} which fires after the client's transport is using
encryption.
"""
disconnected = self.startTLSAndAssertSession()
self.connected.addCallback(self._cbStopClient)
self.connected.addErrback(self._ebGeneral)
return disconnected


def test_doubleSTARTTLS(self):
"""
A server that receives a second C{STARTTLS} sends a C{NO}
response.
"""

class DoubleSTARTTLSClient(SimpleClient):

def startTLS(self):
if not self.startedTLS:
return SimpleClient.startTLS(self)

return self.sendCommand(imap4.Command(b"STARTTLS"))

self.client = DoubleSTARTTLSClient(self.connected,
contextFactory=self.clientCTX)

disconnected = self.startTLSAndAssertSession()

self.connected.addCallback(strip(self.client.startTLS))
self.connected.addErrback(self.assertClientFailureMessage, b"TLS already negotiated")

self.connected.addCallback(self._cbStopClient)
self.connected.addErrback(self._ebGeneral)

return disconnected


def testFailedStartTLS(self):
failures = []
def breakServerTLS(ign):
Expand Down

0 comments on commit 2233a76

Please sign in to comment.