Skip to content

Commit

Permalink
Improve function type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Oct 13, 2024
1 parent 3309569 commit 2af9987
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions src/prettytable/prettytable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
import io
import re
from html.parser import HTMLParser
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from sqlite3 import Cursor

from typing_extensions import Self

# hrule styles
FRAME = 0
Expand All @@ -58,7 +63,7 @@
_re = re.compile(r"\033\[[0-9;]*m|\033\(B")


def _get_size(text):
def _get_size(text: str) -> tuple[int, int]:
lines = text.split("\n")
height = len(lines)
width = max(_str_block_width(line) for line in lines)
Expand Down Expand Up @@ -97,6 +102,7 @@ class PrettyTable:
_attributes: dict[str, str]
_escape_header: bool
_escape_data: bool
_hrule: str

def __init__(self, field_names=None, **kwargs) -> None:
"""Return a new PrettyTable instance
Expand Down Expand Up @@ -310,7 +316,7 @@ def __init__(self, field_names=None, **kwargs) -> None:
self._xhtml = kwargs["xhtml"] or False
self._attributes = kwargs["attributes"] or {}

def _justify(self, text, width, align):
def _justify(self, text: str, width: int, align) -> str:
excess = width - _str_block_width(text)
if align == "l":
return text + excess * " "
Expand Down Expand Up @@ -344,7 +350,7 @@ def __getattr__(self, name):
else:
raise AttributeError(name)

def __getitem__(self, index):
def __getitem__(self, index: int | slice) -> PrettyTable:
new = PrettyTable()
new.field_names = self.field_names
for attr in self._options:
Expand All @@ -366,7 +372,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.get_string()

def _repr_html_(self):
def _repr_html_(self) -> str:
"""
Returns get_html_string value by default
as the repr call in Jupyter notebook environment
Expand Down Expand Up @@ -1581,7 +1587,7 @@ def clear(self) -> None:
# MISC PUBLIC METHODS #
##############################

def copy(self):
def copy(self) -> Self:
import copy

return copy.deepcopy(self)
Expand Down Expand Up @@ -1623,7 +1629,7 @@ def _format_value(self, field, value):
formatter = self._custom_format.get(field, (lambda f, v: str(v)))
return formatter(field, value)

def _compute_table_width(self, options):
def _compute_table_width(self, options) -> int:
if options["vrules"] == FRAME:
table_width = 2
if options["vrules"] == ALL:
Expand Down Expand Up @@ -1712,7 +1718,7 @@ def _compute_widths(self, rows, options) -> None:
widths[-1] += min_width - sum(widths)
self._widths = widths

def _get_padding_widths(self, options):
def _get_padding_widths(self, options) -> tuple[int, int]:
if options["left_padding_width"] is not None:
lpad = options["left_padding_width"]
else:
Expand Down Expand Up @@ -1753,7 +1759,7 @@ def _get_rows(self, options):

return rows

