d-Matrix commited on
Commit
dfd01dd
1 Parent(s): 61e212c

Update modeling_gptj.py

Browse files
Files changed (1) hide show
  1. modeling_gptj.py +16 -2
modeling_gptj.py CHANGED
@@ -40,6 +40,21 @@ from transformers.utils import (
40
  )
41
  from .configuration_gptj import GPTJConfig
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  logger = logging.get_logger(__name__)
45
 
@@ -53,7 +68,6 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
53
  # See all GPT-J models at https://huggingface.co/models?filter=gptj
54
  ]
55
 
56
-
57
  def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
58
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
59
  sinusoid_inp = torch.einsum(
@@ -365,7 +379,7 @@ class GPTJBlock(nn.Module):
365
  return outputs # hidden_states, present, (attentions)
366
 
367
 
368
- class GPTJPreTrainedModel(PreTrainedModel):
369
  """
370
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
371
  models.
 
40
  )
41
  from .configuration_gptj import GPTJConfig
42
 
43
+ from mltools.dmx import DmxModel
44
+
45
+
46
+ class DmxPreTrainedModel(PreTrainedModel):
47
+ @classmethod
48
+ def from_pretrained(cls, *args, **kwargs):
49
+ _model = super().from_pretrained(*args, **kwargs)
50
+ _model = DmxModel.from_torch(
51
+ _model,
52
+ hf=True,
53
+ input_names=["input_ids"], # TODO: no hard-coding!!!
54
+ concrete_args=None,
55
+ )
56
+ return _model
57
+
58
 
59
  logger = logging.get_logger(__name__)
60
 
 
68
  # See all GPT-J models at https://huggingface.co/models?filter=gptj
69
  ]
70
 
 
71
  def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
72
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
73
  sinusoid_inp = torch.einsum(
 
379
  return outputs # hidden_states, present, (attentions)
380
 
381
 
382
+ class GPTJPreTrainedModel(DmxPreTrainedModel):
383
  """
384
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
385
  models.