Skip to content

Commit

Permalink
Merge pull request twisted#1549 from twisted/10125-template-recursion…
Browse files Browse the repository at this point in the history
…-error

Author: glyph

Reviewer: wsanchez,pwestlak-xilinx

Fixes: ticket:10125

Fix and simplify trampoline for twisted.web.template flattening to avoid RecursionErrors with synchronous deferreds.
  • Loading branch information
glyph authored Apr 3, 2021
2 parents 022659c + 9a7b064 commit 12d53db
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 162 deletions.
1 change: 1 addition & 0 deletions src/twisted/newsfragments/10125.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
twisted.web.template.flatten and flattenString will no longer raise RecursionError if a large number of synchronous Deferreds are included in a document.
99 changes: 36 additions & 63 deletions src/twisted/web/_flatten.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- test-case-name: twisted.web.test.test_flatten -*-
# -*- test-case-name: twisted.web.test.test_flatten,twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

Expand Down Expand Up @@ -157,6 +157,25 @@ def _getSlotValue(name, slotData, default=None):
raise UnfilledSlot(name)


def _fork(d):
"""
Create a new L{Deferred} based on C{d} that will fire and fail with C{d}'s
result or error, but will not modify C{d}'s callback type.
"""
d2 = Deferred(d.cancel)

def callback(result):
d2.callback(result)
return result

def errback(failure):
d2.errback(failure)
return failure

d.addCallbacks(callback, errback)
return d2


def _flattenElement(request, root, write, slotData, renderFactory, dataEscaper):
"""
Make C{root} slightly more flat by yielding all its immediate contents as
Expand Down Expand Up @@ -204,6 +223,9 @@ def keepGoing(
request, newRoot, write, slotData, renderFactory, dataEscaper
)

def keepGoingAsync(result):
return result.addCallback(keepGoing)

if isinstance(root, (bytes, str)):
write(dataEscaper(root))
elif isinstance(root, slot):
Expand Down Expand Up @@ -270,18 +292,17 @@ def keepGoing(
escaped = "&#%d;" % (root.ordinal,)
write(escaped.encode("ascii"))
elif isinstance(root, Deferred):
yield root.addCallback(lambda result: (result, keepGoing(result)))
yield keepGoingAsync(_fork(root))
elif iscoroutine(root):
d = ensureDeferred(root)
yield d.addCallback(lambda result: (result, keepGoing(result)))
yield keepGoingAsync(Deferred.fromCoroutine(root))
elif IRenderable.providedBy(root):
result = root.render(request)
yield keepGoing(result, renderFactory=root)
else:
raise UnsupportedType(root)


def _flattenTree(request, root, write):
async def _flattenTree(request, root, write):
"""
Make C{root} into an iterable of L{bytes} and L{Deferred} by doing a depth
first traversal of the tree.
Expand All @@ -297,16 +318,15 @@ def _flattenTree(request, root, write):
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: An iterator which yields objects of type L{bytes} and L{Deferred}.
A L{Deferred} is only yielded when one is encountered in the process of
flattening C{root}. The returned iterator must not be iterated again
until the L{Deferred} is called back.
@return: A C{Deferred}-returning coroutine that resolves to C{None}.
"""
stack = [_flattenElement(request, root, write, [], None, escapeForContent)]
while stack:
try:
frame = stack[-1].gi_frame
element = next(stack[-1])
if isinstance(element, Deferred):
element = await element
except StopIteration:
stack.pop()
except Exception as e:
Expand All @@ -317,51 +337,7 @@ def _flattenTree(request, root, write):
roots.append(frame.f_locals["root"])
raise FlattenerError(e, roots, extract_tb(exc_info()[2]))
else:
if isinstance(element, Deferred):

def cbx(originalAndToFlatten):
original, toFlatten = originalAndToFlatten
stack.append(toFlatten)
return original

yield element.addCallback(cbx)
else:
stack.append(element)


def _writeFlattenedData(state, write, result):
"""
Take strings from an iterator and pass them to a writer function.
@param state: An iterator of L{str} and L{Deferred}. L{str} instances will
be passed to C{write}. L{Deferred} instances will be waited on before
resuming iteration of C{state}.
@param write: A callable which will be invoked with each L{str}
produced by iterating C{state}.
@param result: A L{Deferred} which will be called back when C{state} has
been completely flattened into C{write} or which will be errbacked if
an exception in a generator passed to C{state} or an errback from a
L{Deferred} from state occurs.
@return: L{None}
"""
while True:
try:
element = next(state)
except StopIteration:
result.callback(None)
except BaseException:
result.errback()
else:

def cby(original):
_writeFlattenedData(state, write, result)
return original

element.addCallbacks(cby, result.errback)
break
stack.append(element)


def flatten(request, root, write):
Expand All @@ -377,20 +353,17 @@ def flatten(request, root, write):
@param root: An object to be made flatter. This may be of type L{unicode},
L{bytes}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple},
L{list}, L{types.GeneratorType}, L{Deferred}, or something that provides
L{IRenderable}.
L{list}, L{types.GeneratorType}, L{Deferred}, or something that
provides L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: A L{Deferred} which will be called back when C{root} has been
completely flattened into C{write} or which will be errbacked if an
unexpected exception occurs.
@return: A L{Deferred} which will be called back with C{None} when C{root}
has been completely flattened into C{write} or which will be errbacked
if an unexpected exception occurs.
"""
result = Deferred()
state = _flattenTree(request, root, write)
_writeFlattenedData(state, write, result)
return result
return ensureDeferred(_flattenTree(request, root, write))


def flattenString(request, root):
Expand Down
22 changes: 5 additions & 17 deletions src/twisted/web/test/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

from twisted.internet.defer import succeed
from twisted.web import server
from twisted.trial.unittest import TestCase
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase

from twisted.web._flatten import flattenString
from twisted.web.error import FlattenerError
Expand All @@ -30,7 +29,7 @@ def _render(resource, request):
raise ValueError(f"Unexpected return value: {result!r}")


class FlattenTestCase(TestCase):
class FlattenTestCase(SynchronousTestCase):
"""
A test case that assists with testing L{twisted.web._flatten}.
"""
Expand All @@ -56,25 +55,14 @@ def assertFlattensImmediately(self, root, target):
L{target}.
@rtype: L{bytes}
"""
results = []
it = self.assertFlattensTo(root, target)
it.addBoth(results.append)
# Do our best to clean it up if something goes wrong.
self.addCleanup(it.cancel)
if not results:
self.fail("Rendering did not complete immediately.")
result = results[0]
if isinstance(result, Failure):
result.raiseException()
return results[0]
return self.successResultOf(self.assertFlattensTo(root, target))

def assertFlatteningRaises(self, root, exn):
"""
Assert flattening a root element raises a particular exception.
"""
d = self.assertFailure(self.assertFlattensTo(root, b""), FlattenerError)
d.addCallback(lambda exc: self.assertIsInstance(exc._exception, exn))
return d
failure = self.failureResultOf(self.assertFlattensTo(root, b""), FlattenerError)
self.assertIsInstance(failure.value._exception, exn)


def assertIsFilesystemTemporary(case, fileObj):
Expand Down
Loading

0 comments on commit 12d53db

Please sign in to comment.