import torch
import torch.nn.functional as F

from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat


def infer_code(
    models,
    text,
    spk_emb = None,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.3,
    repetition_penalty = 1.05,
    max_new_token = 2048,
    **kwargs
):
    """Sample discrete audio code tokens from the GPT model for the given text prompt(s)."""

    device = next(models['gpt'].parameters()).device

    # Normalize inputs: accept a single string or a list of strings, and a single
    # temperature or one temperature per VQ codebook.
    if not isinstance(text, list):
        text = [text]

    if not isinstance(temperature, list):
        temperature = [temperature] * models['gpt'].num_vq

    # Wrap each text in the synthesis prompt, using the [spk_emb] placeholder when a
    # speaker embedding is provided and the empty-speaker token otherwise.
    if spk_emb is not None:
        text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
    else:
        text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]

    # Tokenize the prompts and replicate the token ids across all VQ codebooks.
    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    input_ids = text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)

    inputs = {
        'input_ids': input_ids,
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }

    # Compute input embeddings, then overwrite the [spk_emb] positions with the
    # L2-normalized speaker embedding broadcast across the batch.
    emb = models['gpt'].get_emb(**inputs)
    if spk_emb is not None:
        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)

    # The last entry of the code embedding table is reserved as the EOS code.
    num_code = models['gpt'].emb_code[0].num_embeddings - 1

    # Sampling constraints: nucleus (top-p) and top-k filtering, plus an optional
    # repetition penalty on previously generated code tokens.
    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(
            repetition_penalty, num_code, 16))

    # Autoregressively generate code tokens (infer_text=False) until the EOS code
    # or the max_new_token budget is reached.
    result = models['gpt'].generate(
        emb, inputs['input_ids'],
        temperature = torch.tensor(temperature, device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = LogitsWarpers,
        LogitsProcessors = LogitsProcessors,
        eos_token = num_code,
        max_new_token = max_new_token,
        infer_text = False,
        **kwargs
    )

    return result


def refine_text(
    models,
    text,
    top_P = 0.7,
    top_K = 20,
    temperature = 0.7,
    repetition_penalty = 1.0,
    max_new_token = 384,
    prompt = '',
    **kwargs
):
    """Rewrite the input text with the GPT model in text-generation mode (infer_text=True)."""

    device = next(models['gpt'].parameters()).device

    if not isinstance(text, list):
        text = [text]

    assert len(text), 'text should not be empty'

    # Wrap each text in the refinement prompt (with an optional user prompt suffix) and tokenize.
    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)

    inputs = {
        'input_ids': text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq),
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }

    # Same sampling constraints as infer_code, except the repetition penalty is
    # applied over the full text vocabulary rather than the audio codebook.
    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))

    # Generate refined text tokens (infer_text=True), stopping at the [Ebreak] token.
    result = models['gpt'].generate(
        models['gpt'].get_emb(**inputs), inputs['input_ids'],
        temperature = torch.tensor([temperature], device=device),
        attention_mask = inputs['attention_mask'],
        LogitsWarpers = LogitsWarpers,
        LogitsProcessors = LogitsProcessors,
        eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
        max_new_token = max_new_token,
        infer_text = True,
        **kwargs
    )

    return result
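

# Usage sketch (illustrative only, not part of this module): assuming a `models`
# dict holding the loaded 'gpt' and 'tokenizer' entries as expected above, the two
# helpers are typically used together, e.g.:
#
#     refined = refine_text(models, 'Hello, this is a test.')          # text-mode generation
#     codes = infer_code(models, some_text, spk_emb=some_spk_emb)      # code-mode generation
#
# Here `some_text` and `some_spk_emb` are placeholders; decoding the generated ids
# back to text or audio is handled by the caller elsewhere in the repository.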