The data format we currently use to represent `attributes` and `edges` is suboptimal for JAX transformations. In JAX, the structure of a PyTree is fixed at trace (compile) time, and Python containers such as tuples cannot be indexed with Tracers, whose values are only known at run time. For this reason, update functions need to declare the `node_idx` and `edges` variables as static arguments, so a new function is compiled and cached for each node separately. This loses the advantages of the modular implementation, and large models would clearly benefit from a single cached update function.
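The per-node recompilation can be reproduced with a toy network. This is a minimal sketch, assuming a hypothetical layout where `attributes` is a tuple of per-node dicts and `edges` is a tuple of parent indices; the `update` rule (adding parent means) is made up for illustration. A counter incremented at trace time shows that `jax.jit` compiles one function per distinct static `node_idx`, and dropping the static declaration makes the tuple lookup fail on a Tracer.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the current layout: a tuple of per-node
# parameter dicts and a tuple of per-node parent indices.
attributes = (
    {"mean": jnp.array(0.0)},
    {"mean": jnp.array(1.0)},
    {"mean": jnp.array(2.0)},
)
edges = ((1,), (2,), ())  # node 0 <- node 1, node 1 <- node 2

trace_count = 0


@partial(jax.jit, static_argnames=("node_idx", "edges"))
def update(attributes, node_idx, edges):
    global trace_count
    trace_count += 1  # Python side effect: only runs when JAX (re)traces
    # Plain Python indexing into tuples, so node_idx must be a concrete int.
    new_mean = attributes[node_idx]["mean"]
    for parent_idx in edges[node_idx]:
        new_mean = new_mean + attributes[parent_idx]["mean"]
    return new_mean


update(attributes, node_idx=0, edges=edges)
update(attributes, node_idx=1, edges=edges)
update(attributes, node_idx=0, edges=edges)  # cache hit, no retrace
print(trace_count)  # 2: one compiled function per node


# Without static arguments, node_idx becomes a Tracer and tuple indexing fails.
def read_mean(attributes, node_idx):
    return attributes[node_idx]["mean"]


try:
    jax.jit(read_mean)(attributes, 0)
    tracer_indexing_failed = False
except TypeError:
    tracer_indexing_failed = True
print(tracer_indexing_failed)  # True
```

With N nodes, this pattern yields N compiled variants of what is logically the same update function.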
The solution I see would be:
- [x] Use a dictionary of arrays to store the `edges` using a connectivity matrix representation.
- [x] Use a dictionary of arrays for each node parameter.
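The two items above can be sketched together as follows. This assumes a hypothetical layout (the `"value_parents"` key, the dense 0/1 matrix, and the toy update rule are all illustrative, not the project's actual API): each parameter is one array indexed by node, and edges are rows of a connectivity matrix. Because `node_idx` now indexes arrays rather than tuples, it can stay a traced argument and a single compiled function serves every node.

```python
import jax
import jax.numpy as jnp

# Hypothetical vectorized layout: one array per node parameter, and a dense
# connectivity matrix where value_parents[i, j] == 1 means node j is a
# value parent of node i.
attributes = {
    "mean": jnp.array([0.0, 1.0, 2.0]),
    "precision": jnp.array([1.0, 1.0, 1.0]),
}
edges = {
    "value_parents": jnp.array(
        [
            [0, 1, 0],  # node 0 <- node 1
            [0, 0, 1],  # node 1 <- node 2
            [0, 0, 0],  # node 2 has no parent
        ]
    ),
}

trace_count = 0


@jax.jit
def update(attributes, edges, node_idx):
    global trace_count
    trace_count += 1  # only runs when JAX (re)traces
    # node_idx is a traced integer: dynamic indexing into arrays is allowed,
    # so a single compiled function serves every node.
    parent_mask = edges["value_parents"][node_idx]
    # Fixed-shape reduction over all nodes; non-parents contribute zero.
    parent_sum = jnp.sum(parent_mask * attributes["mean"])
    new_mean = attributes["mean"][node_idx] + parent_sum
    return {**attributes, "mean": attributes["mean"].at[node_idx].set(new_mean)}


print(update(attributes, edges, 0)["mean"])  # [1. 1. 2.]
print(update(attributes, edges, 1)["mean"])  # [0. 3. 2.]
print(trace_count)  # 1
```

The price of the fixed-shape reduction is that every update touches all N entries of a row, which is cheap for small networks but grows with the size of the connectivity matrix.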
Update: it is currently (very) difficult to write readable update functions that pass messages across a dynamic number of nodes without something like dynamic shapes, which are under development in JAX but not yet available. Until such a feature is available, it seems unreasonable to move the code to this implementation. We have a working example for the two-level binary HGF, but its total execution time is longer than that of the default implementation, so it is unclear whether we would benefit from this apart from compilation time.
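The dynamic-shape problem shows up as soon as an update function tries to gather "the parents of node i" inside `jax.jit`. This is a minimal sketch on a hypothetical 4-node network: selecting parents with a boolean mask fails to trace because the output length depends on a run-time value. The workaround shown, padding to a hypothetical static `MAX_PARENTS` with `jnp.nonzero(..., size=..., fill_value=...)`, compiles but forces every downstream computation to carry a validity mask, which is the readability cost described above.

```python
import jax
import jax.numpy as jnp

# Hypothetical 4-node network: adjacency[i, j] == 1 means node j is a parent of node i.
means = jnp.array([0.0, 1.0, 2.0, 3.0])
adjacency = jnp.array(
    [
        [0, 1, 1, 0],  # node 0 has two parents
        [0, 0, 0, 1],  # node 1 has one parent
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]
)


@jax.jit
def parent_means(node_idx):
    # The output length equals the number of parents of node_idx, a value
    # only known at run time: this needs dynamic shapes and fails to trace.
    return means[adjacency[node_idx].astype(bool)]


try:
    parent_means(0)
    dynamic_version_failed = False
except (IndexError, TypeError):  # JAX rejects non-concrete boolean indices
    dynamic_version_failed = True

MAX_PARENTS = 2  # hypothetical static upper bound on the number of parents


@jax.jit
def parent_means_padded(node_idx):
    # Pad the parent indices to a static length; -1 marks a missing parent.
    (parent_idx,) = jnp.nonzero(adjacency[node_idx], size=MAX_PARENTS, fill_value=-1)
    valid = parent_idx >= 0
    # Every downstream computation has to carry the `valid` mask along.
    return jnp.where(valid, means[parent_idx], jnp.nan), valid


values, valid = parent_means_padded(1)
print(values, valid)
```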
- LegrandNico changed the title from "Use a dictionary instead of tuples to store the edges of a network" to "Use dictionaries of array to store the edges and attributes of a network" on Oct 23, 2023.
- LegrandNico changed the title from "Use dictionaries of array to store the edges and attributes of a network" to "Use dictionaries of arrays to store the edges and attributes of a network" on Oct 23, 2023.
- LegrandNico changed the title from "Use dictionaries of arrays to store the edges and attributes of a network" to "Vectorize the edges and attributes of the networks" on Oct 24, 2023.