Skip to content

Commit

Permalink
typechecking python (#104)
Browse files Browse the repository at this point in the history
* infrastructure

* typecheck the code
  • Loading branch information
adriandavila authored Dec 28, 2023
1 parent 502eb2d commit 9294c29
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 68 deletions.
9 changes: 5 additions & 4 deletions extras/ortoa/benchmark/infrastucture/experiment_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def model_post_init(self, __context: Any) -> None:
assert self.experiment_path.is_file()
return super().model_post_init(__context)


@classmethod
def construct(cls, experiment: Path) -> List[Self]:
def construct_experiments(cls, experiment: Path) -> List[Self]:
"""
Construct an list of ExperimentPath instances
"""
Expand All @@ -33,12 +34,12 @@ def construct(cls, experiment: Path) -> List[Self]:

@classmethod
def from_path(cls, experiment: Path) -> List[Self]:
return [ExperimentPath(experiment_path=experiment)]
return [cls(experiment_path=experiment)]

@classmethod
def from_dir(cls, experiment_dir: Path) -> List[Self]:
return [
ExperimentPath(experiment_path=e) for e in experiment_dir.glob("**/*.yaml")
cls(experiment_path=e) for e in experiment_dir.glob("**/*.yaml")
]


Expand All @@ -48,6 +49,6 @@ def collect_experiments(experiments: Iterable[Path]) -> List[ExperimentPath]:
"""
return list(
itertools.chain.from_iterable(
[ExperimentPath.construct(experiment) for experiment in experiments]
[ExperimentPath.construct_experiments(experiment) for experiment in experiments]
)
)
7 changes: 4 additions & 3 deletions extras/ortoa/benchmark/infrastucture/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class LogFiles:
class ClientFlags(BaseModel):
initdb: bool = True
nthreads: int = 1
seed: Path = Field(required=True)
operations: Path = Field(required=True)
output: Path = Field(required=True)
seed: Path = Field()
operations: Path = Field()
output: Path = Field()

@property
def initdb_flags(self) -> str:
Expand Down Expand Up @@ -210,6 +210,7 @@ def make_jobs(

for flag in experiment.client_flags:
if flag.name == "nthreads":
assert isinstance(flag.value, int)
e_client_flags.nthreads = flag.value
elif flag.name == "client_logging_enabled":
pass
Expand Down
1 change: 0 additions & 1 deletion extras/ortoa/benchmark/infrastucture/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def model_post_init(self, __context: Any) -> None:

def run(self) -> List[Result[JobT]]:
"""Leaving this for when I'm ready to implement multiprocessing for the benchmarking"""
assert self.max_processes >= 1
raise NotImplementedError

def run_sequential(self) -> List[Result[JobT]]:
Expand Down
2 changes: 1 addition & 1 deletion extras/ortoa/benchmark/infrastucture/stats_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def save_to(self, dir: Path) -> None:
self._save_graphs(dir=dir)

@classmethod
def _parse_result(self, job: ClientJob, results_file: Path) -> pd.DataFrame:
def _parse_result(cls, job: ClientJob, results_file: Path) -> pd.DataFrame:
"""Parse the results from C++ and add them to the dataframe"""

with results_file.open("r") as f:
Expand Down
6 changes: 4 additions & 2 deletions extras/ortoa/benchmark/interface/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_generation_config(
cls, data: DataGenerationConfigBase, output_dir: Path
) -> Self:
seed, operations = data.generate_files(output_dir)
return SeedData(seed=seed, operations=operations)
return cls(seed=seed, operations=operations)


class Config(BaseModel, Generic[FlagT]):
Expand Down Expand Up @@ -131,7 +131,9 @@ def atomicize_experiments(experiments: List[Experiment]) -> List[AtomicExperimen
atomic_experiments: List[AtomicExperiment] = []
for experiment in experiments:
assert isinstance(experiment.client_config.data, SeedData)

assert experiment.client_config.data.seed is not None
assert experiment.client_config.data.operations is not None

all_client_flags = [
flag.get_atomic_flags() for flag in experiment.client_config.flags
]
Expand Down
8 changes: 4 additions & 4 deletions extras/ortoa/benchmark/interface/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __str__(self):
return f"--nthreads {self.value}"

def get_atomic_flags(self) -> List[Self]:
atomic_selfs: List[Self] = []
atomic_flags: List[Self] = []
if isinstance(self.value, int):
atomic_selfs.append(NClientThreads(name=self.name, value=self.value))
atomic_flags.append(self.__class__(name=self.name, value=self.value))
elif isinstance(self.value, (IntegerIncrementRange, IntegerMultiplyRange)):
for val in self.value.generate_values():
atomic_selfs.append(NClientThreads(name=self.name, value=val))
atomic_flags.append(self.__class__(name=self.name, value=val))
else:
raise TypeError(
"NClientThreads::get_atomic_flags() did not recognize type of self.value"
)

return atomic_selfs
return atomic_flags


class ClientLoggingEnabled(ClientFlag):
Expand Down
53 changes: 2 additions & 51 deletions extras/ortoa/benchmark/interface/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,20 @@

from pydantic import BaseModel, Field

T = TypeVar("T", bound=Union[int, str, bool])
T = TypeVar("T", bound=Union[int, str, bool, float])


##########################
# Abstractions
##########################


class FloatType(BaseModel):
type: Literal["float"] = Field(default="float", frozen=True)


class IntType(BaseModel):
type: Literal["int"] = Field(default="int", frozen=True)


class Parameter(BaseModel, ABC):
@abstractmethod
def generate_values(self) -> List[str]:
def generate_values(self) -> List:
raise NotImplementedError


Expand All @@ -33,18 +28,10 @@ class RangeParameter(Parameter, Generic[NumberT]):
maximum: NumberT


class StaticParameter(Parameter, Generic[T]):
value: T

def generate_values(self):
return [str(self.value)]


##########################
# Parameter Types
##########################


class IntegerIncrementRange(RangeParameter[int], IntType):
step: int

Expand All @@ -71,39 +58,3 @@ def generate_values(self) -> List[int]:
i *= self.multiplier

return res


class FloatIncrementRange(RangeParameter[int], FloatType):
step: float

def generate_values(self) -> List[float]:
res: List[float] = []

i = self.minimum
while i <= self.maximum:
res.append(i)
i += self.step

return res


class FloatMultiplyRange(RangeParameter[float], FloatType):
multiplier: float

def generate_values(self) -> List[str]:
res: List[float] = []

i = self.minimum
while i <= self.maximum:
res.append(i)
i *= self.multiplier

return res


class IntegerParameter(StaticParameter[int], IntType):
pass


class FloatParameter(StaticParameter[float], FloatType):
pass
8 changes: 7 additions & 1 deletion extras/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ dev = [
"icecream==2.1.3",
"pytest==7.4.3",
"isort==5.13.1",
"black==23.12.0"
"black==23.12.0",
"pyright==1.1.343"
]

[tool.isort]
profile = "black"

[tool.pyright]
pythonVersion = "3.8"
include = [
"ortoa/"
]
1 change: 0 additions & 1 deletion extras/test/benchmark/interface/test_experiment_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ortoa.benchmark.interface.parameter import (
IntegerIncrementRange,
IntegerMultiplyRange,
IntegerParameter,
)


Expand Down
19 changes: 19 additions & 0 deletions scripts/ortoa-lib.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export REPO_ROOT=$(cd ${SCRIPT_DIR} && git rev-parse --show-superproject-working
export ORTOA_SHARED="${REPO_ROOT}"
export BUILD_DIR="${ORTOA_SHARED}/build"
export INSTALL_DIR="${ORTOA_SHARED}/install"
export SDK_DIR="${ORTOA_SHARED}/extras"

export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib:${REPO_ROOT}/install/lib"

Expand Down Expand Up @@ -257,6 +258,24 @@ Syntax: ortoa-sort-python [-h]
isort extras/
}

ortoa-typecheck-python() {
local HELP="""\
Typechecks the extras/ directory
Syntax: ortoa-typecheck-python [-h]
------------------------------
-h Print this help message
"""
OPTIND=1
while getopts ":h" option; do
case "${option}" in
h) echo "${HELP}"; return 0 ;;
esac
done

pyright -p "${SDK_DIR}" --warnings
}

############################################
# Data Generation
############################################
Expand Down

0 comments on commit 9294c29

Please sign in to comment.