from transformers import LlamaForSequenceClassification, Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


class Weights(torch.nn.Module):
    """MLP gating head that maps a hidden state to 5 per-objective weights."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(4096, 4096, dtype=torch.float16),
            torch.nn.SELU(),
            torch.nn.Linear(4096, 4096, dtype=torch.float16),
            torch.nn.SELU(),
            torch.nn.Linear(4096, 5, dtype=torch.float16),
        )

    def forward(self, x):
        return self.fc(x.to(torch.float16))


class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.weights = Weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # The score head and the backbone hidden states are detached, so only the
        # gating head (self.weights) receives gradients through these outputs.
        logits = self.score(hidden_states).detach()
        weights = self.weights(hidden_states.detach())

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token in each sequence.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool logits and gating weights at the final (non-padding) token.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
        pooled_weights = weights[torch.arange(batch_size, device=weights.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits, pooled_weights) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # Interpret the score head output as 5 (reward, extra) pairs, keep the reward
        # component, and combine the rewards with the learned gating weights into a
        # single scalar score per sequence.
        rews = pooled_logits.view(-1, 5, 2)[:, :, 0]
        scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=scores,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
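

# Minimal usage sketch (an assumption, not part of the original module): the class is
# loaded like any Hugging Face sequence-classification model. The checkpoint name and
# num_labels=10 (5 reward/extra pairs consumed by the view(-1, 5, 2) above) are
# illustrative guesses, not values confirmed by the source.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical base checkpoint (hidden size 4096)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForSequenceClassificationWithNormal_Weights.from_pretrained(
        model_name, num_labels=10, torch_dtype=torch.float16
    )
    # Llama has no pad token by default; the pooling logic above requires one.
    model.config.pad_token_id = tokenizer.eos_token_id

    inputs = tokenizer("The assistant's reply was helpful and accurate.", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    print(out.logits)  # one weighted scalar score per sequence, shape (batch, 1)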