Benjamin Bossan commited on
Commit
dd68837
1 Parent(s): 2923fea

Add HF repo creation feature

Browse files
Files changed (6) hide show
  1. app.py +29 -3
  2. create.py +129 -0
  3. edit.py +43 -14
  4. start.py +14 -10
  5. tasks.py +56 -10
  6. utils.py +18 -1
app.py CHANGED
@@ -4,14 +4,40 @@ This ties together the different parts of the app.
4
 
5
  """
6
 
 
 
 
 
7
  import streamlit as st
8
 
9
- from start import start_input_form
10
  from edit import edit_input_form
 
 
 
 
 
 
 
 
 
11
 
12
  st.header("Skops model card creator")
13
 
14
- if not st.session_state.get("model_card"):
 
 
 
 
 
 
 
 
 
15
  start_input_form()
16
- else:
17
  edit_input_form()
 
 
 
 
 
4
 
5
  """
6
 
7
+ from pathlib import Path
8
+ from tempfile import mkdtemp
9
+ from typing import Literal
10
+
11
  import streamlit as st
12
 
13
+ from create import create_repo_input_form
14
  from edit import edit_input_form
15
+ from start import start_input_form
16
+
17
+
18
+ # Create a hf_path, which is where the repo will be created locally. When the
19
+ # session is created, copy the dummy cat.png file there and make it the cwd
20
+ if "hf_path" not in st.session_state:
21
+ hf_path = Path(mkdtemp(prefix="skops-"))
22
+ st.session_state.hf_path = hf_path
23
+
24
 
25
  st.header("Skops model card creator")
26
 
27
+
28
+ class Screen:
29
+ state: Literal["start", "edit", "create_repo"] = "start"
30
+
31
+
32
+ if "screen" not in st.session_state:
33
+ st.session_state.screen: Screen = Screen()
34
+
35
+
36
+ if st.session_state.screen.state == "start":
37
  start_input_form()
38
+ elif st.session_state.screen.state == "edit":
39
  edit_input_form()
