Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
•
ad7fc61
1
Parent(s):
a6de017
Use classifier category if transformer generates unknown category
Browse files- src/predict.py +6 -4
src/predict.py
CHANGED
@@ -106,7 +106,7 @@ class ClassifierArguments:
|
|
106 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
107 |
|
108 |
|
109 |
-
def
|
110 |
"""Use classifier to filter predictions"""
|
111 |
if not predictions:
|
112 |
return predictions
|
@@ -134,8 +134,10 @@ def add_predictions(predictions, classifier_args): # classifier, vectorizer,
|
|
134 |
if classifier_category is None and classifier_probability > classifier_args.min_probability:
|
135 |
continue # Ignore
|
136 |
|
137 |
-
if
|
138 |
-
|
|
|
|
|
139 |
prediction['category'] = classifier_category
|
140 |
|
141 |
prediction['probability'] = predicted_probabilities[prediction['category']]
|
@@ -173,7 +175,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
173 |
|
174 |
# TODO add back
|
175 |
if classifier_args is not None:
|
176 |
-
predictions =
|
177 |
|
178 |
return predictions
|
179 |
|
|
|
106 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
107 |
|
108 |
|
109 |
+
def filter_and_add_probabilities(predictions, classifier_args): # classifier, vectorizer,
|
110 |
"""Use classifier to filter predictions"""
|
111 |
if not predictions:
|
112 |
return predictions
|
|
|
134 |
if classifier_category is None and classifier_probability > classifier_args.min_probability:
|
135 |
continue # Ignore
|
136 |
|
137 |
+
if (prediction['category'] not in predicted_probabilities) \
|
138 |
+
or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
|
139 |
+
# Unknown category or we are confident enough to overrule,
|
140 |
+
# so change category to what was predicted by classifier
|
141 |
prediction['category'] = classifier_category
|
142 |
|
143 |
prediction['probability'] = predicted_probabilities[prediction['category']]
|
|
|
175 |
|
176 |
# TODO add back
|
177 |
if classifier_args is not None:
|
178 |
+
predictions = filter_and_add_probabilities(predictions, classifier_args)
|
179 |
|
180 |
return predictions
|
181 |
|