
Keras (Google) — Open Source 2026

32 merged pull requests to Keras, Google's deep learning framework, plus additional merged work in Keras Kinetic. Contributions span new ops and layer implementations, cuDNN-backed RNN paths for the JAX and PyTorch backends, ONNX export improvements, performance work, and a long tail of correctness and shape-handling fixes.

Highlights

  • Implemented GRU and LSTM for the PyTorch and JAX backends, including a cuDNN LSTM path and a fused JAX Bidirectional LSTM cuDNN call
  • Added new public ops: keras.ops.matrix_rank, keras.ops.pinv, keras.ops.numpy.geomspace, keras.ops.nn.fold
  • Added a CTC beam search decoder for the torch backend
  • ONNX export improvements: dict/list inputs for torch export, jax2onnx for direct JAX-to-ONNX export
  • Performance work in tree.flatten, tree.map_structure, channels_last torch convs, and sparse_categorical_crossentropy memory use
  • Fixed long-standing correctness bugs in Attention, GroupQueryAttention, ReversibleEmbedding, Softmax with masks, model.summary, and add_loss under JAX JIT

RNN backends and cuDNN

A significant strand of the work brought cuDNN-backed RNN paths to the JAX backend and rewrote the PyTorch LSTM path on the functional API with a CPU fallback:

  • #22791 Fuse JAX Bidirectional LSTM into a single cuDNN call
  • #22470 Rewrite torch LSTM to use functional API with CPU fallback
  • #22401 Add optimized GRU for JAX backend
  • #22399 Add cuDNN LSTM for JAX backend
  • #22115 Implement GRU for PyTorch backend

New ops and features

  • #22799 Add CTC beam search decoder for torch backend
  • #22772 Add keras.ops.matrix_rank
  • #22763 Add keras.ops.pinv
  • #22283 Add fold to keras.ops.nn
  • #22257 Add geomspace to keras.ops.numpy
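
For readers unfamiliar with geomspace, the sketch below shows what such an op computes, assuming it follows the semantics of numpy.geomspace (samples evenly spaced on a log scale, endpoints included by default). It is a pure-Python illustration, not the Keras implementation; it handles positive scalar endpoints only, and the function body is an assumption made for this example.

```python
import math


def geomspace(start, stop, num=50, endpoint=True):
    """Return `num` floats spaced evenly on a log scale from start to stop.

    Illustrative only: positive scalar endpoints, no dtype or axis handling.
    """
    if num < 0:
        raise ValueError("num must be non-negative")
    if start <= 0 or stop <= 0:
        raise ValueError("this sketch handles positive endpoints only")
    if num == 0:
        return []
    if num == 1:
        return [float(start)]

    # With endpoint=True the last sample lands on `stop`, so there are
    # num - 1 intervals; otherwise there are num intervals.
    intervals = num - 1 if endpoint else num
    log_start = math.log(start)
    log_step = (math.log(stop) - log_start) / intervals

    values = [math.exp(log_start + i * log_step) for i in range(num)]

    # Pin the endpoints so rounding in exp/log cannot perturb them.
    values[0] = float(start)
    if endpoint:
        values[-1] = float(stop)
    return values


samples = geomspace(1, 1000, num=4)  # approximately 1, 10, 100, 1000
```

Each sample is the previous one multiplied by a constant ratio, which is the difference from linspace, where consecutive samples differ by a constant amount.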

ONNX export

  • #22737 Support dict/list inputs in torch ONNX export
  • #22443 Use jax2onnx for direct JAX-to-ONNX export

Performance

  • #22770 Optimize common cases for tree.flatten and tree.map_structure
  • #22674 Skip tree.flatten in any_symbolic_tensors for non-nested args
  • #22324 Use channels_last memory format for torch conv ops
  • #22169 Reduce memory usage in sparse_categorical_crossentropy
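
The idea behind the two tree-related changes can be sketched roughly as follows: when none of the arguments is a container, a check over the arguments can run directly and skip the recursive flatten. This is a minimal pure-Python illustration under that assumption. The names flatten, is_nested, and any_match are hypothetical stand-ins and do not reproduce the Keras code.

```python
def flatten(structure):
    """Flatten nested lists, tuples, and dicts into a flat list of leaves.

    Dict values are visited in sorted-key order, so this sketch assumes
    keys that can be compared with each other.
    """
    if isinstance(structure, dict):
        leaves = []
        for key in sorted(structure):
            leaves.extend(flatten(structure[key]))
        return leaves
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def is_nested(value):
    """Return True for the container types that flatten() descends into."""
    return isinstance(value, (list, tuple, dict))


def any_match(args, predicate):
    """Return True if any leaf in `args` satisfies `predicate`."""
    # Fast path: no argument is a container, so every argument is already
    # a leaf and the recursive flatten can be skipped.
    if not any(is_nested(arg) for arg in args):
        return any(predicate(arg) for arg in args)
    # General path: walk the full nested structure.
    return any(predicate(leaf) for leaf in flatten(args))


def is_text(value):
    return isinstance(value, str)


flat_hit = any_match((1, 2.0, "x"), is_text)           # fast path
nested_hit = any_match((1, [2, {"a": "x"}]), is_text)  # general path
```

The fast path costs one extra pass of cheap isinstance checks, which pays off when calls with plain, non-nested arguments are the common case.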

Correctness and bug fixes

  • #22808 Fix torch lstsq with rcond, re-enable test
  • #22738 Fix stale return_attention_scores flag in Attention.compute_output_spec
  • #22661 Pass Python strings to TextVectorization standardize on non-TF backends
  • #22642 Fix symbolic output shape in GroupQueryAttention
  • #22505 Fix KeyError in model.summary() when a layer returns a dict
  • #22395 Fix TracerBoolConversionError when using add_loss with JAX JIT
  • #22379 Fix Softmax OOM when using mask on torch backend
  • #21961 Fix ReversibleEmbedding mask error when using reverse=True

Shape handling and validation

  • #22676 Validate DepthwiseConv output shape in build()
  • #22662 Validate SeparableConv output shape in build()
  • #22280 Use canonicalize_axis in Linspace and Logspace compute_output_spec
  • #22041 Torch dynamic shapes
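
To illustrate why axis canonicalization matters when computing an output shape, here is a minimal sketch assuming NumPy-style linspace semantics, where the new dimension of size num is inserted at axis in the result. Both functions below are hypothetical stand-ins written for this example; they are not the Keras helper or the code changed in #22280.

```python
def canonicalize_axis(axis, ndim):
    """Map an axis in [-ndim, ndim) to its non-negative equivalent."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    return axis + ndim if axis < 0 else axis


def linspace_output_shape(start_shape, num, axis=0):
    """Shape of a linspace result for array-valued endpoints.

    The result has one more dimension than the endpoints, and `axis`
    indexes into the result's dimensions.
    """
    out_ndim = len(start_shape) + 1
    axis = canonicalize_axis(axis, out_ndim)
    shape = list(start_shape)
    shape.insert(axis, num)
    return tuple(shape)


last = linspace_output_shape((2, 3), num=5, axis=-1)   # (2, 3, 5)
first = linspace_output_shape((2, 3), num=5, axis=0)   # (5, 2, 3)
```

Passing a negative axis straight to list.insert would misplace the dimension: inserting at -1 puts the new entry before the last element and yields (2, 5, 3) instead of (2, 3, 5).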

Refactors

  • #22013 Refactor ExtractPatches to handle both 2D and 3D
  • #21980 Unify extract_patches to support both 2D and 3D patches

Test coverage

  • #22066 Increase test coverage for TextVectorization layer
  • #22022 Increase test coverage for IntegerLookup layer

© 2026 Marcos Ashton Iglesias. All Rights Reserved.
