Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve function type annotations #319

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 48 additions & 45 deletions src/prettytable/prettytable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
import io
import re
from html.parser import HTMLParser
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from sqlite3 import Cursor

from typing_extensions import Self

# hrule styles
FRAME = 0
Expand All @@ -58,7 +63,7 @@
_re = re.compile(r"\033\[[0-9;]*m|\033\(B")


def _get_size(text):
def _get_size(text: str) -> tuple[int, int]:
lines = text.split("\n")
height = len(lines)
width = max(_str_block_width(line) for line in lines)
Expand Down Expand Up @@ -97,6 +102,7 @@ class PrettyTable:
_attributes: dict[str, str]
_escape_header: bool
_escape_data: bool
_hrule: str

def __init__(self, field_names=None, **kwargs) -> None:
"""Return a new PrettyTable instance
Expand Down Expand Up @@ -310,7 +316,7 @@ def __init__(self, field_names=None, **kwargs) -> None:
self._xhtml = kwargs["xhtml"] or False
self._attributes = kwargs["attributes"] or {}

def _justify(self, text, width, align):
def _justify(self, text: str, width: int, align) -> str:
excess = width - _str_block_width(text)
if align == "l":
return text + excess * " "
Expand Down Expand Up @@ -344,7 +350,7 @@ def __getattr__(self, name):
else:
raise AttributeError(name)

def __getitem__(self, index):
def __getitem__(self, index: int | slice) -> PrettyTable:
new = PrettyTable()
new.field_names = self.field_names
for attr in self._options:
Expand All @@ -366,7 +372,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.get_string()

def _repr_html_(self):
def _repr_html_(self) -> str:
"""
Returns get_html_string value by default
as the repr call in Jupyter notebook environment
Expand Down Expand Up @@ -1581,7 +1587,7 @@ def clear(self) -> None:
# MISC PUBLIC METHODS #
##############################

def copy(self):
def copy(self) -> Self:
import copy

return copy.deepcopy(self)
Expand Down Expand Up @@ -1623,7 +1629,7 @@ def _format_value(self, field, value):
formatter = self._custom_format.get(field, (lambda f, v: str(v)))
return formatter(field, value)

def _compute_table_width(self, options):
def _compute_table_width(self, options) -> int:
if options["vrules"] == FRAME:
table_width = 2
if options["vrules"] == ALL:
Expand Down Expand Up @@ -1712,7 +1718,7 @@ def _compute_widths(self, rows, options) -> None:
widths[-1] += min_width - sum(widths)
self._widths = widths

def _get_padding_widths(self, options):
def _get_padding_widths(self, options) -> tuple[int, int]:
if options["left_padding_width"] is not None:
lpad = options["left_padding_width"]
else:
Expand Down Expand Up @@ -1753,7 +1759,7 @@ def _get_rows(self, options):

return rows

