Kohaku-Blueleaf committed on
Commit
0d5be9b
1 Parent(s): 89f225d

flash attn to prevent oom

Browse files
Files changed (3) hide show
  1. app.py +7 -1
  2. kgen/generate.py +3 -1
  3. requirements.txt +2 -1
app.py CHANGED
@@ -113,7 +113,13 @@ masterpiece, newest, absurdres, {rating}"""
113
  if __name__ == "__main__":
114
  models = {
115
  model_path: [
116
- LlamaForCausalLM.from_pretrained(model_path).eval().half().to(DEVICE),
 
 
 
 
 
 
117
  LlamaTokenizer.from_pretrained(model_path),
118
  ]
119
  for model_path in MODEL_PATHS
 
113
  if __name__ == "__main__":
114
  models = {
115
  model_path: [
116
+ LlamaForCausalLM.from_pretrained(
117
+ model_path, attn_implementation="flash_attention_2"
118
+ )
119
+ .requires_grad_(False)
120
+ .eval()
121
+ .half()
122
+ .to(DEVICE),
123
  LlamaTokenizer.from_pretrained(model_path),
124
  ]
125
  for model_path in MODEL_PATHS
kgen/generate.py CHANGED
@@ -83,7 +83,9 @@ def tag_gen(
83
  repetition_penalty=None,
84
  max_new_tokens=max_new_tokens,
85
  stream_output=False,
86
- autocast_gen=nullcontext,
 
 
87
  prompt_lookup_num_tokens=10,
88
  pad_token_id=tokenizer.eos_token_id,
89
  eos_token_id=tokenizer.eos_token_id,
 
83
  repetition_penalty=None,
84
  max_new_tokens=max_new_tokens,
85
  stream_output=False,
86
+ autocast_gen=lambda: (
87
+ torch.autocast("cuda") if torch.cuda.is_available() else nullcontext()
88
+ ),
89
  prompt_lookup_num_tokens=10,
90
  pad_token_id=tokenizer.eos_token_id,
91
  eos_token_id=tokenizer.eos_token_id,
requirements.txt CHANGED
@@ -4,4 +4,5 @@ llama-cpp-python
4
  gradio
5
  requests
6
  sentencepiece
7
- spaces
 
 
4
  gradio
5
  requests
6
  sentencepiece
7
+ spaces
8
+ flash-attn