Skip to content

Commit

Permalink
Merge policies-py3-6097.
Browse files Browse the repository at this point in the history
Author: itamar
Review: exarkun
Fixes: twisted#6097

Port twisted.protocols.policies to Python 3.


git-svn-id: svn://svn.twistedmatrix.com/svn/Twisted/trunk@36096 bbbe8e31-12d6-0310-92fd-ac37d47ddeeb
  • Loading branch information
itamarst committed Oct 11, 2012
1 parent cec8f44 commit c243ea8
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 31 deletions.
2 changes: 2 additions & 0 deletions admin/_twistedpython3.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"twisted.names.test",
"twisted.protocols",
"twisted.protocols.basic",
"twisted.protocols.policies",
"twisted.protocols.test",
"twisted.python",
"twisted.python.compat",
Expand Down Expand Up @@ -136,6 +137,7 @@
"twisted.test.test_log",
"twisted.test.test_monkey",
"twisted.test.test_paths",
"twisted.test.test_policies",
"twisted.test.test_randbytes",
"twisted.test.test_setup",
"twisted.test.test_task",
Expand Down
4 changes: 3 additions & 1 deletion twisted/protocols/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""

from __future__ import division, absolute_import

# system imports
import sys, operator

Expand Down Expand Up @@ -235,7 +237,7 @@ class ThrottlingFactory(WrappingFactory):

protocol = ThrottlingProtocol

def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
def __init__(self, wrappedFactory, maxConnectionCount=sys.maxsize,
readLimit=None, writeLimit=None):
WrappingFactory.__init__(self, wrappedFactory)
self.connectionCount = 0
Expand Down
181 changes: 151 additions & 30 deletions twisted/test/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"""
Test code for policies.
"""
from __future__ import division, absolute_import

from zope.interface import Interface, implements, implementedBy

from StringIO import StringIO
from zope.interface import Interface, implementer, implementedBy

from twisted.python.compat import NativeStringIO, _PY3
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
from twisted.test.proto_helpers import StringTransportWithDisconnection
Expand All @@ -21,7 +21,7 @@
class SimpleProtocol(protocol.Protocol):

connected = disconnected = 0
buffer = ""
buffer = b""

def __init__(self):
self.dConnected = defer.Deferred()
Expand Down Expand Up @@ -144,8 +144,9 @@ def test_transportInterfaces(self):
class IStubTransport(Interface):
pass

@implementer(IStubTransport)
class StubTransport:
implements(IStubTransport)
pass

# Looking up what ProtocolWrapper implements also mutates the class.
# It adds __implemented__ and __providedBy__ attributes to it. These
Expand Down Expand Up @@ -216,6 +217,123 @@ class NoProtocol(object):
protocol.logPrefix())


def _getWrapper(self):
"""
Return L{policies.ProtocolWrapper} that has been connected to a
L{StringTransport}.
"""
wrapper = policies.ProtocolWrapper(policies.WrappingFactory(Server()),
protocol.Protocol())
transport = StringTransport()
wrapper.makeConnection(transport)
return wrapper


def test_getHost(self):
"""
L{policies.ProtocolWrapper.getHost} calls C{getHost} on the underlying
transport.
"""
wrapper = self._getWrapper()
self.assertEqual(wrapper.getHost(), wrapper.transport.getHost())


def test_getPeer(self):
"""
L{policies.ProtocolWrapper.getPeer} calls C{getPeer} on the underlying
transport.
"""
wrapper = self._getWrapper()
self.assertEqual(wrapper.getPeer(), wrapper.transport.getPeer())


def test_registerProducer(self):
"""
L{policies.ProtocolWrapper.registerProducer} calls C{registerProducer}
on the underlying transport.
"""
wrapper = self._getWrapper()
producer = object()
wrapper.registerProducer(producer, True)
self.assertIdentical(wrapper.transport.producer, producer)
self.assertTrue(wrapper.transport.streaming)


