Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinCortacero committed Jan 8, 2025
1 parent 8e6422c commit 73adab5
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/components/create_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from kartezio.types import TypeArray


@register(Endpoint, "my_endpoint")
@register(Endpoint)
class MyExampleEndpoint(Endpoint):
"""
A custom endpoint used as the final output node in a CGP graph.
Expand Down Expand Up @@ -75,7 +75,7 @@ def main():

# Instantiate the endpoint from the component registry and apply it
my_endpoint_2 = Components.instantiate(
"Endpoint", "my_endpoint", n_classes=3
"Endpoint", "MyExampleEndpoint", n_classes=3
)
output_2 = my_endpoint_2.call(inputs) # Expected Output: 0
print(f"Output from my_endpoint_2: {output_2}")
Expand Down
4 changes: 2 additions & 2 deletions examples/components/create_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kartezio.types import TypeArray


@register(Primitive, "my_primitive")
@register(Primitive)
class MyExamplePrimitive(Primitive):
"""
A custom primitive operation that replaces specific pixel values.
Expand Down Expand Up @@ -44,7 +44,7 @@ def main():
output = my_primitive.call(inputs) # Apply the primitive operation
print(output) # Output: [1, 0, 1, 0]

my_primitive_2 = Components.instantiate("Primitive", "my_primitive")
my_primitive_2 = Components.instantiate("Primitive", "MyExamplePrimitive")
output_2 = my_primitive_2.call(inputs)
print(output_2) # Output: [1, 0, 1, 0]

Expand Down
11 changes: 5 additions & 6 deletions examples/training/advanced_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from kartezio.callback import CallbackVerbose
from kartezio.core.endpoints import EndpointThreshold
from kartezio.core.endpoints import EndpointThreshold, ThresholdWatershed
from kartezio.core.fitness import IoU
from kartezio.evolution.base import KartezioTrainer
from kartezio.mutation.behavioral import AccumulateBehavior
from kartezio.mutation.decay import LinearDecay
from kartezio.mutation.decay import LinearDecay, DegreeDecay
from kartezio.mutation.edges import MutationEdgesNormal
from kartezio.mutation.effect import MutationNormal
from kartezio.primitives.array import create_array_lib
Expand All @@ -23,9 +23,8 @@ def main():
create_array_lib(use_scalars=True),
library_scalar,
] # Create a library of array operations
endpoint = EndpointThreshold(
128, mode="tozero"
) # Define the endpoint for the model
endpoint = EndpointThreshold(128) # Define the endpoint for the model
endpoint = ThresholdWatershed(True, 128, 192)
fitness = IoU() # Define the fitness metric

# Build the model with specified components
Expand All @@ -38,7 +37,7 @@ def main():
)

model.set_mutation_rates(node_rate=0.5, out_rate=0.2)
model.set_decay(LinearDecay(0.5, 0.01))
model.set_decay(DegreeDecay(4, 0.5, 0.01))
model.set_behavior(AccumulateBehavior())
model.set_mutation_effect(MutationNormal(0.5, 0.005))
model.set_mutation_edges(MutationEdgesNormal(10))
Expand Down
4 changes: 2 additions & 2 deletions src/kartezio/core/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ def call(self, x):
return [
double_threshold_watershed(
image=x[0],
threshold1=self.threshold1,
threshold2=self.threshold2,
threshold1=self.threshold,
threshold2=self.threshold_2,
watershed_line=self.watershed_line,
)
]
Expand Down
1 change: 0 additions & 1 deletion src/kartezio/core/fitness.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Dict

import numpy as np
from scipy.optimize import linear_sum_assignment

from kartezio.core.components import Fitness, register
from kartezio.thirdparty.cellpose import cellpose_ap
Expand Down
261 changes: 236 additions & 25 deletions src/kartezio/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,98 @@
from dataclasses import dataclass, field
from typing import List, Tuple
from abc import abstractmethod
from ast import List
from dataclasses import dataclass
from typing import Dict, Tuple

from kartezio.core.components import Components
from kartezio.utils.directory import Directory
from kartezio.vision.common import draw_overlay
from kartezio.utils.json_handler import json_read, json_write



import numpy as np

from kartezio.core.components import KartezioComponent
from kartezio.utils.json_handler import json_read, json_write

CSV_DATASET = "dataset.csv"
JSON_META = "meta.json"

class DataReader(KartezioComponent):
def __init__(self, directory, scale=1.0):
super().__init__()
self.scale = scale
self.directory = directory

def read(self, filename, shape=None):
if str(filename) == "nan":
filepath = ""
else:
filepath = str(self.directory / filename)
return self._read(filepath, shape)
class Dataset:
class SubSet:
def __init__(self, dataframe):
self.x = []
self.y = []
self.v = []
self.dataframe = dataframe

