"""OCN customer-support chatbot.

Answers user questions by semantic search over two local JSON sources (an
expanded QA dataset and a structured knowledge base) and serves the result
through a Gradio text interface.
"""

import json

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

# Load the lightweight BERT-based QA model optimized for CPU.
# NOTE(review): nothing in this file visibly calls ``qa_pipeline``; answers come
# from semantic search only. Confirm it is still needed, since loading it costs
# startup time and memory.
model_name = "distilbert-base-uncased-distilled-squad"  # Efficient for CPU
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Initialize pipeline for CPU usage.
device = -1  # Force CPU
qa_pipeline = pipeline(
    "question-answering", model=model, tokenizer=tokenizer, device=device
)

# Load Sentence-BERT for semantic search.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Load knowledge base and expanded QA dataset. Decode explicitly as UTF-8: the
# bot's text is Spanish, so the data likely contains accented characters that
# the platform default codec (e.g. cp1252 on Windows) would reject or mangle.
with open('knowledge_base.json', 'r', encoding='utf-8') as f:
    knowledge_base = json.load(f)

with open('expanded_qa_dataset.json', 'r', encoding='utf-8') as f:
    expanded_qa_dataset = json.load(f)


def create_qa_dataset_embeddings(expanded_qa_dataset):
    """Embed every question of the expanded QA dataset.

    Args:
        expanded_qa_dataset: sequence of dicts, each with a ``'question'`` key.

    Returns:
        ``(qa_embeddings, questions)``: a list of 1-D embedding tensors and the
        matching list of question strings, both in dataset order.
    """
    questions = [item['question'] for item in expanded_qa_dataset]
    if not questions:
        return [], []
    # One batched encode call is much faster than encoding question by
    # question; iterating the resulting 2-D tensor yields one row per question.
    batch = embedding_model.encode(questions, convert_to_tensor=True)
    return list(batch), questions


# Create QA dataset embeddings.
qa_embeddings, qa_questions = create_qa_dataset_embeddings(expanded_qa_dataset)


def _as_text(value):
    """Return *value* as one string: lists/tuples are space-joined, None is ''."""
    if value is None:
        return ''
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ' '.join(str(part) for part in value)
    return str(value)


def create_knowledge_base_embeddings(knowledge_base):
    """Embed each titled knowledge-base entry as a single text document.

    Entries without a ``'title'`` key are skipped, so the result is aligned with
    ``[e for e in knowledge_base if 'title' in e]``, not with ``knowledge_base``
    itself when some entry lacks a title.

    Each document is the title followed by every content item's ``text``, then
    the ``details`` of every step, then every FAQ "question answer" pair.

    Returns:
        list of 1-D embedding tensors, one per titled entry, in order.
    """
    documents = []
    for entry in knowledge_base:
        if 'title' not in entry:
            continue
        items = entry.get('content', [])
        parts = [entry['title']]
        parts += [c['text'] for c in items if c.get('text')]
        # ``details`` may be a string or a list; joining a plain string with
        # ' '.join would insert a space between every character.
        parts += [
            _as_text(step.get('details'))
            for c in items if 'steps' in c
            for step in c['steps']
        ]
        parts += [
            faq['question'] + ' ' + faq['answer']
            for c in items if 'faq' in c
            for faq in c['faq']
        ]
        # Separate the title from the body so they are not fused into one word.
        documents.append(' '.join(parts))
    if not documents:
        return []
    batch = embedding_model.encode(documents, convert_to_tensor=True)
    return list(batch)
# Create knowledge base embeddings. ``create_knowledge_base_embeddings`` only
# embeds entries that have a 'title', so keep the matching entries in a parallel
# list; indexing ``knowledge_base`` directly with a search hit would return the
# wrong entry as soon as any entry lacks a title.
knowledge_base_embeddings = create_knowledge_base_embeddings(knowledge_base)
knowledge_base_entries = [entry for entry in knowledge_base if 'title' in entry]