def test_unregisterProducer(self):
"""
L{policies.ProtocolWrapper.unregisterProducer} calls
C{unregisterProducer} on the underlying transport.
"""
wrapper = self._getWrapper()
producer = object()
wrapper.registerProducer(producer, True)
wrapper.unregisterProducer()
self.assertIdentical(wrapper.transport.producer, None)
self.assertIdentical(wrapper.transport.streaming, None)


def test_stopConsuming(self):
"""
L{policies.ProtocolWrapper.stopConsuming} calls C{stopConsuming} on
the underlying transport.
"""
wrapper = self._getWrapper()
result = []
wrapper.transport.stopConsuming = lambda: result.append(True)
wrapper.stopConsuming()
self.assertEqual(result, [True])


def test_startedConnecting(self):
"""
L{policies.WrappingFactory.startedConnecting} calls
C{startedConnecting} on the underlying factory.
"""
result = []
class Factory(object):
def startedConnecting(self, connector):
result.append(connector)

wrapper = policies.WrappingFactory(Factory())
connector = object()
wrapper.startedConnecting(connector)
self.assertEqual(result, [connector])


def test_clientConnectionLost(self):
"""
L{policies.WrappingFactory.clientConnectionLost} calls
C{clientConnectionLost} on the underlying factory.
"""
result = []
class Factory(object):
def clientConnectionLost(self, connector, reason):
result.append((connector, reason))

wrapper = policies.WrappingFactory(Factory())
connector = object()
reason = object()
wrapper.clientConnectionLost(connector, reason)
self.assertEqual(result, [(connector, reason)])


def test_clientConnectionFailed(self):
"""
L{policies.WrappingFactory.clientConnectionFailed} calls
C{clientConnectionFailed} on the underlying factory.
"""
result = []
class Factory(object):
def clientConnectionFailed(self, connector, reason):
result.append((connector, reason))

wrapper = policies.WrappingFactory(Factory())
connector = object()
reason = object()
wrapper.clientConnectionFailed(connector, reason)
self.assertEqual(result, [(connector, reason)])



class WrappingFactory(policies.WrappingFactory):
protocol = lambda s, f, p: p
Expand Down Expand Up @@ -254,8 +372,8 @@ def _connect123(results):
return c3.dDisconnected

def _check123(results):
self.assertEqual([c.connected for c in c1, c2, c3], [1, 1, 1])
self.assertEqual([c.disconnected for c in c1, c2, c3], [0, 0, 1])
self.assertEqual([c.connected for c in (c1, c2, c3)], [1, 1, 1])
self.assertEqual([c.disconnected for c in (c1, c2, c3)], [0, 0, 1])
self.assertEqual(len(tServer.protocols.keys()), 2)
return results

Expand Down Expand Up @@ -289,6 +407,9 @@ def _cleanup(results):
wrapTServer.deferred.addCallback(_cleanup)
return wrapTServer.deferred

if _PY3:
test_limit.skip = "Re-enable in #6002"


