File size: 1,084 Bytes
9c911bc
 
3741ac2
9c911bc
3741ac2
9c911bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Demo: generate a sampled chat reply with allenai/OLMo-7B-Instruct via hf_olmo.
import logging

from hf_olmo import OLMoForCausalLM, OLMoTokenizerFast

# Enable logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(lineno)s - %(funcName)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# set higher logging level for httpx to avoid all GET and POST requests being logged
logging.getLogger("httpx").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

MODEL = "allenai/OLMo-7B-Instruct"


def main() -> None:
    """Load OLMo-7B-Instruct, send it one chat turn, and print the sampled reply."""
    olmo = OLMoForCausalLM.from_pretrained(MODEL)
    tokenizer = OLMoTokenizerFast.from_pretrained(MODEL)
    chat = [
        {"role": "user",
         "content": "What is language modeling?"},
    ]
    # Render the chat turns into the model's prompt format; add_generation_prompt
    # leaves the assistant turn open so generate() completes it.
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # apply_chat_template already inserted the template's special tokens, so skip them here.
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    # Optional: move the model to GPU. `inputs` is a torch.Tensor (not a dict),
    # and it is moved via .to(olmo.device) below, so it follows the model's device.
    # olmo = olmo.to("cuda")
    response = olmo.generate(
        input_ids=inputs.to(olmo.device),
        max_new_tokens=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
    )
    print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])


# Guard the entry point so importing this module does not trigger a model
# download and generation run; `python <script>` behaves exactly as before.
if __name__ == "__main__":
    main()