# Stack the per-item embeddings once at startup instead of on every query.
# ``None`` marks an empty source, where torch.stack([]) would raise.
_qa_embedding_matrix = torch.stack(qa_embeddings) if qa_embeddings else None
_kb_embedding_matrix = (
    torch.stack(knowledge_base_embeddings) if knowledge_base_embeddings else None
)

_FALLBACK_ANSWER = "Lo siento, no encontré una respuesta adecuada para tu pregunta."
# Score for "no usable answer". It loses every comparison against a real
# match in answer_question.
_NO_MATCH_SCORE = float('-inf')


def _best_match(question, matrix, question_embedding=None):
    """Return ``(row_index, cosine_score)`` of the row of *matrix* closest to *question*.

    *question* is encoded only when *question_embedding* is not supplied.
    """
    if question_embedding is None:
        question_embedding = embedding_model.encode(question, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_embedding, matrix)
    best_idx = torch.argmax(cosine_scores).item()
    return best_idx, cosine_scores[0, best_idx].item()


def _step_text(step):
    """Render one step's ``details`` (a string or a list of strings) as one line."""
    details = step.get('details', '')
    if isinstance(details, (list, tuple)):
        return ' '.join(str(part) for part in details)
    return str(details)


def search_expanded_qa(question, question_embedding=None):
    """Semantic search on the expanded QA dataset.

    Args:
        question: user question text.
        question_embedding: optional precomputed embedding of *question*, so a
            caller searching several sources encodes it only once.

    Returns:
        ``(answer, score)``: the best entry's ``'answer'`` and its cosine
        similarity, or the fallback message with ``-inf`` if the dataset is empty.
    """
    if _qa_embedding_matrix is None:
        return _FALLBACK_ANSWER, _NO_MATCH_SCORE
    idx, score = _best_match(question, _qa_embedding_matrix, question_embedding)
    return expanded_qa_dataset[idx]['answer'], score


def search_knowledge_base(question, question_embedding=None):
    """Semantic search on the knowledge base, then extract an answer from the hit.

    The answer is picked from the best-matching entry in this order:
    1. the answer of any FAQ whose question appears (case-insensitively) inside
       the user's question, searched across all content items;
    2. otherwise the first content item that has ``steps`` (details joined by
       newlines) or ``text``.

    Args:
        question: user question text.
        question_embedding: optional precomputed embedding of *question*.

    Returns:
        ``(answer, score)``. When the knowledge base is empty or the matched
        entry has nothing usable, returns the fallback message with ``-inf``,
        so it cannot outrank a real answer from another source.
    """
    if _kb_embedding_matrix is None:
        return _FALLBACK_ANSWER, _NO_MATCH_SCORE
    idx, score = _best_match(question, _kb_embedding_matrix, question_embedding)
    content_items = knowledge_base_entries[idx].get('content', [])
    question_lower = question.lower()

    # Check FAQs in every item first; otherwise an earlier text/steps item
    # would hide a matching FAQ stored in a later item.
    for item in content_items:
        for faq in item.get('faq') or []:
            if faq['question'].lower() in question_lower:
                return faq['answer'], score

    for item in content_items:
        if 'steps' in item:
            return "\n".join(_step_text(step) for step in item['steps']), score
        if 'text' in item:
            return item['text'], score

    return _FALLBACK_ANSWER, _NO_MATCH_SCORE


def answer_question(question):
    """Search both sources and return the higher-scoring answer (ties go to the QA dataset)."""
    if not question or not question.strip():
        return _FALLBACK_ANSWER
    # Encode once and share the embedding between both searches.
    question_embedding = embedding_model.encode(question, convert_to_tensor=True)
    qa_answer, qa_score = search_expanded_qa(question, question_embedding)
    kb_answer, kb_score = search_knowledge_base(question, question_embedding)
    return qa_answer if qa_score >= kb_score else kb_answer


# Gradio interface
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title="OCN Customer Support Chatbot",
    description="Ask questions and get answers from the OCN knowledge base and expanded QA dataset.",
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link; confirm that exposure
    # is intended for this deployment.
    interface.launch(share=True)