MrOvkill's picture
Update handler.py
a2eb90c verified
import json
import os
from typing import Dict, List, Any
from llama_cpp import Llama
import gemma_tools as gem
# Token limit for this handler; equal to the 8192-token context window the
# Llama model is created with.
MAX_TOKENS: int = 8_192
class EndpointHandler:
    """Inference handler serving Gemma-2B-IT (GGUF, q4_k_m) through llama.cpp.

    Instances are callable: each call takes one request payload and returns
    either the llama.cpp completion response or an error dict.
    """

    def __init__(self, data=None):
        """Download (or reuse the cached) GGUF weights and load the model.

        Args:
            data: Whatever the serving runtime passes at construction
                (presumably the model directory — confirm). It is not used:
                the weights are fetched from the Hugging Face Hub instead.
        """
        self.model = Llama.from_pretrained(
            "lmstudio-ai/gemma-2b-it-GGUF",
            filename="gemma-2b-it-q4_k_m.gguf",
            # Keep the context window in sync with the module-level limit
            # instead of repeating the literal.
            n_ctx=MAX_TOKENS,
        )

    @staticmethod
    def _split_args(result):
        """Normalise the result of ``gem.get_args_or_none`` to ``(ok, args)``.

        NOTE(review): the previous code read the result both as a sequence
        (``args[0]`` as a success flag) and as a mapping (``args["status"]``,
        ``args["inputs"]``); one object cannot satisfy both. An ``(ok, payload)``
        pair is assumed to be the real shape — confirm against gemma_tools.
        A bare dict is also accepted (treated as a failure only when its
        ``status`` is ``"error"``), and anything else counts as a failure.
        The returned ``args`` is always a dict.
        """
        if isinstance(result, (tuple, list)) and len(result) == 2:
            ok, payload = result
        elif isinstance(result, dict):
            ok, payload = result.get("status") != "error", result
        else:
            ok, payload = False, None
        if not isinstance(payload, dict):
            payload = {}
        return bool(ok), payload

    @staticmethod
    def _build_prompt(system_prompt, prompt):
        """Render a single Gemma chat exchange ending in an open model turn.

        Gemma's instruction-tuned checkpoints only know the ``user`` and
        ``model`` roles and the underscored ``<start_of_turn>`` /
        ``<end_of_turn>`` control tokens, so the system prompt (if any) is
        prepended to the user turn rather than given a turn of its own.
        """
        user_text = f"{system_prompt}\n\n{prompt}" if system_prompt else f"{prompt}"
        return (
            f"<start_of_turn>user\n{user_text}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a completion for one request.

        Args:
            data: Request payload. It is passed to ``gem.get_args_or_none`` for
                argument parsing; ``max_length`` (default 512) is read from it
                directly and must be convertible with ``int()``.

        Returns:
            The llama.cpp completion response on success. On failure, a dict
            with ``status`` plus either ``message`` (argument parsing failed)
            or ``reason`` (prompt or ``max_length`` problems).
        """
        ok, args = self._split_args(gem.get_args_or_none(data))
        if not ok:
            return {
                "status": args.get("status", "error"),
                "message": args.get("description", "invalid arguments"),
            }

        try:
            prompt = self._build_prompt(args["system_prompt"], args["inputs"])
        except KeyError:
            # Errors are returned as dicts (not json.dumps strings) so every
            # path has the same type and is not double-encoded by a
            # JSON-serialising caller.
            return {"status": "error", "reason": "invalid format"}

        # .get rather than .pop: do not mutate the caller's payload.
        try:
            max_length = int(data.get("max_length", 512))
        except (TypeError, ValueError, OverflowError):
            return {
                "status": "error",
                "reason": "max_length was passed as something that was absolutely not a plain old int",
            }

        # Forward only the sampling options that were supplied; llama.cpp
        # falls back to its own defaults for the rest instead of this handler
        # crashing with a KeyError.
        sampling = {
            key: args[key]
            for key in ("temperature", "top_p", "top_k")
            if key in args
        }
        return self.model(prompt, max_tokens=max_length, **sampling)