32 merged pull requests to Keras, Google's deep learning framework, plus additional merged work in Keras Kinetic. Contributions span new ops and layer implementations, cuDNN-backed RNN paths for the JAX and PyTorch backends, ONNX export improvements, performance work, and a long tail of correctness and shape-handling fixes.
Highlights
- Implemented GRU and LSTM for the PyTorch and JAX backends, including a cuDNN LSTM path and a fused JAX Bidirectional LSTM cuDNN call
- Added new public ops: keras.ops.matrix_rank, keras.ops.pinv, keras.ops.numpy.geomspace, keras.ops.nn.fold
- Added a CTC beam search decoder for the torch backend
- ONNX export improvements: dict/list inputs for torch export, jax2onnx for direct JAX-to-ONNX export
- Performance work in tree.flatten, tree.map_structure, channels_last torch convs, sparse_categorical_crossentropy memory use
- Fixed long-standing correctness bugs in Attention, GroupQueryAttention, ReversibleEmbedding, Softmax with masks, model.summary, and add_loss under JAX JIT
- Repository: keras-team/keras
- PRs: 32 merged in keras-team/keras (plus more in keras-team/kinetic)
- Stack: Python, TensorFlow, PyTorch, JAX
- Areas: RNN layers, ops, ONNX export, performance, correctness
RNN backends and cuDNN
A significant strand of the work brought cuDNN-backed RNN paths to the JAX backend and rewired the PyTorch path:
- #22791 Fuse JAX Bidirectional LSTM into a single cuDNN call
- #22470 Rewrite torch LSTM to use functional API with CPU fallback
- #22401 Add optimized GRU for JAX backend
- #22399 Add cuDNN LSTM for JAX backend
- #22115 Implement GRU for PyTorch backend
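For context on what these backend paths compute, here is a single-timestep LSTM cell sketched in plain NumPy. The fused cuDNN kernels run this recurrence over an entire sequence in one call; the function and variable names here are illustrative, not the actual Keras internals.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, W, U, b):
    # Illustrative single-timestep LSTM cell: one projection produces all
    # four gate pre-activations, which are then split and combined.
    z = x @ W + h @ U + b                 # (batch, 4 * units)
    i, f, g, o = np.split(z, 4, axis=-1)  # input, forget, cell, output gates
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

batch, features, units = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(batch, features))
h = np.zeros((batch, units))
c = np.zeros((batch, units))
W = rng.normal(size=(features, 4 * units))
U = rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
h, c = lstm_step(x, h, c, W, U, b)
print(h.shape)  # (2, 4)
```

A cuDNN-backed path replaces the per-timestep Python loop over `lstm_step` with one fused GPU kernel for the whole sequence, which is where the speedup comes from.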
New ops and features
- #22799 Add CTC beam search decoder for torch backend
- #22772 Add keras.ops.matrix_rank
- #22763 Add keras.ops.pinv
- #22283 Add fold to keras.ops.nn
- #22257 Add geomspace to keras.ops.numpy
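NumPy already defines the reference semantics for these ops, so a quick NumPy sketch shows what the new keras.ops entries compute (this uses `np.*` equivalents, not the Keras implementations):

```python
import numpy as np

# geomspace: numbers spaced evenly on a log scale.
g = np.geomspace(1.0, 1000.0, num=4)    # [1., 10., 100., 1000.]

# matrix_rank: rank computed from singular values above a tolerance.
a = np.array([[1.0, 2.0], [2.0, 4.0]])  # second row is 2x the first
rank = np.linalg.matrix_rank(a)         # 1

# pinv: Moore-Penrose pseudo-inverse, used for least-squares solves.
b = np.array([[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]])
p = np.linalg.pinv(b)                   # shape (2, 3)
```

For a matrix with full column rank, `pinv(b) @ b` recovers the identity, which is a handy sanity check for the op.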
ONNX export
- #22737 Support dict/list inputs in torch ONNX export
- #22443 Use jax2onnx for direct JAX-to-ONNX export
Performance
- #22770 Optimize common cases for tree.flatten and tree.map_structure
- #22674 Skip tree.flatten in any_symbolic_tensors for non-nested args
- #22324 Use channels_last memory format for torch conv ops
- #22169 Reduce memory usage in sparse_categorical_crossentropy
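The flavor of the tree.flatten fast-path work can be sketched as follows: check cheaply whether the structure is a leaf or a flat sequence before paying for a general recursive walk. The function names and dispatch logic here are hypothetical, not the actual optree/Keras code.

```python
def flatten(structure):
    # General recursive flatten over lists, tuples, and dicts.
    if isinstance(structure, (list, tuple)):
        out = []
        for item in structure:
            out.extend(flatten(item))
        return out
    if isinstance(structure, dict):
        out = []
        for key in sorted(structure):
            out.extend(flatten(structure[key]))
        return out
    return [structure]

def flatten_fast(structure):
    # Hypothetical fast path in the spirit of the optimization: leaves
    # and flat sequences skip the recursive walk entirely.
    if not isinstance(structure, (list, tuple, dict)):
        return [structure]          # leaf: nothing to traverse
    if isinstance(structure, (list, tuple)) and not any(
        isinstance(x, (list, tuple, dict)) for x in structure
    ):
        return list(structure)      # flat sequence: one shallow copy
    return flatten(structure)       # nested: fall back to the general case

print(flatten_fast([1, [2, 3], {"a": 4}]))  # [1, 2, 3, 4]
```

The same idea motivates skipping tree.flatten in `any_symbolic_tensors` when arguments are not nested: the common call has plain tensors, so the generic traversal is wasted work.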
Correctness and bug fixes
- #22808 Fix torch lstsq with rcond, re-enable test
- #22738 Fix stale return_attention_scores flag in Attention.compute_output_spec
- #22661 Pass Python strings to TextVectorization standardize on non-TF backends
- #22642 Fix symbolic output shape in GroupQueryAttention
- #22505 Fix KeyError in model.summary() when a layer returns a dict
- #22395 Fix TracerBoolConversionError when using add_loss with JAX JIT
- #22379 Fix Softmax OOM when using mask on torch backend
- #21961 Fix ReversibleEmbedding mask error when using reverse=True
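To make the masked-softmax class of bugs concrete, here is the standard pattern in NumPy: masked positions are pushed toward negative infinity before normalizing, rather than materializing large intermediate tensors. This is an illustrative sketch of the general technique, not the torch-backend fix itself.

```python
import numpy as np

def masked_softmax(logits, mask):
    # Standard masked-softmax pattern: replace masked logits with a very
    # negative value so they contribute ~0 probability after exp().
    neg = np.finfo(logits.dtype).min
    masked = np.where(mask, logits, neg)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[True, True, False]])
probs = masked_softmax(logits, mask)
print(probs)  # masked last position gets ~0 probability
```

The probabilities still sum to 1 across each row, with all mass redistributed over the unmasked positions.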
Shape handling and validation
- #22676 Validate DepthwiseConv output shape in build()
- #22662 Validate SeparableConv output shape in build()
- #22280 Use canonicalize_axis in Linspace and Logspace compute_output_spec
- #22041 Torch dynamic shapes
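Axis canonicalization, as used in the Linspace/Logspace fix, amounts to mapping a possibly negative axis into the valid non-negative range and rejecting out-of-bounds values. A minimal sketch (the real helper lives inside Keras; this version is illustrative):

```python
def canonicalize_axis(axis, num_dims):
    # Map an axis in [-num_dims, num_dims) to its non-negative
    # equivalent, raising on out-of-range values.
    if not -num_dims <= axis < num_dims:
        raise ValueError(
            f"axis {axis} is out of bounds for a rank-{num_dims} tensor"
        )
    return axis + num_dims if axis < 0 else axis

print(canonicalize_axis(-1, 3))  # 2
print(canonicalize_axis(0, 3))   # 0
```

Normalizing axes once, up front, lets compute_output_spec index shape tuples directly without re-handling negative axes at every use site.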
Refactors
- #22013 Refactor ExtractPatches to handle both 2D and 3D
- #21980 Unify extract_patches to support both 2D and 3D patches
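The idea behind unifying the 2D and 3D patch paths can be sketched with a rank-agnostic extractor: NumPy's `sliding_window_view` works for any number of spatial dimensions, so one code path covers both cases. This is an illustrative non-overlapping variant, not the Keras implementation.

```python
import numpy as np

def extract_patches(x, patch_shape):
    # Rank-agnostic patch extraction: sliding_window_view yields every
    # window, then strided slicing keeps only non-overlapping patches
    # (stride equal to the patch size along each dimension).
    windows = np.lib.stride_tricks.sliding_window_view(x, patch_shape)
    strides = tuple(slice(None, None, s) for s in patch_shape)
    return windows[strides]

image = np.arange(16).reshape(4, 4)
patches = extract_patches(image, (2, 2))
print(patches.shape)  # (2, 2, 2, 2)

volume = np.arange(64).reshape(4, 4, 4)
print(extract_patches(volume, (2, 2, 2)).shape)  # (2, 2, 2, 2, 2, 2)
```

Because nothing in the function is specific to two dimensions, the same call handles images and volumes, which is the shape of the unification the two refactor PRs aimed for.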