My CUDA version is 11.6. The version of JAX is 0.4.16, and the version of jaxlib is 0.4.16+cuda11.cudnn86.
When I run a simple Python script, the first error message reads:
W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal : Unsupported .version 7.8; current version is '7.6'
The second error message reads:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
The full log and traceback are as follows:
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1698537556.277868 113566 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
random key: [0 0]
2023-10-28 19:59:41.023080: W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal : Unsupported .version 7.8; current version is '7.6'
ptxas fatal : Ptx assembly aborted due to errors
Relying on driver to perform ptx compilation.
Setting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda or modifying $PATH can be used to set the location of ptxas
This message will only be logged once.
2023-10-28 19:59:41.077542: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:857] failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
2023-10-28 19:59:41.077572: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:862] error log buffer (98 bytes): ptxas application ptx input, line 10; fatal : Unsupported .version 7.8; current version is '7.6
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/shuting/nucleotide_transformer/nucleotide_transformer_test.py", line 76, in <module>
outs = forward_fn.apply(parameters, random_key, tokens)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 183, in apply_fn
out, state = f.apply(params, None, *args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 456, in apply_fn
out = f(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 363, in nucleotide_transformer_fn
outs = encoder(
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 310, in __call__
x, outs = self.apply_attention_blocks(
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 226, in apply_attention_blocks
output = layer(
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 281, in __call__
output = self.self_attention(
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 240, in self_attention
return self.sa_layer(x, x, x, attention_mask=attention_mask)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 149, in __call__
attention_weights = self.attention_weights(query, key, attention_mask)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 83, in attention_weights
query_heads = self._linear_projection_he_init(query, self.key_size, "query")
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 173, in _linear_projection_he_init
y = hk.Linear(
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/basic.py", line 181, in __call__
out = jnp.dot(inputs, w, precision=precision)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
I0000 00:00:1698537581.636191 113566 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
Can anyone help me fix this issue? Thank you very much!
You can find JAX GPU installation instructions here: https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu
In particular, JAX v0.4.16 requires CUDA version 11.8 or newer. You might try installing CUDA via pip, as described under the first sub-heading at that link (a sketch of the commands is included below). If for some reason you must stay on CUDA 11.6, I'd recommend trying an older JAX version; probably one before 0.4.8 or so would be compatible.
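If you go the pip route, the following is a minimal sketch based on the installation page linked above. The `cuda11_pip` and `cuda11_cudnn86` extras and the jax-releases index URL are taken from the JAX docs of that era, and the pinned version is only an example, so double-check everything against the instructions for the release you actually install:

```
# Option 1: let pip install the CUDA/cuDNN wheels alongside JAX
# (needs an NVIDIA driver new enough for the CUDA version pip pulls in)
pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Option 2: keep the system CUDA 11.6 and pin an older JAX/jaxlib pair instead
# (0.4.7 is an illustrative choice; pick a release whose jaxlib wheel supports your CUDA)
pip install "jax[cuda11_cudnn86]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

After reinstalling, a quick `python -c "import jax; print(jax.devices())"` should list a GPU device, and re-running your script should no longer hit the ptxas / PTX version errors.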