Spaces: Running on Zero

HonestAnnie committed • Commit 2ec7158 • 1 Parent(s): 38d2199

trhhh
app.py CHANGED

@@ -1,21 +1,18 @@
 import os
-import requests
 import gradio as gr
 import chromadb
-import json
-import pandas as pd

 from sentence_transformers import SentenceTransformer

 import spaces

 @spaces.GPU
-def get_embeddings(
+def get_embeddings(query, task):
     model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
     task = "Given a question, retrieve passages that answer the question"
-    prompt = f"Instruct: {task}\nQuery: {
-    query_embeddings = model.encode([prompt]
-    return query_embeddings
+    prompt = f"Instruct: {task}\nQuery: {query}"
+    query_embeddings = model.encode([prompt])
+    return query_embeddings


 # Initialize a persistent Chroma client and retrieve collection

@@ -56,27 +53,28 @@ def query_chroma(embeddings, authors, num_results=10):

         return formatted_results
     except Exception as e:
-        return
+        return {"error": str(e)}


 # Main function
-def perform_query(query,
+def perform_query(query, authors, num_results):
+    task = "Given a question, retrieve passages that answer the question"
     embeddings = get_embeddings(query, task)
-
-    results = [(f"{res['author']}, {res['book']}, Distance: {res['distance']}", res['text'], res['id']) for res in initial_results]
+    results = query_chroma(embeddings, authors, num_results)
+
+    if "error" in results:
+        return [gr.update(visible=True, value=f"Error: {results['error']}") for _ in range(max_textboxes * 3)]

     updates = []
-    for
-        markdown_content = f"**{
+    for res in results:
+        markdown_content = f"**{res['author']}, {res['book']}, Distance: {res['distance']}**\n\n{res['text']}"
         updates.append(gr.update(visible=True, value=markdown_content))
         updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{len(updates)//2}"))
-        updates.append(gr.update(visible=False, value=
+        updates.append(gr.update(visible=False, value=res['id']))  # Hide the ID textbox

-    updates += [gr.update(visible=False)] * (3 * (max_textboxes - len(results)
+    updates += [gr.update(visible=False)] * (3 * (max_textboxes - len(results)))

-    return updates
+    return updates

 # Initialize the CSVLogger callback for flagging
 callback = gr.CSVLogger()
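
The second hunk header shows the signature of query_chroma, but most of its body falls outside the diff; only its return of formatted_results and the new {"error": str(e)} fallback are visible. To make the data contract that perform_query relies on easier to follow, here is a minimal sketch of a compatible query_chroma, assuming a persistent Chroma collection whose metadata carries author and book fields. The collection name ("passages"), the storage path, and the $in author filter are assumptions for illustration, not taken from the commit.

import chromadb

# Assumed setup; the real path and collection name are not visible in this diff.
client = chromadb.PersistentClient(path="chroma")
collection = client.get_collection("passages")

def query_chroma(embeddings, authors, num_results=10):
    """Return a list of result dicts with the keys perform_query reads
    (author, book, distance, text, id), or {"error": ...} on failure,
    matching the except branch shown in the diff."""
    try:
        # Optional author filter; Chroma supports the $in operator in where clauses.
        where = {"author": {"$in": authors}} if authors else None
        res = collection.query(
            query_embeddings=embeddings.tolist(),
            n_results=num_results,
            where=where,
        )
        formatted_results = [
            {
                "id": res["ids"][0][i],
                "text": res["documents"][0][i],
                "author": res["metadatas"][0][i].get("author"),
                "book": res["metadatas"][0][i].get("book"),
                "distance": res["distances"][0][i],
            }
            for i in range(len(res["ids"][0]))
        ]
        return formatted_results
    except Exception as e:
        return {"error": str(e)}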
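perform_query also depends on UI pieces defined outside the shown hunks: a max_textboxes constant and, for each result slot, a Markdown view, a Flag button, and a hidden ID textbox, logged through the CSVLogger created at the end of the hunk. The sketch below is a hedged illustration of how that Gradio Blocks wiring could look, purely to show why perform_query returns 3 * max_textboxes updates; every component name, label, the flagging directory, and the max_textboxes value are assumptions rather than code from this Space.

import gradio as gr

max_textboxes = 10  # assumed value; the real constant lives outside the shown hunks
callback = gr.CSVLogger()  # as in the diff

with gr.Blocks() as demo:
    query_box = gr.Textbox(label="Question")
    authors_box = gr.Dropdown(choices=[], multiselect=True, label="Authors")  # choices would come from the collection
    num_results = gr.Slider(1, max_textboxes, value=5, step=1, label="Number of results")
    search = gr.Button("Search")

    components = []
    for _ in range(max_textboxes):
        md = gr.Markdown(visible=False)
        flag_btn = gr.Button("Flag", visible=False)
        passage_id = gr.Textbox(visible=False)  # hidden ID, flagged alongside the passage text
        callback.setup([md, passage_id], "flagged")  # all slots share the same two-column schema
        flag_btn.click(lambda *args: callback.flag(list(args)), [md, passage_id], None)
        components.extend([md, flag_btn, passage_id])

    # perform_query (defined above in app.py) returns one update per component,
    # in the same markdown / flag-button / ID order used when building `components`.
    search.click(perform_query, [query_box, authors_box, num_results], components)

demo.launch()

The ordering of the components list is the important design point: it must match the order in which perform_query appends its updates, otherwise the markdown text, flag buttons, and hidden IDs would receive each other's values.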