|
import json |
|
import os |
|
from typing import Dict, List, Any |
|
from llama_cpp import Llama |
|
import gemma_tools as gem |
|
|
|
# Maximum context length in tokens. NOTE(review): currently duplicated —
# EndpointHandler.__init__ passes a literal n_ctx=8192 instead of this constant.
MAX_TOKENS=8192
|
|
|
class EndpointHandler():
    """Inference endpoint handler serving a GGUF-quantized Gemma-2B-it model
    through llama-cpp-python.

    Follows the custom-handler convention: construct once, then invoke the
    instance with a request payload dict to get a completion.
    """

    def __init__(self, data):
        # `data` is accepted to satisfy the handler-constructor convention
        # (the hosting framework passes a model/path argument) but is unused.
        # n_ctx uses the module-level MAX_TOKENS constant instead of
        # repeating the magic number 8192.
        self.model = Llama.from_pretrained(
            "lmstudio-ai/gemma-2b-it-GGUF",
            filename="gemma-2b-it-q4_k_m.gguf",
            n_ctx=MAX_TOKENS,
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one generation request.

        Parameters
        ----------
        data:
            Request payload. Keys consumed here: ``"max_length"`` (optional,
            default 512). The remaining expected keys ("inputs",
            "system_prompt", "temperature", "top_p", "top_k") are validated
            and extracted by ``gem.get_args_or_none``.

        Returns
        -------
        The llama-cpp completion result on success, or an error payload on
        invalid input. NOTE(review): the error paths are inconsistent — the
        argument-validation branch returns a plain dict while the two
        ``except`` branches return JSON-encoded strings. Preserved as-is to
        keep the existing external contract; callers must handle both.
        """
        args = gem.get_args_or_none(data)

        # Gemma's chat-template control tokens are "<start_of_turn>" /
        # "<end_of_turn>". The previous "<startofturn>"/"<endofturn>"
        # spellings are not special tokens and were fed to the model as
        # literal text. NOTE(review): Gemma officially defines only "user"
        # and "model" turn roles; the "system" turn is kept from the
        # original — confirm the model honors it.
        fmat = (
            "<start_of_turn>system\n{system_prompt} <end_of_turn>\n"
            "<start_of_turn>user\n{prompt} <end_of_turn>\n"
            "<start_of_turn>model"
        )

        # NOTE(review): assumes the object returned by get_args_or_none
        # supports both integer indexing (args[0] as a success flag) and
        # string keys ("status", "description", ...) — verify against
        # gemma_tools.
        if not args[0]:
            return {
                "status": args["status"],
                "message": args["description"],
            }

        try:
            fmat = fmat.format(system_prompt=args["system_prompt"], prompt=args["inputs"])
        except Exception:
            return json.dumps({
                "status": "error",
                "reason": "invalid format"
            })

        # .get() instead of .pop(): no later code reads data["max_length"],
        # and .get() avoids mutating the caller's dict as a side effect.
        max_length = data.get("max_length", 512)
        try:
            max_length = int(max_length)
        except Exception:
            return json.dumps({
                "status": "error",
                "reason": "max_length was passed as something that was absolutely not a plain old int"
            })

        return self.model(
            fmat,
            temperature=args["temperature"],
            top_p=args["top_p"],
            top_k=args["top_k"],
            max_tokens=max_length,
        )