CUDA and JAX versions are not compatible

My CUDA version is 11.6, my JAX version is 0.4.16, and my jaxlib version is 0.4.16+cuda11.cudnn86.

When I run a simple Python script, the first error message reads:

W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal   : Unsupported .version 7.8; current version is '7.6'

The second error message reads:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.

The details are as follows:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1698537556.277868  113566 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
random key: [0 0]
2023-10-28 19:59:41.023080: W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal   : Unsupported .version 7.8; current version is '7.6'
ptxas fatal   : Ptx assembly aborted due to errors

Relying on driver to perform ptx compilation.
Setting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda  or modifying $PATH can be used to set the location of ptxas
This message will only be logged once.
2023-10-28 19:59:41.077542: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:857] failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
2023-10-28 19:59:41.077572: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:862] error log buffer (98 bytes): ptxas application ptx input, line 10; fatal   : Unsupported .version 7.8; current version is '7.6
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/shuting/nucleotide_transformer/nucleotide_transformer_test.py", line 76, in <module>
    outs = forward_fn.apply(parameters, random_key, tokens)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 363, in nucleotide_transformer_fn
    outs = encoder(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 310, in __call__
    x, outs = self.apply_attention_blocks(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 226, in apply_attention_blocks
    output = layer(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 281, in __call__
    output = self.self_attention(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 240, in self_attention
    return self.sa_layer(x, x, x, attention_mask=attention_mask)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 149, in __call__
    attention_weights = self.attention_weights(query, key, attention_mask)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 83, in attention_weights
    query_heads = self._linear_projection_he_init(query, self.key_size, "query")
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 173, in _linear_projection_he_init
    y = hk.Linear(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/basic.py", line 181, in __call__
    out = jnp.dot(inputs, w, precision=precision)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
I0000 00:00:1698537581.636191  113566 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

Can anyone help me fix this issue? Thank you very much!

You can find JAX GPU installation instructions here: https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu

In particular, JAX v0.4.16 requires CUDA version 11.8 or newer. You might try installing CUDA via pip, as mentioned under the first sub-heading at that link. If for some reason you must stay on CUDA 11.6, I'd recommend trying an older JAX version; one released before roughly 0.4.8 should be compatible.
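If you go the pip route, the command the linked page listed around the 0.4.16 release was pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html (treat the exact extra name as an assumption on my part, since those instructions change between releases). After reinstalling, a quick sanity check along the lines of the sketch below confirms that the GPU backend loads and that XLA can compile a kernel, which is exactly the step that fails with CUDA_ERROR_UNSUPPORTED_PTX_VERSION:

import jax
import jax.numpy as jnp

# Should list a GPU device rather than only a CPU device.
print(jax.devices())

# Running a small matmul under jit forces XLA to compile a kernel with ptxas,
# the step that was failing in the traceback above.
x = jnp.ones((8, 8))
print(jax.jit(jnp.dot)(x, x).sum())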
