jhj0517 commited on
Commit
1f79d6e
1 Parent(s): 1c64e54

fix compute type error

Browse files
app.py CHANGED
@@ -18,6 +18,7 @@ class App:
18
  print("Use Faster Whisper implementation")
19
  else:
20
  print("Use Open AI Whisper implementation")
 
21
  self.nllb_inf = NLLBInference()
22
 
23
  @staticmethod
 
18
  print("Use Faster Whisper implementation")
19
  else:
20
  print("Use Open AI Whisper implementation")
21
+ print(f"Device \"{self.whisper_inf.device}\" is detected")
22
  self.nllb_inf = NLLBInference()
23
 
24
  @staticmethod
modules/faster_whisper_inference.py CHANGED
@@ -26,6 +26,7 @@ class FasterWhisperInference(BaseInterface):
26
  self.translatable_models = ["large", "large-v1", "large-v2"]
27
  self.default_beam_size = 1
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
29
 
30
  def transcribe_file(self,
31
  fileobjs: list,
@@ -365,7 +366,7 @@ class FasterWhisperInference(BaseInterface):
365
  device=self.device,
366
  model_size_or_path=model_size,
367
  download_root=os.path.join("models", "Whisper", "faster-whisper"),
368
- compute_type="float16"
369
  )
370
 
371
  @staticmethod
 
26
  self.translatable_models = ["large", "large-v1", "large-v2"]
27
  self.default_beam_size = 1
28
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ self.compute_type = "float16" if self.device == "cuda" else "int8"
30
 
31
  def transcribe_file(self,
32
  fileobjs: list,
 
366
  device=self.device,
367
  model_size_or_path=model_size,
368
  download_root=os.path.join("models", "Whisper", "faster-whisper"),
369
+ compute_type=self.compute_type
370
  )
371
 
372
  @staticmethod
modules/whisper_Inference.py CHANGED
@@ -21,6 +21,7 @@ class WhisperInference(BaseInterface):
21
  self.model = None
22
  self.available_models = whisper.available_models()
23
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
 
24
  self.default_beam_size = 1
25
 
26
  def transcribe_file(self,
 
21
  self.model = None
22
  self.available_models = whisper.available_models()
23
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
24
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
  self.default_beam_size = 1
26
 
27
  def transcribe_file(self,