Skip to content

Commit

Permalink
Fixed Ray compute kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed May 2, 2024
1 parent ec67009 commit e7227bc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion superduperdb/backends/ray/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def _dependable_remote_job(function, *args, **kwargs):
ray.wait(dependencies)
return function(*args, **kwargs)

remote_function = ray.remote(_dependable_remote_job, **compute_kwargs)
if compute_kwargs:
remote_function = ray.remote(**compute_kwargs)(_dependable_remote_job)
else:
remote_function = ray.remote(_dependable_remote_job)
future = remote_function.remote(function, *args, **kwargs)
task_id = str(future.task_id().hex())
self._futures_collection[task_id] = future
Expand Down
7 changes: 6 additions & 1 deletion superduperdb/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class Trainer(Component):
prefetch_size: int = 1000
prefetch_factor: int = 100
in_memory: bool = True
compute_kwargs: t.Dict = dc.field(default_factory=dict)

@abstractmethod
def fit(
Expand Down Expand Up @@ -232,11 +233,15 @@ def fit_in_db_job(
db: Datalayer,
dependencies: t.Sequence[Job] = (),
):
if self.trainer:
compute_kwargs = self.trainer.compute_kwargs or {}
else:
compute_kwargs = {}
job = ComponentJob(
component_identifier=self.identifier,
method_name='fit_in_db',
type_id='model',
kwargs={},
compute_kwargs=compute_kwargs,
)
job(db, dependencies)
return job
Expand Down

0 comments on commit e7227bc

Please sign in to comment.