@abstractmethod
def _read(self, filepath, shape=None):
pass
def add_item(self, x, y):
self.x.append(x)
self.y.append(y)

@classmethod
def __from_dict__(cls, dict_infos: Dict) -> "DataReader":
pass
def add_visual(self, visual):
self.v.append(visual)

@property
def xy(self):
return self.x, self.y

@property
def xyv(self):
return self.x, self.y, self.v

def __init__(
self, train_set, test_set, name, label_name, inputs, indices=None
):
self.train_set = train_set
self.test_set = test_set
self.name = name
self.label_name = label_name
self.inputs = inputs
self.indices = indices

@property
def train_x(self):
return self.train_set.x

@property
def train_y(self):
return self.train_set.y

@property
def train_v(self):
return self.train_set.v

@property
def test_x(self):
return self.test_set.x

@property
def test_y(self):
return self.test_set.y

@property
def test_v(self):
return self.test_set.v

@property
def train_xy(self):
return self.train_set.xy

@property
def test_xy(self):
return self.test_set.xy

@property
def train_xyv(self):
return self.train_set.xyv

@property
def test_xyv(self):
return self.test_set.xyv

@property
def split(self):
return self.train_x, self.train_y, self.test_x, self.test_y

def __to_dict__(self) -> Dict:
return {"scale": self.scale, "directory": str(self.directory)}


