Skip to content

Commit

Permalink
Fix up type hints for getRawHeaders and getAllRawHeaders
Browse files Browse the repository at this point in the history
  • Loading branch information
wsanchez committed Mar 18, 2021
1 parent bceb432 commit 88f32e2
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions src/twisted/web/http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
TypeVar,
Tuple,
Union,
cast,
overload,
)
from collections.abc import Sequence as _Sequence

from twisted.python.compat import comparable, cmp


_T = TypeVar("_T")


def _dashCapitalize(name: bytes) -> bytes:
"""
Return a byte string which is capitalized using '-' as a word separator.
Expand Down Expand Up @@ -215,30 +220,33 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
"bytes or str" % (type(value),)
)

values = self.getRawHeaders(name)

if values is not None:
values.append(value)
else:
values = [value]
# We secretly know getRawHeaders is really returning a list
values = cast(list, self.getRawHeaders(name, default=[]))
values.append(value)

self.setRawHeaders(name, values)

_T = TypeVar("_T")
@overload
def getRawHeaders(self, name: AnyStr) -> Optional[Sequence[AnyStr]]:
...

@overload
def getRawHeaders(self, name: AnyStr, default: _T) -> Union[Sequence[AnyStr], _T]:
...

def getRawHeaders(
self, name: AnyStr, default: Optional[_T] = None
) -> Union[List[AnyStr], Optional[_T]]:
) -> Union[Sequence[AnyStr], Optional[_T]]:
"""
Returns a list of headers matching the given name as the raw string
Returns a sequence of headers matching the given name as the raw string
given.
@param name: The name of the HTTP header to get the values of.
@param default: The value to return if no header with the given C{name}
exists.
@return: If the named header is present, a L{list} of its
@return: If the named header is present, a sequence of its
values. Otherwise, C{default}.
"""
encodedName = self._encodeName(name)
Expand All @@ -250,7 +258,7 @@ def getRawHeaders(
return [v.decode("utf8") for v in values]
return values

def getAllRawHeaders(self) -> Iterator[Tuple[bytes, List[bytes]]]:
def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]:
"""
Return an iterator of key, value pairs of all headers contained in this
object, as L{bytes}. The keys are capitalized in canonical
Expand Down

0 comments on commit 88f32e2

Please sign in to comment.