jhj0517 commited on
Commit
91b9b83
1 Parent(s): 736206b

better read

Browse files
Files changed (1) hide show
  1. modules/model_Inference.py +72 -66
modules/model_Inference.py CHANGED
@@ -1,48 +1,52 @@
1
  import whisper
2
- from modules.subtitle_manager import get_srt,get_vtt,write_srt,write_vtt,safe_filename
3
- from modules.youtube_manager import get_ytdata,get_ytaudio
4
  import gradio as gr
5
  import os
6
  from datetime import datetime
7
 
8
- DEFAULT_MODEL_SIZE="large-v2"
9
 
10
- class WhisperInference():
 
11
  def __init__(self):
12
  print("\nInitializing Model..\n")
13
  self.current_model_size = DEFAULT_MODEL_SIZE
14
- self.model = whisper.load_model(name=DEFAULT_MODEL_SIZE,download_root="models")
15
  self.available_models = whisper.available_models()
16
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
17
 
18
- def transcribe_file(self,fileobjs
19
- ,model_size,lang,subformat,istranslate,
20
  progress=gr.Progress()):
21
-
22
  def progress_callback(progress_value):
23
- progress(progress_value,desc="Transcribing..")
24
-
25
  if model_size != self.current_model_size:
26
- progress(0,desc="Initializing Model..")
27
  self.current_model_size = model_size
28
- self.model = whisper.load_model(name=model_size,download_root="models")
29
 
30
- if lang == "Automatic Detection" :
31
- lang = None
32
 
33
- progress(0,desc="Loading Audio..")
34
 
35
  files_info = {}
36
- for fileobj in fileobjs:
 
37
  audio = whisper.load_audio(fileobj.name)
38
 
39
- translatable_model = ["large","large-v1","large-v2"]
40
  if istranslate and self.current_model_size in translatable_model:
41
- result = self.model.transcribe(audio=audio,language=lang,verbose=False,task="translate",progress_callback=progress_callback)
42
- else :
43
- result = self.model.transcribe(audio=audio,language=lang,verbose=False,progress_callback=progress_callback)
 
 
44
 
45
- progress(1,desc="Completed!")
46
 
47
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
48
  file_name = file_name[:-9]
@@ -52,47 +56,49 @@ class WhisperInference():
52
 
53
  if subformat == "SRT":
54
  subtitle = get_srt(result["segments"])
55
- write_srt(subtitle,f"{output_path}.srt")
56
  elif subformat == "WebVTT":
57
  subtitle = get_vtt(result["segments"])
58
- write_vtt(subtitle,f"{output_path}.vtt")
59
 
60
  files_info[file_name] = subtitle
61
 
62
  total_result = ''
63
- for file_name,subtitle in files_info.items():
64
- total_result+='------------------------------------\n'
65
- total_result+=f'{file_name}\n\n'
66
- total_result+=f'{subtitle}'
67
 
68
  return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
69
-
70
- def transcribe_youtube(self,youtubelink
71
- ,model_size,lang,subformat,istranslate,
72
- progress=gr.Progress()):
73
-
74
  def progress_callback(progress_value):
75
- progress(progress_value,desc="Transcribing..")
76
 
77
  if model_size != self.current_model_size:
78
- progress(0,desc="Initializing Model..")
79
  self.current_model_size = model_size
80
- self.model = whisper.load_model(name=model_size,download_root="models")
81
 
82
- if lang == "Automatic Detection" :
83
- lang = None
84
 
85
- progress(0,desc="Loading Audio from Youtube..")
86
  yt = get_ytdata(youtubelink)
87
  audio = whisper.load_audio(get_ytaudio(yt))
88
 
89
- translatable_model = ["large","large-v1","large-v2"]
90
  if istranslate and self.current_model_size in translatable_model:
91
- result = self.model.transcribe(audio=audio,language=lang,verbose=False,task="translate",progress_callback=progress_callback)
92
- else :
93
- result = self.model.transcribe(audio=audio,language=lang,verbose=False,progress_callback=progress_callback)
 
 
94
 
