import math
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
# Import FusedLayerNorm if we have apex, otherwise use regular LayerNorm
try:
    from apex.normalization import FusedLayerNorm
    print("Using apex FusedLayerNorm")
except ImportError:
    from torch.nn import LayerNorm as FusedLayerNorm

class LayerNorm(FusedLayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(normalized_shape)
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Inputs larger than max_numel go through the unfused fp32 F.layer_norm;
        # everything else uses the (possibly fused) parent forward, also in fp32,
        # and the result is cast back to the input dtype.
        if input.numel() > self.max_numel:
            return F.layer_norm(input.float(), self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
        else:
            return super(LayerNorm, self).forward(input.float()).type_as(input)
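
# A minimal illustrative sketch of the dtype handling above, assuming the plain
# torch.nn.LayerNorm fallback (no apex): normalization runs in fp32 and the
# result is cast back, so half-precision activations stay half outside this module.
def _layer_norm_fp16_sketch():
    ln = LayerNorm(64)
    x = t.randn(2, 10, 64).half()
    y = ln(x)                       # normalized in fp32 internally
    assert y.dtype == t.float16     # cast back by type_as(input)
    return y
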
def gelu(x):
    return 0.5 * x * (1 + t.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * t.pow(x, 3))))

def swish(x):
    return x * t.sigmoid(x)

@t.jit.script
def quick_gelu(x):
    return x * t.sigmoid(1.702 * x)

@t.jit.script
def quick_gelu_bwd(x, grad_output):
    sig = t.sigmoid(1.702 * x)
    return grad_output * sig * (1.702 * x * (1 - sig) + 1.)
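
# Derivation for quick_gelu_bwd: with s = sigmoid(1.702 * x),
#   d/dx [x * s] = s + 1.702 * x * s * (1 - s) = s * (1.702 * x * (1 - s) + 1),
# which is exactly the factor multiplying grad_output above.
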
class QuickGelu(t.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return quick_gelu(x)

    @staticmethod
    def backward(ctx, grad_output):
        return quick_gelu_bwd(ctx.saved_tensors[0], grad_output)

def memory_efficient_quick_gelu(x):
    return QuickGelu.apply(x)
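
# QuickGelu saves only the input tensor and recomputes the sigmoid in backward,
# rather than letting autograd keep the intermediate sigmoid/product tensors
# alive until the backward pass, hence "memory efficient".
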
ACT_FNS = {
    'relu': t.nn.functional.relu,
    'swish': swish,
    'gelu': gelu,
    'quick_gelu': memory_efficient_quick_gelu  # quick_gelu
}

def _move_to_gpu_and_convert_conv_weights_to_fp16(l):
    l.cuda()
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.half()

def _convert_conv_weights_to_fp32(l):
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.float()

def _convert_conv_weights_to_fp16(l):
    if isinstance(l, Conv1D):
        l.w.data = l.w.data.half()

def _convert_embedding_weights_to_fp16(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.half()

def _convert_embedding_weights_to_fp32(l):
    if isinstance(l, t.nn.Embedding):
        l.weight.data = l.weight.data.float()

class Conv1D(nn.Module):
    def __init__(self, n_in, n_out, zero_out=False, init_scale=1.0):
        super(Conv1D, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        if zero_out:
            w = t.zeros(n_in, n_out)
        else:
            w = t.empty(n_in, n_out)
            nn.init.normal_(w, std=0.02 * init_scale)
        b = t.zeros(n_out)
        self.w = nn.Parameter(w)
        self.b = nn.Parameter(b)

    def forward(self, x):
        size_out = (*x.size()[:-1], self.n_out)
        x = t.addmm(self.b.type_as(x), x.view(-1, x.size(-1)), self.w.type_as(x))  # If x is float then float, else half
        x = x.view(*size_out)
        return x
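
# Illustrative sketch: despite the name, Conv1D is a position-wise linear layer
# over the last dimension, i.e. x @ w + b applied independently at every position.
def _conv1d_sketch():
    conv = Conv1D(n_in=8, n_out=16)
    x = t.randn(2, 5, 8)                # (batch, n_ctx, n_in)
    y = conv(x)                         # (batch, n_ctx, n_out)
    assert t.allclose(y, x @ conv.w + conv.b, atol=1e-6)
    return y
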
# For large contexts, masks can take up memory, so you can make a single saved mask for all layers
class Mask(nn.Module):
    def __init__(self, n_ctx):
        super().__init__()
        self.register_buffer('b', t.tril(t.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))

    def forward(self, w):
        w = w * self.b + -1e9 * (1 - self.b)  # For fp16 do w = w.float().masked_fill(self.b == 0, float('-inf'))
        return w
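
# Illustrative sketch: the cached causal mask pushes future positions down to ~-1e9
# before the softmax, so each query position only attends to itself and the past.
def _mask_sketch():
    n_ctx = 4
    mask = Mask(n_ctx)
    w = t.randn(1, 1, n_ctx, n_ctx)     # (batch, heads, query, key) attention scores
    probs = F.softmax(mask(w), dim=-1)  # future (upper-triangular) entries get ~zero weight
    return probs
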
def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution of shape (batch size, vocabulary size)
            top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
    """
    #assert logits.dim() == 2  # batch size 1 for now - could be updated for more but the code would be less clear
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check
    assert (top_k == 0) or (top_p == 0.0)
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < t.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = t.sort(logits, descending=True, dim=-1)
        cumulative_probs = t.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        #indices_to_remove = sorted_indices[sorted_indices_to_remove]
        # Scatter the removal mask back to the original (unsorted) token order
        indices_to_remove = t.zeros_like(logits, dtype=t.bool).scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
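
# Illustrative usage sketch: apply nucleus (top-p) filtering to a batch of logits
# and sample one token per row; the vocabulary size and top_p value are arbitrary.
def _sampling_sketch():
    logits = t.randn(2, 1000)                              # (batch, vocab)
    filtered = filter_logits(logits, top_k=0, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)                    # filtered tokens get zero probability
    return t.multinomial(probs, num_samples=1)             # (batch, 1) sampled token ids
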