Cannot push model to Hugging Face

When I push my model to Hugging Face, it always fails with the following error:

/pyenv/versions/3.10.0/lib/python3.10/site-packages/transformers/utils/hub.py:844: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
  warnings.warn(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[249], line 1
----> 1 model.push_to_hub(model_id, use_auth_token=hf_auth_token)
      2 tokenizer.push_to_hub(model_id, use_auth_token=hf_auth_token)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/transformers/utils/hub.py:893, in PushToHubMixin.push_to_hub(self, repo_id, use_temp_dir, commit_message, private, token, max_shard_size, create_pr, safe_serialization, revision, **deprecated_kwargs)
    890 files_timestamps = self._get_files_timestamps(work_dir)
    892 # Save all files.
--> 893 self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
    895 return self._upload_modified_files(
    896     work_dir,
    897     repo_id,
   (...)
    902     revision=revision,
    903 )

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:1120, in FlaxPreTrainedModel.save_pretrained(self, save_directory, params, push_to_hub, max_shard_size, token, **kwargs)
   1118     with open(output_model_file, "wb") as f:
   1119         params = params if params is not None else self.params
-> 1120         model_bytes = to_bytes(params)
   1121         f.write(model_bytes)
   1123 else:

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:459, in to_bytes(target)
    449 """Save optimizer or other object as msgpack-serialized state-dict.
    450 
    451 Args:
   (...)
    456   Bytes of msgpack-encoded state-dict of `target` object.
    457 """
    458 state_dict = to_state_dict(target)
--> 459 return msgpack_serialize(state_dict, in_place=True)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:407, in msgpack_serialize(pytree, in_place)
    405 if not in_place:
    406   pytree = jax.tree_util.tree_map(lambda x: x, pytree)
--> 407 pytree = _np_convert_in_place(pytree)
    408 pytree = _chunk_array_leaves_in_place(pytree)
    409 return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:328, in _np_convert_in_place(d)
    326       d[k] = np.array(v)
    327     elif isinstance(v, dict):
--> 328       _np_convert_in_place(v)
    329 elif isinstance(d, jax.Array):
    330   return np.array(d)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:328, in _np_convert_in_place(d)
    326       d[k] = np.array(v)
    327     elif isinstance(v, dict):
--> 328       _np_convert_in_place(v)
    329 elif isinstance(d, jax.Array):
    330   return np.array(d)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:328, in _np_convert_in_place(d)
    326       d[k] = np.array(v)
    327     elif isinstance(v, dict):
--> 328       _np_convert_in_place(v)
    329 elif isinstance(d, jax.Array):
    330   return np.array(d)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/flax/serialization.py:326, in _np_convert_in_place(d)
    324 for k, v in d.items():
    325   if isinstance(v, jax.Array):
--> 326     d[k] = np.array(v)
    327   elif isinstance(v, dict):
    328     _np_convert_in_place(v)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/jax/_src/array.py:377, in ArrayImpl.__array__(self, dtype, context)
    376 def __array__(self, dtype=None, context=None):
--> 377   return np.asarray(self._value, dtype=dtype)

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
    311 @wraps(func)
    312 def wrapper(*args, **kwargs):
    313   with TraceAnnotation(name, **decorator_kwargs):
--> 314     return func(*args, **kwargs)
    315   return wrapper

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/jax/_src/array.py:562, in ArrayImpl._value(self)
    559 @property
    560 @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
    561 def _value(self) -> np.ndarray:
--> 562   self._check_if_deleted()
    564   if self._npy_value is None:
    565     if self.is_fully_replicated:

File /pyenv/versions/3.10.0/lib/python3.10/site-packages/jax/_src/array.py:530, in ArrayImpl._check_if_deleted(self)
    528 def _check_if_deleted(self):
    529   if self.is_deleted():
--> 530     raise RuntimeError(
    531         f"Array has been deleted with shape={self.aval.str_short()}.")

RuntimeError: Array has been deleted with shape=float32[1024].

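The error occurs in the last step of the tutorial (`model.push_to_hub(...)`). The traceback shows the failure happens while `save_pretrained` serializes `self.params`. At that point, one of the parameter arrays (shape `float32[1024]`) has already been deleted on the JAX side. Can anyone help? Thanks!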

I expected the model to be pushed to Hugging Face without errors, as shown in this tutorial: https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/text_classification_flax.ipynb#scrollTo=FJHgcqvkBsIc

Leave a Comment