meg-huggingface committed
Commit d3c28ec
1 parent: e3f7160

Loading per-widget. Various changes to Streamlit interactions for efficiency.

app.py CHANGED
@@ -101,20 +101,18 @@ def load_or_prepare(ds_args, show_embeddings, use_cache=False):
     if use_cache:
         logs.warning("Using cache")
     dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
-    logs.warning("Loading Dataset")
+    logs.warning("Loading dataset")
     dstats.load_or_prepare_dataset()
-    logs.warning("Extracting Labels")
+    logs.warning("Loading labels")
     dstats.load_or_prepare_labels()
-    logs.warning("Computing Text Lengths")
+    logs.warning("Loading text lengths")
     dstats.load_or_prepare_text_lengths()
-    logs.warning("Computing Duplicates")
+    logs.warning("Loading duplicates")
     dstats.load_or_prepare_text_duplicates()
-    logs.warning("Extracting Vocabulary")
+    logs.warning("Loading vocabulary")
     dstats.load_or_prepare_vocab()
-    logs.warning("Calculating General Statistics...")
+    logs.warning("Loading general statistics...")
     dstats.load_or_prepare_general_stats()
-    logs.warning("Completed Calculation.")
-    logs.warning("Calculating Fine-Grained Statistics...")
     if show_embeddings:
         logs.warning("Loading Embeddings")
         dstats.load_or_prepare_embeddings()
@@ -135,6 +133,7 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
     Returns:

     """
+
     if not isdir(CACHE_DIR):
         logs.warning("Creating cache")
         # We need to preprocess everything.
@@ -143,6 +142,8 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
     if use_cache:
         logs.warning("Using cache")
     dstats = dataset_statistics.DatasetStatisticsCacheClass(CACHE_DIR, **ds_args, use_cache=use_cache)
+    # Don't recalculate; we're live
+    dstats.set_deployment(True)
     # Header widget
     dstats.load_or_prepare_dset_peek()
     # General stats widget
@@ -157,23 +158,21 @@ def load_or_prepare_widgets(ds_args, show_embeddings, use_cache=False):
     dstats.load_or_prepare_text_duplicates()
     dstats.load_or_prepare_npmi()
     dstats.load_or_prepare_zipf()
-    # Don't recalculate; we're live
-    dstats.set_deployment(True)
+    return dstats

-def show_column(dstats, ds_name_to_dict, show_embeddings, column_id, use_cache=True):
+def show_column(dstats, ds_name_to_dict, show_embeddings, column_id):
     """
     Function for displaying the elements in the right column of the streamlit app.
     Args:
         ds_name_to_dict (dict): the dataset name and options in dictionary form
         show_embeddings (Bool): whether embeddings should we loaded and displayed for this dataset
         column_id (str): what column of the dataset the analysis is done on
-        use_cache (Bool): whether the cache is used by default or not
     Returns:
         The function displays the information using the functions defined in the st_utils class.
     """
     # Note that at this point we assume we can use cache; default value is True.
     # start showing stuff
-    title_str = f"### Showing{column_id}: {dstats.dset_name} - {dstats.dset_config} - {'-'.join(dstats.text_field)}"
+    title_str = f"### Showing{column_id}: {dstats.dset_name} - {dstats.dset_config} - {dstats.split_name} - {'-'.join(dstats.text_field)}"
     st.markdown(title_str)
     logs.info("showing header")
     st_utils.expander_header(dstats, ds_name_to_dict, column_id)
@@ -230,7 +229,7 @@ def main():
     else:
         logs.warning("Using Single Dataset Mode")
         dataset_args = st_utils.sidebar_selection(ds_name_to_dict, "")
-        dstats = load_or_prepare(dataset_args, show_embeddings, use_cache=use_cache)
+        dstats = load_or_prepare_widgets(dataset_args, show_embeddings, use_cache=use_cache)
         show_column(dstats, ds_name_to_dict, show_embeddings, "")

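Taken together, the app.py changes restructure startup: main() now calls load_or_prepare_widgets, which flags the statistics object as deployed before loading each widget's artifacts and, new in this commit, returns that object to the caller. Below is a minimal, self-contained sketch of the resulting flow; DatasetStats and its methods are illustrative stand-ins for DatasetStatisticsCacheClass and the real Streamlit plumbing, not the app's actual API.

```python
# Illustrative sketch of the per-widget startup flow after this commit.
class DatasetStats:
    def __init__(self, dset_name, use_cache=False):
        self.dset_name = dset_name
        self.use_cache = use_cache
        self.live = False

    def set_deployment(self, live=True):
        # Don't recalculate; we're live.
        self.live = live

    def load_or_prepare_dset_peek(self):
        print("header widget ready")            # one cheap load per widget

    def load_or_prepare_general_stats(self):
        print("general stats widget ready")


def load_or_prepare_widgets(ds_args, use_cache=False):
    dstats = DatasetStats(**ds_args, use_cache=use_cache)
    dstats.set_deployment(True)   # set *before* loading, as in the new code
    dstats.load_or_prepare_dset_peek()
    dstats.load_or_prepare_general_stats()
    return dstats                 # new in this commit: hand dstats back


def show_column(dstats, column_id=""):
    print(f"### Showing{column_id}: {dstats.dset_name}")


dstats = load_or_prepare_widgets({"dset_name": "imdb"}, use_cache=True)
show_column(dstats)
```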
data_measurements/dataset_statistics.py CHANGED
@@ -178,6 +178,7 @@ class DatasetStatisticsCacheClass:
         self.dset_config = dset_config
         # name of the split to analyze
         self.split_name = split_name
+        # TODO: Chould this be "feature" ?
         # which text fields are we analysing?
         self.text_field = text_field
         # which label fields are we analysing?
@@ -207,6 +208,7 @@ class DatasetStatisticsCacheClass:
         self.vocab_counts_df = None
         # Vocabulary filtered to remove stopwords
         self.vocab_counts_filtered_df = None
+        self.sorted_top_vocab_df = None
         ## General statistics and duplicates
         self.total_words = 0
         self.total_open_words = 0
@@ -340,12 +342,13 @@
             logs.info('Loading cached general stats')
             self.load_general_stats()
         else:
-            logs.info('Preparing general stats')
-            self.prepare_general_stats()
-            if save:
-                write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
-                write_df(self.dup_counts_df, self.dup_counts_df_fid)
-                write_json(self.general_stats_dict, self.general_stats_json_fid)
+            if not self.live:
+                logs.info('Preparing general stats')
+                self.prepare_general_stats()
+                if save:
+                    write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
+                    write_df(self.dup_counts_df, self.dup_counts_df_fid)
+                    write_json(self.general_stats_dict, self.general_stats_json_fid)


     def load_or_prepare_text_lengths(self, save=True):
@@ -362,17 +365,19 @@
         if (self.use_cache and exists(self.fig_tok_length_fid)):
             self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         else:
-            self.prepare_fig_text_lengths()
-            if save:
-                write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
+            if not self.live:
+                self.prepare_fig_text_lengths()
+                if save:
+                    write_plotly(self.fig_tok_length, self.fig_tok_length_fid)

         # Text length dataframe
         if self.use_cache and exists(self.length_df_fid):
             self.length_df = feather.read_feather(self.length_df_fid)
         else:
-            self.prepare_length_df()
-            if save:
-                write_df(self.length_df, self.length_df_fid)
+            if not self.live:
+                self.prepare_length_df()
+                if save:
+                    write_df(self.length_df, self.length_df_fid)

         # Text length stats.
         if self.use_cache and exists(self.length_stats_json_fid):
@@ -382,9 +387,10 @@
             self.std_length = self.length_stats_dict["std length"]
             self.num_uniq_lengths = self.length_stats_dict["num lengths"]
         else:
-            self.prepare_text_length_stats()
-            if save:
-                write_json(self.length_stats_dict, self.length_stats_json_fid)
+            if not self.live:
+                self.prepare_text_length_stats()
+                if save:
+                    write_json(self.length_stats_dict, self.length_stats_json_fid)

     def prepare_length_df(self):
         if not self.live:
@@ -481,15 +487,17 @@
             with open(self.dup_counts_df_fid, "rb") as f:
                 self.dup_counts_df = feather.read_feather(f)
         elif self.dup_counts_df is None:
-            self.prepare_text_duplicates()
-            if save:
-                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+            if not self.live:
+                self.prepare_text_duplicates()
+                if save:
+                    write_df(self.dup_counts_df, self.dup_counts_df_fid)
         else:
-            # This happens when self.dup_counts_df is already defined;
-            # This happens when general_statistics were calculated first,
-            # since general statistics requires the number of duplicates
-            if save:
-                write_df(self.dup_counts_df, self.dup_counts_df_fid)
+            if not self.live:
+                # This happens when self.dup_counts_df is already defined;
+                # This happens when general_statistics were calculated first,
+                # since general statistics requires the number of duplicates
+                if save:
+                    write_df(self.dup_counts_df, self.dup_counts_df_fid)

     def load_general_stats(self):
         self.general_stats_dict = json.load(open(self.general_stats_json_fid, encoding="utf-8"))
@@ -815,6 +823,8 @@ class nPMIStatisticsCacheClass:
             write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
             with open(joint_npmi_fid, "w+") as f:
                 joint_npmi_df.to_csv(f)
+        else:
+            joint_npmi_df = pd.DataFrame()
         logs.info("The joint npmi df is")
         logs.info(joint_npmi_df)
         return joint_npmi_df
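The recurring edit in this file wraps every "prepare and save" branch in `if not self.live:`, so a deployed (live) instance only ever reads cached artifacts and silently skips recomputation instead of blocking the app. A minimal, self-contained sketch of that guard follows; the class name, JSON cache, and `_fid` path are illustrative assumptions, not the real `DatasetStatisticsCacheClass` internals.

```python
# Minimal sketch of the load-or-prepare guard pattern, assuming a JSON cache.
import json
from os.path import exists


class StatsCache:
    def __init__(self, cache_dir, use_cache=False):
        self.use_cache = use_cache
        self.live = False                   # flipped by set_deployment()
        self.stats_json_fid = f"{cache_dir}/general_stats.json"
        self.general_stats_dict = None

    def set_deployment(self, live=True):
        self.live = live

    def load_or_prepare_general_stats(self, save=True):
        if self.use_cache and exists(self.stats_json_fid):
            # Cached result: cheap to load, safe in deployment.
            with open(self.stats_json_fid, encoding="utf-8") as f:
                self.general_stats_dict = json.load(f)
        else:
            # The commit's key change: only recompute when *not* live,
            # so a deployed app never runs the expensive preparation.
            if not self.live:
                self.general_stats_dict = self.prepare_general_stats()
                if save:
                    with open(self.stats_json_fid, "w", encoding="utf-8") as f:
                        json.dump(self.general_stats_dict, f)

    def prepare_general_stats(self):
        return {"total_words": 0}           # placeholder for the real work


cache = StatsCache("/tmp", use_cache=True)
cache.set_deployment(True)    # live: load from cache if present, else skip
cache.load_or_prepare_general_stats()
```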
data_measurements/streamlit_utils.py CHANGED
@@ -126,13 +126,18 @@ def expander_general_stats(dstats, column_id):
                 str(dstats.text_nan_count)
             )
         )
-        st.markdown(
-            "There are {0} duplicate items in the dataset. "
-            "For more information about the duplicates, "
-            "click the 'Duplicates' tab below.".format(
-                str(dstats.dedup_total)
+        if dstats.dedup_total > 0:
+            st.markdown(
+                "There are {0} duplicate items in the dataset. "
+                "For more information about the duplicates, "
+                "click the 'Duplicates' tab below.".format(
+                    str(dstats.dedup_total)
+                )
             )
-        )
+        else:
+            st.markdown(
+                "There are 0 duplicate items in the dataset. ")
+


     ### Show the label distribution from the datasets
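The streamlit_utils.py change guards the duplicates message so users are only pointed at the 'Duplicates' tab when it has content. Condensed into a standalone helper for clarity (the function name `report_duplicates` is hypothetical; only `st.markdown` comes from the real code):

```python
# Condensed form of the guarded duplicates message (hypothetical helper).
import streamlit as st


def report_duplicates(dedup_total: int):
    if dedup_total > 0:
        st.markdown(
            "There are {0} duplicate items in the dataset. "
            "For more information about the duplicates, "
            "click the 'Duplicates' tab below.".format(dedup_total)
        )
    else:
        # No duplicates: don't advertise a tab with nothing in it.
        st.markdown("There are 0 duplicate items in the dataset.")
```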