schroneko commited on
Commit
f0f8a6c
1 Parent(s): c6ba5c0
Files changed (2) hide show
  1. app.py +50 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import gradio as gr
5
+ import spaces
6
+
7
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
8
+ if not huggingface_token:
9
+ raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
10
+
11
+ model_id = "meta-llama/Llama-Guard-3-8B"
12
+ device = "cuda"
13
+ dtype = torch.bfloat16
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
16
+
17
+
18
+ @spaces.GPU
19
+ def moderate(chat):
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ model_id, torch_dtype=dtype, device_map=device, token=huggingface_token
22
+ )
23
+ input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
24
+ output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
25
+ prompt_len = input_ids.shape[-1]
26
+ return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
27
+
28
+
29
+ def gradio_moderate(user_input, assistant_response):
30
+ chat = [
31
+ {"role": "user", "content": user_input},
32
+ ]
33
+ if assistant_response:
34
+ chat.append({"role": "assistant", "content": assistant_response})
35
+ return moderate(chat)
36
+
37
+
38
+ iface = gr.Interface(
39
+ fn=gradio_moderate,
40
+ inputs=[
41
+ gr.Textbox(lines=3, label="User Input"),
42
+ gr.Textbox(lines=3, label="Assistant Response (optional)"),
43
+ ],
44
+ outputs=gr.Textbox(label="Moderation Result"),
45
+ title="Llama Guard Moderation",
46
+ description="Enter a user input and an optional assistant response to check for content moderation.",
47
+ )
48
+
49
+ if __name__ == "__main__":
50
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ spaces
4
+ bitsandbytes
5
+ accelerate
6
+ transformers