p1atdev commited on
Commit
489cd9a
1 Parent(s): 188c9e8

Upload playground.ipynb

Browse files
Files changed (1) hide show
  1. playground.ipynb +167 -0
playground.ipynb ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 65,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from PIL import Image\n",
10
+ "\n",
11
+ "import torch\n",
12
+ "from transformers import (\n",
13
+ " AutoModelForImageClassification,\n",
14
+ " AutoImageProcessor,\n",
15
+ ")\n",
16
+ "import numpy as np"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "MODEL_NAME = \"p1atdev/siglip-tagger-test-2\""
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 44,
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "model = AutoModelForImageClassification.from_pretrained(\n",
35
+ " MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True\n",
36
+ ")\n",
37
+ "model.eval()\n",
38
+ "processor = AutoImageProcessor.from_pretrained(MODEL_NAME)"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 45,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "image = Image.open(\"sample.jpg\")\n",
48
+ "inputs = processor(image, return_tensors=\"pt\").to(model.device, model.dtype)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 70,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "logits = model(**inputs).logits.detach().cpu().float()[0]\n",
58
+ "logits = np.clip(logits, 0.0, 1.0)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 80,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "results = {\n",
68
+ " model.config.id2label[i]: logit for i, logit in enumerate(logits) if logit > 0\n",
69
+ "}\n",
70
+ "results = sorted(results.items(), key=lambda x: x[1], reverse=True)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 81,
76
+ "metadata": {},
77
+ "outputs": [
78
+ {
79
+ "name": "stdout",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "1girl: 100.00%\n",
83
+ "outdoors: 100.00%\n",
84
+ "sky: 100.00%\n",
85
+ "solo: 100.00%\n",
86
+ "school uniform: 96.88%\n",
87
+ "skirt: 92.97%\n",
88
+ "day: 89.06%\n",
89
+ "cloud: 85.94%\n",
90
+ "scenery: 79.69%\n",
91
+ "pleated skirt: 72.27%\n",
92
+ "black hair: 66.80%\n",
93
+ "standing: 65.62%\n",
94
+ "sailor collar: 59.38%\n",
95
+ "sitting: 57.81%\n",
96
+ "long sleeves: 53.52%\n",
97
+ "serafuku: 53.12%\n",
98
+ "holding: 52.34%\n",
99
+ "tree: 47.46%\n",
100
+ "dress: 46.48%\n",
101
+ "shoes: 43.55%\n",
102
+ "building: 42.77%\n",
103
+ "neckerchief: 40.82%\n",
104
+ "short hair: 38.09%\n",
105
+ "water: 38.09%\n",
106
+ "cloudy sky: 37.30%\n",
107
+ "looking at viewer: 32.23%\n",
108
+ "long hair: 32.03%\n",
109
+ "brown eyes: 31.45%\n",
110
+ "plant: 31.05%\n",
111
+ "bag: 29.30%\n",
112
+ "railing: 29.10%\n",
113
+ "sunlight: 28.12%\n",
114
+ "from side: 27.73%\n",
115
+ "window: 27.54%\n",
116
+ "brown hair: 26.37%\n",
117
+ "white shirt: 25.78%\n",
118
+ "shirt: 25.39%\n",
119
+ "blue sky: 23.93%\n",
120
+ "hairclip: 23.44%\n",
121
+ "blunt bangs: 21.58%\n",
122
+ "picture frame: 19.34%\n",
123
+ "hand up: 18.26%\n",
124
+ "black skirt: 17.87%\n",
125
+ "smile: 17.87%\n",
126
+ "from behind: 13.57%\n",
127
+ "cowboy shot: 10.99%\n",
128
+ "indoors: 10.74%\n",
129
+ "curtains: 10.25%\n",
130
+ "facing away: 9.23%\n",
131
+ "white socks: 6.08%\n",
132
+ "bottle: 6.01%\n",
133
+ "mountain: 5.66%\n",
134
+ "blue skirt: 5.13%\n",
135
+ "drinking straw: 3.37%\n",
136
+ "kneehighs: 1.71%\n"
137
+ ]
138
+ }
139
+ ],
140
+ "source": [
141
+ "for tag, score in results:\n",
142
+ " print(f\"{tag}: {score*100:.2f}%\")"
143
+ ]
144
+ }
145
+ ],
146
+ "metadata": {
147
+ "kernelspec": {
148
+ "display_name": "py310",
149
+ "language": "python",
150
+ "name": "python3"
151
+ },
152
+ "language_info": {
153
+ "codemirror_mode": {
154
+ "name": "ipython",
155
+ "version": 3
156
+ },
157
+ "file_extension": ".py",
158
+ "mimetype": "text/x-python",
159
+ "name": "python",
160
+ "nbconvert_exporter": "python",
161
+ "pygments_lexer": "ipython3",
162
+ "version": "3.10.13"
163
+ }
164
+ },
165
+ "nbformat": 4,
166
+ "nbformat_minor": 2
167
+ }