# gpt-wiki-az / generate.py
# Uploaded by IsmatS with huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 463c2c1 (verified)
import torch
from tokenizers import Tokenizer
from train import GPT, GPTConfig # Assuming your model definition is in train.py
import torch.nn.functional as F
def nucleus_sampling(logits, p=0.9):
    """Sample one token id with nucleus (top-p) sampling.

    Keeps the highest-probability tokens until their cumulative probability
    exceeds ``p``. The token that crosses the threshold is kept, and the single
    most likely token is always kept. Every other logit is masked to ``-inf``;
    the result is renormalized with softmax and one token is drawn.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary.
            It is NOT modified.
        p: Cumulative-probability threshold, expected in (0, 1].

    Returns:
        The sampled token id as a Python ``int``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens past the threshold, then shift the mask right by one so the
    # token that first pushes the mass over ``p`` is still kept.
    sorted_to_remove = cumulative_probs > p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order. sorted_indices
    # is a permutation, so every position is overwritten.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)

    # masked_fill returns a new tensor, so the caller's logits stay intact.
    filtered = logits.masked_fill(to_remove, float('-inf'))
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1).item()
def load_model_and_tokenizer(checkpoint_path='best_model.pt', tokenizer_path='az_tokenizer.json', device='cpu'):
    """Build the GPT model, load its weights, and load the tokenizer.

    Args:
        checkpoint_path: File holding the model's ``state_dict``.
        tokenizer_path: JSON file readable by ``tokenizers.Tokenizer.from_file``.
        device: Passed to ``torch.load`` as ``map_location``, so tensors are
            loaded onto this device regardless of where they were saved.

    Returns:
        ``(model, tokenizer)``, with the model switched to evaluation mode.
    """
    model = GPT(GPTConfig())
    # NOTE(review): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only load trusted checkpoints, or pass weights_only=True
    # on PyTorch versions that support it.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: affects layers such as dropout, if present
    tokenizer = Tokenizer.from_file(tokenizer_path)
    return model, tokenizer
def apply_repetition_penalty(logits, input_ids, penalty=1.2):
    """Make already-seen tokens less likely (CTRL-style repetition penalty).

    For every distinct token id in ``input_ids``, the score in row 0 of
    ``logits`` is adjusted as follows:
    - Positive scores are divided by ``penalty``.
    - Negative scores are multiplied by it.

    Dividing a negative score would move it toward zero and *raise* the
    token's probability, which is the opposite of the intent.

    Args:
        logits: 2-D tensor whose row 0 holds the scores to adjust. It is
            modified in place.
        input_ids: Iterable of previously seen token ids.
        penalty: Values > 1 discourage repetition; 1.0 is a no-op.

    Returns:
        The same ``logits`` tensor, for chaining.
    """
    unique_ids = sorted(set(input_ids))
    if not unique_ids:
        return logits
    index = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
    scores = logits[0, index]
    logits[0, index] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits
def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.001, p=0.95, repetition_penalty=1.5, device='cpu'):
    """Autoregressively extend ``prompt`` and return the decoded text.

    Each step does the following:
    1. Run the model on the whole sequence so far and take the logits of the
       last position.
    2. Divide those logits by ``temperature``.
    3. Apply ``apply_repetition_penalty`` against every token seen so far.
    4. Pick the next token with ``nucleus_sampling``.

    Generation stops after ``max_new_tokens`` steps, or earlier when the
    ``[END]`` token is produced.

    Args:
        model: Callable returning a ``(logits, _)`` pair; logits are indexed
            as ``[:, -1, :]``. Presumably (batch, seq, vocab) — TODO confirm.
        tokenizer: Object providing ``encode(...).ids``, ``decode`` and
            ``token_to_id`` (a ``tokenizers.Tokenizer`` here).
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens appended to the prompt.
        temperature: Logit divisor; the default 0.001 is nearly greedy.
            ``temperature <= 0`` picks the arg-max token directly
            (deterministic greedy decoding) instead of dividing by zero.
        p: Nucleus-sampling threshold.
        repetition_penalty: Factor passed to ``apply_repetition_penalty``.
        device: Device the model and input tensor are moved to.

    Returns:
        The decoded prompt plus continuation, with every non-overlapping
        occurrence of ``' i '`` replaced by a single space.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    model = model.to(device)
    input_ids = list(tokenizer.encode(prompt).ids)
    if not input_ids:
        raise ValueError("prompt produced no tokens; cannot start generation")
    # Look the end token up once. token_to_id returns None when it is absent,
    # in which case generation simply runs for max_new_tokens steps.
    end_token_id = tokenizer.token_to_id('[END]')

    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    # NOTE(review): the sequence is never cropped to the model's context
    # length; long prompts/generations may exceed it — confirm GPTConfig limits.
    for _ in range(max_new_tokens):
        with torch.no_grad():
            output_logits, _ = model(input_tensor)
        last_logits = output_logits[:, -1, :]
        if temperature > 0:
            last_logits = last_logits / temperature
        # Clone so the in-place penalty never writes into the model's output.
        last_logits = apply_repetition_penalty(last_logits.clone(), input_ids, penalty=repetition_penalty)
        if temperature > 0:
            next_token_id = nucleus_sampling(last_logits[0], p=p)
        else:
            next_token_id = int(torch.argmax(last_logits[0]).item())
        input_ids.append(next_token_id)
        # Append one column instead of rebuilding the tensor from the list.
        next_column = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
        input_tensor = torch.cat([input_tensor, next_column], dim=1)
        if next_token_id == end_token_id:
            break
    generated_text = tokenizer.decode(input_ids)
    # NOTE(review): this deletes every standalone " i " from the output, not
    # just extra whitespace; kept unchanged for backward compatibility.
    return generated_text.replace(' i ', ' ')
def main():
    """Load the model and tokenizer, continue a fixed prompt, and print the result."""
    loaded_model, loaded_tokenizer = load_model_and_tokenizer()
    # Azerbaijani prompt meaning "History of Azerbaijan".
    start_text = "Azərbaycanın tarixi"
    completion = generate_text(
        loaded_model,
        loaded_tokenizer,
        start_text,
        p=0.9,  # narrower nucleus than generate_text's default of 0.95
    )
    print("Generated Text:", completion)


if __name__ == '__main__':
    main()