Why is the custom_pt dict error only visible in sentence-transformers and not in plain transformers?

#51
by sleeping4cat - opened

I tried loading the Jina v3 model on my 3090 node. When I used the transformers library to load the model and generate embeddings, it worked fine, even though I am on Python 3.8.10. But when I load the model with the sentence-transformers library, it raises an error. The fix for that error is trivial: I have to modify one line in the custom_st.py file to make it compatible with my Python version.

But I am puzzled: why does this error happen only in sentence-transformers?

Jina AI org

Hi @sleeping4cat, it's because custom_st.py is used only by sentence-transformers, not by transformers. Feel free to open a PR changing that one line in custom_st.py; it would be appreciated!
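The thread never says which line of custom_st.py breaks on Python 3.8. A common cause of 3.8-only failures is a PEP 585 annotation such as dict[str, Any], which only became subscriptable at runtime in Python 3.9, and the "dict error" in the title may point that way. That is a guess. Assuming it is the culprit, here is a minimal sketch of the portable spelling. The function names are hypothetical.

```python
import sys
from typing import Any, Dict


def merge_options(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """typing.Dict annotations evaluate fine on Python 3.8 and later."""
    merged = dict(defaults)
    merged.update(overrides)
    return merged


def builtin_generics_supported() -> bool:
    """Return True if dict[str, int] can be evaluated at runtime (PEP 585, Python 3.9+)."""
    try:
        dict[str, int]  # on 3.8: TypeError: 'type' object is not subscriptable
        return True
    except TypeError:
        return False


print(sys.version_info[:2], builtin_generics_supported())
print(merge_options({"a": 1}, {"b": 2}))
```

The other usual fix is to put "from __future__ import annotations" at the top of the file. That stops annotations from being evaluated when the function is defined.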

@jupyterjazz thanks! I don't think changing the code is a good idea, since hardly anyone uses Python 3.8 anymore. I installed Python 3.9 and then got a PyTorch error. I'm not sure how to fix it, because it seems I would need to load the model with PyTorch directly and fix it there. I haven't investigated much, but here is the traceback:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 9
      5 if __name__ == '__main__':
      6     
      7     # model = SentenceTransformer('all-MiniLM-L6-v2')
      8     model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
----> 9     pool = model.start_multi_process_pool()
     10     embed = model.encode_multi_process(sentences, pool=pool, batch_size=32, show_progress_bar=True)
     11     print('Embeddings computed. Shape:', embed.shape)

File /mnt/raid/gpu/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py:851, in SentenceTransformer.start_multi_process_pool(self, target_devices)
    845 for device_id in target_devices:
    846     p = ctx.Process(
    847         target=SentenceTransformer._encode_multi_process_worker,
    848         args=(device_id, self, input_queue, output_queue),
    849         daemon=True,
    850     )
--> 851     p.start()
    852     processes.append(p)
    854 return {"input": input_queue, "output": output_queue, "processes": processes}

File /usr/lib/python3.9/multiprocessing/process.py:121, in BaseProcess.start(self)
    118 assert not _current_process._config.get('daemon'), \
    119        'daemonic processes are not allowed to have children'
    120 _cleanup()
--> 121 self._popen = self._Popen(self)
    122 self._sentinel = self._popen.sentinel
    123 # Avoid a refcycle if the target function holds an indirect
    124 # reference to the process object (see bpo-30775)

File /usr/lib/python3.9/multiprocessing/context.py:284, in SpawnProcess._Popen(process_obj)
    281 @staticmethod
    282 def _Popen(process_obj):
    283     from .popen_spawn_posix import Popen
--> 284     return Popen(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:32, in Popen.__init__(self, process_obj)
     30 def __init__(self, process_obj):
     31     self._fds = []
---> 32     super().__init__(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_fork.py:19, in Popen.__init__(self, process_obj)
     17 self.returncode = None
     18 self.finalizer = None
---> 19 self._launch(process_obj)

File /usr/lib/python3.9/multiprocessing/popen_spawn_posix.py:47, in Popen._launch(self, process_obj)
     45 try:
     46     reduction.dump(prep_data, fp)
---> 47     reduction.dump(process_obj, fp)
     48 finally:
     49     set_spawning_popen(None)

File /usr/lib/python3.9/multiprocessing/reduction.py:60, in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)

File /mnt/raid/gpu/lib/python3.9/site-packages/torch/nn/utils/parametrize.py:340, in _inject_new_class.<locals>.getstate(self)
    339 def getstate(self):
--> 340     raise RuntimeError(
    341         "Serialization of parametrized modules is only "
    342         "supported through state_dict(). See:\n"
    343         "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
    344         "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
    345     )

RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training

I received this error after loading the model with SentenceTransformer and calling start_multi_process_pool.

@jupyterjazz I have identified the issue. SentenceTransformer loads the model through the custom_st.py file. Since that file is not part of the main model code, it causes a serialization error under multiprocessing. Loading the model without the multi-process function works fine. But I need the multi-process function to work, because I am going to run a big job on a supercluster with (4x A100s) x 16, and I need all the performance I can get. I tried colbertv2, and it works with multi-process encoding on GPUs.

If you could help resolve this for the jina-embeddings-v3 model, it would mean a lot.

Jina AI org

@michael-guenther @Saba can you take a look?

Jina AI org

The issue seems to be that jina-embeddings-v3 uses custom parametrized modules (from torch.nn.utils.parametrize) to implement LoRA layers. I suspect these layers are causing the serialization problem because start_multi_process_pool seems to be working fine with colbertv2 which is based on the same architecture but without LoRA adapters.
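The traceback ends in torch/nn/utils/parametrize.py, which fits this explanation. Registering a parametrization swaps the module's class for a generated subclass whose __getstate__ raises. The "spawn" start method pickles every worker argument, including the model. Below is a minimal reproduction with a toy low-rank parametrization. LowRankDelta is hypothetical and is not the jina-embeddings-v3 LoRA code.

```python
import pickle

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LowRankDelta(nn.Module):
    """Toy LoRA-style parametrization (hypothetical): weight -> weight + B @ A."""

    def __init__(self, out_features: int, in_features: int, rank: int = 2):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight + self.B @ self.A


def pickle_status(module: nn.Module) -> str:
    """Return "ok" if the module pickles, else the first line of the RuntimeError."""
    try:
        pickle.dumps(module)
        return "ok"
    except RuntimeError as err:
        return str(err).splitlines()[0]


plain = nn.Linear(8, 4)
lora = nn.Linear(8, 4)
parametrize.register_parametrization(lora, "weight", LowRankDelta(4, 8))

print(pickle_status(plain))  # the plain layer pickles
print(pickle_status(lora))   # the parametrized layer raises the error seen in the traceback
```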

@tomaarsen, is there anything we can try here? One possible solution could be to pass the state_dict instead of the model object to each process and reinitialize the model inside the process. I think the error message suggests this as well.
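The error message points to a pattern: send a state_dict, or just enough information to rebuild the model, instead of the module itself. Here is a minimal sketch of that pattern. It uses a hypothetical toy parametrization and factory, not Jina's code. The worker is simulated in-process so the example runs anywhere. The parent pickles plain tensors, and the "worker" builds a fresh model and loads them.

```python
import pickle

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LowRankDelta(nn.Module):
    """Toy LoRA-style parametrization (hypothetical): weight -> weight + B @ A."""

    def __init__(self, out_features: int, in_features: int, rank: int = 2):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.randn(out_features, rank) * 0.01)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight + self.B @ self.A


def build_model() -> nn.Module:
    """Hypothetical factory called by the parent and by every worker."""
    layer = nn.Linear(8, 4)
    parametrize.register_parametrization(layer, "weight", LowRankDelta(4, 8))
    return layer


def worker_rebuild(state_bytes: bytes) -> nn.Module:
    """What a worker would do: build a fresh model, then load the shipped weights."""
    model = build_model()
    model.load_state_dict(pickle.loads(state_bytes))
    return model


torch.manual_seed(0)
parent_model = build_model()
# state_dict() holds plain tensors, so it pickles even though parent_model does not.
state_bytes = pickle.dumps(parent_model.state_dict())
worker_model = worker_rebuild(state_bytes)
```

In a real pool, state_bytes (or a model name or path) would travel in the Process args. The model object itself would not.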

Hmm, I'm not familiar with this issue, and I've never worked with parametrized modules before. The error message is indeed relatively clear: a potential solution would be to write a custom multiprocessing function that loads a fresh model instance in each of the processes.

  • Tom Aarsen
Jina AI org

Thanks @tomaarsen

@sleeping4cat which task do you plan to use? We could also try merging the specific LoRA adapter into the main weights; that way there shouldn't be any custom parametrization modules.

@jupyterjazz I want to rewrite articles and create embeddings for training models in the science domain. Think of the Stanford Storm project; my applications are in the same domain. In my case, I don't think I need LoRA adapters, since my GPUs have enough memory to load the entire model. Even the old 3090s I'm using to test the code have no trouble fitting the whole model in VRAM.

But I want to point out that we plan to release this code as part of our embeddings pipeline in LAION's GitHub org, so others may want to use LoRA adapters or 4-bit quantization. If that's too much, I can just mention that they should use the normal encode function to generate the embeddings. Another option: if I find time in the next few weeks, I could write a function that works like the sentence-transformers multi-process function and use it together with transformers, since transformers doesn't use the custom_st.py file.

@jupyterjazz also, yes! The LoRA adapters are behind the issue. I went through the code last night, and apart from LoRA I couldn't think of any other layers that might be responsible.

Jina AI org

I validated that custom parametrizations are indeed causing the issue because the same v3 model works fine when I remove LoRA weights. So a possible solution could be to merge the specific LoRA adapter with the main weights or to use the model without any adapters.
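Merging can be sketched with PyTorch's own helper. Calling torch.nn.utils.parametrize.remove_parametrizations with leave_parametrized=True writes the current value of weight + B @ A back into an ordinary parameter. After that, the module pickles normally. The toy LowRankDelta layer below is hypothetical. The real model would need one chosen task adapter merged, and this sketch does not attempt that.

```python
import pickle

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class LowRankDelta(nn.Module):
    """Toy LoRA-style parametrization (hypothetical): weight -> weight + B @ A."""

    def __init__(self, out_features: int, in_features: int, rank: int = 2):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features))
        self.B = nn.Parameter(torch.randn(out_features, rank))

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight + self.B @ self.A


def merge_adapter(module: nn.Module, tensor_name: str = "weight") -> nn.Module:
    """Bake the parametrized value into a plain parameter and drop the parametrization."""
    parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=True)
    return module


torch.manual_seed(0)
layer = nn.Linear(8, 4)
parametrize.register_parametrization(layer, "weight", LowRankDelta(4, 8))
merged_expected = layer.weight.detach().clone()  # W + B @ A, computed before merging

merge_adapter(layer)
restored = pickle.loads(pickle.dumps(layer))  # no RuntimeError any more
```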

However, if you plan to release the code, I don't think we should modify the model. I think it's better to adapt start_multi_process_pool. I made some quick changes and it seems to work fine; feel free to use this code if it helps:

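# Run this as a .py script (e.g. python encode.py), not from Jupyter or `python -c`:
# with the "spawn" start method each worker process looks up
# encode_multi_process_worker by name, which fails if __main__ has no file.
import logging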
from typing import Any, Dict, List, Literal
from queue import Empty, Queue
from sentence_transformers import SentenceTransformer

import torch
import torch.multiprocessing as mp
from transformers import is_torch_npu_available

logger = logging.getLogger(__name__)


def start_multi_process_pool(
    model_name: str, target_devices: List[str] = None
) -> Dict[Literal["input", "output", "processes"], Any]:  # typing.List/Dict also work on Python 3.8
    if target_devices is None:
        if torch.cuda.is_available():
            target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        elif is_torch_npu_available():
            target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
        else:
            logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
            target_devices = ["cpu"] * 4

    logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))

    ctx = mp.get_context("spawn")
    input_queue = ctx.Queue()
    output_queue = ctx.Queue()
    processes = []

    for device_id in target_devices:
        p = ctx.Process(
            target=encode_multi_process_worker,
            args=(device_id, model_name, input_queue, output_queue),
            daemon=True,
        )
        p.start()
        processes.append(p)

    return {"input": input_queue, "output": output_queue, "processes": processes}


def encode_multi_process_worker(
    target_device: str, model_name: str, input_queue: Queue, results_queue: Queue
) -> None:
    """
    Internal working process to encode sentences in multi-process setup
    """
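    # Each worker loads its own model, so the parametrized model is never pickled.
    # Passing device= loads it straight onto this worker's device.
    model = SentenceTransformer(model_name, device=target_device, trust_remote_code=True)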
    while True:
        try:
            chunk_id, batch_size, sentences, prompt_name, prompt, precision, normalize_embeddings = (
                input_queue.get()
            )
            embeddings = model.encode(
                sentences,
                prompt_name=prompt_name,
                prompt=prompt,
                device=target_device,
                show_progress_bar=False,
                precision=precision,
                convert_to_numpy=True,
                batch_size=batch_size,
                normalize_embeddings=normalize_embeddings,
            )

            results_queue.put([chunk_id, embeddings])
        except Empty:  # Empty is imported from queue at the top; the name `queue` is not
            break


def main():
    sentences = ['some text' for _ in range(100)]
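    # encode_multi_process is an instance method, so the parent still needs a model object;
    # the workers started below load their own copies.
    model = SentenceTransformer('jinaai/jina-embeddings-v3', trust_remote_code=True)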
    pool = start_multi_process_pool(model_name='jinaai/jina-embeddings-v3')
    embed = model.encode_multi_process(sentences, pool=pool, batch_size=4, show_progress_bar=True)
    SentenceTransformer.stop_multi_process_pool(pool)

if __name__ == "__main__":
    main()

Regarding the LoRA layers: it's not really about memory, as they make up only a small part of the final weights. The goal is to generate better task-specific embeddings by combining the main weights with a specific LoRA adapter, which can be one of these: retrieval.query, retrieval.passage, separation, classification, text-matching. You can read more about this in the README. For your project, it seems to me that either retrieval or text-matching might be suitable.
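Here is a toy layer that makes "main weights plus one task adapter" concrete. It holds one shared weight and one low-rank delta per task, using the five task names above. It is an illustration under assumptions, not how jina-embeddings-v3 stores or selects its adapters. See the README for how the task is passed to the real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Task names taken from the thread; everything else is a toy illustration.
TASKS = ("retrieval.query", "retrieval.passage", "separation", "classification", "text-matching")


class MultiTaskLoRALinear(nn.Module):
    """Shared base weight plus one low-rank delta (B @ A) per task (hypothetical)."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        n_tasks = len(TASKS)
        # Adapters are stacked along dim 0 instead of registered under the task names,
        # because parameter names may not contain ".".
        self.lora_A = nn.Parameter(torch.randn(n_tasks, rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(n_tasks, out_features, rank))

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        t = TASKS.index(task)  # ValueError for an unknown task name
        weight = self.base.weight + self.lora_B[t] @ self.lora_A[t]
        return F.linear(x, weight, self.base.bias)


torch.manual_seed(0)
layer = MultiTaskLoRALinear(8, 4)
x = torch.randn(2, 8)

# B starts at zero, so every adapter is initially a no-op.
starts_as_base = torch.allclose(layer(x, "retrieval.query"), layer.base(x))

with torch.no_grad():
    layer.lora_B.normal_()  # pretend the adapters have been trained
query_out = layer(x, "retrieval.query")
passage_out = layer(x, "retrieval.passage")
```

Initializing B to zero is the usual LoRA choice: before training, each adapter leaves the base weights unchanged.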

@jupyterjazz can you open a PR on the sentence-transformers library and ask @tomaarsen to merge it? I think this will be an ongoing issue, since more models will ship LoRA adapter weights and run into the same problem.

Jina AI org

I think the issue is not LoRA in general, but our implementation of LoRA which is different, so it shouldn't affect other models.
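For contrast, consider a LoRA written as an ordinary wrapper submodule instead of through torch.nn.utils.parametrize. It has no custom __getstate__ and pickles like any other module. The toy LoRALinear below is hypothetical. It only illustrates why the problem is specific to the parametrization-based implementation.

```python
import pickle

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Toy wrapper-style LoRA (hypothetical): y = base(x) + x @ A.T @ B.T."""

    def __init__(self, base: nn.Linear, rank: int = 2):
        super().__init__()
        self.base = base
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.randn(base.out_features, rank) * 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.A.T @ self.B.T


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(8, 4))
restored = pickle.loads(pickle.dumps(layer))  # ordinary module, default pickling works
x = torch.randn(3, 8)
```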

@jupyterjazz in that case I will test the code and let you know :) if I need any other modifications or help.

@jupyterjazz I ran the code on my system, and it is quite buggy and error-prone. In particular, even though the entire model is saved locally, I had to keep trust_remote_code enabled; otherwise it says custom_pet is not found. It is a headache. I could theoretically load the model and compute embeddings with the transformers library and avoid this. But then I would have to write custom logic to use all my GPUs, which is also a headache and not worth my time.

That's why I am using sentence-transformers in the first place. At this point I would rather have embeddings with higher dimensions than go through the pain of writing custom functions.

Also, when I tried to run the code you posted, it raised this error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'encode_multi_process_worker' on <module '__main__' (built-in)>
  File "/usr/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/usr/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'encode_multi_process_worker' on <module '__main__' (built-in)>

I think the fix is to put the function in a separate file. Still, that won't solve my issue. My German supercomputer has no internet access, so the model has to be saved locally. If you could merge code so that it doesn't require custom_pt, I could still use it in my project; otherwise I'll have to find a different model. :(
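The AttributeError above comes from how "spawn" hands the worker to the child process. A function is pickled by reference: its module name plus its qualified name. The child then has to find that name again. In Jupyter, __main__ has no importable file, so the lookup fails. Moving the worker into its own module, as suggested above, avoids that. Here is a small demonstration with plain pickle. The names are hypothetical, and the exact exception type for unpicklable functions varies across Python versions.

```python
import pickle


def encode_worker_stub(chunk):
    """Hypothetical worker defined at module top level: importable by name."""
    return [len(text) for text in chunk]


def make_nested_worker():
    def nested_worker(chunk):  # exists only inside this function's scope
        return chunk
    return nested_worker


def is_picklable(obj) -> bool:
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A pickled function holds only its module and qualified name, not its code;
# unpickling looks that name up again, which is the step that fails in the child.
payload = pickle.dumps(encode_worker_stub)
```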

Jina AI org

@sleeping4cat I sent you a friend request on Discord; let's continue our discussion there, as it will be more convenient.
