File size: 2,786 Bytes
463c2c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from tokenizers import Tokenizer
from train import GPT, GPTConfig  # Assuming your model definition is in train.py

import torch.nn.functional as F

def nucleus_sampling(logits, p=0.9):
    """Sample token id(s) with top-p (nucleus) sampling.

    Keeps the highest-scoring tokens until their cumulative softmax probability
    exceeds ``p``, including the token that crosses the threshold. The single
    best token is always kept. The result is renormalized over that set and one
    token is drawn from it. The input tensor is not modified.

    Args:
        logits: Unnormalized scores; sampling runs along the last dimension.
        p: Cumulative-probability threshold. Larger values keep more tokens.

    Returns:
        For a 1-D ``logits`` tensor, the sampled token id as a Python ``int``.
        For a batched ``(..., vocab)`` tensor, a ``LongTensor`` of shape
        ``(..., 1)`` holding one sampled id per row.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens past the threshold, then shift right by one so the token that
    # crosses p is kept; the top-ranked token is never removed.
    sorted_to_remove = cumulative_probs > p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Translate the mask from sorted order back to vocabulary order. Because
    # sorted_indices is a permutation, every position is written exactly once.
    to_remove = torch.zeros_like(sorted_to_remove).scatter(-1, sorted_indices, sorted_to_remove)

    # masked_fill returns a new tensor, so the caller's logits stay untouched.
    filtered = logits.masked_fill(to_remove, float('-inf'))
    probabilities = F.softmax(filtered, dim=-1)

    if probabilities.dim() == 1:
        return torch.multinomial(probabilities, num_samples=1).item()

    # torch.multinomial only accepts 1-D or 2-D input, so flatten leading dims.
    flat = probabilities.reshape(-1, probabilities.size(-1))
    samples = torch.multinomial(flat, num_samples=1)
    return samples.reshape(*probabilities.shape[:-1], 1)

def load_model_and_tokenizer(checkpoint_path='best_model.pt', tokenizer_path='az_tokenizer.json', map_location='cpu'):
    """Build the GPT model, load its trained weights, and load the tokenizer.

    Args:
        checkpoint_path: File holding the model's ``state_dict``.
        tokenizer_path: JSON file readable by ``tokenizers.Tokenizer.from_file``.
        map_location: Device string (or ``torch.device``) to load the weights onto.

    Returns:
        ``(model, tokenizer)``, with the model already switched to eval mode.
    """
    config = GPTConfig()
    model = GPT(config)
    # The checkpoint is passed straight to load_state_dict, so it only needs
    # tensors and plain containers. weights_only=True refuses arbitrary pickled
    # objects, which could otherwise execute code from an untrusted file.
    # NOTE(review): requires torch >= 1.13 for the weights_only argument.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(map_location), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Disable training-only behavior (e.g. dropout) for inference.
    tokenizer = Tokenizer.from_file(tokenizer_path)
    return model, tokenizer

def apply_repetition_penalty(logits, input_ids, penalty=1.2):
    """Penalize, in place, the scores of tokens that already appear in ``input_ids``.

    Applies the sign-aware rule: positive scores are divided by ``penalty`` and
    negative scores are multiplied by it. With ``penalty > 1`` every previously
    seen token therefore becomes less likely. Plain division would instead push
    negative scores toward zero and make them *more* likely.

    Only row 0 of ``logits`` is modified, matching a single-sequence caller.

    Args:
        logits: Score tensor indexed as ``logits[0, token_id]``.
        input_ids: Iterable of previously seen token ids; duplicates are ignored.
        penalty: Penalty factor; values above 1 discourage repetition.

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    seen = torch.tensor(sorted(set(input_ids)), dtype=torch.long, device=logits.device)
    if seen.numel() == 0:
        return logits
    row = logits[0]  # a view: index assignment below writes through to logits
    scores = row[seen]
    row[seen] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return logits

def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=0.001, p=0.95, repetition_penalty=1.5, device='cpu'):
    """Autoregressively extend ``prompt`` and return the decoded text.

    Each step runs the model on the full sequence and takes the last position's
    logits. It then applies temperature scaling and the repetition penalty, and
    draws the next token with nucleus sampling. Generation stops after
    ``max_new_tokens`` tokens or when the ``'[END]'`` token is sampled; that
    token is included in the output.

    Args:
        model: Callable returning ``(logits, _)``; logits are indexed as
            ``[:, -1, :]``, so presumably ``(batch, seq, vocab)`` — TODO confirm.
        tokenizer: ``tokenizers.Tokenizer`` used for encode/decode.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Positive divisor for the logits; small values approach
            greedy decoding.
        p: Nucleus-sampling cumulative-probability threshold.
        repetition_penalty: Factor passed to ``apply_repetition_penalty``.
        device: Device the model and inputs are moved to.

    Returns:
        The decoded prompt plus continuation, with every ``' i '`` replaced by ``' '``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        # Zero would divide by zero; negative values would invert the distribution.
        raise ValueError(f"temperature must be positive, got {temperature}")

    model = model.to(device)
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    # Loop-invariant lookup. If the vocabulary has no '[END]' token this is
    # presumably None, the comparison below never matches, and generation
    # runs for max_new_tokens. Replace with the actual end token if applicable.
    end_token_id = tokenizer.token_to_id('[END]')

    for _ in range(max_new_tokens):
        with torch.no_grad():
            output_logits, _ = model(input_tensor)

        # Division already yields a fresh tensor, so the in-place penalty
        # below cannot corrupt the model's output.
        logits = output_logits[:, -1, :] / temperature
        logits = apply_repetition_penalty(logits, input_ids, penalty=repetition_penalty)

        next_token_id = nucleus_sampling(logits[0], p=p)

        # Keep the Python list (used for the penalty and decoding) and the
        # model input in sync, appending instead of rebuilding the tensor.
        input_ids.append(next_token_id)
        next_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
        input_tensor = torch.cat([input_tensor, next_tensor], dim=1)

        if next_token_id == end_token_id:
            break

    generated_text = tokenizer.decode(input_ids)
    # Example post-processing carried over from the original code.
    # NOTE(review): this also removes legitimate standalone "i" words.
    return generated_text.replace(' i ', ' ')


def main():
    """Demo entry point: load model and tokenizer, continue a fixed prompt, print it."""
    model, tokenizer = load_model_and_tokenizer()
    # The nucleus threshold is set below the generate_text default of 0.95.
    text = generate_text(model, tokenizer, "Azərbaycanın tarixi", p=0.9)
    print("Generated Text:", text)

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()