TinyChat / app.py
import os
import threading
import time
import subprocess
# from transformers import pipeline  # only needed for the unused predict_t() path below
import ollama
import gradio
# Download the Ollama binary on first run and make it executable.
OLLAMA = os.path.expanduser("~/ollama")
if not os.path.exists(OLLAMA):
    subprocess.run("curl -L https://ollama.com/download/ollama-linux-amd64 -o ~/ollama", shell=True)
    os.chmod(OLLAMA, 0o755)
# Shared chat history as (user, assistant) pairs; also backs the Chatbot display.
history = []

def ollama_service_thread():
    # Run the Ollama server in the background for the lifetime of the app.
    subprocess.run("~/ollama serve", shell=True)

OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
OLLAMA_SERVICE_THREAD.start()

print("Giving ollama serve a moment")
time.sleep(10)

# Fetch the TinyDolphin model before the UI starts handling requests.
subprocess.run("~/ollama pull tinydolphin:latest", shell=True)
def get_history_messages():
    # Flatten the (user, assistant) history into the message format ollama.chat expects.
    messages = []
    for user, assist in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assist})
    return messages
def predict(prompt):
    # Stream a reply from the local Ollama server, updating the chatbot as chunks arrive.
    response = ollama.chat(
        model="tinydolphin",
        messages=[
            *get_history_messages(),
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    history.append((prompt, ""))
    message = ""
    for chunk in response:
        message += chunk["message"]["content"]
        history[-1] = (prompt, message)
        # Yield an empty string to clear the textbox and the updated history for the chatbot.
        yield "", history
def predict_t(prompt):
    # Unused transformers-based alternative to predict(); requires uncommenting the
    # `from transformers import pipeline` import above.
    print("Predict:", prompt)
    print("Loading model")
    pipe = pipeline("conversational", model="cognitivecomputations/TinyDolphin-2.8-1.1b")
    print("Running pipeline")
    response = pipe(
        [
            *get_history_messages(),
            {"role": "user", "content": prompt},
        ],
    )
    history.append((prompt, response.messages[-1]["content"]))
    print("Predict done")
    return "", history
with gradio.Blocks(fill_height=True) as demo:
    # Simple chat UI: chatbot on top, prompt box and send button below.
    chat = gradio.Chatbot(scale=1)
    with gradio.Row(variant="compact"):
        prompt = gradio.Textbox(show_label=False, scale=6, autofocus=True)
        button = gradio.Button(scale=1)
    # Both clicking the button and pressing Enter in the textbox trigger predict().
    for handler in [button.click, prompt.submit]:
        handler(predict, inputs=[prompt], outputs=[prompt, chat])

if __name__ == '__main__':
    demo.launch()