Skip to content

Commit

Permalink
Merge pull request twisted#1389 from twisted/9841-rodrigc-headers-typ…
Browse files Browse the repository at this point in the history
…echeck

Author: hawkowl, rodrigc
Reviewer: glyph,wsanchez,twm
Fixes: ticket:9841 

Clean up type checking/hinting for twisted.web.http_headers
  • Loading branch information
rodrigc authored Sep 22, 2020
2 parents 22f949f + e6f417b commit f0e7b34
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 100 deletions.
174 changes: 82 additions & 92 deletions src/twisted/web/http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,42 @@
An API for storing HTTP header names and values.
"""

from typing import (
AnyStr,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
TypeVar,
Tuple,
Union,
)
from collections.abc import Sequence as _Sequence

from twisted.python.compat import comparable, cmp


def _dashCapitalize(name):
def _dashCapitalize(name: bytes) -> bytes:
"""
Return a byte string which is capitalized using '-' as a word separator.
@param name: The name of the header to capitalize.
@type name: L{bytes}
@return: The given header capitalized using '-' as a word separator.
@rtype: L{bytes}
"""
return b"-".join([word.capitalize() for word in name.split(b"-")])


def _sanitizeLinearWhitespace(headerComponent):
def _sanitizeLinearWhitespace(headerComponent: bytes) -> bytes:
r"""
Replace linear whitespace (C{\n}, C{\r\n}, C{\r}) in a header key
or value with a single space. If C{headerComponent} is not
L{bytes}, it is passed through unchanged.
or value with a single space.
@param headerComponent: The header key or value to sanitize.
@type headerComponent: L{bytes}
@return: The sanitized header key or value.
@rtype: L{bytes}
"""
return b" ".join(headerComponent.splitlines())

Expand All @@ -43,8 +51,7 @@ class Headers:
"""
Stores HTTP headers in a key and multiple value format.
Most methods accept L{bytes} and L{unicode}, with an internal L{bytes}
representation. When passed L{unicode}, header names (e.g. 'Content-Type')
When passed L{str}, header names (e.g. 'Content-Type')
are encoded using ISO-8859-1 and header values (e.g.
'text/html;charset=utf-8') are encoded using UTF-8. Some methods that return
values will return them in the same type as the name given.
Expand All @@ -71,8 +78,11 @@ class Headers:
b"x-xss-protection": b"X-XSS-Protection",
}

def __init__(self, rawHeaders=None):
self._rawHeaders = {}
def __init__(
self,
rawHeaders: Optional[Mapping[AnyStr, Sequence[AnyStr]]] = None,
):
self._rawHeaders = {} # type: Dict[bytes, List[bytes]]
if rawHeaders is not None:
for name, values in rawHeaders.items():
self.setRawHeaders(name, values)
Expand All @@ -97,68 +107,19 @@ def __cmp__(self, other):
)
return NotImplemented

def _encodeName(self, name):
def _encodeName(self, name: AnyStr) -> bytes:
"""
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1 encoded
bytestring if required.
@param name: A HTTP header name
@type name: L{unicode} or L{bytes}
@return: C{name}, encoded if required, lowercased
@rtype: L{bytes}
"""
if isinstance(name, str):
return name.lower().encode("iso-8859-1")
return name.lower()

def _encodeValue(self, value):
"""
Encode a single header value to a UTF-8 encoded bytestring if required.
@param value: A single HTTP header value.
@type value: L{bytes} or L{unicode}
@return: C{value}, encoded if required
@rtype: L{bytes}
"""
if isinstance(value, str):
return value.encode("utf8")
return value

def _encodeValues(self, values):
"""
Encode a L{list} of header values to a L{list} of UTF-8 encoded
bytestrings if required.
@param values: A list of HTTP header values.
@type values: L{list} of L{bytes} or L{unicode} (mixed types allowed)
@return: C{values}, with each item encoded if required
@rtype: L{list} of L{bytes}
"""
newValues = []

for value in values:
newValues.append(self._encodeValue(value))
return newValues

def _decodeValues(self, values):
"""
Decode a L{list} of header values into a L{list} of Unicode strings.
@param values: A list of HTTP header values.
@type values: L{list} of UTF-8 encoded L{bytes}
@return: C{values}, with each item decoded
@rtype: L{list} of L{unicode}
"""
newValues = []

for value in values:
newValues.append(value.decode("utf8"))
return newValues

def copy(self):
"""
Return a copy of itself with the same headers set.
Expand All @@ -167,65 +128,93 @@ def copy(self):
"""
return self.__class__(self._rawHeaders)

