jbetker committed
Commit f37375b
Parent: 73e9929

updates for new autoregressive

api_new_autoregressive.py CHANGED
@@ -135,7 +135,7 @@ class TextToSpeech:
         download_models()
 
         self.autoregressive = AutoregressiveCodegen(1024, 16).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_autoregressive_codegen\\models\\11000_codegen_ema.pth'))
+        self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_autoregressive_codegen\\models\\17000_codegen_ema.pth'))
 
         self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                               text_seq_len=350, text_heads=8,
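
For reference, a minimal sketch of the same loading pattern, assuming a hypothetical local checkpoint path; the `map_location` argument is an assumption for CPU-only hosts and is not something this commit adds:

    # Illustrative sketch only; CHECKPOINT_PATH is a placeholder and the import
    # path is assumed from this repo's layout.
    import torch
    from models.new_autoregressive import AutoregressiveCodegen

    CHECKPOINT_PATH = 'path/to/17000_codegen_ema.pth'  # hypothetical local copy of the EMA weights

    model = AutoregressiveCodegen(1024, 16).cpu().eval()
    state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')  # keep tensors on CPU
    model.load_state_dict(state_dict)
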
models/new_autoregressive.py CHANGED
@@ -85,7 +85,13 @@ class InferenceModel(GPT2PreTrainedModel):
         assert labels is None  # Training not supported by this inference model.
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True)
+        out = self.transformer.decoder(input_ids, full_context=self.context, return_embeddings=True, past_key_values=past_key_values,
+                                       use_cache=use_cache, expected_seq_len=100)
+        if use_cache:
+            hidden_states, present_key_values = out
+        else:
+            hidden_states = out
+            present_key_values = None
         logits = self.transformer.decoder.to_logits(hidden_states)
 
         if not return_dict:
@@ -94,7 +100,7 @@
         return CausalLMOutputWithCrossAttentions(
             loss=None,
             logits=logits,
-            past_key_values=None,
+            past_key_values=present_key_values,
             hidden_states=hidden_states,
             attentions=None,
             cross_attentions=None,
@@ -258,7 +264,7 @@ class AutoregressiveCodegen(nn.Module):
         inference_model.store_context(full_context)
 
         gen = inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
-                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True,
+                                       max_length=max_tokens, output_attentions=False, return_dict_in_generate=True, use_cache=False,
                                        **hf_generate_kwargs)
         return gen.sequences
 
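The first hunk above makes the decoder call's return shape depend on `use_cache`: with caching enabled it yields both hidden states and per-layer key/value pairs, otherwise just the hidden states. A self-contained toy sketch of that unpacking convention (the decoder body, shapes, and layer count below are placeholders, not the real model):

    # Toy illustration of the use_cache-dependent return convention introduced above.
    # toy_decoder stands in for self.transformer.decoder; shapes and values are arbitrary.
    import torch

    def toy_decoder(input_ids, use_cache=False):
        hidden = torch.randn(input_ids.shape[0], input_ids.shape[1], 16)  # fake embeddings
        if use_cache:
            # one (key, value) pair per attention layer; here a single fake layer
            present = [(torch.randn(1, 2, input_ids.shape[1], 8),
                        torch.randn(1, 2, input_ids.shape[1], 8))]
            return hidden, present
        return hidden

    input_ids = torch.zeros(1, 4, dtype=torch.long)
    for use_cache in (False, True):
        out = toy_decoder(input_ids, use_cache=use_cache)
        if use_cache:
            hidden_states, present_key_values = out
        else:
            hidden_states, present_key_values = out, None
        print(hidden_states.shape, None if present_key_values is None else len(present_key_values))

Note that the generate() call in the last hunk passes use_cache=False, so the cache plumbing is wired up here but not yet exercised during sampling.
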
models/xtransformers.py CHANGED
@@ -24,7 +24,8 @@ Intermediates = namedtuple('Intermediates', [
 
 LayerIntermediates = namedtuple('Intermediates', [
     'hiddens',
-    'attn_intermediates'
+    'attn_intermediates',
+    'past_key_values',
 ])
 
 
@@ -589,7 +590,8 @@ class Attention(nn.Module):
             sinusoidal_emb=None,
             rotary_pos_emb=None,
             prev_attn=None,
-            mem=None
+            mem=None,
+            layer_past=None,
     ):
         b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
             context)
@@ -620,6 +622,13 @@
         k = rearrange(k, 'b n d -> b () n d')
         v = rearrange(v, 'b n (h d) -> b h n d', h=h)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat([past_key, k], dim=-2)
+            v = torch.cat([past_value, v], dim=-2)
+        k_cache = k
+        v_cache = v
+
         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
             (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
@@ -723,7 +732,7 @@
                 post_softmax_attn=post_softmax_attn
             )
 
-        return self.to_out(out), intermediates
+        return self.to_out(out), intermediates, k_cache, v_cache
 
 
 class AttentionLayers(nn.Module):
@@ -770,6 +779,7 @@ class AttentionLayers(nn.Module):
         self.dim = dim
         self.depth = depth
         self.layers = nn.ModuleList([])
+        self.causal = causal
 
         rel_pos_bias = 'rel_pos_bias' in attn_kwargs
         self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
@@ -911,6 +921,8 @@ class AttentionLayers(nn.Module):
             mems=None,
             return_hiddens=False,
             norm_scale_shift_inp=None,
+            past_key_values=None,
+            expected_seq_len=None,
     ):
 
         assert not (self.cross_attend ^ (exists(context) or exists(
@@ -929,9 +941,17 @@
 
         rotary_pos_emb = None
         if exists(self.rotary_pos_emb):
-            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
+            if not self.training and self.causal:
+                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
+            elif expected_seq_len is None:
+                expected_seq_len = 0
+            seq_len = x.shape[1]
+            if past_key_values is not None:
+                seq_len += past_key_values[0][0].shape[-2]
+            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
+        present_key_values = []
         cross_attn_count = 0
         for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
             if layer_type == 'a':
@@ -944,18 +964,28 @@
             if exists(pre_branch_norm):
                 x = pre_branch_norm(x, **norm_args)
 
+            if layer_type == 'a' or layer_type == 'c':
+                if past_key_values is not None:
+                    layer_kv = past_key_values.pop(0)
+                    layer_past = tuple(s.to(x.device) for s in layer_kv)
+                else:
+                    layer_past = None
+
             if layer_type == 'a':
-                out, inter = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
-                                        prev_attn, layer_mem)
+                out, inter, k, v = checkpoint(block, x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
+                                              prev_attn, layer_mem, layer_past)
             elif layer_type == 'c':
                 if exists(full_context):
-                    out, inter = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
-                                            None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, full_context[cross_attn_count], mask, context_mask, None, None,
+                                                  None, prev_attn, None, layer_past)
                 else:
-                    out, inter = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn)
+                    out, inter, k, v = checkpoint(block, x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
             elif layer_type == 'f':
                 out = checkpoint(block, x)
 
+            if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
+                present_key_values.append((k.detach(), v.detach()))
+
             if exists(post_branch_norm):
                 out = post_branch_norm(out, **norm_args)
 
@@ -981,7 +1011,8 @@
         if return_hiddens:
             intermediates = LayerIntermediates(
                 hiddens=hiddens,
-                attn_intermediates=intermediates
+                attn_intermediates=intermediates,
+                past_key_values=present_key_values
            )
 
             return x, intermediates
@@ -1115,6 +1146,7 @@ class TransformerWrapper(nn.Module):
             return_hiddens=False,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
@@ -1147,11 +1179,14 @@ class TransformerWrapper(nn.Module):
             hiddens = intermediates.hiddens
             return out, hiddens
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        return res
 
 
 class ContinuousTransformerWrapper(nn.Module):
@@ -1191,6 +1226,7 @@ class ContinuousTransformerWrapper(nn.Module):
             mask=None,
             return_attn=False,
             mems=None,
+            use_cache=False,
             **kwargs
     ):
         b, n, _, device = *x.shape, x.device
@@ -1204,11 +1240,14 @@ class ContinuousTransformerWrapper(nn.Module):
 
         out = self.project_out(x) if not return_embeddings else x
 
+        res = [out]
         if return_attn:
             attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
-            return out, attn_maps
+            res.append(attn_maps)
+        if use_cache:
+            res.append(intermediates.past_key_values)
 
-        return out
+        return tuple(res)
 
 
 class XTransformer(nn.Module):
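
Taken together, the xtransformers hunks thread a per-layer key/value cache through Attention and AttentionLayers: each attention layer concatenates the cached keys and values from earlier steps onto the new ones along the sequence axis, and the grown pair is handed back out so it can serve as the next step's `layer_past`. A minimal, self-contained sketch of that caching pattern, with illustrative names and shapes rather than the actual Attention module:

    # Minimal sketch of per-step key/value caching, the pattern the Attention hunks add.
    # The single attend_step function and its shapes are illustrative; the real module
    # also handles masks, rotary embeddings, talking heads, etc.
    import torch

    def attend_step(q_new, k_new, v_new, layer_past=None):
        # q_new, k_new, v_new: (batch, heads, 1, dim_head) for the newest token only
        if layer_past is not None:
            past_key, past_value = layer_past
            k_new = torch.cat([past_key, k_new], dim=-2)    # grow keys along the sequence axis
            v_new = torch.cat([past_value, v_new], dim=-2)  # grow values the same way
        scores = q_new @ k_new.transpose(-1, -2) / q_new.shape[-1] ** 0.5
        out = scores.softmax(dim=-1) @ v_new
        return out, (k_new, v_new)  # the grown pair becomes the next step's layer_past

    # Usage: decode three tokens one at a time, reusing the cache instead of
    # recomputing keys and values for the whole prefix.
    cache = None
    for step in range(3):
        q = torch.randn(1, 2, 1, 8)
        k = torch.randn(1, 2, 1, 8)
        v = torch.randn(1, 2, 1, 8)
        out, cache = attend_step(q, k, v, layer_past=cache)
        print(step, out.shape, cache[0].shape)  # cached keys grow by one each step

The present_key_values list collected in AttentionLayers follows the same per-layer (key, value) convention that InferenceModel.forward exposes through past_key_values in its CausalLMOutputWithCrossAttentions.
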