Benjamin Bossan commited on
Commit
ba2892a
1 Parent(s): 41e8f46

Make it possible to load skops model format

Browse files
Files changed (2) hide show
  1. make-data.py +2 -0
  2. start.py +10 -3
make-data.py CHANGED
@@ -3,6 +3,7 @@
3
  import pickle
4
 
5
  import pandas as pd
 
6
  from sklearn.datasets import make_classification
7
  from sklearn.linear_model import LogisticRegression
8
  from sklearn.pipeline import Pipeline
@@ -21,6 +22,7 @@ clf.fit(X, y)
21
 
22
  with open("logreg.pkl", "wb") as f:
23
  pickle.dump(clf, f)
 
24
 
25
 
26
  df.to_csv("data.csv", index=False)
 
3
  import pickle
4
 
5
  import pandas as pd
6
+ import skops.io as sio
7
  from sklearn.datasets import make_classification
8
  from sklearn.linear_model import LogisticRegression
9
  from sklearn.pipeline import Pipeline
 
22
 
23
  with open("logreg.pkl", "wb") as f:
24
  pickle.dump(clf, f)
25
+ sio.dump(clf, "logreg.skops")
26
 
27
 
28
  df.to_csv("data.csv", index=False)
start.py CHANGED
@@ -45,7 +45,10 @@ def load_model() -> None:
45
  return
46
 
47
  bytes_data = st.session_state.model_file.getvalue()
48
- model = pickle.loads(bytes_data)
 
 
 
49
  assert isinstance(model, BaseEstimator), "model must be an sklearn model"
50
 
51
  st.session_state.model = model
@@ -167,7 +170,11 @@ def start_input_form():
167
  )
168
 
169
  if not st.session_state.get("model_file"):
170
- st.file_uploader("Upload a model*", on_change=load_model, key="model_file")
 
 
 
 
171
 
172
  st.markdown("---")
173
 
@@ -176,7 +183,7 @@ def start_input_form():
176
  "This sample data can be attached to the metadata of the model card"
177
  )
178
  st.file_uploader(
179
- "Upload X data (csv)*", type=["csv"], on_change=load_data, key="data_file"
180
  )
181
  st.markdown("---")
182
 
 
45
  return
46
 
47
  bytes_data = st.session_state.model_file.getvalue()
48
+ if st.session_state.model_file.name.endswith("skops"):
49
+ model = sio.loads(bytes_data, trusted=True)
50
+ else:
51
+ model = pickle.loads(bytes_data)
52
  assert isinstance(model, BaseEstimator), "model must be an sklearn model"
53
 
54
  st.session_state.model = model
 
170
  )
171
 
172
  if not st.session_state.get("model_file"):
173
+ st.file_uploader(
174
+ "Upload an sklearn model (pickle or skops format)",
175
+ on_change=load_model,
176
+ key="model_file",
177
+ )
178
 
179
  st.markdown("---")
180
 
 
183
  "This sample data can be attached to the metadata of the model card"
184
  )
185
  st.file_uploader(
186
+ "Upload input data (csv)", type=["csv"], on_change=load_data, key="data_file"
187
  )
188
  st.markdown("---")
189