def _get_dividers(self, options):
def _get_dividers(self, options) -> list[bool]:
"""Return only those dividers that should be printed, based on slicing.
Arguments:
Expand Down Expand Up @@ -1834,7 +1840,7 @@ def get_string(self, **kwargs) -> str:

options = self._get_options(kwargs)

lines = []
lines: list[str] = []

# Don't think too hard about an empty table
# Is this the desired behaviour? Maybe we should still print the header?
Expand Down Expand Up @@ -1894,7 +1900,7 @@ def get_string(self, **kwargs) -> str:

return "\n".join(lines)

def _stringify_hrule(self, options, where: str = ""):
def _stringify_hrule(self, options, where: str = "") -> str:
if not options["border"] and not options["preserve_internal_border"]:
return ""
lpad, rpad = self._get_padding_widths(options)
Expand Down Expand Up @@ -1933,7 +1939,7 @@ def _stringify_hrule(self, options, where: str = ""):

return "".join(bits)

def _stringify_title(self, title, options):
def _stringify_title(self, title: str, options) -> str:
lines = []
lpad, rpad = self._get_padding_widths(options)
if options["border"]:
Expand All @@ -1943,7 +1949,7 @@ def _stringify_title(self, title, options):
options["vrules"] = ALL
elif options["vrules"] == FRAME:
lines.append(self._stringify_hrule(options, "top_"))
bits = []
bits: list[str] = []
endpoint = (
options["vertical_char"]
if options["vrules"] in (ALL, FRAME) and options["border"]
Expand All @@ -1959,8 +1965,8 @@ def _stringify_title(self, title, options):
lines.append("".join(bits))
return "\n".join(lines)

def _stringify_header(self, options):
bits = []
def _stringify_header(self, options) -> str:
bits: list[str] = []
lpad, rpad = self._get_padding_widths(options)
if options["border"]:
if options["hrules"] in (ALL, FRAME):
Expand Down Expand Up @@ -2027,7 +2033,7 @@ def _stringify_header(self, options):
bits.append(self._hrule)
return "".join(bits)

def _stringify_row(self, row, options, hrule):
def _stringify_row(self, row, options, hrule: str) -> str:
import textwrap

for index, field, value, width in zip(
Expand All @@ -2052,7 +2058,7 @@ def _stringify_row(self, row, options, hrule):
if h > row_height:
row_height = h

bits = []
bits: list[list[str]] = []
lpad, rpad = self._get_padding_widths(options)
for y in range(0, row_height):
bits.append([])
Expand All @@ -2078,8 +2084,7 @@ def _stringify_row(self, row, options, hrule):
else:
lines = lines + [""] * d_height

y = 0
for line in lines:
for y, line in enumerate(lines):
if options["fields"] and field not in options["fields"]:
continue

Expand All @@ -2093,7 +2098,6 @@ def _stringify_row(self, row, options, hrule):
bits[y].append(self.vertical_char)
else:
bits[y].append(" ")
y += 1

# If only preserve_internal_border is true, then we just appended
# a vertical character at the end when we wanted a space
Expand All @@ -2112,13 +2116,11 @@ def _stringify_row(self, row, options, hrule):
bits[row_height - 1].append("\n")
bits[row_height - 1].append(hrule)

for y in range(0, row_height):
bits[y] = "".join(bits[y])

return "\n".join(bits)
bits_str = ["".join(bits_y) for bits_y in bits]
return "\n".join(bits_str)

def paginate(self, page_length: int = 58, line_break: str = "\f", **kwargs):
pages = []
def paginate(self, page_length: int = 58, line_break: str = "\f", **kwargs) -> str:
pages: list[str] = []
kwargs["start"] = kwargs.get("start", 0)
true_end = kwargs.get("end", self.rowcount)
while True:
Expand Down Expand Up @@ -2259,10 +2261,10 @@ def get_html_string(self, **kwargs) -> str:

return string

def _get_simple_html_string(self, options):
def _get_simple_html_string(self, options) -> str:
from html import escape

lines = []
lines: list[str] = []
if options["xhtml"]:
linebreak = "<br/>"
else:
Expand Down Expand Up @@ -2318,10 +2320,10 @@ def _get_simple_html_string(self, options):

return "\n".join(lines)

def _get_formatted_html_string(self, options):
def _get_formatted_html_string(self, options) -> str:
from html import escape

lines = []
lines: list[str] = []
lpad, rpad = self._get_padding_widths(options)
if options["xhtml"]:
linebreak = "<br/>"
Expand Down Expand Up @@ -2378,8 +2380,8 @@ def _get_formatted_html_string(self, options):
lines.append(" <tbody>")
rows = self._get_rows(options)
formatted_rows = self._format_rows(rows)
aligns = []
valigns = []
aligns: list[str] = []
valigns: list[str] = []
for field in self._field_names:
aligns.append(
{"l": "left", "r": "right", "c": "center"}[self._align[field]]
Expand Down Expand Up @@ -2449,8 +2451,8 @@ def get_latex_string(self, **kwargs) -> str:
string = self._get_simple_latex_string(options)
return string

def _get_simple_latex_string(self, options):
lines = []
def _get_simple_latex_string(self, options) -> str:
lines: list[str] = []

wanted_fields = []
if options["fields"]:
Expand Down Expand Up @@ -2482,8 +2484,8 @@ def _get_simple_latex_string(self, options):

return "\r\n".join(lines)

def _get_formatted_latex_string(self, options):
lines = []
def _get_formatted_latex_string(self, options) -> str:
lines: list[str] = []

wanted_fields = []
if options["fields"]:
Expand Down Expand Up @@ -2542,7 +2544,7 @@ def _get_formatted_latex_string(self, options):
##############################


def _str_block_width(val):
def _str_block_width(val: str) -> int:
import wcwidth # type: ignore[import-not-found]

return wcwidth.wcswidth(_re.sub("", val))
Expand All @@ -2553,7 +2555,7 @@ def _str_block_width(val):
##############################


def from_csv(fp, field_names: Any | None = None, **kwargs):
def from_csv(fp, field_names: Any | None = None, **kwargs) -> PrettyTable:
import csv

fmtparams = {}
Expand Down Expand Up @@ -2588,16 +2590,17 @@ def from_csv(fp, field_names: Any | None = None, **kwargs):
return table


def from_db_cursor(cursor, **kwargs):
def from_db_cursor(cursor: Cursor, **kwargs) -> PrettyTable | None:
if cursor.description:
table = PrettyTable(**kwargs)
table.field_names = [col[0] for col in cursor.description]
for row in cursor.fetchall():
table.add_row(row)
return table
return None


def from_json(json_string, **kwargs):
def from_json(json_string: str | bytes, **kwargs) -> PrettyTable:
import json

table = PrettyTable(**kwargs)
Expand All @@ -2613,7 +2616,7 @@ class TableHandler(HTMLParser):
def __init__(self, **kwargs) -> None:
HTMLParser.__init__(self)
self.kwargs = kwargs
self.tables: list[list] = []
self.tables: list[PrettyTable] = []
self.last_row: list[str] = []
self.rows: list[Any] = []
self.max_row_width = 0
Expand Down Expand Up @@ -2654,7 +2657,7 @@ def handle_endtag(self, tag) -> None:
def handle_data(self, data) -> None:
self.last_content += data

def generate_table(self, rows):
def generate_table(self, rows) -> PrettyTable:
"""
Generates from a list of rows a PrettyTable object.
"""
Expand Down Expand Up @@ -2682,7 +2685,7 @@ def make_fields_unique(self, fields) -> None:
fields[j] += "'"


def from_html(html_code, **kwargs):
def from_html(html_code: str, **kwargs) -> list[PrettyTable]:
"""
Generates a list of PrettyTables from a string of HTML code. Each <table> in
the HTML becomes one PrettyTable object.
Expand All @@ -2693,7 +2696,7 @@ def from_html(html_code, **kwargs):
return parser.tables


def from_html_one(html_code, **kwargs):
def from_html_one(html_code: str, **kwargs) -> PrettyTable:
"""
Generates a PrettyTables from a string of HTML code which contains only a
single <table>
Expand Down

0 comments on commit 2af9987

Please sign in to comment.