PoTaTo721 committed
Commit 537a375
Parent: 8dfc341

Fix cache max_seq_len

Files changed (2):
  1. app.py (+2 -2)
  2. tools/llama/generate.py (+4 -2)
app.py CHANGED

@@ -414,7 +414,7 @@ def build_app():
                 label="Maximum tokens per batch, 0 means no limit",
                 minimum=0,
                 maximum=2048,
-                value=1024,  # 0 means no limit
+                value=0,  # 0 means no limit
                 step=8,
             )
 
@@ -640,7 +640,7 @@ if __name__ == "__main__":
         reference_audio=None,
         reference_text="",
         max_new_tokens=0,
-        chunk_length=100,
+        chunk_length=200,
         top_p=0.7,
         repetition_penalty=1.2,
         temperature=0.7,
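The five keyword arguments in the first hunk read like a Gradio slider definition. A rough reconstruction of the updated component follows; only the keyword arguments come from the diff, while gr.Slider and the variable name are assumptions:

import gradio as gr

# Hypothetical surroundings: gr.Slider and the variable name are assumed,
# only the five keyword arguments below appear in the commit.
max_batch_tokens = gr.Slider(
    label="Maximum tokens per batch, 0 means no limit",
    minimum=0,
    maximum=2048,
    value=0,  # 0 means no limit
    step=8,
)

With value=0 the default now matches the intent already recorded in the inline comment, rather than capping batches at 1024 tokens out of the box.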
tools/llama/generate.py CHANGED

@@ -250,9 +250,11 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     with torch.device(device):
         model.setup_caches(
-            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
         )
-
+
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
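For context, setup_caches in GPT-style decoders preallocates the key/value buffers once before decoding starts. A minimal sketch of why the size passed here matters; the shapes and head counts below are an illustrative stand-in, not fish-speech's actual KVCache layout:

import torch

def setup_caches_sketch(max_batch_size, max_seq_len, dtype,
                        n_heads=8, head_dim=64):
    # One slot per position up to max_seq_len; decoding writes keys and
    # values into these slots in place, indexed by token position.
    shape = (max_batch_size, n_heads, max_seq_len, head_dim)
    return torch.zeros(shape, dtype=dtype), torch.zeros(shape, dtype=dtype)

# Before the fix the caches were sized to T_new, this call's prompt length
# plus the tokens requested, so the cache geometry depended on the request.
# Sizing them to model.config.max_seq_len, the model's fixed upper bound,
# keeps the allocation stable across generate() calls.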