# Committed by Millakisan — commit 9c911bc ("Update app.py", verified)
#
import logging
from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast
# Configure the root logger: INFO and above, with timestamp, logger name,
# line number, function name and level attached to every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(lineno)s - %(funcName)s - %(levelname)s - %(message)s",
)
# Raise httpx's threshold to WARNING so its per-request GET/POST lines
# are not emitted at INFO level.
logging.getLogger("httpx").setLevel(level=logging.WARNING)
# Logger for this script's own messages.
logger = logging.getLogger(name=__name__)
MODEL = "allenai/OLMo-7B-Instruct"

# Load the OLMo checkpoint and its matching tokenizer.
logger.info("Loading model and tokenizer from %s", MODEL)
olmo = OLMoForCausalLM.from_pretrained(MODEL)
tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)

# Optional: run on GPU. Only the model needs moving — the generate() call
# below already sends the prompt tensors to whatever device the model is on.
# olmo = olmo.to("cuda")

# A single-turn conversation rendered through the tokenizer's chat template.
chat = [
    {"role": "user",
     "content": "What is language modeling?"},
]
# tokenize=False returns the rendered prompt as a string; add_generation_prompt=True
# asks the template to append the assistant-turn prefix so the model answers.
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# add_special_tokens=False: presumably the chat template already supplies any
# special tokens it needs — TODO confirm against the OLMo template.
# Calling the tokenizer (rather than encode) also yields an attention mask, so
# generate() does not have to infer one from the pad token.
encoded = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
inputs = encoded["input_ids"]  # same tensor tokenizer.encode(..., return_tensors="pt") gives

# Sample up to 100 new tokens with top-k / top-p sampling (non-deterministic).
response = olmo.generate(
    input_ids=inputs.to(olmo.device),
    attention_mask=encoded["attention_mask"].to(olmo.device),
    max_new_tokens=100,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)
# For a causal LM, generate() returns prompt + new tokens, so the printed text
# includes the rendered prompt followed by the model's answer.
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])