# Minimal chat-generation script for allenai/OLMo-7B-Instruct.
import logging

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast
# Enable logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(lineno)s - %(funcName)s - %(levelname)s - %(message)s",
    level=logging.INFO
)
# Set a higher logging level for httpx to avoid logging every GET and POST request
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
MODEL = "allenai/OLMo-7B-Instruct"

olmo = OLMoForCausalLM.from_pretrained(MODEL)
tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)
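# Optional (an assumption, not part of the original script): on a GPU you can
# load the weights in half precision to roughly halve memory use, e.g.
#   import torch
#   olmo = OLMoForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)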
chat = [
    {"role": "user", "content": "What is language modeling?"},
]

# Render the chat into the model's prompt format, then tokenize it to input IDs
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
# Optional: move the model and inputs to GPU.
# Note that `inputs` is a tensor here (not a dict), so move it directly:
# olmo = olmo.to('cuda')
# inputs = inputs.to('cuda')
# Sample up to 100 new tokens with top-k/top-p (nucleus) sampling
response = olmo.generate(input_ids=inputs.to(olmo.device), max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
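
# ---------------------------------------------------------------------------
# A minimal alternative sketch using the generic transformers Auto classes
# instead of the hf_olmo package. This assumes a transformers version with
# native OLMo support and the converted "allenai/OLMo-7B-Instruct-hf"
# checkpoint; both are assumptions on top of the original script.
# ---------------------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer

HF_MODEL = "allenai/OLMo-7B-Instruct-hf"
hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL)
hf_tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)

# apply_chat_template can tokenize directly and return a tensor of input IDs
input_ids = hf_tokenizer.apply_chat_template(
    chat, add_generation_prompt=True, return_tensors="pt"
)
output = hf_model.generate(input_ids, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
print(hf_tokenizer.batch_decode(output, skip_special_tokens=True)[0])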