Skip to content

Conversation

@copybara-service
Copy link

[pmap] Avoid degraded performance under the new jax.pmap.

This change prepares for the new jax.pmap by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with jax.pmap is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.

The change is under the jax_pmap_shmap_merge configuration flag. If True, the new jax.pmap implementation based on jax.jit(jax.shard_map) is used and requires the new explicit shard access. If False, the old jax.pmap implementation is used and there is a special case in how x[0] works.

Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays

This change prepares for the new `jax.pmap` by implementing the recommended mechanism for accessing the first shard in a sharded array. A common pattern used with `jax.pmap` is to shard an array that is semantically replicated and grabbing the first shard is meant to "unreplicate". However, JAX does not know that a sharded array is actually replicated, so we must now explicitly grab the first shard.

The change is under the `jax_pmap_shmap_merge` configuration flag. If `True`, the new `jax.pmap` implementation based on `jax.jit(jax.shard_map)` is used and requires the new explicit shard access. If `False`, the old `jax.pmap` implementation is used and there is a special case in how `x[0]` works.

Please see details here: https://docs.jax.dev/en/latest/migrate_pmap.html#int-array-indexing-into-sharded-arrays

PiperOrigin-RevId: 846723995
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant