"""Utility functions for training and inference.""" | |
import functools | |
from pathlib import Path | |
import pickle | |
import warnings | |
from io import BytesIO | |
import torch | |
import torch.utils._device | |
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy | |
from torch.distributed.fsdp import FullStateDictConfig | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | |
from torch.distributed.fsdp import StateDictType | |
def save_model_checkpoint(fabric, model, file_path):
    """Handles boilerplate logic for retrieving and saving the state_dict.

    This will be upstreamed to Fabric soon.
    """
    file_path = Path(file_path)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

        fabric.save(file_path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            # Create a consolidated checkpoint with the same name next to the deepspeed checkpoint
            convert_zero_checkpoint_to_fp32_state_dict(file_path, file_path.with_suffix(".pth"))
        return

    if isinstance(fabric.strategy, FSDPStrategy):
        save_policy = FullStateDictConfig(offload_to_cpu=(fabric.world_size > 1), rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
            state_dict = model._forward_module.state_dict()
    else:
        state_dict = model.state_dict()

    if fabric.global_rank == 0:
        torch.save(state_dict, file_path)
    fabric.barrier()
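
# Example usage (a minimal sketch; `MyModel` and the paths are illustrative
# placeholders, not part of this module):
#
#   import lightning as L
#   fabric = L.Fabric(devices=1)
#   fabric.launch()
#   model = fabric.setup(MyModel())
#   save_model_checkpoint(fabric, model, "checkpoints/model.pth")
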
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None, dtype=None, quantization_mode=None):
        """Create tensors with the given device and dtype and don't run initialization
        (but instead use "empty tensors", i.e. uninitialized memory).

        device: `torch.device` to work with
        dtype: `torch.dtype` to work with
        quantization_mode: optional string, quantization mode to work with, default `None`.
            Available modes: `llm.int8`: bitsandbytes LLM.int8 quantization (only on GPU)
                             `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models

        Example::
            with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
                model = LLaMA.from_name('7B')
            model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))
        """
        self.quantization_mode = quantization_mode
        self.quantized_linear_cls = None
        if self.quantization_mode == 'llm.int8':
            if device.type != "cuda":
                raise ValueError("Quantization is only supported on the GPU.")
            from .quantization import Linear8bitLt

            self.quantized_linear_cls = Linear8bitLt
        elif self.quantization_mode == 'gptq.int4':
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
        elif self.quantization_mode == 'gptq.int8':
            from .quantization import ColBlockQuantizedLinear

            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
        elif self.quantization_mode is not None:
            raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        # Temporarily swap `torch.nn.Linear` for the quantized linear class, if one was requested
        if self.quantized_linear_cls is not None:
            self.torch_linear_cls = torch.nn.Linear
            torch.nn.Linear = self.quantized_linear_cls
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original `torch.nn.Linear`
        if self.quantized_linear_cls is not None:
            torch.nn.Linear = self.torch_linear_cls
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Skip `torch.nn.init` calls: return the (uninitialized) tensor unchanged
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            else:
                return args[0]
        # Redirect tensor constructors to the requested device/dtype unless given explicitly
        if (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        ):
            kwargs["device"] = self.device
        if (
            self.dtype is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("dtype") is None
        ):
            kwargs["dtype"] = self.dtype
        return func(*args, **kwargs)
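
# Example usage with quantization (a minimal sketch; `LLaMA` and the checkpoint
# path come from the surrounding project and are illustrative here):
#
#   with EmptyInitOnDevice(device=torch.device("cuda"), dtype=torch.bfloat16,
#                          quantization_mode="llm.int8"):
#       model = LLaMA.from_name("7B")
#   model.load_state_dict(torch.load("checkpoints/lit-llama/7B/lit-llama.pth"))
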
# this is taken from torchhacks https://github.com/lernapparat/torchhacks


class NotYetLoadedTensor:
    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild(
        cls,
        storage,
        storage_offset,
        size,
        stride,
        requires_grad,
        backward_hooks,
        metadata=None,
        archiveinfo=None,
    ):
        rebuild_args = (
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        metatensor = torch._utils._rebuild_tensor_v2(
            storage,
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        storageinfo = storage.archiveinfo
        return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self):
        name, storage_cls, fn, device, size = self.storageinfo
        dtype = self.metatensor.dtype

        # Read the raw (untyped) storage for this tensor from the checkpoint archive
        uts = (
            self.archiveinfo.zipfile.get_storage_from_record(
                f"data/{fn}",
                size * torch._utils._element_size(dtype),
                torch.UntypedStorage,
            )
            ._typed_storage()
            ._untyped_storage
        )
        with warnings.catch_warnings():
            # TypedStorage is deprecated; silence the warning while we still rely on it
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(
                wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True
            )
        tensor = torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Materialize any lazy tensors among the arguments before dispatching
        loaded_args = [
            (a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args
        ]
        res = func(*loaded_args, **kwargs)
        # gc.collect would be costly here, maybe do it optionally
        return res

    def __getattr__(self, name):
        # properties
        ## TODO: device, is_...??
        ## TODO: mH, mT, H, T, data, imag, real
        ## name ???
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)
        if name in {"size"}:
            return getattr(self.metatensor, name)
        # materializing with contiguous is needed for quantization
        if name in {"contiguous"}:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"
class LazyLoadingUnpickler(pickle.Unpickler):
    def __init__(self, file, zipfile):
        super().__init__(file)
        self.zipfile = zipfile

    def find_class(self, module, name):
        # Intercept tensor reconstruction so tensors come back as NotYetLoadedTensor placeholders
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        name, cls, fn, device, size = pid
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Build a meta-device storage as a stand-in; the real data stays in the archive
            s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
        s.archiveinfo = pid
        return s

def lazy_load(fn):
    zf = torch._C.PyTorchFileReader(str(fn))
    with BytesIO(zf.get_record("data.pkl")) as pkl:
        mup = LazyLoadingUnpickler(pkl, zf)
        sd = mup.load()
    return sd
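
# Example usage (a minimal sketch; the checkpoint path and `LLaMA` are illustrative):
#
#   checkpoint = lazy_load("checkpoints/lit-llama/7B/lit-llama.pth")
#   with EmptyInitOnDevice(device=torch.device("cuda"), dtype=torch.bfloat16):
#       model = LLaMA.from_name("7B")
#   model.load_state_dict(checkpoint)   # tensors materialize as they are copied in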