def hasHeader(self, name):
def hasHeader(self, name: AnyStr) -> bool:
"""
Check for the existence of a given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to check for.
@rtype: L{bool}
@return: C{True} if the header exists, otherwise C{False}.
"""
return self._encodeName(name) in self._rawHeaders

def removeHeader(self, name):
def removeHeader(self, name: AnyStr) -> None:
"""
Remove the named header from this header object.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to remove.
@return: L{None}
"""
self._rawHeaders.pop(self._encodeName(name), None)

def setRawHeaders(self, name, values):
def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:
"""
Sets the raw representation of the given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to set the values for.
@type values: L{list} of L{bytes} or L{unicode} strings
@param values: A list of strings each one being a header value of
the given name.
@raise TypeError: Raised if C{values} is not a L{list} of L{bytes}
or L{str} strings, or if C{name} is not a L{bytes} or
L{str} string.
@return: L{None}
"""
if not isinstance(values, list):
if not isinstance(values, _Sequence):
raise TypeError(
"Header entry %r should be list but found "
"Header entry %r should be sequence but found "
"instance of %r instead" % (name, type(values))
)

name = _sanitizeLinearWhitespace(self._encodeName(name))
encodedValues = [
_sanitizeLinearWhitespace(v) for v in self._encodeValues(values)
]

self._rawHeaders[name] = self._encodeValues(encodedValues)
if not isinstance(name, (bytes, str)):
raise TypeError(
"Header name is an instance of %r, " "not bytes or str" % (type(name),)
)

def addRawHeader(self, name, value):
for count, value in enumerate(values):
if not isinstance(value, (bytes, str)):
raise TypeError(
"Header value at position %s is an instance of %r, not "
"bytes or str"
% (
count,
type(value),
)
)

_name = _sanitizeLinearWhitespace(self._encodeName(name))
encodedValues = [] # type: List[bytes]
for v in values:
if isinstance(v, str):
_v = v.encode("utf8")
else:
_v = v
encodedValues.append(_sanitizeLinearWhitespace(_v))

self._rawHeaders[_name] = encodedValues

def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
"""
Add a new raw value for the given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the header for which to set the value.
@type value: L{bytes} or L{unicode}
@param value: The value to set for the named header.
"""
if not isinstance(name, (bytes, str)):
raise TypeError(
"Header name is an instance of %r, " "not bytes or str" % (type(name),)
)

if not isinstance(value, (bytes, str)):
raise TypeError(
"Header value is an instance of %r, not "
"bytes or str" % (type(value),)
)

values = self.getRawHeaders(name)

if values is not None:
Expand All @@ -235,30 +224,33 @@ def addRawHeader(self, name, value):

self.setRawHeaders(name, values)

def getRawHeaders(self, name, default=None):
_T = TypeVar("_T")

def getRawHeaders(
self, name: AnyStr, default: Optional[_T] = None
) -> Union[List[AnyStr], Optional[_T]]:
"""
Returns a list of headers matching the given name as the raw string
given.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to get the values of.
@param default: The value to return if no header with the given C{name}
exists.
@rtype: L{list} of strings, same type as C{name} (except when
C{default} is returned).
@return: If the named header is present, a L{list} of its
values. Otherwise, C{default}.
"""
encodedName = self._encodeName(name)
values = self._rawHeaders.get(encodedName, default)
values = self._rawHeaders.get(encodedName, [])
if not values:
return default

if isinstance(name, str) and values is not default:
return self._decodeValues(values)
if isinstance(name, str):
return [v if isinstance(v, str) else v.decode("utf8") for v in values]
return values

def getAllRawHeaders(self):
def getAllRawHeaders(self) -> Iterator[Tuple[bytes, List[bytes]]]:
"""
Return an iterator of key, value pairs of all headers contained in this
object, as L{bytes}. The keys are capitalized in canonical
Expand All @@ -267,15 +259,13 @@ def getAllRawHeaders(self):
for k, v in self._rawHeaders.items():
yield self._canonicalNameCaps(k), v

def _canonicalNameCaps(self, name):
def _canonicalNameCaps(self, name: bytes) -> bytes:
"""
Return the canonical name for the given header.
@type name: L{bytes}
@param name: The all-lowercase header name to capitalize in its
canonical form.
@rtype: L{bytes}
@return: The canonical name of the header.
"""
return self._caseMappings.get(name, _dashCapitalize(name))
Expand Down
Empty file.
Loading

0 comments on commit f0e7b34

Please sign in to comment.