95
- progress(1,desc="Completed!")
96
 
97
  file_name = safe_filename(yt.title)
98
  timestamp = datetime.now().strftime("%m%d%H%M%S")
@@ -100,48 +106,48 @@ class WhisperInference():
100
 
101
  if subformat == "SRT":
102
  subtitle = get_srt(result["segments"])
103
- write_srt(subtitle,f"{output_path}.srt")
104
  elif subformat == "WebVTT":
105
  subtitle = get_vtt(result["segments"])
106
- write_vtt(subtitle,f"{output_path}.vtt")
107
 
108
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
109
-
110
- def transcribe_mic(self,micaudio
111
- ,model_size,lang,subformat,istranslate,
112
- progress=gr.Progress()):
113
 
114
  def progress_callback(progress_value):
115
- progress(progress_value,desc="Transcribing..")
116
-
117
  if model_size != self.current_model_size:
118
- progress(0,desc="Initializing Model..")
119
  self.current_model_size = model_size
120
- self.model = whisper.load_model(name=model_size,download_root="models")
121
 
122
- if lang == "Automatic Detection" :
123
- lang = None
124
 
125
- progress(0,desc="Loading Audio..")
126
 
127
- translatable_model = ["large","large-v1","large-v2"]
128
  if istranslate and self.current_model_size in translatable_model:
129
- result = self.model.transcribe(audio=micaudio,language=lang,verbose=False,task="translate",progress_callback=progress_callback)
130
- else :
131
- result = self.model.transcribe(audio=micaudio,language=lang,verbose=False,progress_callback=progress_callback)
 
 
132
 
133
- progress(1,desc="Completed!")
134
 
135
  timestamp = datetime.now().strftime("%m%d%H%M%S")
136
  output_path = f"outputs/Mic-{timestamp}"
137
 
138
  if subformat == "SRT":
139
  subtitle = get_srt(result["segments"])
140
- write_srt(subtitle,f"{output_path}.srt")
141
  elif subformat == "WebVTT":
142
  subtitle = get_vtt(result["segments"])
143
- write_vtt(subtitle,f"{output_path}.vtt")
144
-
145
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
146
-
147
-
 
1
  import whisper
2
+ from modules.subtitle_manager import get_srt, get_vtt, write_srt, write_vtt, safe_filename
3
+ from modules.youtube_manager import get_ytdata, get_ytaudio
4
  import gradio as gr
5
  import os
6
  from datetime import datetime
7
 
8
+ DEFAULT_MODEL_SIZE = "large-v2"
9
 
10
+
11
+ class WhisperInference:
12
  def __init__(self):
13
  print("\nInitializing Model..\n")
14
  self.current_model_size = DEFAULT_MODEL_SIZE
15
+ self.model = whisper.load_model(name=DEFAULT_MODEL_SIZE, download_root="models")
16
  self.available_models = whisper.available_models()
17
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
18
 
19
+ def transcribe_file(self, fileobjs
20
+ , model_size, lang, subformat, istranslate,
21
  progress=gr.Progress()):
22
+
23
  def progress_callback(progress_value):
24
+ progress(progress_value, desc="Transcribing..")
25
+
26
  if model_size != self.current_model_size:
27
+ progress(0, desc="Initializing Model..")
28
  self.current_model_size = model_size
29
+ self.model = whisper.load_model(name=model_size, download_root="models")
30
 
31
+ if lang == "Automatic Detection":
32
+ lang = None
33
 
34
+ progress(0, desc="Loading Audio..")
35
 
36
  files_info = {}
37
+ for fileobj in fileobjs:
38
+
39
  audio = whisper.load_audio(fileobj.name)
40
 
41
+ translatable_model = ["large", "large-v1", "large-v2"]
42
  if istranslate and self.current_model_size in translatable_model:
43
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
44
+ progress_callback=progress_callback)
45
+ else:
46
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False,
47
+ progress_callback=progress_callback)
48
 
49
+ progress(1, desc="Completed!")
50
 
51
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj.orig_name))
52
  file_name = file_name[:-9]
 
