Petr Tsvetkov committed on
Commit
a52ecf0
1 Parent(s): 8b55f41

Apply fix to the HF dataset saver

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. hf_dataset_saver_builder.py +71 -0
app.py CHANGED
@@ -7,15 +7,16 @@ from difflib import ndiff
7
  import gradio as gr
8
 
9
  from data_loader import load_data
 
10
 
11
- HF_TOKEN = os.environ.get('HF_TOKEN')
12
- HF_DATASET = os.environ.get('HF_DATASET')
13
 
14
  data = load_data()
15
 
16
  n_samples = len(data)
17
 
18
- saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, HF_DATASET, private=True, separate_dirs=True)
19
 
20
 
21
  def convert_diff_to_unified(diff):
 
7
  import gradio as gr
8
 
9
  from data_loader import load_data
10
+ from hf_dataset_saver_builder import get_dataset_saver
11
 
12
+ HF_TOKEN = os.environ.get('HF_REWRITING_TOKEN')
13
+ HF_DATASET = os.environ.get('HF_REWRITING_DATASET')
14
 
15
  data = load_data()
16
 
17
  n_samples = len(data)
18
 
19
+ saver = get_dataset_saver(HF_TOKEN, HF_DATASET, private=True, separate_dirs=True)
20
 
21
 
22
  def convert_diff_to_unified(diff):
hf_dataset_saver_builder.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import gradio as gr
6
+
7
+
8
def _deserialize_components_fix(
    self,
    data_dir: Path,
    flag_data: list[Any],
    flag_option: str = "",
    username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
    """Deserialize components and return the corresponding row for the flagged sample.

    Images/audio are saved to disk as individual files.

    This function is written as a method replacement: ``self`` is expected to
    be the saver object, and the function reads ``self.components``,
    ``self.dataset_dir`` and ``self.dataset_id``.

    Args:
        self: Saver instance the function is bound to.
        data_dir: Directory under which one sub-directory per component label
            is created to hold that component's flagged data.
        flag_data: Flagged samples, paired positionally with ``self.components``.
        flag_option: Value stored in the ``flag`` column.
        username: Value stored in the ``username`` column.

    Returns:
        A ``(features, row)`` tuple: ``features`` maps column names to their
        feature descriptions (in column order) and ``row`` holds the matching
        values for the flagged sample.
    """
    # Components that can have a preview on dataset repos
    file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

    # Generate the row corresponding to the flagged sample
    features = OrderedDict()
    row = []
    for component, sample in zip(self.components, flag_data):
        # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
        label = component.label or ""
        save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(label)
        save_dir.mkdir(exist_ok=True, parents=True)
        deserialized = component.flag(sample, save_dir)

        # Add deserialized object to row
        features[label] = {"dtype": "string", "_type": "Value"}

        # Store a path relative to the dataset directory when the value is an
        # existing file; otherwise store the value itself as a string.
        # An explicit check is used instead of `assert` so the behavior does
        # not change when Python runs with assertions disabled (-O).
        relative_path = None
        try:
            candidate = Path(deserialized)
            if candidate.exists():
                relative_path = str(candidate.relative_to(self.dataset_dir))
        except (TypeError, ValueError, OSError):
            # TypeError: value is not path-like (e.g. None).
            # ValueError: path is not located under the dataset directory.
            # OSError: the OS rejected the path (e.g. a very long text value).
            relative_path = None

        if relative_path is None:
            deserialized = "" if deserialized is None else str(deserialized)
            row.append(deserialized)
        else:
            row.append(relative_path)

        # If component is eligible for a preview, add the URL of the file
        # Be mindful that images and audio can be None
        preview_type = next(
            (
                _type
                for _component, _type in file_preview_types.items()
                if isinstance(component, _component)
            ),
            None,
        )
        if preview_type is None:
            continue

        features[label + " file"] = {"_type": preview_type}
        if deserialized:
            # returned filepath is absolute, we want it relative to compute URL
            path_in_repo = str(
                Path(deserialized).relative_to(self.dataset_dir)
            ).replace("\\", "/")
            row.append(
                gr.flagging.huggingface_hub.hf_hub_url(
                    repo_id=self.dataset_id,
                    filename=path_in_repo,
                    repo_type="dataset",
                )
            )
        else:
            row.append("")

    features["flag"] = {"dtype": "string", "_type": "Value"}
    features["username"] = {"dtype": "string", "_type": "Value"}
    row.append(flag_option)
    row.append(username)
    return features, row
66
+
67
+
68
def get_dataset_saver(*args, **kwargs):
    """Build a ``gr.HuggingFaceDatasetSaver`` that uses the patched deserializer.

    All positional and keyword arguments are forwarded unchanged to
    ``gr.HuggingFaceDatasetSaver``.

    Returns:
        The saver instance, with its ``_deserialize_components`` attribute
        replaced by ``_deserialize_components_fix`` bound to that instance.
    """
    # Function-scope import keeps this block self-contained.
    import types

    saver = gr.HuggingFaceDatasetSaver(*args, **kwargs)
    # A plain function stored on an *instance* is not bound automatically, so
    # `saver._deserialize_components(...)` would be called without `self`.
    # Bind it explicitly so the saver is passed as the first argument.
    saver._deserialize_components = types.MethodType(_deserialize_components_fix, saver)
    return saver