def test_writeLimit(self):
"""
Expand All @@ -303,9 +424,9 @@ def test_writeLimit(self):
port.makeConnection(tr)
port.producer = port.wrappedProtocol

port.dataReceived("0123456789")
port.dataReceived("abcdefghij")
self.assertEqual(tr.value(), "0123456789abcdefghij")
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.writtenThisSecond, 20)
self.assertFalse(port.wrappedProtocol.paused)

Expand Down Expand Up @@ -333,9 +454,9 @@ def test_readLimit(self):
tr.protocol = port
port.makeConnection(tr)

port.dataReceived("0123456789")
port.dataReceived("abcdefghij")
self.assertEqual(tr.value(), "0123456789abcdefghij")
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.readThisSecond, 20)

tServer.clock.advance(1.05)
Expand All @@ -347,9 +468,9 @@ def test_readLimit(self):
self.assertEqual(tr.producerState, 'producing')

tr.clear()
port.dataReceived("0123456789")
port.dataReceived("abcdefghij")
self.assertEqual(tr.value(), "0123456789abcdefghij")
port.dataReceived(b"0123456789")
port.dataReceived(b"abcdefghij")
self.assertEqual(tr.value(), b"0123456789abcdefghij")
self.assertEqual(tServer.readThisSecond, 20)

tServer.clock.advance(1.05)
Expand Down Expand Up @@ -409,14 +530,14 @@ def test_sendAvoidsTimeout(self):

# Send some data (self.proto is the /real/ proto's transport, so this
# is the write that gets called)
self.proto.write('bytes bytes bytes')
self.proto.write(b'bytes bytes bytes')

# More time passes, putting us past the original timeout
self.clock.pump([0.0, 1.0, 1.0])
self.failIf(self.proto.wrappedProtocol.disconnected)

# Make sure writeSequence delays timeout as well
self.proto.writeSequence(['bytes'] * 3)
self.proto.writeSequence([b'bytes'] * 3)

# Tick tock
self.clock.pump([0.0, 1.0, 1.0])
Expand All @@ -436,7 +557,7 @@ def test_receiveAvoidsTimeout(self):
self.failIf(self.proto.wrappedProtocol.disconnected)

# Some bytes arrive, they should reset the counter
self.proto.dataReceived('bytes bytes bytes')
self.proto.dataReceived(b'bytes bytes bytes')

# We pass the original timeout
self.clock.pump([0.0, 1.0, 1.0])
Expand Down Expand Up @@ -547,7 +668,7 @@ def test_noTimeout(self):

self.clock.pump([0, 0.5, 1.0, 1.0])
self.failIf(self.proto.timedOut)
self.proto.dataReceived('hello there')
self.proto.dataReceived(b'hello there')
self.clock.pump([0, 1.0, 1.0, 0.5])
self.failIf(self.proto.timedOut)
self.clock.pump([0, 1.0])
Expand Down Expand Up @@ -663,7 +784,7 @@ def connectionMade(self):

class WriteSequenceEchoProtocol(EchoProtocol):
def dataReceived(self, bytes):
if bytes.find('vector!') != -1:
if bytes.find(b'vector!') != -1:
self.transport.writeSequence([bytes])
else:
EchoProtocol.dataReceived(self, bytes)
Expand All @@ -672,7 +793,7 @@ class TestLoggingFactory(policies.TrafficLoggingFactory):
openFile = None
def open(self, name):
assert self.openFile is None, "open() called too many times"
self.openFile = StringIO()
self.openFile = NativeStringIO()
return self.openFile


Expand All @@ -695,21 +816,21 @@ def test_thingsGetLogged(self):
p.makeConnection(t)

v = f.openFile.getvalue()
self.failUnless('*' in v, "* not found in %r" % (v,))
self.assertIn('*', v)
self.failIf(t.value())

p.dataReceived('here are some bytes')
p.dataReceived(b'here are some bytes')

v = f.openFile.getvalue()
self.assertIn("C 1: 'here are some bytes'", v)
self.assertIn("S 1: 'here are some bytes'", v)
self.assertEqual(t.value(), 'here are some bytes')
self.assertIn("C 1: %r" % (b'here are some bytes',), v)
self.assertIn("S 1: %r" % (b'here are some bytes',), v)
self.assertEqual(t.value(), b'here are some bytes')

t.clear()
p.dataReceived('prepare for vector! to the extreme')
p.dataReceived(b'prepare for vector! to the extreme')
v = f.openFile.getvalue()
self.assertIn("SV 1: ['prepare for vector! to the extreme']", v)
self.assertEqual(t.value(), 'prepare for vector! to the extreme')
self.assertIn("SV 1: %r" % ([b'prepare for vector! to the extreme'],), v)
self.assertEqual(t.value(), b'prepare for vector! to the extreme')

p.loseConnection()

Expand Down
Empty file added twisted/topfiles/6097.misc
Empty file.

0 comments on commit c243ea8

Please sign in to comment.