40
+ elif st.session_state.screen.state == "create_repo":
41
+ create_repo_input_form()
42
+ else:
43
+ st.write("Something went wrong, please open an issue")
create.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import streamlit as st
5
+ from skops import hub_utils
6
+
7
+ from utils import get_rendered_model_card
8
+
9
+
10
+ def _add_back_button():
11
+ def fn():
12
+ st.session_state.screen.state = "edit"
13
+
14
+ st.button("Back", help="continue editing the model card", on_click=fn)
15
+
16
+
17
+ def _add_delete_button():
18
+ def fn():
19
+ st.session_state.screen.state = "start"
20
+ if "model_card" in st.session_state:
21
+ del st.session_state["model_card"]
22
+ if "task_state" in st.session_state:
23
+ st.session_state.task_state.reset()
24
+ if "create_repo_name" in st.session_state:
25
+ del st.session_state["create_repo_name"]
26
+ if "hf_token" in st.session_state:
27
+ del st.session_state["hf_token"]
28
+
29
+ st.button("Delete", on_click=fn, help="Start over from scratch (lose all progress)")
30
+
31
+
32
+ def _save_model_card(path: Path) -> None:
33
+ model_card = st.session_state.get("model_card")
34
+ if model_card:
35
+ # do not use model_card.save, see doc of get_rendered_model_card
36
+ rendered = get_rendered_model_card(
37
+ model_card, hf_path=str(st.session_state.hf_path)
38
+ )
39
+ with open(path / "README.md", "w") as f:
40
+ f.write(rendered)
41
+
42
+
43
+ def _display_repo_overview(path: Path) -> None:
44
+ text = "Files included in the repository:\n"
45
+ for file in os.listdir(path):
46
+ size = os.path.getsize(path / file)
47
+ text += f"- `{file} ({size:,} bytes)`\n"
48
+ st.markdown(text)
49
+
50
+
51
+ def _display_private_box():
52
+ tip = (
53
+ "Private repositories can only seen by you or members of the same "
54
+ "organization, see https://huggingface.co/docs/hub/repositories-settings"
55
+ )
56
+ st.checkbox("Make repo private", value=True, help=tip, key="create_repo_private")
57
+
58
+
59
+ def _repo_id_field():
60
+ st.text_input("Name of the repository (e.g. 'User/MyRepo')", key="create_repo_name")
61
+
62
+
63
+ def _hf_token_field():
64
+ tip = "The Hugging Face token can be found at https://hf.co/settings/token"
65
+ st.text_input(
66
+ "Enter your Hugging Face token ('hf_***')", key="hf_token", help=tip
67
+ )
68
+
69
+
70
+ def _create_hf_repo(path, repo_name, hf_token, private):
71
+ try:
72
+ hub_utils.push(
73
+ repo_id=repo_name,
74
+ source=path,
75
+ token=hf_token,
76
+ private=private,
77
+ create_remote=True,
78
+ )
79
+ except Exception as exc:
80
+ st.error(
81
+ "Oops, something went wrong, please create an issue. "
82
+ f"The error message is:\n\n{exc}"
83
+ )
84
+ return
85
+
86
+ st.success(f"Successfully created the repo 'https://huggingface.co/{repo_name}'")
87
+
88
+
89
+ def _add_create_repo_button():
90
+ private = bool(st.session_state.get("create_repo_private"))
91
+ repo_name = st.session_state.get("create_repo_name")
92
+ hf_token = st.session_state.get("hf_token")
93
+ disabled = (not repo_name) or (not hf_token)
94
+
95
+ button_text = "Create a new repository"
96
+ tip = "Creating a repo requires a name and a token"
97
+ path = st.session_state.get("hf_path")
98
+ st.button(
99
+ button_text,
100
+ help=tip,
101
+ disabled=disabled,
102
+ on_click=_create_hf_repo,
103
+ args=(path, repo_name, hf_token, private),
104
+ )
105
+
106
+ if not repo_name:
107
+ st.info("Repository name is required")
108
+ if not hf_token:
109
+ st.info("Token is required")
110
+
111
+
112
+ def create_repo_input_form():
113
+ if not st.session_state.screen.state == "create_repo":
114
+ return
115
+
116
+ col_0, col_1, *_ = st.columns([2, 2, 2, 2])
117
+ with col_0:
118
+ _add_back_button()
119
+ with col_1:
120
+ _add_delete_button()
121
+
122
+ hf_path = st.session_state.hf_path
123
+ _save_model_card(hf_path)
124
+ _display_repo_overview(hf_path)
125
+ _display_private_box()
126
+ st.markdown("---")
127
+ _repo_id_field()
128
+ _hf_token_field()
129
+ _add_create_repo_button()
edit.py CHANGED
@@ -35,7 +35,11 @@ from huggingface_hub import hf_hub_download
35
  from skops import card
36
  from skops.card._model_card import PlotSection, split_subsection_names
37
 