56
 
57
  if subformat == "SRT":
58
  subtitle = get_srt(result["segments"])
59
+ write_srt(subtitle, f"{output_path}.srt")
60
  elif subformat == "WebVTT":
61
  subtitle = get_vtt(result["segments"])
62
+ write_vtt(subtitle, f"{output_path}.vtt")
63
 
64
  files_info[file_name] = subtitle
65
 
66
  total_result = ''
67
+ for file_name, subtitle in files_info.items():
68
+ total_result += '------------------------------------\n'
69
+ total_result += f'{file_name}\n\n'
70
+ total_result += f'{subtitle}'
71
 
72
  return f"Done! Subtitle is in the outputs folder.\n\n{total_result}"
73
+
74
+ def transcribe_youtube(self, youtubelink
75
+ , model_size, lang, subformat, istranslate,
76
+ progress=gr.Progress()):
77
+
78
  def progress_callback(progress_value):
79
+ progress(progress_value, desc="Transcribing..")
80
 
81
  if model_size != self.current_model_size:
82
+ progress(0, desc="Initializing Model..")
83
  self.current_model_size = model_size
84
+ self.model = whisper.load_model(name=model_size, download_root="models")
85
 
86
+ if lang == "Automatic Detection":
87
+ lang = None
88
 
89
+ progress(0, desc="Loading Audio from Youtube..")
90
  yt = get_ytdata(youtubelink)
91
  audio = whisper.load_audio(get_ytaudio(yt))
92
 
93
+ translatable_model = ["large", "large-v1", "large-v2"]
94
  if istranslate and self.current_model_size in translatable_model:
95
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False, task="translate",
96
+ progress_callback=progress_callback)
97
+ else:
98
+ result = self.model.transcribe(audio=audio, language=lang, verbose=False,
99
+ progress_callback=progress_callback)
100
 
101
+ progress(1, desc="Completed!")
102
 
103
  file_name = safe_filename(yt.title)
104
  timestamp = datetime.now().strftime("%m%d%H%M%S")
 
106
 
107
  if subformat == "SRT":
108
  subtitle = get_srt(result["segments"])
109
+ write_srt(subtitle, f"{output_path}.srt")
110
  elif subformat == "WebVTT":
111
  subtitle = get_vtt(result["segments"])
112
+ write_vtt(subtitle, f"{output_path}.vtt")
113
 
114
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"
115
+
116
+ def transcribe_mic(self, micaudio
117
+ , model_size, lang, subformat, istranslate,
118
+ progress=gr.Progress()):
119
 
120
  def progress_callback(progress_value):
121
+ progress(progress_value, desc="Transcribing..")
122
+
123
  if model_size != self.current_model_size:
124
+ progress(0, desc="Initializing Model..")
125
  self.current_model_size = model_size
126
+ self.model = whisper.load_model(name=model_size, download_root="models")
127
 
128
+ if lang == "Automatic Detection":
129
+ lang = None
130
 
131
+ progress(0, desc="Loading Audio..")
132
 
133
+ translatable_model = ["large", "large-v1", "large-v2"]
134
  if istranslate and self.current_model_size in translatable_model:
135
+ result = self.model.transcribe(audio=micaudio, language=lang, verbose=False, task="translate",
136
+ progress_callback=progress_callback)
137
+ else:
138
+ result = self.model.transcribe(audio=micaudio, language=lang, verbose=False,
139
+ progress_callback=progress_callback)
140
 
141
+ progress(1, desc="Completed!")
142
 
143
  timestamp = datetime.now().strftime("%m%d%H%M%S")
144
  output_path = f"outputs/Mic-{timestamp}"
145
 
146
  if subformat == "SRT":
147
  subtitle = get_srt(result["segments"])
148
+ write_srt(subtitle, f"{output_path}.srt")
149
  elif subformat == "WebVTT":
150
  subtitle = get_vtt(result["segments"])
151
+ write_vtt(subtitle, f"{output_path}.vtt")
152
+
153
  return f"Done! Subtitle file is in the outputs folder.\n\n{subtitle}"