Skip to content

Commit

Permalink
Remove some internal functions and consolidate logic
Browse files Browse the repository at this point in the history
in getRawHeaders for converting from str to bytes.
  • Loading branch information
rodrigc committed Sep 19, 2020
1 parent affb7b3 commit 50ccd1c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 56 deletions.
72 changes: 18 additions & 54 deletions src/twisted/web/http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

from typing import (
Any,
AnyStr,
Dict,
Iterator,
Expand All @@ -16,8 +15,8 @@
Optional,
Sequence,
Tuple,
Union,
)
from collections.abc import Sequence as _Sequence

from twisted.python.compat import comparable, cmp

Expand Down Expand Up @@ -119,47 +118,6 @@ def _encodeName(self, name: AnyStr) -> bytes:
return name.lower().encode("iso-8859-1")
return name.lower()

def _encodeValue(self, value: AnyStr) -> bytes:
"""
Encode a single header value to a UTF-8 encoded bytestring if required.
@param value: A single HTTP header value.
@return: C{value}, encoded if required
"""
if isinstance(value, str):
return value.encode("utf8")
return value

def _encodeValues(self, values: List[AnyStr]) -> List[bytes]:
"""
Encode a L{list} of header values to a L{list} of UTF-8 encoded
bytestrings if required.
@param values: A list of HTTP header values.
@return: C{values}, with each item encoded if required
"""
newValues = []

for value in values:
newValues.append(self._encodeValue(value))
return newValues

def _decodeValues(self, values: List[bytes]) -> List[str]:
"""
Decode a L{list} of header values into a L{list} of Unicode strings.
@param values: A list of HTTP header values.
@return: C{values}, with each item decoded
"""
newValues = []

for value in values:
newValues.append(value.decode("utf8"))
return newValues

def copy(self):
"""
Return a copy of itself with the same headers set.
Expand Down Expand Up @@ -188,7 +146,7 @@ def removeHeader(self, name: AnyStr) -> None:
"""
self._rawHeaders.pop(self._encodeName(name), None)

def setRawHeaders(self, name: AnyStr, values: List[AnyStr]) -> None:
def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:
"""
Sets the raw representation of the given header.
Expand All @@ -203,9 +161,9 @@ def setRawHeaders(self, name: AnyStr, values: List[AnyStr]) -> None:
@return: L{None}
"""
if not isinstance(values, list):
if not isinstance(values, _Sequence):
raise TypeError(
"Header entry %r should be list but found "
"Header entry %r should be sequence but found "
"instance of %r instead" % (name, type(values))
)

Expand All @@ -226,9 +184,13 @@ def setRawHeaders(self, name: AnyStr, values: List[AnyStr]) -> None:
)

_name = _sanitizeLinearWhitespace(self._encodeName(name))
encodedValues = [
_sanitizeLinearWhitespace(v) for v in self._encodeValues(values)
]
encodedValues = [] # type: List[bytes]
for v in values:
if isinstance(v, str):
_v = v.encode("utf8")
else:
_v = v
encodedValues.append(_sanitizeLinearWhitespace(_v))

self._rawHeaders[_name] = encodedValues

Expand Down Expand Up @@ -261,8 +223,8 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
self.setRawHeaders(name, values)

def getRawHeaders(
self, name: AnyStr, default: Optional[Any] = None
) -> Optional[Union[List[AnyStr], Any]]:
self, name: AnyStr, default: List[AnyStr] = None
) -> Optional[List[AnyStr]]:
"""
Returns a list of headers matching the given name as the raw string
given.
Expand All @@ -276,10 +238,12 @@ def getRawHeaders(
values. Otherwise, C{default}.
"""
encodedName = self._encodeName(name)
values = self._rawHeaders.get(encodedName, default)
values = self._rawHeaders.get(encodedName, None)
if values is None:
return default

if isinstance(name, str) and values is not default and values is not None:
return self._decodeValues(values)
if isinstance(name, str):
return [v if isinstance(v, str) else v.decode("utf8") for v in values]
return values

def getAllRawHeaders(self) -> Iterator[Tuple[bytes, List[bytes]]]:
Expand Down
4 changes: 2 additions & 2 deletions src/twisted/web/test/test_http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_setRawHeaders(self):
self.assertTrue(h.hasHeader(b"Test"))
self.assertEqual(h.getRawHeaders(b"test"), rawValue)

def test_rawHeadersTypeCheckingValuesList(self):
def test_rawHeadersTypeCheckingValuesIterable(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_rawHeadersValueEncoding(self):

def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
L{Headers.setRawHeaders} requires values to be of type sequence
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, "key", {"Foo": "bar"})
Expand Down

0 comments on commit 50ccd1c

Please sign in to comment.