Kohaku-Blueleaf commited on
Commit
4a7ac82
1 Parent(s): edff9b1

auto device

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -13,6 +13,8 @@ from kgen.metainfo import SPECIAL, TARGET
13
 
14
 
15
  MODEL_PATH = "KBlueLeaf/DanTagGen"
 
 
16
 
17
 
18
  @torch.no_grad()
@@ -31,7 +33,6 @@ def get_result(
31
  escape_bracket: bool = False,
32
  temperature: float = 1.35,
33
  ):
34
- text_model.eval().half().cuda()
35
  start = time_ns()
36
  print("=" * 50, "\n")
37
  # Use LLM to predict possible summary
@@ -114,7 +115,7 @@ masterpiece, newest, absurdres, {rating}"""
114
  if __name__ == "__main__":
115
  tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
116
  text_model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
117
- text_model = text_model.eval()
118
 
119
  @spaces.GPU
120
  def wrapper(
 
13
 
14
 
15
  MODEL_PATH = "KBlueLeaf/DanTagGen"
16
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+ print(f"Using device: {DEVICE}")
18
 
19
 
20
  @torch.no_grad()
 
33
  escape_bracket: bool = False,
34
  temperature: float = 1.35,
35
  ):
 
36
  start = time_ns()
37
  print("=" * 50, "\n")
38
  # Use LLM to predict possible summary
 
115
  if __name__ == "__main__":
116
  tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
117
  text_model = LlamaForCausalLM.from_pretrained(MODEL_PATH)
118
+ text_model = text_model.eval().half().to(DEVICE)
119
 
120
  @spaces.GPU
121
  def wrapper(