Open
Description
opened on Nov 8, 2024
Hi,
cc @dionhaefner
The following attempts to serialize a zero-sized array, but it fails validation in `_validate_params`.
I believe the problem is that `_validate_params` expects to find, for every `foo/.zarray` entry, a matching data entry `foo/0`. However, this code produces the TensorStore entries `a/0`, `a/.zarray`, and `z/.zarray`, but not `z/0`, since there is no data in the `z` tensor.
I'm actually not sure whether TensorStore saves a `z/0` entry or not, or what the intended behavior should be.
Any insight would be greatly appreciated!
import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp
# Minimal reproduction: save a pytree that mixes a regular int32 array with a
# zero-sized array. Running this is expected to raise
# `ValueError: Save failed: 1/2 params are missing in checkpoint: z.`
target = dict(
    a=jnp.array([1, 2, 3], jnp.int32),
    z=jnp.zeros((0,)),  # zero elements -- the leaf reported as missing
)

# One default SaveArgs per leaf, mirroring the structure of `target`.
save_args = jtu.tree_map(lambda _leaf: ocp.SaveArgs(), target)

handler = ocp.PyTreeCheckpointHandler()
orbax_checkpointer = ocp.Checkpointer(handler)

overwrite = True  # forwarded to `save(force=...)`

with tempfile.TemporaryDirectory() as ckpt_path:
    orbax_checkpointer.save(
        ckpt_path,
        target,
        save_args=save_args,
        force=overwrite,
    )
(jax_env) henry@henry-gs65:orbax$ python flax4309.py
Traceback (most recent call last):
File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
self._handler.finalize(tmpdir.get())
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
self._handler_impl.finalize(directory)
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
asyncio_utils.run_sync(
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
return asyncio.run(coro)
^^^^^^^^^^^^^^^^^
File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
'base': {
'driver': 'file',
'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
},
'cache_pool': 'cache_pool#ocdbt',
'config': {
'compression': {'id': 'zstd'},
'max_decoded_node_bytes': 100000000,
'max_inline_value_bytes': 1024,
'uuid': '3ef941407cca4f778414e9e92b15dedb',
'version_tree_arity_log2': 4,
},
'context': {
'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
'data_copy_concurrency': {},
'file_io_concurrency': {'limit': 128},
'file_io_sync': True,
'ocdbt_coordinator': {},
},
'driver': 'ocdbt',
'experimental_read_coalescing_interval': '1ms',
'experimental_read_coalescing_merged_bytes': 500000000000,
'experimental_read_coalescing_threshold_bytes': 1000000,
}).
Metadata
Assignees
Labels
No labels
Activity