def _dashCapitalize(name: bytes) -> bytes:
    """
    Capitalize every C{-}-separated word of a header name.

    @param name: The name of the header to capitalize.

    @return: C{name} with each dash-delimited word capitalized and the
        dashes left in place.
    """
    words = name.split(b"-")
    return b"-".join(word.capitalize() for word in words)


def _sanitizeLinearWhitespace(headerComponent: bytes) -> bytes:
    r"""
    Collapse line breaks (C{\n}, C{\r\n}, C{\r}) inside a header key or
    value so the component occupies a single line.

    The component is split with L{bytes.splitlines} and the pieces are
    joined with one space each, so a line break at the very end is dropped
    rather than replaced.

    @param headerComponent: The header key or value to sanitize.

    @return: The sanitized header key or value.
    """
    lines = headerComponent.splitlines()
    return b" ".join(lines)
def _encodeName(self, name: AnyStr) -> bytes:
    """
    Normalize the name of a header (eg 'Content-Type') to a lowercased
    bytestring.

    @param name: A HTTP header name.  A L{str} name is encoded as
        ISO-8859-1 after lowercasing; any other name is only lowercased.

    @return: C{name}, encoded if required, lowercased
    """
    lowered = name.lower()
    if not isinstance(name, str):
        return lowered
    return lowered.encode("iso-8859-1")
def hasHeader(self, name: AnyStr) -> bool:
    """
    Check for the existence of a given header.

    @param name: The name of the HTTP header to check for.

    @return: C{True} if the header exists, otherwise C{False}.
    """
    return self._encodeName(name) in self._rawHeaders


def removeHeader(self, name: AnyStr) -> None:
    """
    Remove the named header from this header object.  Removing a header
    that is not present is not an error.

    @param name: The name of the HTTP header to remove.

    @return: L{None}
    """
    self._rawHeaders.pop(self._encodeName(name), None)


def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:
    """
    Sets the raw representation of the given header, replacing any values
    previously stored under that name.

    L{str} values are encoded as UTF-8, and line breaks in the name and in
    every value are replaced by L{_sanitizeLinearWhitespace} before
    storing.  Nothing is stored if any argument is rejected.

    @param name: The name of the HTTP header to set the values for.

    @param values: A sequence of strings each one being a header value of
        the given name.  A single L{bytes} or L{str} object is I{not}
        accepted here, even though both are sequences.

    @raise TypeError: Raised if C{values} is not a sequence of L{bytes}
        or L{str} strings (or is itself a L{bytes} or L{str} string), or
        if C{name} is not a L{bytes} or L{str} string.

    @return: L{None}
    """
    # bytes and str are registered Sequence types.  Without the explicit
    # rejection a lone str would pass every check below and be stored as
    # one header value per character.
    if isinstance(values, (bytes, str)) or not isinstance(values, _Sequence):
        raise TypeError(
            "Header entry %r should be sequence but found "
            "instance of %r instead" % (name, type(values))
        )

    if not isinstance(name, (bytes, str)):
        raise TypeError(
            "Header name is an instance of %r, not bytes or str" % (type(name),)
        )

    # Validate every value before encoding anything so a bad entry late in
    # the sequence is reported by position and leaves the stored headers
    # untouched.
    for count, value in enumerate(values):
        if not isinstance(value, (bytes, str)):
            raise TypeError(
                "Header value at position %s is an instance of %r, not "
                "bytes or str" % (count, type(value))
            )

    _name = _sanitizeLinearWhitespace(self._encodeName(name))
    self._rawHeaders[_name] = [
        _sanitizeLinearWhitespace(v.encode("utf8") if isinstance(v, str) else v)
        for v in values
    ]


_T = TypeVar("_T")


