```python
import logging

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

# Enable logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(lineno)s - %(funcName)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Raise the logging level for httpx so every GET and POST request is not logged
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

MODEL = "allenai/OLMo-7B-Instruct"

# Load the instruction-tuned OLMo model and its fast tokenizer
olmo = OLMoForCausalLM.from_pretrained(MODEL)
tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)

# Build a single-turn chat and render it with the model's chat template
chat = [
    {"role": "user", "content": "What is language modeling?"},
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# `encode` with return_tensors="pt" yields a tensor of input ids
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

# Optional: move inputs and model to the GPU
# inputs = inputs.to('cuda')
# olmo = olmo.to('cuda')

# Sample a response with top-k and nucleus (top-p) sampling
response = olmo.generate(
    input_ids=inputs.to(olmo.device),
    max_new_tokens=100,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
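If a CUDA device is available, the commented-out lines above can be folded into the load step. Below is a minimal sketch that additionally loads the checkpoint in half precision to roughly halve GPU memory use; it assumes a single CUDA device with enough memory for the 7B model in fp16, and relies only on `torch_dtype`, a standard `from_pretrained` keyword from `transformers`, not anything specific to `hf_olmo`:

```python
import torch
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

MODEL = "allenai/OLMo-7B-Instruct"

# Assumption: the 7B checkpoint fits on one CUDA device in fp16.
olmo = OLMoForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16).to("cuda")
tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)

chat = [{"role": "user", "content": "What is language modeling?"}]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# Move the input ids to the same device as the model
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to("cuda")

response = olmo.generate(input_ids=inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

Because `do_sample=True`, the output varies from run to run; passing `do_sample=False` instead gives deterministic greedy decoding.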