Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import json | |
import os | |
# Inference endpoint URL and bearer token, read from the environment.
# Each is None when the corresponding variable is not set.
API_URL = os.environ.get("API_URL")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")

# Running text buffer that generate_and_append_text extends on every call.
# It is module-level state, so every caller in this process shares it.
generated_text = ""
def query(payload, *, timeout=120):
    """POST ``payload`` to the inference endpoint and return the decoded JSON body.

    Args:
        payload: JSON-serialisable request body, sent as the request's JSON.
        timeout: Seconds to wait for the server, forwarded to ``requests``.
            ``requests`` never times out unless one is given explicitly, so
            without it a stalled endpoint would block the caller forever.

    Returns:
        The parsed JSON response. generate_and_append_text expects either a
        list of generation results or a dict carrying an ``"error"`` key.

    Raises:
        requests.exceptions.RequestException: on connection failure or timeout.
        json.JSONDecodeError: if the response body is not valid JSON.
    """
    response = requests.request(
        "POST",
        API_URL,
        json=payload,
        headers={"Authorization": f"Bearer {HF_API_TOKEN}"},
        timeout=timeout,
    )
    return json.loads(response.content.decode("utf-8"))
def generate_and_append_text(max_length):
    """Ask the endpoint to continue the running text and return the updated text.

    The module-level ``generated_text`` buffer is sent as the prompt (a single
    space while the buffer is still blank). The newly generated continuation
    is appended to the buffer.

    Args:
        max_length: Number of new tokens to request (``max_new_tokens``).
            Coerced to ``int`` because the slider value may arrive as a float.

    Returns:
        The full accumulated text on success. When the endpoint reports an
        error, returns an error message wrapped in an HTML ``<span>`` and
        leaves the buffer unchanged.
    """
    global generated_text
    # Never send an empty prompt; a single space stands in for it.
    input_text = generated_text if generated_text.strip() else " "
    parameters = {
        "max_new_tokens": int(max_length),
        "top_p": 0.9,
        "do_sample": True,
        "seed": 42,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    payload = {"inputs": input_text, "parameters": parameters, "options": {"use_cache": False}}
    data = query(payload)
    # Only a dict response can carry an "error" key; on a list, `in` would be
    # a (meaningless) element-membership test.
    if isinstance(data, dict) and "error" in data:
        # NOTE(review): the output component is a gr.Textbox, which presumably
        # shows this markup literally rather than rendering it -- confirm.
        return f"<span style='color:red'>ERROR: {data['error']} </span>"
    output = data[0]["generated_text"]
    # The response presumably echoes the prompt ahead of the continuation.
    # Remove only that leading copy. str.replace would also delete any later
    # repeat of the prompt, and with the " " placeholder prompt it would strip
    # every space from the output.
    if output.startswith(input_text):
        output = output[len(input_text):]
    new_text = output.strip()
    # Skip empty continuations so the buffer does not collect trailing spaces.
    if new_text:
        if generated_text.strip():
            generated_text = f"{generated_text} {new_text}"
        else:
            generated_text = new_text
    return generated_text
if __name__ == "__main__":
    # Layout: the generate button and the token-count slider share one row,
    # and the accumulated text is shown beneath them.
    with gr.Blocks() as demo:
        with gr.Row():
            generate_button = gr.Button("Generate Text")
            tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
        text_out = gr.Textbox(label="Generated Text")

        # Each click passes the slider value to generate_and_append_text and
        # writes its return value into the textbox.
        generate_button.click(generate_and_append_text, inputs=tokens, outputs=text_out)

    demo.launch()