Etash Guha commited on
Commit
d4db51c
1 Parent(s): 7936375

changed key

Browse files
Files changed (1) hide show
  1. generators/model.py +20 -13
generators/model.py CHANGED
@@ -125,13 +125,7 @@ class Samba():
125
  resps = []
126
 
127
  for i in range(num_comps):
128
- url = "kjddazcq2e2wzvzv.snova.ai"
129
- api_key = "bGlnaHRuaW5nOlUyM3pMcFlHY3dmVzRzUGFy"
130
- url = f'https://{url}/api/v1/chat/completion'
131
- headers = {
132
- 'Authorization': f'Basic {api_key}',
133
- 'Content-Type': 'application/json',
134
- }
135
  payload = {
136
  "inputs": [dataclasses.asdict(message) for message in messages],
137
  "params": {
@@ -142,22 +136,26 @@ class Samba():
142
  },
143
  "model": "llama3-8b"
144
  }
145
-
146
-
147
-
148
- post_response = requests.post(url, headers=headers, data=payload)
 
 
 
149
  response_text = ""
150
  for line in post_response.iter_lines():
151
  if line.startswith(b"data: "):
152
  data_str = line.decode('utf-8')[6:]
153
  try:
154
  line_json = json.loads(data_str)
155
- content = line_json['0'].get("stream_token", "")
156
  if content:
157
  response_text += content
158
  except json.JSONDecodeError as e:
159
  pass
160
- breakpoint()
 
161
  resps.append(response_text)
162
  if num_comps == 1:
163
  return resps[0]
@@ -339,3 +337,12 @@ If a question does not make any sense, or is not factually coherent, explain why
339
  def extract_output(self, output: str) -> str:
340
  out = output.split("[/INST]")[-1].split("</s>")[0].strip()
341
  return out
 
 
 
 
 
 
 
 
 
 
125
  resps = []
126
 
127
  for i in range(num_comps):
128
+
 
 
 
 
 
 
129
  payload = {
130
  "inputs": [dataclasses.asdict(message) for message in messages],
131
  "params": {
 
136
  },
137
  "model": "llama3-8b"
138
  }
139
+ url = "kjddazcq2e2wzvzv.snova.ai"
140
+ key = "bGlnaHRuaW5nOlUyM3pMcFlHY3dmVzRzUGFy"
141
+ headers = {
142
+ "Authorization": f"Basic {key}",
143
+ "Content-Type": "application/json"
144
+ }
145
+ post_response = requests.post(f'https://{url}/api/v1/chat/completion', json=payload, headers=headers, stream=True)
146
  response_text = ""
147
  for line in post_response.iter_lines():
148
  if line.startswith(b"data: "):
149
  data_str = line.decode('utf-8')[6:]
150
  try:
151
  line_json = json.loads(data_str)
152
+ content = line_json.get('completion')
153
  if content:
154
  response_text += content
155
  except json.JSONDecodeError as e:
156
  pass
157
+ except:
158
+ pass
159
  resps.append(response_text)
160
  if num_comps == 1:
161
  return resps[0]
 
337
  def extract_output(self, output: str) -> str:
338
  out = output.split("[/INST]")[-1].split("</s>")[0].strip()
339
  return out
340
+
341
+ if __name__ == "__main__":
342
+ model = Samba()
343
+ messages = [Message(
344
+ role="user", # TODO: check this
345
+ content="say something",
346
+ )]
347
+ out= model.generate_chat(messages)
348
+ breakpoint()