from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoModelWithLMHead,
    AutoTokenizer,
    GPT2Model,
    GPT2Tokenizer,
    LogitsProcessorList,
    PreTrainedModel,
    PreTrainedTokenizer,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

from mario_gpt.prompter import Prompter

PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
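

# MarioLM bundles the pretrained MarioGPT language model, its tokenizer, and a
# Prompter that encodes text prompts into hidden states used to condition
# level generation.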
class MarioLM:
    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        prompter: Optional[Prompter] = None,
    ):
        self.context_len = context_len

        self.lm = lm
        if lm is None:
            self.lm = self.load_pretrained_lm()

        self.tokenizer = tokenizer
        if tokenizer is None:
            self.tokenizer = self.load_pretrained_tokenizer()

        self.prompter = prompter
        if prompter is None:
            self.prompter = Prompter(self.tokenizer)

    @property
    def device(self):
        return self.lm.device

    def to(self, device: torch.device):
        self.lm = self.lm.to(device)
        return self

    def load_pretrained_lm(self) -> GPT2Model:
        print(f"Using {PRETRAINED_MODEL_PATH} model")
        return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)

    def load_pretrained_tokenizer(self) -> GPT2Tokenizer:
        print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
        return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)
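
    # Sample one next token for each sequence in `seed`: the forward pass is
    # given the prompt's encoder hidden states, the logits are restricted to
    # the top 16 entries (per the inline comment, the number of tile
    # characters), rescaled by `temperature`, and sampled multinomially.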
    def sample_step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temperature: float = 2.0,
    ):
        lm = self.lm
        logits_processor = LogitsProcessorList()
        logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(16),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            input_ids = seed
            out = lm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if len(logits.shape) == 2:
                logits = logits.view(1, 1, -1)
            next_token_logits = logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens, encoder_hidden_states
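
    # Autoregressively generate `num_steps` tokens. The effective context is
    # capped 28 tokens below the model's 700-token training length, and when
    # the window is truncated it is kept aligned to multiples of 14 tokens,
    # the height of a Mario level column, so the model sees whole columns.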
    def sample(
        self,
        seed: Optional[torch.Tensor] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        temperature: float = 2.0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        use_tqdm: bool = False,
    ):
        context_len = self.context_len - 28
        self.lm.eval()
        with torch.no_grad():
            if seed is None:
                seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
            out = seed.to(self.device)
            if encoder_hidden_states is None:
                if prompts is not None:
                    encoder_hidden_states = torch.stack(
                        [self.prompter.output_hidden(prompt) for prompt in prompts]
                    )
                else:
                    encoder_hidden_states = torch.stack(
                        [
                            self.prompter(sample_prompt=True)[1]
                            for _ in range(seed.shape[0])
                        ]
                    )
            encoder_hidden_states = encoder_hidden_states.to(
                self.device
            )  # b x 1 x hidden_dim
            encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
            if not use_tqdm:
                bar = np.arange(num_steps)
            else:
                bar = tqdm(np.arange(num_steps))
            with torch.no_grad():
                for i in bar:
                    inp = out * 1
                    if len(out.shape) > 0 and out.shape[-1] > context_len:
                        diff = inp.shape[-1] % 14  # height of mario level
                        ctx = context_len + diff
                        inp = inp[:, -ctx:] * 1
                    next_tokens, encoder_hidden_states = self.sample_step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                        temperature=temperature,
                    )
                    out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
                    if use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
                        )
        if use_tqdm:
            bar.close()
        self.lm.train()
        return out
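

# A minimal usage sketch, not part of the original module: it assumes the
# pretrained checkpoint above can be downloaded from the Hugging Face Hub and
# uses only the class defined here plus the standard tokenizer API.
if __name__ == "__main__":
    mario_lm = MarioLM()
    generated = mario_lm.sample(
        prompts=["many pipes, many enemies, some blocks, high elevation"],
        num_steps=140,  # 10 columns of a 14-tile-high level
        temperature=2.0,
        use_tqdm=True,
    )
    # Decode the generated token ids back into the level's tile characters.
    print(mario_lm.tokenizer.batch_decode(generated))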