JAX, a high-performance numerical computing library built around automatic differentiation and compilation, is fielding a steady stream of user-reported performance issues across hardware backends while actively expanding its feature set and improving its documentation.
Recent issues and pull requests (PRs) highlight ongoing challenges with numerical stability and performance, particularly on GPUs and TPUs. Key examples include #23637, which asks that tracer errors raised inside loop bodies point at the user's code, and #23634, a corner-case sharding regression. Both reflect a broader theme of backend-specific discrepancies and the need for better error reporting.
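As background on the first of these, here is a minimal sketch of the kind of tracer error involved; it uses jax.lax.fori_loop rather than the experimental for_loop the issue targets, and the body function is hypothetical:

```python
import jax
import jax.numpy as jnp

# Python-level control flow on a traced value inside a loop body raises
# a tracer error; issue #23637 asks that the resulting message point at
# the user's body function.
def body(i, carry):
    if carry > 0:  # bool() on a tracer is not allowed while tracing
        carry = carry - 1.0
    return carry

@jax.jit
def f(x):
    return jax.lax.fori_loop(0, 5, body, x)

try:
    f(jnp.float32(3.0))
except jax.errors.TracerBoolConversionError as e:
    print("tracer error:", type(e).__name__)
```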
The most active maintainers and contributors in this period include:
George Necula (gnecula)
Google ML Automation
Kanglan Tang (kanglant)
Sergei Lebedev (superbobry)
Parker Schuh (pschuh)
Yash Katariya (yashk2810)
Peter Hawkins (hawkinsp)
Matthew Johnson (mattjj)
Justin Fu (justinjfu)
Dougal Maclaurin (dougalm)
Timespan | Opened | Closed | Comments | Labeled | Milestones |
---|---|---|---|---|---|
7 Days | 31 | 11 | 100 | 0 | 1 |
30 Days | 96 | 47 | 305 | 0 | 1 |
90 Days | 261 | 162 | 828 | 1 | 1 |
All Time | 5514 | 4177 | - | - | - |
Like all software activity quantification, these numbers are imperfect but sometimes useful. The Comments, Labeled, and Milestones counts refer to issues opened in the timespan in question.
Developer | Branches | PRs | Commits | Files | Changes |
---|---|---|---|---|---|
Yash Katariya | 5 | 0/0/0 | 19 | 33 | 3284 |
Google ML Automation | 11 | 0/0/0 | 38 | 46 | 2085 |
Jake Vanderplas | 4 | 10/9/0 | 12 | 18 | 1854 |
Sergei Lebedev | 8 | 5/3/1 | 31 | 43 | 1791 |
Dan Foreman-Mackey | 2 | 4/2/0 | 7 | 24 | 1714 |
Pawel Paruzel | 4 | 0/0/0 | 4 | 11 | 1472 |
Jevin Jiang | 2 | 0/0/0 | 5 | 15 | 1204 |
rajasekharporeddy | 4 | 11/9/0 | 14 | 3 | 1085 |
Vadym Matsishevskyi | 1 | 0/0/0 | 1 | 3 | 731 |
Justin Fu | 1 | 1/1/0 | 6 | 30 | 542 |
Peter Hawkins | 6 | 9/8/1 | 32 | 42 | 539 |
None (dependabot[bot]) | 4 | 12/0/8 | 4 | 9 | 381 |
Jérome Eertmans | 1 | 0/0/0 | 1 | 4 | 331 |
Jake VanderPlas | 2 | 0/0/0 | 2 | 5 | 321 |
George Necula | 2 | 2/2/0 | 5 | 11 | 233 |
Parker Schuh | 2 | 1/1/0 | 4 | 10 | 227 |
Ayaka | 1 | 2/1/0 | 3 | 6 | 215 |
Kaixi Hou | 2 | 1/1/0 | 2 | 2 | 201 |
Vladimir Belitskiy | 2 | 0/0/0 | 2 | 6 | 164 |
Shanbin Ke | 1 | 0/0/0 | 1 | 2 | 159 |
Chris Jones | 1 | 0/0/0 | 2 | 1 | 151 |
Bart Chrzaszcz | 2 | 0/0/0 | 2 | 8 | 124 |
Yury Kirpichev | 1 | 0/0/0 | 1 | 2 | 90 |
Keshav Balasubramanian | 1 | 0/0/0 | 1 | 4 | 85 |
Selam Waktola | 1 | 3/2/0 | 2 | 1 | 83 |
Georg Stefan Schmid (gspschmid) | 1 | 1/1/0 | 1 | 2 | 62 |
Christos Perivolaropoulos | 1 | 0/0/0 | 1 | 4 | 61 |
Keith Rush | 2 | 0/0/0 | 2 | 2 | 60 |
Pearu Peterson | 1 | 1/1/0 | 1 | 2 | 59 |
Adam Paszke | 2 | 0/0/0 | 2 | 5 | 57 |
Damiano Amatruda (damianoamatruda) | 1 | 1/1/0 | 1 | 2 | 43 |
Matthew Johnson | 2 | 4/4/0 | 4 | 3 | 40 |
Dougal Maclaurin | 1 | 0/0/0 | 8 | 5 | 38 |
Sharad Vikram | 3 | 0/0/0 | 3 | 2 | 38 |
Roy Frostig | 2 | 1/1/0 | 2 | 1 | 37 |
Abhinav Gunjal | 1 | 0/0/0 | 1 | 1 | 21 |
Jaroslav Sevcik | 1 | 2/1/0 | 1 | 1 | 20 |
Sebastian Bodenstein | 1 | 0/0/0 | 1 | 2 | 17 |
Kanglan Tang | 1 | 0/0/0 | 1 | 1 | 15 |
Adam Banaś | 1 | 0/0/0 | 1 | 1 | 13 |
Enrique Piqueras | 1 | 0/0/0 | 1 | 1 | 12 |
Carlos Martin | 1 | 1/1/0 | 2 | 3 | 9 |
Frederik Wilde | 1 | 1/1/0 | 2 | 2 | 4 |
Michael Deistler | 1 | 1/1/0 | 1 | 1 | 4 |
Luke Yang | 1 | 1/1/0 | 1 | 1 | 2 |
Fabian Pedregosa | 1 | 0/0/0 | 1 | 1 | 2 |
David Mis | 1 | 0/0/0 | 1 | 1 | 2 |
Zheng Zeng (Aiemu) | 0 | 1/0/0 | 0 | 0 | 0 |
Robert Dyro (rdyro) | 0 | 1/0/0 | 0 | 0 | 0 |
Roman Knyazhitskiy (knyazer) | 0 | 1/0/0 | 0 | 0 | 0 |
Mathew Odden (mrodden) | 0 | 1/0/0 | 0 | 0 | 0 |
Yunlong Liu (yliu120) | 0 | 1/0/0 | 0 | 0 | 0 |
None (pkgoogle) | 0 | 2/2/0 | 0 | 0 | 0 |
Ilia Sergachev (sergachev) | 0 | 2/1/0 | 0 | 0 | 0 |
Alexander Pivovarov (apivovarov) | 0 | 1/0/0 | 0 | 0 | 0 |
jax authors | 0 | 0/0/0 | 0 | 0 | 0 |
Shaikh Yaser (shaikhyaser) | 0 | 1/0/1 | 0 | 0 | 0 |
Joshua G Albert (Joshuaalbert) | 0 | 1/0/0 | 0 | 0 | 0 |
Abhinav Goel (abhinavgoel95) | 0 | 1/0/0 | 0 | 0 | 0 |
Chase Riley Roberts (chaserileyroberts) | 0 | 1/0/0 | 0 | 0 | 0 |
None (copybara-service[bot]) | 0 | 111/79/5 | 0 | 0 | 0 |
PRs column: opened/merged/closed-unmerged counts for PRs created by that developer during the period.
The JAX project on GitHub currently has 1,337 open issues, indicating a high level of ongoing activity and user engagement. Recent issues highlight various bugs, performance discrepancies, and feature requests, particularly concerning the handling of complex numbers, performance on different hardware backends (especially GPUs and TPUs), and the integration of new features like Pallas for advanced kernel programming.
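Since Pallas recurs throughout this report, here is a minimal sketch of the kernel style it enables; the elementwise kernel is an illustrative toy (not code from any PR), and interpret=True is set so the sketch runs on CPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Toy Pallas kernel: elementwise add. Refs are read and written
# in-place; pallas_call stages the kernel into a JAX operation.
def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode, so no GPU/TPU is required
    )(x, y)

print(add(jnp.ones(8), jnp.ones(8)))  # -> [2. 2. ... 2.]
```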
Notable themes include jax.scipy.special methods and jax.lax operations that behave inconsistently across platforms.
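A hedged sketch of how such discrepancies are typically surfaced: evaluate one function on each backend available to the process and compare against the CPU result (the choice of gammaln and the input range are arbitrary, not drawn from a specific issue):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Evaluate the same special function on every available backend and
# report the maximum deviation from the CPU result.
x = jnp.linspace(0.1, 50.0, 1000)

results = {}
for platform in ("cpu", "gpu", "tpu"):
    try:
        device = jax.devices(platform)[0]
    except RuntimeError:
        continue  # backend not present in this process
    results[platform] = jax.device_get(gammaln(jax.device_put(x, device)))

for platform, vals in results.items():
    max_err = float(jnp.max(jnp.abs(vals - results["cpu"])))
    print(f"{platform}: max abs deviation from cpu = {max_err:.3e}")
```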
Here are some of the most recently created and updated issues:
Issue #23637: When a tracer error happens in for_loop, should point to the user's body function
Issue #23634: Corner-case sharding regression when replacing concrete mesh with abstract mesh
Issue #23626: jax.lax.linalg.lu returns LU factorisation for singular matrix
Issue #23625: When calculating the loss, the input data does not contain NaN, but the output contains NaN
Issue #23624: pure_callback is broken with multiple vmap
Issue #23616: Orthogonal Initializer raises gpusolverDnCreate(&handle) failed
Issue #23600: Make jax.debug.print work with non-jax types
Issue #23599: Make jax.distributed timeouts configurable via jax.config
Issue #23594: Allow tuple inputs to scatter_dimension in jax.lax.psum_scatter
Issue #23590: Wrong results on CPU since 0.4.32
These issues reflect a mix of bugs related to core functionalities and enhancements aimed at improving usability and performance across different platforms.
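To make one of these concrete, issue #23624 involves composing jax.pure_callback with stacked vmaps; the sketch below shows the basic pattern with an illustrative NumPy callback (the issue reports this composition breaking, so the behavior depends on the JAX version):

```python
import numpy as np
import jax
import jax.numpy as jnp

def np_square(x):
    # Host-side callback; receives and returns NumPy arrays.
    return np.square(x)

def f(x):
    return jax.pure_callback(
        np_square, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )

xs = jnp.arange(6.0).reshape(2, 3)
# A single vmap batches the callback; issue #23624 concerns stacking
# a second vmap on top, as done here.
print(jax.vmap(jax.vmap(f))(xs))
```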
Addressing these bugs while continuing to enhance existing functionality will be important for maintaining user trust and satisfaction in the JAX ecosystem.
The analysis of the current pull requests (PRs) for the JAX project reveals a total of 363 open PRs, with a significant focus on enhancing functionality, improving documentation, and fixing bugs. The recent activity indicates a strong emphasis on GPU and TPU optimizations, as well as ongoing efforts to refine the API and improve user experience.
PR #23640: Pallas pipeline API tweaks for more advanced pipelining patterns.
PR #23636: Add Python 3.13.0rc2 support to the build.
PR #23635: Failing test for issue #23634.
PR #23633: Add memory space annotation to ShapedArray.
PR #23632: Generalize global jit cpp cache keys.
PR #23627: Add a "broadcasting vmap" helper to custom_batching.
PR #23623: Improve documentation for jax.numpy: power and pow (a short usage sketch follows this list).
PR #23620: Test for io_callback in custom partitioning.
PR #23619: Relax usage of io_callback in automatic differentiation (AD).
PR #23617: Generalize global jit cpp cache keys (Take 2).
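As context for PR #23623, here is a usage sketch of the two spellings (assuming a recent JAX release in which jnp.pow exists as an array-API-style alias for jnp.power):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.power(x, 2))  # elementwise exponentiation -> [1. 4. 9.]
print(jnp.pow(x, 2))    # array-API-style alias for jnp.power
print(x ** 2)           # the ** operator lowers to the same operation
```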
The recent pull requests in the JAX repository reflect several key themes and areas of focus:
A significant number of PRs are dedicated to optimizing performance, particularly concerning GPU and TPU capabilities. For instance, PRs like #23632 and #23617 aim to enhance caching mechanisms in JIT compilation, which is crucial for improving execution speed on accelerators. The introduction of features like "broadcasting vmap" (PR #23627) further indicates an effort to streamline vectorized operations, which can significantly impact performance when processing large datasets or complex models.
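The PR's new helper is not shown in this report; as background, plain jax.vmap already expresses broadcasting of an unbatched argument via in_axes=None, the pattern such a helper builds on (a hedged reading, not the PR's code):

```python
import jax
import jax.numpy as jnp

def affine(x, w):
    return w @ x

xs = jnp.ones((8, 4))  # batch of 8 input vectors
w = jnp.ones((3, 4))   # a single weight matrix shared across the batch

# in_axes=(0, None) maps over xs but broadcasts w to every element.
ys = jax.vmap(affine, in_axes=(0, None))(xs, w)
print(ys.shape)  # (8, 3)
```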
Quality assurance remains a priority, as evidenced by PRs that add failing tests (e.g., PR #23635) or address known issues (e.g., PR #23620). This proactive approach ensures that regressions are caught early and that the stability of the library is maintained as new features are introduced. The discussions surrounding these PRs also highlight community engagement in identifying and resolving issues collaboratively.
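A minimal sketch of the failing-test pattern in the abstract (a hypothetical test, not PR #23635's actual code): capture the expected behavior in an assertion so the regression fails loudly until fixed.

```python
import jax
import jax.numpy as jnp

def test_jit_matches_eager():
    # The usual shape of a regression test for backend bugs: a jitted
    # computation must agree with its eager counterpart.
    x = jnp.arange(12.0).reshape(3, 4)
    eager = jnp.mean(x, axis=0)
    jitted = jax.jit(lambda a: jnp.mean(a, axis=0))(x)
    assert jnp.allclose(eager, jitted)
```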
Documentation is another area receiving attention, with multiple PRs aimed at clarifying existing functionalities (e.g., PRs #23623 and #23596). Clear documentation is vital for user adoption and effective utilization of the library's features, especially given JAX's complexity and its use cases in advanced numerical computing and machine learning.
Several pull requests focus on refining the API to enhance usability and clarity (e.g., PR #23619). The ongoing discussions about naming conventions and whether to expose certain functionalities indicate an iterative process aimed at balancing feature richness with usability concerns.
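For readers unfamiliar with the API in question, here is a minimal io_callback sketch (the host-side logger is illustrative; PR #23619 concerns how such callbacks interact with automatic differentiation, which is not exercised here):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_log(x):
    # Runs on the host; must return arrays matching the declared
    # result shape/dtype.
    print("host saw:", np.asarray(x))
    return np.asarray(x)

@jax.jit
def f(x):
    x = io_callback(host_log, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return jnp.sin(x)

print(f(jnp.arange(3.0)))
```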
The active discussions among contributors regarding new features and their implications demonstrate a healthy community dynamic within the JAX project. Contributors are not only focused on coding but also engaging in meaningful dialogues about best practices, potential pitfalls, and future directions for the library.
Overall, the current landscape of pull requests in JAX showcases a vibrant development environment focused on performance enhancements, robust quality assurance practices, improved documentation, and thoughtful API design. As the project continues to evolve, maintaining this balance will be crucial for its success and adoption within the broader machine learning community.
The JAX development team is exhibiting robust activity with a clear focus on enhancing functionality while maintaining code quality. The collaborative environment fosters effective problem-solving and innovation, particularly in the evolving landscape of numerical computing and machine learning applications.