_validate_params fails on zero-sized arrays #1309

Open
@hrbigelow

Description

Hi,

@niketkumar @cpgaffney1,

cc @dionhaefner

The script below attempts to save a PyTree containing a zero-sized array, but the save fails validation in _validate_params.

I believe the problem is that _validate_params expects every 'foo/.zarray' metadata entry to have a matching data entry 'foo/0'. However, this code produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but not 'z/0', since there is no data in the z tensor.
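
To make the suspected logic concrete, here is a minimal pure-Python sketch of the kind of check I think is happening. This is only my guess at the behavior, not the actual orbax code: `missing_params` is a made-up helper, and the key list is just the set of entries described above.

```python
def missing_params(keys):
    """Return param names that have a '.zarray' metadata key but no data key.

    Hypothetical stand-in for the check I suspect _validate_params performs;
    this is NOT the actual orbax implementation.
    """
    with_metadata = set()
    with_data = set()
    for key in keys:
        # Split 'z/.zarray' into the param name 'z' and the leaf '.zarray'.
        name, _, leaf = key.rpartition('/')
        if leaf == '.zarray':
            with_metadata.add(name)
        else:
            with_data.add(name)
    # A param counts as "missing" if metadata exists but no chunk was written.
    return sorted(with_metadata - with_data)


# Entries described in this issue: 'z' has metadata but no chunk entry.
keys = ['a/0', 'a/.zarray', 'z/.zarray']
print(missing_params(keys))  # ['z']
```

If the real check works roughly like this, any zero-sized array would be reported as missing even though its metadata was written correctly.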

I'm not sure whether tensorstore is supposed to write a z/0 entry for an empty array, or what the intended behavior is.

Any insight would be greatly appreciated!

import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp

target = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
}

orbax_checkpointer = ocp.Checkpointer(
  ocp.PyTreeCheckpointHandler()
)

with tempfile.TemporaryDirectory() as ckpt_path:
  overwrite = True
  save_args = jtu.tree_map(lambda _: ocp.SaveArgs(), target)
  orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)

(jax_env) henry@henry-gs65:orbax$ python flax4309.py
Traceback (most recent call last):
  File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
    self._handler_impl.finalize(directory)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
    asyncio_utils.run_sync(
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '3ef941407cca4f778414e9e92b15dedb',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).
