anas-awadalla commited on
Commit
28a1d28
1 Parent(s): 46b0f0f

removed lm weights from checkpoint

Browse files
Files changed (2) hide show
  1. checkpoint.pt +2 -2
  2. clean_checkpoint.py +12 -0
checkpoint.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:def3dbd2e3cbf471019bc5d8ee854fc644b3eca62dd3e0fc81e76dd6e0363c06
3
- size 15077874322
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a7cc063ff3f02187dba34c45a7f58e60c291b7b37144a965edccd0c877c8f5a
3
+ size 4872875366
clean_checkpoint.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # Load the checkpoint
4
+ checkpoint = torch.load('checkpoint_cleaned.pt', map_location=torch.device('cpu'))
5
+ print(checkpoint.keys())
6
+ # remove keys of fform lang_encoder.gpt_neox.layers.x.decoder_layer
7
+ for key in list(checkpoint.keys()):
8
+ if 'decoder_layer' in key:
9
+ del checkpoint[key]
10
+
11
+ # save the checkpoint
12
+ torch.save(checkpoint, 'checkpoint_cleaned.pt')