class DatasetMeta:
Expand All @@ -46,7 +107,7 @@ def write(
label_name,
scale=1.0,
mode="dataframe",
meta_filename="META.json",
meta_filename=JSON_META,
):
json_data = {
"name": name,
Expand All @@ -63,6 +124,137 @@ def read(filepath, meta_filename):
return json_read(filepath / meta_filename)


@dataclass
class DatasetReader(Directory):
    """Load a dataset described by a meta file and a tabular index file.

    The reader is itself a ``Directory``; the meta file is read from
    ``self._path`` and the index file through ``self.read``. Only the
    ``"dataframe"`` mode is currently handled.
    """

    # When True, each label is stored as ``[first label array, count]``.
    counting: bool = False
    # When True, an overlay image is written for every sample read.
    preview: bool = False
    # Only assigned in ``__post_init__`` when ``preview`` is True.
    preview_dir: Directory = field(init=False)

    def __post_init__(self, path):
        """Run the parent initialisation and prepare the preview directory."""
        super().__post_init__(path)
        if self.preview:
            # presumably a "__preview__" sub-directory -- confirm against
            # Directory.next
            self.preview_dir = self.next("__preview__")

    def _read_meta(self, meta_filename):
        """Read the meta file and set name, scale, mode, label name, readers.

        Raises:
            KeyError: if a required key is missing from the meta data.
        """
        # Function-scope import kept from the original.
        # NOTE(review): presumably avoids a circular import -- confirm.
        from kartezio.readers import ImageRGBReader, RoiPolygonReader

        meta = DatasetMeta.read(self._path, meta_filename=meta_filename)
        self.name = meta["name"]
        self.scale = meta["scale"]
        self.mode = meta["mode"]
        self.label_name = meta["label_name"]
        # The reader names are computed but not used yet: the readers below
        # are hard-coded. The lookups are kept so that a meta file lacking the
        # "input"/"label" sections still fails early with a KeyError.
        # TODO: instantiate the readers from the component registry using
        # these names (and ``scale=self.scale``) once registered.
        input_reader_name = (
            f"{meta['input']['type']}_{meta['input']['format']}"
        )
        label_reader_name = (
            f"{meta['label']['type']}_{meta['label']['format']}"
        )
        # NOTE(review): ``self.scale`` is not forwarded to the readers here.
        self.input_reader = ImageRGBReader(self)
        self.label_reader = RoiPolygonReader(self)

    def read_dataset(
        self,
        dataset_filename=CSV_DATASET,
        meta_filename=JSON_META,
        indices=None,
    ):
        """Read the meta file, then load the dataset according to its mode.

        Args:
            dataset_filename: name of the tabular index file.
            meta_filename: name of the meta file.
            indices: optional row labels restricting the training rows.

        Returns:
            The loaded ``Dataset``.

        Raises:
            AttributeError: if the mode declared in the meta file is not
                ``"dataframe"``.
        """
        self._read_meta(meta_filename)
        if self.mode == "dataframe":
            return self._read_from_dataframe(dataset_filename, indices)
        raise AttributeError(f"{self.mode} is not handled yet")

    def _read_from_dataframe(self, dataset_filename, indices):
        """Build a ``Dataset`` from the rows tagged "training" and "testing".

        ``indices`` only restricts the training rows; the testing rows are
        always read in full.
        """
        dataframe = self.read(dataset_filename)
        training = self._read_dataset(
            dataframe[dataframe["set"] == "training"], indices
        )
        testing = self._read_dataset(dataframe[dataframe["set"] == "testing"])

        input_sizes = np.array(
            [len(xi) for xi in training.x] + [len(xi) for xi in testing.x]
        )
        inputs = int(input_sizes[0])
        if not np.all(input_sizes == inputs):
            # Deliberately a warning, not an error: the dataset is still
            # returned, using the size of the first sample.
            print(
                f"Inconsistent size of inputs for this dataset: sizes: {input_sizes}"
            )

        if self.preview:
            self._write_previews(training, "train")
            self._write_previews(testing, "test")
        return Dataset(
            training, testing, self.name, self.label_name, inputs, indices
        )

    def _write_previews(self, subset, prefix):
        """Write one ``<prefix>_<i>.png`` overlay per sample of ``subset``."""
        color = [98, 36, 97]
        for i, (visual, labels) in enumerate(zip(subset.v, subset.y)):
            overlay = draw_overlay(
                visual,
                labels[0].astype(np.uint8),
                color=color,
                alpha=0.5,
                thickness=3,
            )
            self.preview_dir.write(f"{prefix}_{i}.png", overlay)

    def _read_auto(self, dataset):
        """Placeholder; not implemented."""
        pass

    def _read_dataset(self, dataframe, indices=None):
        """Read every row of ``dataframe`` into a ``Dataset.SubSet``.

        Args:
            dataframe: rows with ``input`` and ``label`` columns and an
                optional ``visual`` column.
            indices: optional row labels (list or array) selected with
                ``.loc`` after the index has been reset; ``None`` or an empty
                sequence keeps every row.

        Returns:
            The populated ``Dataset.SubSet``.
        """
        # Work on a re-indexed copy instead of mutating the caller's frame.
        dataframe = dataframe.reset_index()
        dataset = Dataset.SubSet(dataframe)
        # ``if indices:`` is ambiguous for NumPy arrays, so test explicitly.
        if indices is not None and len(indices) > 0:
            dataframe = dataframe.loc[indices]
        has_visual_column = "visual" in dataframe.columns
        for row in dataframe.itertuples():
            x = self.input_reader.read(row.input, shape=None)
            y = self.label_reader.read(row.label, shape=x.shape)
            labels = [y.datalist[0], y.count] if self.counting else y.datalist
            dataset.n_inputs = x.size
            dataset.add_item(x.datalist, labels)
            # Use the visual listed in the table when present; otherwise fall
            # back to the visual produced by the input reader.
            if has_visual_column and str(row.visual) != "nan":
                dataset.add_visual(self.read(row.visual))
            else:
                dataset.add_visual(x.visual)
        return dataset


class DataReader:
    """Base reader that resolves a filename against a directory.

    Subclasses provide the actual decoding by overriding ``_read``.
    """

    def __init__(self, directory, scale=1.0):
        """Store the directory that filenames are resolved against and a scale."""
        self.directory = directory
        self.scale = scale

    def read(self, filename, shape=None):
        """Resolve ``filename`` and delegate to ``_read``.

        A filename whose string form is ``"nan"`` is passed on as an empty
        path; any other filename is joined to ``self.directory`` with ``/``
        and converted to ``str``.
        """
        is_missing = str(filename) == "nan"
        filepath = "" if is_missing else str(self.directory / filename)
        return self._read(filepath, shape)

    @abstractmethod
    def _read(self, filepath, shape=None):
        """Decode the file at ``filepath``; this base version returns None."""
        pass


@dataclass
class DataItem:
datalist: List
Expand All @@ -73,3 +265,22 @@ class DataItem:
@property
def size(self):
return len(self.datalist)


def read_dataset(
    dataset_path,
    filename=CSV_DATASET,
    meta_filename=JSON_META,
    indices=None,
    counting=False,
    preview=False,
    reader=None,
):
    """Create a ``DatasetReader`` for ``dataset_path`` and load the dataset.

    ``counting`` and ``preview`` are forwarded to the ``DatasetReader``
    constructor. ``filename``, ``meta_filename`` and ``indices`` are
    forwarded to ``DatasetReader.read_dataset``, whose result is returned.
    """
    loader = DatasetReader(dataset_path, counting=counting, preview=preview)
    if reader is not None:
        # NOTE(review): ``add_reader`` is not defined on DatasetReader in this
        # file; presumably inherited from Directory -- confirm.
        loader.add_reader(reader)
    return loader.read_dataset(
        dataset_filename=filename,
        meta_filename=meta_filename,
        indices=indices,
    )
Loading

0 comments on commit 73adab5

Please sign in to comment.