# A Multi-task learning model with two prediction heads

* The first prediction head classifies keyword queries vs. statements/questions
* The second prediction head classifies statements vs. questions

## Article

[Medium article](https://medium.com/@shahrukhx01/multi-task-learning-with-transformers-part-1-multi-prediction-heads-b7001cf014bf)

## Demo Notebook

[Colab Notebook Multi-task Query classifiers](https://colab.research.google.com/drive/1R7WcLHxDsVvZXPhr5HBgIWa3BlSZKY6p?usp=sharing)

## Clone the model repo

```bash
git clone https://huggingface.co/shahrukhx01/bert-multitask-query-classifiers
```

```python
# Jupyter/Colab magic to move into the cloned repo
%cd bert-multitask-query-classifiers/
```

## Load the model and tokenizer

```python
from multitask_model import BertForSequenceClassification
from transformers import AutoTokenizer
import torch

# Each task name maps to the number of labels for its prediction head
model = BertForSequenceClassification.from_pretrained(
    "shahrukhx01/bert-multitask-query-classifiers",
    task_labels_map={"quora_keyword_pairs": 2, "spaadia_squad_pairs": 2},
)
tokenizer = AutoTokenizer.from_pretrained("shahrukhx01/bert-multitask-query-classifiers")
```

## Run inference on both tasks

Using the `model` and `tokenizer` loaded above:

```python
## Keyword vs. statement/question classifier
queries = ["keyword query", "is this a keyword query?"]
task_name = "quora_keyword_pairs"

# Tokenize the batch, then route it through the task-specific head
input_ids = tokenizer(queries, padding=True, return_tensors="pt")["input_ids"]
logits = model(input_ids, task_name=task_name)[0]
predictions = torch.argmax(torch.softmax(logits, dim=1).detach().cpu(), dim=1)

for query, prediction in zip(queries, predictions):
    print(f"task: {task_name}, input: {query} \n prediction=> {prediction}")
    print()
```
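The second head follows the same pattern with `task_name="spaadia_squad_pairs"`. A minimal sketch, assuming the example sentences below; the card does not document which class index corresponds to statement vs. question, so the printed prediction is the raw label index:

```python
## Statement vs. question classifier (second prediction head)
queries = ["this is a statement", "is this a question?"]  # illustrative inputs (assumed)
task_name = "spaadia_squad_pairs"

# Same tokenize-and-forward pattern as above, selecting the second task head
input_ids = tokenizer(queries, padding=True, return_tensors="pt")["input_ids"]
logits = model(input_ids, task_name=task_name)[0]
predictions = torch.argmax(torch.softmax(logits, dim=1).detach().cpu(), dim=1)

for query, prediction in zip(queries, predictions):
    print(f"task: {task_name}, input: {query} \n prediction=> {prediction}")
    print()
```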