iofu728 committed
Commit 7b75ee1
1 Parent(s): ad9d4f6

Feature(MInference): changing to GPU

Files changed (2):
  1. app.py +12 -8
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,10 +1,10 @@
-import subprocess
-# Install flash attention, skipping CUDA build if necessary
-subprocess.run(
-    "pip install flash-attn --no-build-isolation",
-    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-    shell=True,
-)
+# import subprocess
+# # Install flash attention, skipping CUDA build if necessary
+# subprocess.run(
+#     "pip install flash-attn --no-build-isolation",
+#     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+#     shell=True,
+# )
 
 import gradio as gr
 import os
@@ -59,6 +59,10 @@ model_name = "gradientai/Llama-3-8B-Instruct-262k"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")  # to("cuda:0")
 
+from minference import MInference
+minference_patch = MInference("minference", model_name)
+model = minference_patch(model)
+
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
@@ -80,7 +84,7 @@ def chat_llama3_8b(message: str,
     Returns:
         str: The generated response.
     """
-    global model
+    # global model
     conversation = []
     for user, assistant in history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
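
Taken together, the app.py changes drop the runtime flash-attn install and instead patch the loaded model with MInference's sparse-attention kernels. Below is a minimal sketch of the resulting load path, reconstructed from the diff; the torch_dtype argument is an assumption (not in the commit) added so the 8B model fits on a single GPU in half precision.

# Sketch of the post-commit load path (assumes a CUDA GPU and minference installed).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # assumption: half precision; the diff relies on device_map alone
    device_map="auto",           # place/shard weights on the available GPU(s)
)

# Patch attention with MInference, as in the diff above; the callable returns
# the same model object with its attention forward swapped for the "minference" variant.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)
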
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 triton==2.1.0
 accelerate
-transformers
+transformers
+flash_attn
+pycuda==2023.1
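
With flash_attn and pycuda now pinned in requirements.txt, the subprocess install that the first app.py hunk comments out becomes redundant. A hypothetical startup check (not part of the commit) that fails fast when the GPU dependencies are missing:

# Hypothetical sanity check: confirm the GPU deps from requirements.txt resolve
# before loading the 8B model.
import importlib.util
import torch

assert torch.cuda.is_available(), "MInference's kernels require a CUDA GPU"
for pkg in ("flash_attn", "pycuda"):
    if importlib.util.find_spec(pkg) is None:
        raise ImportError(f"{pkg} is listed in requirements.txt but not installed")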