jhj0517 committed on
Commit
9390f92
2 Parent(s): f282b7c 67cc6b1

Merge pull request #15 from damho1104/mitigate-cuda-out-of-memory

Browse files
modules/base_interface.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import List
4
+
5
+
6
class BaseInterface:
    """Base class offering cleanup helpers shared by the inference classes.

    Both helpers are static, so they can be invoked either on the class or
    through an instance (``self.release_cuda_memory()``).
    """

    def __init__(self):
        pass

    @staticmethod
    def release_cuda_memory():
        """Release cached CUDA memory and reset the peak-memory statistics.

        Does nothing when CUDA is not available. This helper is meant to be
        called from ``finally`` clauses, where an exception raised on a
        CPU-only host would replace the caller's return value.
        """
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        # ``reset_max_memory_allocated`` is deprecated and simply forwards to
        # ``reset_peak_memory_stats``; call the supported API directly.
        torch.cuda.reset_peak_memory_stats()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Delete every file in ``file_paths``, ignoring ones already gone.

        Parameters
        ----------
        file_paths: List[str]
            Paths of the files to delete.
        """
        for file_path in file_paths:
            # Remove first and tolerate absence: an exists()-then-remove
            # sequence can still fail if the file disappears in between.
            try:
                os.remove(file_path)
            except FileNotFoundError:
                continue
modules/nllb_inference.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
  import torch
@@ -10,8 +11,9 @@ DEFAULT_MODEL_SIZE = "facebook/nllb-200-1.3B"
10
  NLLB_MODELS = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
11
 
12
 
13
- class NLLBInference:
14
  def __init__(self):
 
15
  self.default_model_size = DEFAULT_MODEL_SIZE
16
  self.current_model_size = None
17
  self.model = None
@@ -29,69 +31,74 @@ class NLLBInference:
29
  def translate_file(self, fileobjs
30
  , model_size, src_lang, tgt_lang,
31
  progress=gr.Progress()):
 
 
 
 
 
 
 
 
 
32
 
33
- if model_size != self.current_model_size or self.model is None:
34
- print("\nInitializing NLLB Model..\n")
35
- progress(0, desc="Initializing NLLB Model..")
36
- self.current_model_size = model_size
37
- self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
38
- cache_dir="models/NLLB")
39
- self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
40
- cache_dir=f"models/NLLB/tokenizers")
41
 
42
- src_lang = NLLB_AVAILABLE_LANGS[src_lang]
43
- tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
 
 
 
 
44
 
45
- self.pipeline = pipeline("translation",
46
- model=self.model,
47
- tokenizer=self.tokenizer,
48
- src_lang=src_lang,
49
- tgt_lang=tgt_lang,
50
- device=self.device)
 
 
 
 
 
 
51
 
52
- files_info = {}
53
- for fileobj in fileobjs:
54
- file_path = fileobj.name
55
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
56
- if file_ext == ".srt":
57
- parsed_dicts = parse_srt(file_path=file_path)
58
- total_progress = len(parsed_dicts)
59
- for index, dic in enumerate(parsed_dicts):
60
- progress(index / total_progress, desc="Translating..")
61
- translated_text = self.translate_text(dic["sentence"])
62
- dic["sentence"] = translated_text
63
- subtitle = get_serialized_srt(parsed_dicts)
64
 
65
- timestamp = datetime.now().strftime("%m%d%H%M%S")
66
- file_name = file_name[:-9]
67
- output_path = f"outputs/translations/{file_name}-{timestamp}"
68
 
69
- write_file(subtitle, f"{output_path}.srt")
 
 
 
 
 
 
 
70
 
71
- elif file_ext == ".vtt":
72
- parsed_dicts = parse_vtt(file_path=file_path)
73
- total_progress = len(parsed_dicts)
74
- for index, dic in enumerate(parsed_dicts):
75
- progress(index / total_progress, desc="Translating..")
76
- translated_text = self.translate_text(dic["sentence"])
77
- dic["sentence"] = translated_text
78
- subtitle = get_serialized_vtt(parsed_dicts)
79
 
80
- timestamp = datetime.now().strftime("%m%d%H%M%S")
81
- file_name = file_name[:-9]
82
- output_path = f"outputs/translations/{file_name}-{timestamp}"
83
 
84
- write_file(subtitle, f"{output_path}.vtt")
85
 
86
- files_info[file_name] = subtitle
 
 
 
 
87
 
88
- total_result = ''
89
- for file_name, subtitle in files_info.items():
90
- total_result += '------------------------------------\n'
91
- total_result += f'{file_name}\n\n'
92
- total_result += f'{subtitle}'
93
-
94
- return f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
95
 
96
 
97
  NLLB_AVAILABLE_LANGS = {
 
1
+ from .base_interface import BaseInterface
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  import gradio as gr
4
  import torch
 
11
  NLLB_MODELS = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
12
 
13
 
14
+ class NLLBInference(BaseInterface):
15
  def __init__(self):
16
+ super().__init__()
17
  self.default_model_size = DEFAULT_MODEL_SIZE
18
  self.current_model_size = None
19
  self.model = None
 
31
  def translate_file(self, fileobjs
32
  , model_size, src_lang, tgt_lang,
33
  progress=gr.Progress()):
34
+ try:
35
+ if model_size != self.current_model_size or self.model is None:
36
+ print("\nInitializing NLLB Model..\n")
37
+ progress(0, desc="Initializing NLLB Model..")
38
+ self.current_model_size = model_size
39
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
40
+ cache_dir="models/NLLB")
41
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
42
+ cache_dir=f"models/NLLB/tokenizers")
43
 
44
+ src_lang = NLLB_AVAILABLE_LANGS[src_lang]
45
+ tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
 
 
 
 
 
 
46
 
47
+ self.pipeline = pipeline("translation",
48
+ model=self.model,
49
+ tokenizer=self.tokenizer,
50
+ src_lang=src_lang,
51
+ tgt_lang=tgt_lang,
52
+ device=self.device)
53
 
54
+ files_info = {}
55
+ for fileobj in fileobjs:
56
+ file_path = fileobj.name
57
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
58
+ if file_ext == ".srt":
59
+ parsed_dicts = parse_srt(file_path=file_path)
60
+ total_progress = len(parsed_dicts)
61
+ for index, dic in enumerate(parsed_dicts):
62
+ progress(index / total_progress, desc="Translating..")
63
+ translated_text = self.translate_text(dic["sentence"])
64
+ dic["sentence"] = translated_text
65
+ subtitle = get_serialized_srt(parsed_dicts)
66
 
67
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
68
+ file_name = file_name[:-9]
69
+ output_path = f"outputs/translations/{file_name}-{timestamp}"
 
 
 
 
 
 
 
 
 
70
 
71
+ write_file(subtitle, f"{output_path}.srt")
 
 
72
 
73
+ elif file_ext == ".vtt":
74
+ parsed_dicts = parse_vtt(file_path=file_path)
75
+ total_progress = len(parsed_dicts)
76
+ for index, dic in enumerate(parsed_dicts):
77
+ progress(index / total_progress, desc="Translating..")
78
+ translated_text = self.translate_text(dic["sentence"])
79
+ dic["sentence"] = translated_text
80
+ subtitle = get_serialized_vtt(parsed_dicts)
81
 
82
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
83
+ file_name = file_name[:-9]
84
+ output_path = f"outputs/translations/{file_name}-{timestamp}"
 
 
 
 
 
85
 
86
+ write_file(subtitle, f"{output_path}.vtt")
 
 
87
 
88
+ files_info[file_name] = subtitle
89
 
90
+ total_result = ''
91
+ for file_name, subtitle in files_info.items():
92
+ total_result += '------------------------------------\n'
93
+ total_result += f'{file_name}\n\n'
94
+ total_result += f'{subtitle}'
95
 
96
+ return f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
97
+ except Exception as e:
98
+ return f"Error: {str(e)}"
99
+ finally:
100
+ self.release_cuda_memory()
101
+ self.remove_input_files([fileobj.name for fileobj in fileobjs])
 
102
 
103
 
104
  NLLB_AVAILABLE_LANGS = {
modules/whisper_Inference.py CHANGED
@@ -1,4 +1,5 @@
1
  import whisper
 
2
  from modules.subtitle_manager import get_srt, get_vtt, write_file, safe_filename
3
  from modules.youtube_manager import get_ytdata, get_ytaudio
4
  import gradio as gr
@@ -8,8 +9,9 @@ from datetime import datetime
8
  DEFAULT_MODEL_SIZE = "large-v2"
9
 
10
 
11
- class WhisperInference:
12
  def __init__(self):
 
13
  self.current_model_size = None
14
  self.model = None
15
  self.available_models = whisper.available_models()
@@ -71,11 +73,10 @@ class WhisperInference:
71
 
72
  return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
73
  except Exception as e:
74
- return str(e)
75
  finally:
76
- for fileobj in fileobjs:
77
- if os.path.exists(fileobj.name):
78
- os.remove(fileobj.name)
79
 
80
  def transcribe_youtube(self, youtubelink,
81
  model_size, lang, subformat, istranslate,
@@ -120,12 +121,12 @@ class WhisperInference:
120
 
121
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
122
  except Exception as e:
123
- return str(e)
124
  finally:
125
  yt = get_ytdata(youtubelink)
126
  file_path = get_ytaudio(yt)
127
- if os.path.exists(file_path):
128
- os.remove(file_path)
129
 
130
  def transcribe_mic(self, micaudio,
131
  model_size, lang, subformat, istranslate,
@@ -167,7 +168,7 @@ class WhisperInference:
167
 
168
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
169
  except Exception as e:
170
- print(str(e))
171
  finally:
172
- if os.path.exists(micaudio):
173
- os.remove(micaudio)
 
1
  import whisper
2
+ from .base_interface import BaseInterface
3
  from modules.subtitle_manager import get_srt, get_vtt, write_file, safe_filename
4
  from modules.youtube_manager import get_ytdata, get_ytaudio
5
  import gradio as gr
 
9
  DEFAULT_MODEL_SIZE = "large-v2"
10
 
11
 
12
+ class WhisperInference(BaseInterface):
13
  def __init__(self):
14
+ super().__init__()
15
  self.current_model_size = None
16
  self.model = None
17
  self.available_models = whisper.available_models()
 
73
 
74
  return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
75
  except Exception as e:
76
+ return f"Error: {str(e)}"
77
  finally:
78
+ self.release_cuda_memory()
79
+ self.remove_input_files([fileobj.name for fileobj in fileobjs])
 
80
 
81
  def transcribe_youtube(self, youtubelink,
82
  model_size, lang, subformat, istranslate,
 
121
 
122
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
123
  except Exception as e:
124
+ return f"Error: {str(e)}"
125
  finally:
126
  yt = get_ytdata(youtubelink)
127
  file_path = get_ytaudio(yt)
128
+ self.release_cuda_memory()
129
+ self.remove_input_files([file_path])
130
 
131
  def transcribe_mic(self, micaudio,
132
  model_size, lang, subformat, istranslate,
 
168
 
169
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
170
  except Exception as e:
171
+ return f"Error: {str(e)}"
172
  finally:
173
+ self.release_cuda_memory()
174
+ self.remove_input_files([micaudio])