Skip to content

Commit

Permalink
Support async put for vineyard client.
Browse files Browse the repository at this point in the history
Signed-off-by: Ye Cao <caoye.cao@alibaba-inc.com>
  • Loading branch information
dashanji committed Nov 18, 2024
1 parent 0f78867 commit cdffb5b
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
42 changes: 40 additions & 2 deletions python/vineyard/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
session: int = None,
username: str = None,
password: str = None,
max_workers: int = 8,
config: str = None,
):
"""Connects to the vineyard IPC socket and RPC socket.
Expand Down Expand Up @@ -211,6 +212,8 @@ def __init__(
is enabled.
password: Optional, the required password of vineyardd when authentication
is enabled.
max_workers: Optional, the maximum number of threads that can be used to
asynchronously put objects to vineyard. Default is 8.
config: Optional, can either be a path to a YAML configuration file or
a path to a directory containing the default config file
`vineyard-config.yaml`. Also, the environment variable
Expand Down Expand Up @@ -290,6 +293,9 @@ def __init__(
except VineyardException:
continue

self._max_workers = max_workers
self._put_thread_pool = None

self._spread = False
self._compression = True
if self._ipc_client is None and self._rpc_client is None:
Expand Down Expand Up @@ -347,6 +353,13 @@ def rpc_client(self) -> RPCClient:
assert self._rpc_client is not None, "RPC client is not available."
return self._rpc_client

@property
def put_thread_pool(self) -> ThreadPoolExecutor:
    """Return the executor used for asynchronous put.

    The executor is created on first access with ``self._max_workers``
    worker threads, cached on the instance, and the same object is handed
    back on every later access.
    """
    pool = self._put_thread_pool
    if pool is not None:
        return pool
    # First access: build the pool now and remember it for later calls.
    pool = ThreadPoolExecutor(max_workers=self._max_workers)
    self._put_thread_pool = pool
    return pool

def has_ipc_client(self):
    """Return True when ``self._ipc_client`` is set, False when it is None."""
    if self._ipc_client is None:
        return False
    return True

Expand Down Expand Up @@ -820,8 +833,7 @@ def get(
):
return get(self, object_id, name, resolver, fetch, **kwargs)

@_apply_docstring(put)
def put(
def _put_internal(
self,
value: Any,
builder: Optional[BuilderContext] = None,
Expand Down Expand Up @@ -858,6 +870,32 @@ def put(
self.compression = previous_compression_state
return put(self, value, builder, persist, name, **kwargs)

@_apply_docstring(put)
def put(
    self,
    value: Any,
    builder: Optional[BuilderContext] = None,
    persist: bool = False,
    name: Optional[str] = None,
    as_async: bool = False,
    **kwargs,
):
    # NOTE(review): no inline docstring here on purpose; the decorator above
    # presumably installs the docstring of the module-level ``put`` -- confirm.
    #
    # as_async: when False (the default) the object is put on the calling
    #   thread and the result of ``_put_internal`` is returned directly.
    #   When True the put is submitted to ``self.put_thread_pool`` and the
    #   ``concurrent.futures.Future`` for it is returned instead; the caller
    #   can use that future to wait for the put or to retrieve its exception.
    if not as_async:
        return self._put_internal(value, builder, persist, name, **kwargs)

    # Imported locally so this method does not depend on the module's
    # top-level import list.
    import logging

    logger = logging.getLogger(__name__)

    def _default_callback(future):
        # Report the outcome through logging instead of printing to stdout,
        # so applications embedding the client decide whether and where the
        # messages show up. The exception (if any) stays attached to the
        # returned future; it is only reported here, never swallowed.
        try:
            result = future.result()
        except Exception as e:
            logger.error("Failed to put object: %s", e)
        else:
            logger.debug("Successfully put object %s", result)

    future = self.put_thread_pool.submit(
        self._put_internal, value, builder, persist, name, **kwargs
    )
    future.add_done_callback(_default_callback)
    return future

@contextlib.contextmanager
def with_compression(self, enabled: bool = True):
"""Disable compression for the following put operations."""
Expand Down
39 changes: 39 additions & 0 deletions python/vineyard/core/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import itertools
import multiprocessing
import random
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from threading import Thread

import numpy as np

Expand Down Expand Up @@ -317,3 +319,40 @@ def test_memory_trim(vineyard_client):

# there might be some fragmentation overhead
assert parse_shared_memory_usage() <= original_memory_usage + 2 * data_kbytes


def test_async_put_and_get(vineyard_client):
    """Run an asynchronous producer and a name-waiting consumer concurrently.

    The producer submits ``object_nums`` named, persisted arrays with
    ``as_async=True`` followed by one synchronous put. The consumer resolves
    each name with ``wait=True`` and fetches the object. Any exception raised
    in either thread, including a failed background put or a data mismatch,
    fails the test.
    """
    data = np.ones((100, 100, 16))
    object_nums = 100
    # An exception raised inside a Thread target is not re-raised by join(),
    # so each worker records its failure here for the final assertion.
    errors = []

    def producer(vineyard_client):
        try:
            start_time = time.time()
            client = vineyard_client.fork()
            futures = [
                client.put(data, name="test" + str(i), as_async=True, persist=True)
                for i in range(object_nums)
            ]
            client.put(data)
            end_time = time.time()
            print("Producer time: ", end_time - start_time)
            # Timing above covers submission only; now surface any put that
            # failed in the background pool.
            for future in futures:
                future.result()
        except Exception as e:
            errors.append(e)

    def consumer(vineyard_client):
        try:
            start_time = time.time()
            client = vineyard_client.fork()
            for i in range(object_nums):
                object_id = client.get_name(name="test" + str(i), wait=True)
                # NOTE(review): assumes get() hands back the stored ndarray
                # -- confirm against the client's default resolver.
                np.testing.assert_array_equal(client.get(object_id), data)
            end_time = time.time()
            print("Consumer time: ", end_time - start_time)
        except Exception as e:
            errors.append(e)

    producer_thread = Thread(target=producer, args=(vineyard_client,))
    consumer_thread = Thread(target=consumer, args=(vineyard_client,))

    start_time = time.time()

    producer_thread.start()
    consumer_thread.start()

    producer_thread.join()
    consumer_thread.join()

    end_time = time.time()
    print("Total time: ", end_time - start_time)

    assert not errors, errors

0 comments on commit cdffb5b

Please sign in to comment.