Commit
•
24688b3
1
Parent(s):
290cdf1
Update README.md
Browse files
README.md
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
---
|
2 |
-
|
3 |
---
|
4 |
language:
|
5 |
- multilingual
|
@@ -14,18 +12,10 @@ metrics:
|
|
14 |
datasets:
|
15 |
- mnli
|
16 |
- xnli
|
17 |
-
- anli
|
18 |
-
license: mit
|
19 |
pipeline_tag: zero-shot-classification
|
20 |
widget:
|
21 |
-
- text: "
|
22 |
-
candidate_labels: "
|
23 |
-
- text: "La película empezaba bien pero terminó siendo un desastre."
|
24 |
-
candidate_labels: "positivo, negativo, neutral"
|
25 |
-
- text: "La película empezó siendo un desastre pero en general fue bien."
|
26 |
-
candidate_labels: "positivo, negativo, neutral"
|
27 |
-
- text: "¿A quién vas a votar en 2020?"
|
28 |
-
candidate_labels: "Europa, elecciones, política, ciencia, deportes"
|
29 |
---
|
30 |
# Multilingual mDeBERTa-v3-base-mnli-xnli
|
31 |
## Model description
|
@@ -41,8 +31,8 @@ import torch
|
|
41 |
model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-mnli"
|
42 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
43 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
44 |
-
premise = "
|
45 |
-
hypothesis = "
|
46 |
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
|
47 |
output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
|
48 |
prediction = torch.softmax(output["logits"][0], -1).tolist()
|
@@ -70,7 +60,7 @@ training_args = TrainingArguments(
|
|
70 |
The model was evaluated using the matched test set and achieves 0.90 accuracy.
|
71 |
|
72 |
average | ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vu | zh
|
73 |
-
|
74 |
0.808 | 0.802 | 0.829 | 0.825 | 0.826 | 0.883 | 0.845 | 0.834 | 0.771 | 0.813 | 0.748 | 0.793 | 0.807 | 0.740 | 0.795 | 0.8116
|
75 |
|
76 |
{'ar': 0.8017964071856287, 'bg': 0.8287425149700599, 'de': 0.8253493013972056, 'el': 0.8267465069860279, 'en': 0.8830339321357286, 'es': 0.8449101796407186, 'fr': 0.8343313373253493, 'hi': 0.7712574850299401, 'ru': 0.8127744510978044, 'sw': 0.7483033932135729, 'th': 0.792814371257485, 'tr': 0.8065868263473054, 'ur': 0.7403193612774451, 'vi': 0.7954091816367266, 'zh': 0.8115768463073852}
|
|
|
|
|
|
|
1 |
---
|
2 |
language:
|
3 |
- multilingual
|
|
|
12 |
datasets:
|
13 |
- mnli
|
14 |
- xnli
|
|
|
|
|
15 |
pipeline_tag: zero-shot-classification
|
16 |
widget:
|
17 |
+
- text: "Angela Merkel ist eine Politikerin in Deutschland und Vorsitzende der CDU"
|
18 |
+
candidate_labels: "politics, economy, entertainment, environment"
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
---
|
20 |
# Multilingual mDeBERTa-v3-base-mnli-xnli
|
21 |
## Model description
|
|
|
31 |
model_name = "MoritzLaurer/mDeBERTa-v3-base-xnli-mnli"
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
33 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
34 |
+
premise = "Angela Merkel ist eine Politikerin in Deutschland und Vorsitzende der CDU"
|
35 |
+
hypothesis = "Emmanuel Macron is the President of France"
|
36 |
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
|
37 |
output = model(input["input_ids"].to(device)) # device = "cuda:0" or "cpu"
|
38 |
prediction = torch.softmax(output["logits"][0], -1).tolist()
|
|
|
60 |
The model was evaluated using the matched test set and achieves 0.90 accuracy.
|
61 |
|
62 |
average | ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vu | zh
|
63 |
+
---------|----------|---------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------
|
64 |
0.808 | 0.802 | 0.829 | 0.825 | 0.826 | 0.883 | 0.845 | 0.834 | 0.771 | 0.813 | 0.748 | 0.793 | 0.807 | 0.740 | 0.795 | 0.8116
|
65 |
|
66 |
{'ar': 0.8017964071856287, 'bg': 0.8287425149700599, 'de': 0.8253493013972056, 'el': 0.8267465069860279, 'en': 0.8830339321357286, 'es': 0.8449101796407186, 'fr': 0.8343313373253493, 'hi': 0.7712574850299401, 'ru': 0.8127744510978044, 'sw': 0.7483033932135729, 'th': 0.792814371257485, 'tr': 0.8065868263473054, 'ur': 0.7403193612774451, 'vi': 0.7954091816367266, 'zh': 0.8115768463073852}
|