Model checkpoint loading throws an error

#5
by alvations - opened

With `unbabel-comet==2.2.2`, I load the checkpoint as follows:

```python
import os
from huggingface_hub import snapshot_download

from comet.models.multitask.xcomet_metric import XCOMETMetric

# Download (or reuse a cached copy of) the XCOMET-XL repository so that model_path is defined.
model_path = snapshot_download(repo_id="Unbabel/XCOMET-XL", cache_dir=os.path.abspath(os.path.dirname('.')))

# Load the Lightning checkpoint shipped in the repository.
model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"
xcometxl = XCOMETMetric.load_from_checkpoint(model_checkpoint_path)
```

It raises an error while loading the `state_dict`:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-103976254713> in <cell line: 7>()
      5 
      6 model_checkpoint_path = f"{model_path}/checkpoints/model.ckpt"
----> 7 xcometxl = XCOMETMetric.load_from_checkpoint(model_checkpoint_path)

4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
   2187 
   2188         if len(error_msgs) > 0:
-> 2189             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2191         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for XCOMETMetric:
    Unexpected key(s) in state_dict: "encoder.model.embeddings.position_ids".
```

Is there something else that needs to be initialized for the model checkpoint to load properly?
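For context, the error comes from the key comparison that strict state-dict loading performs: every key stored in the checkpoint must match a parameter or persistent buffer of the freshly built model, and vice versa. The sketch below is a simplified, pure-Python model of that comparison, not PyTorch's actual code. The helpers `compare_keys`, `strict_load` and `drop_keys`, and the short key lists, are hypothetical and chosen only to reproduce the shape of the message in the traceback above.

```python
# Simplified model of strict state_dict loading (NOT PyTorch's implementation).
# Keys are plain strings and values are placeholders instead of tensors.

def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, sorted for stable output."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model expects it, file lacks it
    unexpected = sorted(checkpoint_keys - model_keys)  # file has it, model does not know it
    return missing, unexpected


def strict_load(model_keys, state_dict):
    """Raise RuntimeError if any key is missing or unexpected, else return a copy."""
    missing, unexpected = compare_keys(model_keys, state_dict.keys())
    errors = []
    if missing:
        errors.append("Missing key(s) in state_dict: "
                      + ", ".join(f'"{k}"' for k in missing) + ".")
    if unexpected:
        errors.append("Unexpected key(s) in state_dict: "
                      + ", ".join(f'"{k}"' for k in unexpected) + ".")
    if errors:
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(errors))
    return dict(state_dict)


def drop_keys(state_dict, keys_to_drop):
    """Return a copy of state_dict without the given keys."""
    drop = set(keys_to_drop)
    return {k: v for k, v in state_dict.items() if k not in drop}


# Hypothetical, heavily truncated key sets; the real XCOMET checkpoint has many more.
model_keys = ["encoder.model.embeddings.word_embeddings.weight"]
checkpoint = {
    "encoder.model.embeddings.word_embeddings.weight": "tensor...",
    "encoder.model.embeddings.position_ids": "tensor...",
}

# Strict loading fails because the checkpoint carries one extra key.
try:
    strict_load(model_keys, checkpoint)
except RuntimeError as err:
    print(err)

# After dropping the extra key, the key sets match and strict loading succeeds.
cleaned = drop_keys(checkpoint, ["encoder.model.embeddings.position_ids"])
print(sorted(strict_load(model_keys, cleaned)))
```

In PyTorch itself, `Module.load_state_dict(state_dict, strict=False)` skips the raise and reports the extra keys in its return value instead. Whether ignoring `position_ids` is safe for XCOMET is not established in this thread.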