38
- from utils import iterate_key_section_content, process_card_for_rendering
 
 
 
 
39
  from tasks import (
40
  AddMetricsTask,
41
  AddSectionTask,
@@ -94,17 +98,20 @@ def _update_model_card(
94
  return
95
 
96
  if is_fig:
97
- fpath = None
98
  if new_content: # new figure uploaded
99
  fname = new_content.name.replace(" ", "_")
100
- fpath = tmp_path / fname
 
 
101
  task = UpdateFigureTask(
102
  model_card,
103
  key=key,
104
  old_name=section_name,
105
  new_name=new_title,
106
  data=new_content,
107
- path=fpath,
 
108
  )
109
  else:
110
  task = UpdateSectionTask(
@@ -128,12 +135,17 @@ def _add_section(model_card: card.Card, key: str) -> None:
128
 
129
  def _add_figure(model_card: card.Card, key: str) -> None:
130
  section_name = f"{key}/Untitled"
131
- task = AddFigureTask(model_card, title=section_name, content="cat.png")
 
 
 
132
  st.session_state.task_state.add(task)
133
 
134
 
135
- def _delete_section(model_card: card.Card, key: str) -> None:
136
- task = DeleteSectionTask(model_card, key=key)
 
 
137
  st.session_state.task_state.add(task)
138
 
139
 
@@ -197,10 +209,11 @@ def create_form_from_section(
197
 
198
  col_0, col_1, col_2 = st.columns([4, 2, 2])
199
  with col_0:
 
200
  st.button(
201
  f"delete '{arepr.repr(old_title)}'",
202
  on_click=_delete_section,
203
- args=(model_card, key),
204
  key=f"{key}.delete",
205
  help="Delete this section, including all its subsections",
206
  )
@@ -267,6 +280,7 @@ def reset_model_card() -> None:
267
 
268
 
269
  def delete_model_card() -> None:
 
270
  if "model_card" in st.session_state:
271
  del st.session_state["model_card"]
272
  if "task_state" in st.session_state:
@@ -284,19 +298,32 @@ def redo_last():
284
 
285
 
286
  def add_download_model_card_button():
287
- model_card = st.session_state.get("model_card")
288
- download_disabled = not bool(model_card)
289
- data = model_card.render()
 
290
  tip = "Download the generated model card as markdown file"
291
  st.download_button(
292
  "Save (md)",
293
  data=data,
294
- disabled=download_disabled,
295
  help=tip,
296
  file_name="README.md",
297
  )
298
 
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  def display_edit_buttons():
301
  # first row: undo + redo + reset
302
  col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
@@ -314,11 +341,13 @@ def display_edit_buttons():
314
  tip = "Undo all edits"
315
  st.button("Reset", on_click=reset_model_card, help=tip)
316
 
317
- # second row: download + delete
318
- col_0, col_1, *_ = st.columns([2, 2, 2, 2])
319
  with col_0:
320
  add_download_model_card_button()
321
  with col_1:
 
 
322
  tip = "Start over from scratch (lose all progress)"
323
  st.button("Delete", on_click=delete_model_card, help=tip)
324
 
 
35
  from skops import card
36
  from skops.card._model_card import PlotSection, split_subsection_names
37
 
38
+ from utils import (
39
+ get_rendered_model_card,
40
+ iterate_key_section_content,
41
+ process_card_for_rendering,
42
+ )
43
  from tasks import (
44
  AddMetricsTask,
45
  AddSectionTask,
 
98
  return
99
 
100
  if is_fig:
101
+ old_path, fpath = None, None
102
  if new_content: # new figure uploaded
103
  fname = new_content.name.replace(" ", "_")
104
+ fpath = st.session_state.hf_path / fname
105
+ old_path = fpath.parent / model_card.select(key).content.path
106
+
107
  task = UpdateFigureTask(
108
  model_card,
109
  key=key,
110
  old_name=section_name,
111
  new_name=new_title,
112
  data=new_content,
113
+ new_path=fpath,
114
+ old_path=old_path,
115
  )
116
  else:
117
  task = UpdateSectionTask(
 
135
 
136
  def _add_figure(model_card: card.Card, key: str) -> None:
137
  section_name = f"{key}/Untitled"
138
+ hf_path = st.session_state.hf_path
139
+ task = AddFigureTask(
140
+ model_card, path=hf_path, title=section_name, content="cat.png"
141
+ )
142
  st.session_state.task_state.add(task)
143
 
144
 
145
+ def _delete_section(
146
+ model_card: card.Card, key: str, path: Path
147
+ ) -> None:
148
+ task = DeleteSectionTask(model_card, key=key, path=path)
149
  st.session_state.task_state.add(task)
150
 
151
 
 
209
 
210
  col_0, col_1, col_2 = st.columns([4, 2, 2])
211
  with col_0:
212
+ path = st.session_state.hf_path / content.path if is_fig else None
213
  st.button(
214
  f"delete '{arepr.repr(old_title)}'",
215
  on_click=_delete_section,
216
+ args=(model_card, key, path),
217
  key=f"{key}.delete",
218
  help="Delete this section, including all its subsections",
219
  )
 
280
 
281
 
282
  def delete_model_card() -> None:
283
+ st.session_state.screen.state = "start"
284
  if "model_card" in st.session_state:
285
  del st.session_state["model_card"]
286
  if "task_state" in st.session_state:
 
298
 
299
 
300
  def add_download_model_card_button():
301
+ model_card = st.session_state.model_card
302
+ data = get_rendered_model_card(
303
+ model_card, hf_path=str(st.session_state.hf_path)
304
+ )
305
  tip = "Download the generated model card as markdown file"
306
  st.download_button(
307
  "Save (md)",
308
  data=data,
 
309
  help=tip,
310
  file_name="README.md",
311
  )
312
 
313
 
314
+ def add_create_repo_button():
315
+ def fn():
316
+ st.session_state.screen.state = "create_repo"
317
+
318
+ button_disabled = not bool(st.session_state.get("model_card"))
319
+ st.button(
320
+ "Create Repo",
321
+ help="Create a model repository on Hugging Face Hub",
322
+ on_click=fn,
323
+ disabled=button_disabled,
324
+ )
325
+
326
+
327
  def display_edit_buttons():
328
  # first row: undo + redo + reset
329
  col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
 
341
  tip = "Undo all edits"
342
  st.button("Reset", on_click=reset_model_card, help=tip)
343
 
344
+ # second row: download + create repo + delete
345
+ col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
346
  with col_0:
347
  add_download_model_card_button()
348
  with col_1:
349
+ add_create_repo_button()
350
+ with col_2:
351
  tip = "Start over from scratch (lose all progress)"
352
  st.button("Delete", on_click=delete_model_card, help=tip)
353
 
start.py CHANGED
@@ -30,7 +30,6 @@ import skops.io as sio
30
  from skops import card, hub_utils
31
 
32
 
33
- hf_path = Path(mkdtemp(prefix="skops-")) # hf repo
34
  tmp_path = Path(mkdtemp(prefix="skops-")) # temporary files
35
  description = """Create an sklearn model card
36
 
@@ -70,7 +69,8 @@ def _clear_repo(path: str) -> None:
70
  shutil.rmtree(file_path)
71
 
72
 
73
- def init_repo(path: str) -> None:
 
74
  _clear_repo(path)
75
  requirements = []
76
  task = "tabular-classification"
@@ -104,26 +104,28 @@ def init_repo(path: str) -> None:
104
 
105
 
106
  def create_skops_model_card() -> None:
107
- init_repo(hf_path)
108
- metadata = card.metadata_from_config(hf_path)
109
  model_card = card.Card(model=st.session_state.model, metadata=metadata)
110
  st.session_state.model_card = model_card
111
  st.session_state.model_card_type = "skops"
 
112
 
113
 
114
  def create_empty_model_card() -> None:
115
- init_repo(hf_path)
116
- metadata = card.metadata_from_config(hf_path)
117
  model_card = card.Card(
118
  model=st.session_state.model, metadata=metadata, template=None
119
  )
120
  model_card.add(**{"Untitled": "[More Information Needed]"})
121
  st.session_state.model_card = model_card
122
  st.session_state.model_card_type = "empty"
 
123
 
124
 
125
  def create_hf_model_card() -> None:
126
- repo_id = st.session_state.get("hf_repo_id", "").strip("'").strip('"')
127
  if not repo_id:
128
  return
129
 
@@ -139,6 +141,7 @@ def create_hf_model_card() -> None:
139
  model_card = card.parse_modelcard(path)
140
  st.session_state.model_card = model_card
141
  st.session_state.model_card_type = "loaded"
 
142
 
143
 
144
  def start_input_form():
@@ -158,7 +161,10 @@ def start_input_form():
158
  "Upload an sklearn model (strongly recommended)\n"
159
  "The model can be used to automatically populate fields in the model card."
160
  )
161
- st.file_uploader("Upload a model*", on_change=load_model, key="model_file")
 
 
 
162
  st.markdown("---")
163
 
164
  st.text(
@@ -180,7 +186,6 @@ def start_input_form():
180
  ],
181
  key="task",
182
  on_change=init_repo,
183
- args=(hf_path,),
184
  )
185
  st.markdown("---")
186
 
@@ -189,7 +194,6 @@ def start_input_form():
189
  value=f"scikit-learn=={sklearn.__version__}\n",
190
  key="requirements",
191
  on_change=init_repo,
192
- args=(hf_path,),
193
  )
194
  st.markdown("---")
195
 
 
30
  from skops import card, hub_utils
31
 
32
 
 
33
  tmp_path = Path(mkdtemp(prefix="skops-")) # temporary files
34
  description = """Create an sklearn model card
35
 
 
69
  shutil.rmtree(file_path)
70
 
71
 
72
+ def init_repo() -> None:
73
+ path = st.session_state.hf_path
74
  _clear_repo(path)
75
  requirements = []
76
  task = "tabular-classification"
 
104
 
105
 
106
  def create_skops_model_card() -> None:
107
+ init_repo()
108
+ metadata = card.metadata_from_config(st.session_state.hf_path)
109
  model_card = card.Card(model=st.session_state.model, metadata=metadata)
110
  st.session_state.model_card = model_card
111
  st.session_state.model_card_type = "skops"
112
+ st.session_state.screen.state = "edit"
113
 
114
 
115
  def create_empty_model_card() -> None:
116
+ init_repo()
117
+ metadata = card.metadata_from_config(st.session_state.hf_path)
118
  model_card = card.Card(
119
  model=st.session_state.model, metadata=metadata, template=None
120
  )
121
  model_card.add(**{"Untitled": "[More Information Needed]"})
122
  st.session_state.model_card = model_card
123
  st.session_state.model_card_type = "empty"
124
+ st.session_state.screen.state = "edit"
125
 
126
 
127
  def create_hf_model_card() -> None:
128
+ repo_id = st.session_state.get("hf_repo_id", "").strip().strip("'").strip('"')
129
  if not repo_id:
130
  return
131
 
 
141
  model_card = card.parse_modelcard(path)
142
  st.session_state.model_card = model_card
143
  st.session_state.model_card_type = "loaded"
144
+ st.session_state.screen.state = "edit"
145
 
146
 
147
  def start_input_form():
 
161
  "Upload an sklearn model (strongly recommended)\n"
162
  "The model can be used to automatically populate fields in the model card."
163
  )
164
+
165
+ if not st.session_state.get("model_file"):
166
+ st.file_uploader("Upload a model*", on_change=load_model, key="model_file")
167
+
168
  st.markdown("---")
169
 
170
  st.text(
 
186
  ],
187
  key="task",
188
  on_change=init_repo,
 
189
  )
190
  st.markdown("---")
191
 
 
194
  value=f"scikit-learn=={sklearn.__version__}\n",
195
  key="requirements",
196
  on_change=init_repo,
 
197
  )
198
  st.markdown("---")
199
 
tasks.py CHANGED
@@ -5,7 +5,9 @@ Tasks are used to implement "undo" and "redo" functionality.
5
  """
6
  from __future__ import annotations
7
 
 
8
  from pathlib import Path
 
9
  from uuid import uuid4
10
 
11
  from skops import card
@@ -80,26 +82,42 @@ class AddSectionTask(Task):
80
 
81
 
82
  class AddFigureTask(Task):
83
- """Add a new figure section"""
 
 
 
 
84
 
85
  def __init__(
86
  self,
87
  model_card: card.Card,
 
88
  title: str,
89
  content: str,
90
  ) -> None:
91
  self.model_card = model_card
92
  self.title = title
93
- self.key = title + " " + str(uuid4())[:6]
94
- self.content = content
 
 
 
 
 
 
 
 
 
95
 
96
  def do(self) -> None:
 
97
  self.model_card.add_plot(**{self.key: self.content})
98
  section = self.model_card.select(self.key)
99
  section.title = split_subsection_names(self.title)[-1]
100
  section.is_fig = True # type: ignore
101
 
102
  def undo(self) -> None:
 
103
  self.model_card.delete(self.key)
104
 
105
 
@@ -115,15 +133,23 @@ class DeleteSectionTask(Task):
115
  self,
116
  model_card: card.Card,
117
  key: str,
 
118
  ) -> None:
119
  self.model_card = model_card
120
  self.key = key
 
 
 
121
 
122
  def do(self) -> None:
123
  self.model_card.select(self.key).visible = False
 
 
124
 
125
  def undo(self) -> None:
126
  self.model_card.select(self.key).visible = True
 
 
127
 
128
 
129
  class UpdateSectionTask(Task):
@@ -159,7 +185,20 @@ class UpdateSectionTask(Task):
159
 
160
 
161
  class UpdateFigureTask(Task):
162
- """Change the title or image of a figure section"""
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def __init__(
165
  self,
@@ -168,14 +207,18 @@ class UpdateFigureTask(Task):
168
  old_name: str,
169
  new_name: str,
170
  data: UploadedFile | None,
171
- path: Path | None,
 
172
  ) -> None:
173
  self.model_card = model_card
174
  self.key = key
175
  self.old_name = old_name
176
  self.new_name = new_name
177
  self.old_data = self.model_card.select(self.key).content
178
- self.path = path
 
 
 
179
 
180
  if not data:
181
  self.new_data = self.old_data
@@ -192,12 +235,14 @@ class UpdateFigureTask(Task):
192
  # write figure
193
  # note: this can still be the same image if the image is a file, there
194
  # is no test to check, e.g., the hash of the image
195
- with open(self.path, "wb") as f:
 
 
196
  f.write(self.new_data.getvalue())
197
  section.content = PlotSection(
198
  alt_text=self.new_data.name,
199
- path=self.path,
200
- ).format()
201
 
202
  def undo(self) -> None:
203
  section = self.model_card.select(self.key)
@@ -206,7 +251,8 @@ class UpdateFigureTask(Task):
206
  if self.new_data == self.old_data: # image is same
207
  return
208
 
209
- self.path.unlink(missing_ok=True)
 
210
  section.content = self.old_data
211
 
212
 
 
5
  """
6
  from __future__ import annotations
7
 
8
+ import shutil
9
  from pathlib import Path
10
+ from tempfile import mkdtemp
11
  from uuid import uuid4
12
 
13
  from skops import card
 
82
 
83
 
84
  class AddFigureTask(Task):
85
+ """Add a new figure section
86
+
87
+ Figure always starts out with dummy image cat.png.
88
+
89
+ """
90
 
91
  def __init__(
92
  self,
93
  model_card: card.Card,
94
+ path: Path,
95
  title: str,
96
  content: str,
97
  ) -> None:
98
  self.model_card = model_card
99
  self.title = title
100
+
101
+ # Create a unique file name, since the same image can exist more than
102
+ # once per model card.
103
+ fname = Path(content)
104
+ stem = fname.stem
105
+ suffix = fname.suffix
106
+ uniq = str(uuid4())[:6]
107
+ new_fname = str(path / stem) + "_" + uniq + suffix
108
+
109
+ self.key = title + " " + uniq
110
+ self.content = Path(new_fname)
111
 
112
  def do(self) -> None:
113
+ shutil.copy("cat.png", self.content)
114
  self.model_card.add_plot(**{self.key: self.content})
115
  section = self.model_card.select(self.key)
116
  section.title = split_subsection_names(self.title)[-1]
117
  section.is_fig = True # type: ignore
118
 
119
  def undo(self) -> None:
120
+ self.content.unlink(missing_ok=True)
121
  self.model_card.delete(self.key)
122
 
123
 
 
133
  self,
134
  model_card: card.Card,
135
  key: str,
136
+ path: Path | None,
137
  ) -> None:
138
  self.model_card = model_card
139
  self.key = key
140
+ # when 'deleting' a file, move it to a temp file
141
+ self.path = path
142
+ self.tmp_path = Path(mkdtemp(prefix="skops-")) / str(uuid4())
143
 
144
  def do(self) -> None:
145
  self.model_card.select(self.key).visible = False
146
+ if self.path:
147
+ shutil.move(self.path, self.tmp_path)
148
 
149
  def undo(self) -> None:
150
  self.model_card.select(self.key).visible = True
151
+ if self.path:
152
+ shutil.move(self.tmp_path, self.path)
153
 
154
 
155
  class UpdateSectionTask(Task):
 
185
 
186
 
187
  class UpdateFigureTask(Task):
188
+ """Change the title or image of a figure section
189
+
190
+ Changing the title is easy, just replace it and be done.
191
+
192
+ Changing the figure is a bit more tricky. The old figure is in the hf_path
193
+ under its old name. The new figure is an UploadFile object. For the DO
194
+ operation, move the old figure to a temporary file and store the UploadFile
195
+ content to a new file (which may have a different name).
196
+
197
+ For the UNDO operation, delete the new figure (its content is still stored
198
+ in the UploadFile) and move back the old figure from its temporary file to
199
+ the original location (with its original name).
200
+
201
+ """
202
 
203
  def __init__(
204
  self,
 
207
  old_name: str,
208
  new_name: str,
209
  data: UploadedFile | None,
210
+ new_path: Path | None,
211
+ old_path: Path | None,
212
  ) -> None:
213
  self.model_card = model_card
214
  self.key = key
215
  self.old_name = old_name
216
  self.new_name = new_name
217
  self.old_data = self.model_card.select(self.key).content
218
+ self.new_path = new_path
219
+ self.old_path = old_path
220
+ # when 'deleting' the old image, move to temp path
221
+ self.tmp_path = Path(mkdtemp(prefix="skops-")) / str(uuid4())
222
 
223
  if not data:
224
  self.new_data = self.old_data
 
235
  # write figure
236
  # note: this can still be the same image if the image is a file, there
237
  # is no test to check, e.g., the hash of the image
238
+ shutil.move(self.old_path, self.tmp_path)
239
+
240
+ with open(self.new_path, "wb") as f:
241
  f.write(self.new_data.getvalue())
242
  section.content = PlotSection(
243
  alt_text=self.new_data.name,
244
+ path=self.new_path,
245
+ )
246
 
247
  def undo(self) -> None:
248
  section = self.model_card.select(self.key)
 
251
  if self.new_data == self.old_data: # image is same
252
  return
253
 
254
+ self.new_path.unlink(missing_ok=True)
255
+ shutil.move(self.tmp_path, self.old_path)
256
  section.content = self.old_data
257
 
258
 
utils.py CHANGED
@@ -3,14 +3,31 @@
3
  from __future__ import annotations
4
 
5
  import base64
 
6
  import re
7
  from dataclasses import dataclass
8
  from pathlib import Path
9
- from typing import Iterator
10
 
 
11
  from skops.card._model_card import Section
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def process_card_for_rendering(rendered: str) -> tuple[str, str]:
15
  idx = rendered[1:].index("\n---") + 1
16
  metadata = rendered[3:idx]
 
3
  from __future__ import annotations
4
 
5
  import base64
6
+ import os
7
  import re
8
  from dataclasses import dataclass
9
  from pathlib import Path
 
10
 
11
+ from skops import card
12
  from skops.card._model_card import Section
13
 
14
 
15
+ def get_rendered_model_card(model_card: card.Card, hf_path: str) -> str:
16
+ # This is a bit hacky:
17
+ # As a space, the model card is created in a temporary hf_path directory,
18
+ # which is where all the files are put. So e.g. if a figure is added, it is
19
+ # found at /tmp/skops-jtyqdgk3/fig.png. However, when the model card is is
20
+ # actually used, we don't want that, since there, the files will be in the
21
+ # cwd. Therefore, we remove the tmp directory everywhere we find it in the
22
+ # file.
23
+ if not hf_path.endswith(os.path.sep):
24
+ hf_path += os.path.sep
25
+
26
+ rendered = model_card.render()
27
+ rendered = rendered.replace(hf_path, "")
28
+ return rendered
29
+
30
+
31
  def process_card_for_rendering(rendered: str) -> tuple[str, str]:
32
  idx = rendered[1:].index("\n---") + 1
33
  metadata = rendered[3:idx]