yuntian-deng commited on
Commit
3241c9e
1 Parent(s): 14684f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import json
3
  import re
@@ -13,7 +14,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
14
  model.eval()
15
  if torch.cuda.is_available():
16
- model.cuda()
17
 
18
  validation_results = json.load(open('validation_results.json'))
19
  scores, thresholds, precisions, recalls = validation_results['scores'], validation_results['thresholds'], validation_results['precisions'], validation_results['recalls']
@@ -72,18 +73,24 @@ Authors: {authors}
72
  Abstract: {abstract}"""
73
  return text
74
 
75
- @torch.no_grad()
 
76
  def model_inference(title, authors, abstract):
 
 
 
77
  text = fill_template(title, authors, abstract)
78
  text = f'[CLS] {text} [SEP]'
79
  print (text)
 
80
  inputs = tokenizer([text], return_tensors="pt", truncation=True, max_length=max_length)
81
  if torch.cuda.is_available():
82
- inputs = {key: value.cuda() for key, value in inputs.items()}
83
  outputs = model(**inputs)
84
  logits = outputs.logits
85
  probs = logits.softmax(dim=-1).view(-1)
86
  score = probs[1].item()
 
87
  return score
88
 
89
  def predict(title, authors, abstract):
 
1
+ import spaces
2
  import gradio as gr
3
  import json
4
  import re
 
14
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
15
  model.eval()
16
  if torch.cuda.is_available():
17
+ model.to("cuda:0")
18
 
19
  validation_results = json.load(open('validation_results.json'))
20
  scores, thresholds, precisions, recalls = validation_results['scores'], validation_results['thresholds'], validation_results['precisions'], validation_results['recalls']
 
73
  Abstract: {abstract}"""
74
  return text
75
 
76
+ @torch.no_grad
77
+ @spaces.GPU
78
  def model_inference(title, authors, abstract):
79
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
80
+ if device != model.device:
81
+ model.to(device)
82
  text = fill_template(title, authors, abstract)
83
  text = f'[CLS] {text} [SEP]'
84
  print (text)
85
+ print (device)
86
  inputs = tokenizer([text], return_tensors="pt", truncation=True, max_length=max_length)
87
  if torch.cuda.is_available():
88
+ inputs = {key: value.to(device) for key, value in inputs.items()}
89
  outputs = model(**inputs)
90
  logits = outputs.logits
91
  probs = logits.softmax(dim=-1).view(-1)
92
  score = probs[1].item()
93
+ print (score)
94
  return score
95
 
96
  def predict(title, authors, abstract):