d-Matrix
commited on
Commit
•
036923e
1
Parent(s):
dfd01dd
Update modeling_gptj.py
Browse files- modeling_gptj.py +0 -16
modeling_gptj.py
CHANGED
@@ -40,22 +40,6 @@ from transformers.utils import (
|
|
40 |
)
|
41 |
from .configuration_gptj import GPTJConfig
|
42 |
|
43 |
-
from mltools.dmx import DmxModel
|
44 |
-
|
45 |
-
|
46 |
-
class DmxPreTrainedModel(PreTrainedModel):
|
47 |
-
@classmethod
|
48 |
-
def from_pretrained(cls, *args, **kwargs):
|
49 |
-
_model = super().from_pretrained(*args, **kwargs)
|
50 |
-
_model = DmxModel.from_torch(
|
51 |
-
_model,
|
52 |
-
hf=True,
|
53 |
-
input_names=["input_ids"], # TODO: no hard-coding!!!
|
54 |
-
concrete_args=None,
|
55 |
-
)
|
56 |
-
return _model
|
57 |
-
|
58 |
-
|
59 |
logger = logging.get_logger(__name__)
|
60 |
|
61 |
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
|
|
|
40 |
)
|
41 |
from .configuration_gptj import GPTJConfig
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
logger = logging.get_logger(__name__)
|
44 |
|
45 |
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
|