The Dispatch

OSS Report: google/jax


JAX Development Faces Performance Challenges Amidst Active Feature Expansion

The JAX project, a high-performance numerical computing library built around automatic differentiation and compilation, is fielding a steady stream of user-reported correctness and performance problems across hardware backends, while actively expanding its feature set and improving its documentation.

Recent Activity

Recent issues and pull requests (PRs) point to ongoing challenges with numerical correctness and performance, some of them specific to GPU or TPU backends. Key issues include #23637, which asks that tracer errors raised inside for_loop point to the user's body function, and #23634, a corner-case sharding regression when a concrete mesh is replaced with an abstract mesh. Together with the backend-specific discrepancies, these underline a need for clearer error reporting.

Development Team and Recent Activity

  1. George Necula (gnecula)

    • Cleaned up forward-compatibility conditionals in Pallas lowering.
    • Improved documentation for forward compatibility.
  2. Google ML Automation

    • Fixed flaky tests and updated XLA dependencies.
    • Merged PRs related to documentation and bug fixes.
  3. Kanglan Tang (kanglant)

    • Fixed layout test failures on the GPU backend.
  4. Sergei Lebedev (superbobry)

    • Added new flags for GPU parameters; improved docstrings.
  5. Parker Schuh (pschuh)

    • Fixed bugs related to device assignment checks.
  6. Yash Katariya (yashk2810)

    • Fixed bugs, deprecated old code, and improved error handling.
  7. Peter Hawkins (hawkinsp)

    • Disabled failing tests on specific hardware; fixed bugs.
  8. Matthew Johnson (mattjj)

    • Improved documentation; fixed bugs in the shard_map implementation.
  9. Justin Fu (justinjfu)

    • Implemented Pallas Mosaic GPU features, including debugging tools.
  10. Dougal Maclaurin (dougalm)

    • Added leak checkers; adjusted custom derivative tests.


Quantified Reports

Quantify Issues



Recent GitHub Issues Activity

Timespan   Opened  Closed  Comments  Labeled  Milestones
7 Days         31      11       100        0           1
30 Days        96      47       305        0           1
90 Days       261     162       828        1           1
All Time     5514    4177         -        -           -

Like all software activity quantification, these numbers are imperfect but sometimes useful. Comments, Labels, and Milestones refer to those issues opened in the timespan in question.
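The rows above can be turned into simple derived figures. A minimal sketch, assuming the Closed column counts issues closed during each window (the table does not say so explicitly, so the closed/opened ratio is only a rough indicator):

```python
# Figures copied from the "Recent GitHub Issues Activity" table.
activity = {
    "7 Days": {"opened": 31, "closed": 11},
    "30 Days": {"opened": 96, "closed": 47},
    "90 Days": {"opened": 261, "closed": 162},
    "All Time": {"opened": 5514, "closed": 4177},
}

def net_open(row):
    """Issues opened minus issues closed for one table row."""
    return row["opened"] - row["closed"]

def closed_share(row):
    """Closed count as a fraction of the opened count."""
    return row["closed"] / row["opened"]

for window, row in activity.items():
    print(f"{window:>8}: net {net_open(row):+5d}, closed/opened {closed_share(row):.2f}")
```

For the All Time row, 5514 − 4177 = 1337, which matches the 1,337 open issues cited in the detailed issues report.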

Quantify commits



Quantified Commit Activity Over 30 Days

Developer Branches PRs Commits Files Changes
Yash Katariya 5 0/0/0 19 33 3284
Google ML Automation 11 0/0/0 38 46 2085
Jake Vanderplas 4 10/9/0 12 18 1854
Sergei Lebedev 8 5/3/1 31 43 1791
Dan Foreman-Mackey 2 4/2/0 7 24 1714
Pawel Paruzel 4 0/0/0 4 11 1472
Jevin Jiang 2 0/0/0 5 15 1204
rajasekharporeddy 4 11/9/0 14 3 1085
Vadym Matsishevskyi 1 0/0/0 1 3 731
Justin Fu 1 1/1/0 6 30 542
Peter Hawkins 6 9/8/1 32 42 539
None (dependabot[bot]) 4 12/0/8 4 9 381
Jérome Eertmans 1 0/0/0 1 4 331
Jake VanderPlas 2 0/0/0 2 5 321
George Necula 2 2/2/0 5 11 233
Parker Schuh 2 1/1/0 4 10 227
Ayaka 1 2/1/0 3 6 215
Kaixi Hou 2 1/1/0 2 2 201
Vladimir Belitskiy 2 0/0/0 2 6 164
Shanbin Ke 1 0/0/0 1 2 159
Chris Jones 1 0/0/0 2 1 151
Bart Chrzaszcz 2 0/0/0 2 8 124
Yury Kirpichev 1 0/0/0 1 2 90
Keshav Balasubramanian 1 0/0/0 1 4 85
Selam Waktola 1 3/2/0 2 1 83
Georg Stefan Schmid (gspschmid) 1 1/1/0 1 2 62
Christos Perivolaropoulos 1 0/0/0 1 4 61
Keith Rush 2 0/0/0 2 2 60
Pearu Peterson 1 1/1/0 1 2 59
Adam Paszke 2 0/0/0 2 5 57
Damiano Amatruda (damianoamatruda) 1 1/1/0 1 2 43
Matthew Johnson 2 4/4/0 4 3 40
Dougal Maclaurin 1 0/0/0 8 5 38
Sharad Vikram 3 0/0/0 3 2 38
Roy Frostig 2 1/1/0 2 1 37
Abhinav Gunjal 1 0/0/0 1 1 21
Jaroslav Sevcik 1 2/1/0 1 1 20
Sebastian Bodenstein 1 0/0/0 1 2 17
Kanglan Tang 1 0/0/0 1 1 15
Adam Banaś 1 0/0/0 1 1 13
Enrique Piqueras 1 0/0/0 1 1 12
Carlos Martin 1 1/1/0 2 3 9
Frederik Wilde 1 1/1/0 2 2 4
Michael Deistler 1 1/1/0 1 1 4
Luke Yang 1 1/1/0 1 1 2
Fabian Pedregosa 1 0/0/0 1 1 2
David Mis 1 0/0/0 1 1 2
Zheng Zeng (Aiemu) 0 1/0/0 0 0 0
Robert Dyro (rdyro) 0 1/0/0 0 0 0
Roman Knyazhitskiy (knyazer) 0 1/0/0 0 0 0
Mathew Odden (mrodden) 0 1/0/0 0 0 0
Yunlong Liu (yliu120) 0 1/0/0 0 0 0
None (pkgoogle) 0 2/2/0 0 0 0
Ilia Sergachev (sergachev) 0 2/1/0 0 0 0
Alexander Pivovarov (apivovarov) 0 1/0/0 0 0 0
jax authors 0 0/0/0 0 0 0
Shaikh Yaser (shaikhyaser) 0 1/0/1 0 0 0
Joshua G Albert (Joshuaalbert) 0 1/0/0 0 0 0
Abhinav Goel (abhinavgoel95) 0 1/0/0 0 0 0
Chase Riley Roberts (chaserileyroberts) 0 1/0/0 0 0 0
None (copybara-service[bot]) 0 111/79/5 0 0 0

PRs: pull requests created by that developer, shown as opened/merged/closed-unmerged during the period.
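Because the PRs column packs three counts into one cell, a small parser makes it easier to compare developers. A minimal sketch; the helper names are hypothetical, and treating merged/opened as a "merge rate" is an assumption of this example, not a metric the report defines:

```python
from typing import NamedTuple

class PRCounts(NamedTuple):
    opened: int
    merged: int
    closed_unmerged: int

def parse_pr_column(cell: str) -> PRCounts:
    """Parse a PRs cell such as '111/79/5' into its three counts."""
    opened, merged, closed = (int(part) for part in cell.split("/"))
    return PRCounts(opened, merged, closed)

def merge_rate(counts: PRCounts) -> float:
    """Rough ratio of merged to opened PRs; 0.0 when nothing was opened."""
    return counts.merged / counts.opened if counts.opened else 0.0

# Example cells copied from the table above.
bot = parse_pr_column("111/79/5")        # copybara-service[bot]
dependabot = parse_pr_column("12/0/8")   # dependabot[bot]
print(bot, f"merge rate {merge_rate(bot):.2f}")
print(dependabot, f"merge rate {merge_rate(dependabot):.2f}")
```

By the footnote's definition, a 0/0/0 entry next to nonzero commits only means that the developer opened no PRs of their own during the period.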

Detailed Reports

Report On: Issues



Recent Activity Analysis

The JAX project on GitHub currently has 1,337 open issues (5,514 opened and 4,177 closed over its lifetime), which reflects both a large, engaged user base and a sizable backlog. Recent issues cover bugs, performance and result discrepancies, and feature requests. They cluster around complex-number handling, behavior on specific hardware backends (especially GPUs and TPUs), and newer features such as Pallas, JAX's extension for writing custom kernels.

Notable themes include:

  • Performance Issues: Many users report unexpected slowdowns or discrepancies in results between different backends (CPU vs. GPU).
  • Feature Requests: There is a demand for enhanced functionality, such as support for additional data types and improved error handling.
  • Bugs in Specific Functions: Several issues concern individual functions, such as routines in jax.scipy.special or jax.lax operations, that behave inconsistently across platforms.

Issue Details

Here are some of the most recently created and updated issues:

  1. Issue #23637: When a tracer error happens in for_loop, should point to the user's body function

    • Type: Bug
    • Status: Open
    • Created: 1 day ago
    • Updated: N/A
  2. Issue #23634: Corner-case sharding regression when replacing concrete mesh with abstract mesh

    • Type: Bug
    • Status: Open
    • Created: 1 day ago
    • Updated: N/A
  3. Issue #23626: jax.lax.linalg.lu returns LU factorisation for singular matrix

    • Type: Bug
    • Status: Open
    • Created: 2 days ago
    • Updated: N/A
  4. Issue #23625: When calculating the loss, the input data does not contain NaN, but the output contains NaN

    • Type: Bug
    • Status: Open
    • Created: 2 days ago
    • Updated: N/A
  5. Issue #23624: pure_callback is broken with multiple vmap

    • Type: Bug
    • Status: Open
    • Created: 2 days ago
    • Updated: N/A
  6. Issue #23616: Orthogonal Initializer raises gpusolverDnCreate(&handle) failed

    • Type: Bug
    • Status: Open
    • Created: 3 days ago
    • Updated: N/A
  7. Issue #23600: Make jax.debug.print work with non-jax types

    • Type: Enhancement
    • Status: Open
    • Created: 3 days ago
    • Updated: N/A
  8. Issue #23599: Make jax.distributed timeouts configurable via jax.config

    • Type: Enhancement
    • Status: Open
    • Created: 3 days ago
    • Updated: N/A
  9. Issue #23594: Allow tuple inputs to scatter_dimension in jax.lax.psum_scatter

    • Type: Enhancement
    • Status: Open
    • Created: 4 days ago
    • Updated: N/A
  10. Issue #23590: Wrong results on CPU since 0.4.32

    • Type: Bug
    • Status: Open
    • Created: 4 days ago
    • Updated: N/A

These issues mix bugs in core functionality with enhancements aimed at improving usability and performance across platforms.
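Issue #23637 concerns where a tracer error raised inside a loop body points. The issue names for_loop, an internal JAX primitive. As a stand-in, the minimal sketch below uses the public jax.lax.fori_loop to show how such an error arises: the body is traced with abstract values, so Python control flow on the loop carry cannot be evaluated. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp

def body(i, acc):
    # `acc` is a tracer while fori_loop traces the body, so Python's `if`
    # cannot decide which branch to take and JAX raises an error.
    if acc > 0:
        return acc + i
    return acc - i

def body_fixed(i, acc):
    # jnp.where keeps the branch selection inside the traced computation.
    return jnp.where(acc > 0, acc + i, acc - i)

def run(loop_body, n=3):
    return jax.lax.fori_loop(0, n, loop_body, jnp.float32(1.0))

try:
    run(body)
except jax.errors.ConcretizationTypeError as err:
    # The traceback includes the frame in `body` where `if` was evaluated.
    print(type(err).__name__)

print(run(body_fixed))
```

The issue asks that the error message lead users to their own body function. This sketch only reproduces the class of error, not the specific for_loop case.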

Important Observations

  • Reports of numerical-stability problems (for example, unexpected NaN outputs in #23625) and GPU-specific failures (for example, #23616) suggest a need for further optimization and testing across hardware configurations.
  • Many recent issues concern individual functions that misbehave on particular inputs or under particular transformations, such as jax.lax.linalg.lu on a singular matrix (#23626) or pure_callback under nested vmap (#23624).
  • The volume of community reports points to an active user base that depends on JAX, particularly for machine learning.
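For context on issue #23626: LU factorization with partial pivoting is defined for every square matrix, singular or not. Singularity shows up as a zero (or, in floating point, a tiny) pivot on the diagonal of U. The report does not say how the issue will be resolved. The minimal sketch below checks for a zero pivot on a hand-picked rank-1 matrix, and its tolerance is an arbitrary illustrative choice, not a JAX default.

```python
import jax
import jax.numpy as jnp

# A rank-1 (singular) 2x2 matrix: the second row is twice the first.
a = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])

# jax.lax.linalg.lu returns the packed LU factors, the pivot indices,
# and the row permutation.
lu, pivots, permutation = jax.lax.linalg.lu(a)

# U is the upper triangle of `lu` (diagonal included); L is the strict
# lower triangle plus an implicit unit diagonal.
u = jnp.triu(lu)
l = jnp.tril(lu, k=-1) + jnp.eye(2)

# The factorization holds even though `a` is singular: a[permutation] == L @ U.
reconstructed_ok = jnp.allclose(a[permutation], l @ u)

# A zero pivot on U's diagonal signals singularity (illustrative tolerance).
min_pivot = jnp.min(jnp.abs(jnp.diag(lu)))
is_singular = min_pivot <= 1e-6 * jnp.max(jnp.abs(a))
print(bool(reconstructed_ok), bool(is_singular))  # True True
```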

This analysis highlights the importance of addressing both the reported bugs and enhancing existing functionalities to maintain user trust and satisfaction in the JAX ecosystem.

Report On: Pull Requests



Overview

The analysis of the current pull requests (PRs) for the JAX project reveals a total of 363 open PRs, with a significant focus on enhancing functionality, improving documentation, and fixing bugs. The recent activity indicates a strong emphasis on GPU and TPU optimizations, as well as ongoing efforts to refine the API and improve user experience.

Summary of Pull Requests

  1. PR #23640: Pallas pipeline API tweaks for more advanced pipelining patterns.

    • Significance: Introduces enhancements to the Pallas pipeline API, allowing for more complex pipelining patterns.
    • Notable: Created just today, indicating active development.
  2. PR #23636: Add Python 3.13.0rc2 support to the build.

    • Significance: Prepares JAX's build for Python 3.13, currently at release candidate 2.
    • Notable: Dependent on another PR in XLA, highlighting interdependencies within the ecosystem.
  3. PR #23635: Failing test for issue #23634.

    • Significance: Adds a test case to capture a known issue, which is crucial for regression testing.
    • Notable: Indicates proactive measures in quality assurance.
  4. PR #23633: Add memory space annotation to ShapedArray.

    • Significance: Adds a memory-space annotation to ShapedArray, JAX's abstract array type, which could let JAX track where an array's data lives and improve memory management.
    • Notable: Reflects ongoing improvements in data handling.
  5. PR #23632: Generalize global jit cpp cache keys.

    • Significance: Aims to improve cache hit rates in JIT compilation, enhancing performance.
    • Notable: Reverts a previous change, suggesting iterative refinement of the caching mechanism.
  6. PR #23627: Add a "broadcasting vmap" helper to custom_batching.

    • Significance: Introduces new functionality to simplify vectorized operations.
    • Notable: Active discussions among contributors about naming and exposing new APIs.
  7. PR #23623: Improve documentation for jax.numpy: power and pow.

    • Significance: Enhances user understanding of functions within JAX's numpy module.
    • Notable: Part of broader documentation improvement efforts.
  8. PR #23620: Test for io_callback in custom partitioning.

    • Significance: Adds a test covering a known bug with io_callback inside custom partitioning.
    • Notable: Highlights ongoing debugging efforts within the project.
  9. PR #23619: Relax usage of io_callback in automatic differentiation (AD).

    • Significance: Expands the usability of io_callback under certain conditions.
    • Notable: Suggests flexibility improvements in callback handling.
  10. PR #23617: Generalize global jit cpp cache keys (Take 2).

    • Significance: Aims to optimize caching further by allowing additional keys.
    • Notable: Indicates ongoing performance tuning efforts.
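PR #23627 adds a "broadcasting vmap" helper to jax.custom_batching. The report gives neither the helper's name nor its signature, so the sketch below does not use it. Instead, it hand-writes the kind of rule such a helper presumably automates (an assumption of this example), using the existing jax.custom_batching.custom_vmap API: unbatched arguments are broadcast along the batch axis, and the computation then runs on the batched operands directly. The name add_rows is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.custom_vmap
def add_rows(x, y):
    # Elementwise addition; it works unchanged on inputs that carry an
    # extra leading batch axis.
    return x + y

@add_rows.def_vmap
def add_rows_vmap_rule(axis_size, in_batched, x, y):
    # in_batched holds one bool per argument. Batched arguments arrive with
    # the batch axis at position 0; unbatched ones keep their original shape.
    x_batched, y_batched = in_batched
    if not x_batched:
        x = jnp.broadcast_to(x, (axis_size,) + x.shape)
    if not y_batched:
        y = jnp.broadcast_to(y, (axis_size,) + y.shape)
    # Both operands now share the leading batch axis, so the output is batched.
    return x + y, True

xs = jnp.ones((3, 2))
y = jnp.arange(2.0)
out = jax.vmap(add_rows, in_axes=(0, None))(xs, y)
print(out.shape)  # (3, 2)
```

Broadcasting first keeps the rule uniform: after the two `broadcast_to` calls, every operand has the same leading axis, so a single code path handles all combinations of batched and unbatched inputs.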

Analysis of Pull Requests

The recent pull requests in the JAX repository reflect several key themes and areas of focus:

Performance Optimization

A number of PRs target performance, alongside Pallas work aimed at GPU and TPU kernels (for example, PR #23640). PRs #23632 and #23617 generalize the cache keys used by jit's C++ fast path. More cache hits mean less Python-side dispatch overhead per call on any backend; they do not make the compiled kernels themselves faster. The "broadcasting vmap" helper (PR #23627) is mainly a convenience for writing custom batching rules, though it may also simplify vectorized code for large datasets or complex models.
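The report does not detail the changes to the C++ cache keys. The underlying idea, that a jitted function is traced and compiled once per distinct input signature and reused afterwards, can be seen in a minimal sketch that counts traces through a Python side effect, since such side effects run only while JAX traces the function. The counter is illustrative and does not inspect the C++ cache itself.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scale(x):
    global trace_count
    # Python side effects run only during tracing, not on cached calls.
    trace_count += 1
    return 2.0 * x

scale(jnp.ones(3))             # first call: trace and compile
scale(jnp.ones(3))             # same shape and dtype: reuses the cached trace
scale(jnp.ones(4))             # new shape: traced and compiled again
scale(jnp.ones(3, jnp.int32))  # new dtype: traced again
print(trace_count)  # 3
```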

Bug Fixes and Quality Assurance

Quality assurance remains a priority, as evidenced by PRs that add failing tests (e.g., PR #23635) or address known issues (e.g., PR #23620). This proactive approach ensures that regressions are caught early and that the stability of the library is maintained as new features are introduced. The discussions surrounding these PRs also highlight community engagement in identifying and resolving issues collaboratively.

Documentation Improvements

Documentation is another area receiving attention, with multiple PRs aimed at clarifying existing functionalities (e.g., PRs #23623 and #23596). Clear documentation is vital for user adoption and effective utilization of the library's features, especially given JAX's complexity and its use cases in advanced numerical computing and machine learning.

API Refinements

Several pull requests focus on refining the API to enhance usability and clarity (e.g., PR #23619). The ongoing discussions about naming conventions and whether to expose certain functionalities indicate an iterative process aimed at balancing feature richness with usability concerns.

Community Engagement

The active discussions among contributors regarding new features and their implications demonstrate a healthy community dynamic within the JAX project. Contributors are not only focused on coding but also engaging in meaningful dialogues about best practices, potential pitfalls, and future directions for the library.

Conclusion

Overall, the current landscape of pull requests in JAX showcases a vibrant development environment focused on performance enhancements, robust quality assurance practices, improved documentation, and thoughtful API design. As the project continues to evolve, maintaining this balance will be crucial for its success and adoption within the broader machine learning community.

Report On: Commits



Repo Commits Analysis

Development Team and Recent Activity

Team Members and Their Recent Activities

  1. George Necula (gnecula)

    • Recent Activity: Cleaned up forward-compatibility conditionals in Pallas lowering. Improved documentation for forward compatibility.
    • Collaborations: Worked on documentation alongside Google ML Automation.
  2. Google ML Automation

    • Recent Activity: Numerous updates including fixing flaky tests, updating XLA dependencies, and merging various pull requests related to documentation and bug fixes.
    • Collaborations: Collaborated with multiple developers on various pull requests.
  3. Kanglan Tang (kanglant)

    • Recent Activity: Fixed layout test failures on the GPU backend.
    • Collaborations: Primarily independent work.
  4. Sergei Lebedev (superbobry)

    • Recent Activity: Made multiple changes including adding new flags for GPU parameters, cleaning up unused arguments, and improving docstrings.
    • Collaborations: Collaborated with other team members on Pallas-related features.
  5. Parker Schuh (pschuh)

    • Recent Activity: Fixed bugs related to device assignment checks and improved error messages.
    • Collaborations: Worked independently with occasional collaborations.
  6. Yash Katariya (yashk2810)

    • Recent Activity: Multiple commits focusing on bug fixes, deprecating old code, and improving error handling.
    • Collaborations: Worked independently but also contributed to collaborative efforts.
  7. Peter Hawkins (hawkinsp)

    • Recent Activity: Engaged in numerous activities including disabling tests that fail on specific hardware, fixing bugs, and updating configurations.
    • Collaborations: Collaborated with various team members on testing and configuration issues.
  8. Matthew Johnson (mattjj)

    • Recent Activity: Focused on improving documentation and fixing bugs in the shard_map implementation.
    • Collaborations: Worked closely with other developers on documentation improvements.
  9. Justin Fu (justinjfu)

    • Recent Activity: Implemented features related to Pallas Mosaic GPU, including debugging tools.
    • Collaborations: Primarily focused on Pallas features.
  10. Dougal Maclaurin (dougalm)

    • Recent Activity: Added leak checkers and made adjustments to tests related to custom derivatives.
    • Collaborations: Worked independently but integrated changes into the main branch.

Patterns and Themes

  • The development team is actively engaged in both feature development and bug fixing, with a strong emphasis on improving documentation and ensuring compatibility across various hardware platforms.
  • There is a notable focus on Pallas-related features, indicating ongoing enhancements in this area of the JAX library.
  • Collaboration is prevalent among team members, especially in merging pull requests and addressing common issues such as flaky tests or compatibility problems.
  • The recent activities reflect a balance between adding new features and maintaining existing code quality through extensive testing and documentation efforts.

Conclusions

The JAX development team is exhibiting robust activity with a clear focus on enhancing functionality while maintaining code quality. The collaborative environment fosters effective problem-solving and innovation, particularly in the evolving landscape of numerical computing and machine learning applications.