def _get_dividers(self, options):
def _get_dividers(self, options) -> list[bool]:
"""Return only those dividers that should be printed, based on slicing.

Arguments:
Expand Down Expand Up @@ -1834,7 +1840,7 @@ def get_string(self, **kwargs) -> str:

options = self._get_options(kwargs)

lines = []
lines: list[str] = []

# Don't think too hard about an empty table
# Is this the desired behaviour? Maybe we should still print the header?
Expand Down Expand Up @@ -1894,7 +1900,7 @@ def get_string(self, **kwargs) -> str:

return "\n".join(lines)

def _stringify_hrule(self, options, where: str = ""):
def _stringify_hrule(self, options, where: str = "") -> str:
if not options["border"] and not options["preserve_internal_border"]:
return ""
lpad, rpad = self._get_padding_widths(options)
Expand Down Expand Up @@ -1933,7 +1939,7 @@ def _stringify_hrule(self, options, where: str = ""):

return "".join(bits)

def _stringify_title(self, title, options):
def _stringify_title(self, title: str, options) -> str:
lines = []
lpad, rpad = self._get_padding_widths(options)
if options["border"]:
Expand All @@ -1943,7 +1949,7 @@ def _stringify_title(self, title, options):
options["vrules"] = ALL
elif options["vrules"] == FRAME:
lines.append(self._stringify_hrule(options, "top_"))
bits = []
bits: list[str] = []
endpoint = (
options["vertical_char"]
if options["vrules"] in (ALL, FRAME) and options["border"]
Expand All @@ -1959,8 +1965,8 @@ def _stringify_title(self, title, options):
lines.append("".join(bits))
return "\n".join(lines)

def _stringify_header(self, options):
bits = []
def _stringify_header(self, options) -> str:
bits: list[str] = []
lpad, rpad = self._get_padding_widths(options)
if options["border"]:
if options["hrules"] in (ALL, FRAME):
Expand Down Expand Up @@ -2027,7 +2033,7 @@ def _stringify_header(self, options):
bits.append(self._hrule)
return "".join(bits)

def _stringify_row(self, row, options, hrule):
def _stringify_row(self, row, options, hrule: str) -> str:
import textwrap

for index, field, value, width in zip(
Expand All @@ -2052,7 +2058,7 @@ def _stringify_row(self, row, options, hrule):
if h > row_height:
row_height = h

bits = []
bits: list[list[str]] = []
lpad, rpad = self._get_padding_widths(options)
for y in range(0, row_height):
bits.append([])
Expand All @@ -2078,8 +2084,7 @@ def _stringify_row(self, row, options, hrule):
else:
lines = lines + [""] * d_height

y = 0
for line in lines:
for y, line in enumerate(lines):
if options["fields"] and field not in options["fields"]:
continue

Expand All @@ -2093,7 +2098,6 @@ def _stringify_row(self, row, options, hrule):
bits[y].append(self.vertical_char)
else:
bits[y].append(" ")
y += 1

# If only preserve_internal_border is true, then we just appended
# a vertical character at the end when we wanted a space
Expand All @@ -2112,13 +2116,11 @@ def _stringify_row(self, row, options, hrule):
bits[row_height - 1].append("\n")
bits[row_height - 1].append(hrule)

for y in range(0, row_height):
bits[y] = "".join(bits[y])

return "\n".join(bits)
bits_str = ["".join(bits_y) for bits_y in bits]
return "\n".join(bits_str)

def paginate(self, page_length: int = 58, line_break: str = "\f", **kwargs):
pages = []
def paginate(self, page_length: int = 58, line_break: str = "\f", **kwargs) -> str:
pages: list[str] = []
kwargs["start"] = kwargs.get("start", 0)
true_end = kwargs.get("end", self.rowcount)
while True:
Expand Down Expand Up @@ -2259,10 +2261,10 @@ def get_html_string(self, **kwargs) -> str:

return string

def _get_simple_html_string(self, options):
def _get_simple_html_string(self, options) -> str:
from html import escape

lines = []
lines: list[str] = []
if options["xhtml"]:
linebreak = "<br/>"
else:
Expand Down Expand Up @@ -2318,10 +2320,10 @@ def _get_simple_html_string(self, options):

return "\n".join(lines)

def _get_formatted_html_string(self, options):
def _get_formatted_html_string(self, options) -> str:
from html import escape

lines = []
lines: list[str] = []
lpad, rpad = self._get_padding_widths(options)
if options["xhtml"]:
linebreak = "<br/>"
Expand Down Expand Up @@ -2378,8 +2380,8 @@ def _get_formatted_html_string(self, options):
lines.append(" <tbody>")
rows = self._get_rows(options)
formatted_rows = self._format_rows(rows)
aligns = []
valigns = []
aligns: list[str] = []
valigns: list[str] = []
for field in self._field_names:
aligns.append(
{"l": "left", "r": "right", "c": "center"}[self._align[field]]
Expand Down Expand Up @@ -2449,8 +2451,8 @@ def get_latex_string(self, **kwargs) -> str:
string = self._get_simple_latex_string(options)
return string

def _get_simple_latex_string(self, options):
lines = []
def _get_simple_latex_string(self, options) -> str:
lines: list[str] = []

wanted_fields = []
if options["fields"]:
Expand Down Expand Up @@ -2482,8 +2484,8 @@ def _get_simple_latex_string(self, options):

return "\r\n".join(lines)

def _get_formatted_latex_string(self, options):
lines = []
def _get_formatted_latex_string(self, options) -> str:
lines: list[str] = []

wanted_fields = []
if options["fields"]:
Expand Down Expand Up @@ -2542,7 +2544,7 @@ def _get_formatted_latex_string(self, options):
##############################


def _str_block_width(val):
def _str_block_width(val: str) -> int:
import wcwidth # type: ignore[import-not-found]

return wcwidth.wcswidth(_re.sub("", val))
Expand All @@ -2553,7 +2555,7 @@ def _str_block_width(val):
##############################


def from_csv(fp, field_names: Any | None = None, **kwargs):
def from_csv(fp, field_names: Any | None = None, **kwargs) -> PrettyTable:
import csv

fmtparams = {}
Expand Down Expand Up @@ -2588,16 +2590,17 @@ def from_csv(fp, field_names: Any | None = None, **kwargs):
return table


def from_db_cursor(cursor, **kwargs):
def from_db_cursor(cursor: Cursor, **kwargs) -> PrettyTable | None:
if cursor.description:
table = PrettyTable(**kwargs)
table.field_names = [col[0] for col in cursor.description]
for row in cursor.fetchall():
table.add_row(row)
return table
return None


def from_json(json_string, **kwargs):
def from_json(json_string: str | bytes, **kwargs) -> PrettyTable:
import json

table = PrettyTable(**kwargs)
Expand All @@ -2613,7 +2616,7 @@ class TableHandler(HTMLParser):
def __init__(self, **kwargs) -> None:
HTMLParser.__init__(self)
self.kwargs = kwargs
self.tables: list[list] = []
self.tables: list[PrettyTable] = []
self.last_row: list[str] = []
self.rows: list[Any] = []
self.max_row_width = 0
Expand Down Expand Up @@ -2654,7 +2657,7 @@ def handle_endtag(self, tag) -> None:
def handle_data(self, data) -> None:
self.last_content += data

def generate_table(self, rows):
def generate_table(self, rows) -> PrettyTable:
"""
Generates from a list of rows a PrettyTable object.
"""
Expand Down Expand Up @@ -2682,7 +2685,7 @@ def make_fields_unique(self, fields) -> None:
fields[j] += "'"


def from_html(html_code, **kwargs):
def from_html(html_code: str, **kwargs) -> list[PrettyTable]:
"""
Generates a list of PrettyTables from a string of HTML code. Each <table> in
the HTML becomes one PrettyTable object.
Expand All @@ -2693,9 +2696,9 @@ def from_html(html_code, **kwargs):
return parser.tables


def from_html_one(html_code, **kwargs):
def from_html_one(html_code: str, **kwargs) -> PrettyTable:
"""
Generates a PrettyTables from a string of HTML code which contains only a
Generates a PrettyTable from a string of HTML code which contains only a
single <table>
"""

Expand Down
Loading