def getRawHeaders(
    self, name: AnyStr, default: Optional[_T] = None
) -> Union[List[AnyStr], Optional[_T]]:
    """
    Returns a list of headers matching the given name as the raw string
    given.

    @param name: The name of the HTTP header to get the values of.

    @param default: The value to return if no header with the given C{name}
        exists.

    @return: If the named header is present, a L{list} of its values
        (possibly empty), of the same string type as C{name}.  For a
        L{bytes} name this is the stored list itself, not a copy; for a
        L{str} name it is a new list of UTF-8 decoded values.  Otherwise,
        C{default}.
    """
    # Look up with None as the miss marker: a header stored with an empty
    # value list is still present (hasHeader reports it), so it must yield
    # [] rather than the caller's default.
    values = self._rawHeaders.get(self._encodeName(name))
    if values is None:
        return default

    if isinstance(name, str):
        # Stored values are always bytes; setRawHeaders encodes them.
        return [v.decode("utf8") for v in values]
    return values


def getAllRawHeaders(self) -> Iterator[Tuple[bytes, List[bytes]]]:
    """
    Return an iterator of key, value pairs of all headers contained in this
    object, as L{bytes}.  Each key is the stored name passed through
    C{_canonicalNameCaps}; each value is the stored list of values for that
    name.
    """
    for k, v in self._rawHeaders.items():
        yield self._canonicalNameCaps(k), v
def test_rawHeadersTypeCheckingValuesIterable(self):
    """
    L{Headers.setRawHeaders} requires values to be a sequence; a mapping
    is rejected with L{TypeError}.
    """
    h = Headers()
    self.assertRaises(TypeError, h.setRawHeaders, b"key", {b"Foo": b"bar"})


def test_rawHeadersTypeCheckingName(self):
    """
    L{Headers.setRawHeaders} requires C{name} to be a L{bytes} or
    L{str} string.
    """
    h = Headers()
    e = self.assertRaises(TypeError, h.setRawHeaders, None, [b"foo"])
    # The implementation formats the offending type with %r, so the
    # expected text is built the same way instead of hard-coding the repr.
    self.assertEqual(
        e.args[0],
        "Header name is an instance of %r, not bytes or str" % (type(None),),
    )


def test_rawHeadersTypeCheckingValuesAreString(self):
    """
    L{Headers.setRawHeaders} requires values to be a sequence of L{bytes}
    or L{str} strings, and reports the position of the first offender.
    """
    h = Headers()
    e = self.assertRaises(TypeError, h.setRawHeaders, b"key", [b"bar", None])
    self.assertEqual(
        e.args[0],
        "Header value at position 1 is an instance of %r, "
        "not bytes or str" % (type(None),),
    )
def test_addRawHeaderTypeCheckName(self):
    """
    L{Headers.addRawHeader} requires C{name} to be a L{bytes} or L{str}
    string.
    """
    h = Headers()
    e = self.assertRaises(TypeError, h.addRawHeader, None, b"foo")
    # The implementation formats the offending type with %r, so the
    # expected text is built the same way instead of hard-coding the repr.
    self.assertEqual(
        e.args[0],
        "Header name is an instance of %r, not bytes or str" % (type(None),),
    )


def test_addRawHeaderTypeCheckValue(self):
    """
    L{Headers.addRawHeader} requires value to be a L{bytes} or L{str}
    string.
    """
    h = Headers()
    e = self.assertRaises(TypeError, h.addRawHeader, b"key", None)
    self.assertEqual(
        e.args[0],
        "Header value is an instance of %r, not bytes or str" % (type(None),),
    )
""" @@ -359,7 +408,7 @@ def test_nameNotEncodable(self): def test_nameEncoding(self): """ - Passing L{unicode} to any function that takes a header name will encode + Passing L{str} to any function that takes a header name will encode said header name as ISO-8859-1. """ h = Headers() @@ -376,7 +425,7 @@ def test_nameEncoding(self): def test_rawHeadersValueEncoding(self): """ - Passing L{unicode} to L{Headers.setRawHeaders} will encode the name as + Passing L{str} to L{Headers.setRawHeaders} will encode the name as ISO-8859-1 and values as UTF-8. """ h = Headers() @@ -386,7 +435,7 @@ def test_rawHeadersValueEncoding(self): def test_rawHeadersTypeChecking(self): """ - L{Headers.setRawHeaders} requires values to be of type list. + L{Headers.setRawHeaders} requires values to be of type sequence """ h = Headers() self.assertRaises(TypeError, h.setRawHeaders, "key", {"Foo": "bar"})