FrankJJHu committed on
Commit
71169d5
1 Parent(s): 88e47d9

compatible with transformers>=4.42.0

Browse files
Files changed (1) hide show
  1. modeling_kangaroo.py +9 -1
modeling_kangaroo.py CHANGED
@@ -1346,7 +1346,15 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
1346
  position_ids = position_ids[:, -input_ids.shape[1] :]
1347
 
1348
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1349
- if inputs_embeds is not None and past_key_values is None:
 
 
 
 
 
 
 
 
1350
  model_inputs = {"inputs_embeds": inputs_embeds}
1351
  else:
1352
  # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
 
1346
  position_ids = position_ids[:, -input_ids.shape[1] :]
1347
 
1348
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1349
+ set_inputs_embeds = False
1350
+ if inputs_embeds is not None:
1351
+ if isinstance(past_key_values, Cache):
1352
+ if past_key_values.get_seq_length() == 0:
1353
+ set_inputs_embeds = True
1354
+ else:
1355
+ if past_key_values is None:
1356
+ set_inputs_embeds = True
1357
+ if set_inputs_embeds:
1358
  model_inputs = {"inputs_embeds": inputs_embeds}
1359
  else:
1360
  # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise