Unable to install jax==0.4.5 from git repository

I have Python 3.9 on my machine.
When I run:

    pip install jax jaxlib

This installs jax 0.4.21 and jaxlib 0.4.21. However, I need version 0.4.5 of both jax and jaxlib.

When I run:

    pip install jax==0.4.5 jaxlib==0.4.5

I get this error:

    ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.5 (from versions: 0.4.18, 0.4.19, 0.4.20, 0.4.21)
    ERROR: No matching distribution found for jaxlib==0.4.5

I don't know why this happens. I then decided to install jax from the Git URL, like this:

    pip install "git+https://github.com/google/jax.git@v0.4.5"

This is the release: https://github.com/google/jax/releases/tag/jax-v0.4.5

However, I got this error:

    Collecting git+https://github.com/google/jax.git@v0.4.5
      Cloning https://github.com/google/jax.git (to revision v0.4.5) to /tmp/pip-req-build-_8dc4uf5
      Running command git clone --filter=blob:none --quiet https://github.com/google/jax.git /tmp/pip-req-build-_8dc4uf5
      WARNING: Did not find branch or tag 'v0.4.5', assuming revision or ref.
      Running command git checkout -q v0.4.5
      error: pathspec 'v0.4.5' did not match any file(s) known to git
      error: subprocess-exited-with-error
      
      × git checkout -q v0.4.5 did not run successfully.
      │ exit code: 1
      ╰─> See above for output.
      
      note: This error originates from a subprocess, and is likely not a problem with pip.
    error: subprocess-exited-with-error
    
    × git checkout -q v0.4.5 did not run successfully.
    │ exit code: 1
    ╰─> See above for output.
    
    note: This error originates from a subprocess, and is likely not a problem with pip.

And when I change the version like this (jax-v0.4.5):

    pip install "git+https://github.com/google/jax.git@jax-v0.4.5"

I got this output:

    Collecting git+https://github.com/google/jax.git@jax-v0.4.5
      Cloning https://github.com/google/jax.git (to revision jax-v0.4.5) to /tmp/pip-req-build-828jle88
      Running command git clone --filter=blob:none --quiet https://github.com/google/jax.git /tmp/pip-req-build-828jle88
      Running command git checkout -q cafaa50b25515a554568db06667b0c9b6abaff27
      Resolved https://github.com/google/jax.git to commit cafaa50b25515a554568db06667b0c9b6abaff27
      Preparing metadata (setup.py) ... done
    Requirement already satisfied: numpy>=1.20 in ./.myenv/lib/python3.9/site-packages (from jax==0.4.5) (1.26.2)
    Requirement already satisfied: opt_einsum in ./.myenv/lib/python3.9/site-packages (from jax==0.4.5) (3.3.0)
    Requirement already satisfied: scipy>=1.5 in ./.myenv/lib/python3.9/site-packages (from jax==0.4.5) (1.11.4)
    Building wheels for collected packages: jax
      Building wheel for jax (setup.py) ... done
      Created wheel for jax: filename=jax-0.4.5-py3-none-any.whl size=1424415 sha256=aff60a0ed1dea1d004cb5748ed322b9db90cd0064e470567f102a2f83d1873d1
      Stored in directory: /tmp/pip-ephem-wheel-cache-e8920quj/wheels/47/59/88/cceb9c59d0d692b940160f055bae0c60cd1295e4edc393ff48
    Successfully built jax
    Installing collected packages: jax
    ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
    chex 0.1.85 requires jaxlib>=0.1.37, which is not installed.
    optax 0.1.7 requires jaxlib>=0.1.37, which is not installed.
    orbax-checkpoint 0.4.6 requires jaxlib, which is not installed.
    chex 0.1.85 requires jax>=0.4.16, but you have jax 0.4.5 which is incompatible.
    flax 0.7.5 requires jax>=0.4.19, but you have jax 0.4.5 which is incompatible.
    orbax-checkpoint 0.4.6 requires jax>=0.4.9, but you have jax 0.4.5 which is incompatible.

It seems that I first need to install jaxlib, but I don't know how to install it from a GitHub URL. I know where the GitHub repository for jax is, but I don't know where the repository for jaxlib is.

I would appreciate any ideas on how to fix this problem. Thanks so much!

There was never a version 0.4.5 release of jaxlib, as you can see from the jaxlib PyPI history or from the JAX CHANGELOG.
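
If you want to check the release list yourself, here is a minimal sketch that reads PyPI's JSON API (the `https://pypi.org/pypi/<project>/json` endpoint is an addition beyond this answer, and fetching it needs network access). The helper names `fetch_pypi_json` and `release_versions` are hypothetical.

```python
import json
import urllib.request

# Assumption: PyPI's per-project JSON endpoint, which lists releases under "releases".
PYPI_JSON_URL = "https://pypi.org/pypi/{project}/json"


def fetch_pypi_json(project: str) -> dict:
    """Download PyPI's JSON metadata for `project` (requires network access)."""
    url = PYPI_JSON_URL.format(project=project)
    with urllib.request.urlopen(url, timeout=30) as response:
        return json.load(response)


def release_versions(payload: dict) -> set:
    """Return the version strings listed under the payload's "releases" key."""
    return set(payload.get("releases", {}))
```

For example, `"0.4.5" in release_versions(fetch_pypi_json("jaxlib"))` should be `False` according to this answer. Keep in mind that a version missing from PyPI could also have been deleted later, so the JAX CHANGELOG remains the better record of what was actually released.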

When JAX v0.4.5 was released, the newest jaxlib was v0.4.4, so you can install them this way:

    pip install jax==0.4.5 jaxlib==0.4.4
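
The pairing rule described here (use the newest jaxlib that is not newer than the jax release) can be sketched in Python. This is a heuristic, not JAX's own resolution logic: it assumes plain `X.Y.Z` version strings and uses version order as a stand-in for release order, and it ignores the minimum jaxlib version that each jax release declares. `newest_jaxlib_for` and `parse_version` are hypothetical helpers.

```python
def parse_version(version: str) -> tuple:
    """Turn a plain release string like "0.4.5" into (0, 4, 5) for comparison.

    Assumption: purely numeric X.Y.Z versions; pre-releases such as "0.4.5rc1"
    are not handled by this sketch.
    """
    return tuple(int(part) for part in version.split("."))


def newest_jaxlib_for(jax_version: str, jaxlib_versions: list) -> str:
    """Pick the newest jaxlib version that is not newer than `jax_version`.

    Heuristic only: version order is treated as a proxy for release order.
    """
    target = parse_version(jax_version)
    candidates = [v for v in jaxlib_versions if parse_version(v) <= target]
    if not candidates:
        raise ValueError(f"no jaxlib version at or below {jax_version}")
    return max(candidates, key=parse_version)
```

With an illustrative list such as `["0.4.3", "0.4.4", "0.4.6"]`, `newest_jaxlib_for("0.4.5", ...)` returns `"0.4.4"`, which matches the pairing above.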

or, if you want the appropriate CPU jaxlib version to be determined automatically, you can use

    pip install "jax[cpu]==0.4.5"

Note that at some point in the future this jaxlib version may no longer be available on PyPI (refer to Installing older jaxlib wheels), at which point you will be able to install it this way:

    pip install "jax[cpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/jax_releases.html
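
To confirm which pair actually ended up in your environment, a small check with the standard-library `importlib.metadata` module (available since Python 3.8, so it works on Python 3.9) may help. `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the currently active environment.
            result[name] = None
    return result


print(installed_versions())
```

If one of the commands above succeeded, this should report jax 0.4.5 and jaxlib 0.4.4. A `None` value means that package is missing from the active environment (for example, the `.myenv` virtualenv is not activated).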
