File size: 2,421 Bytes
0d375ed
 
822daa6
 
 
 
 
 
 
0d375ed
 
 
 
 
2e16cc5
822daa6
 
 
 
5313aae
822daa6
5313aae
822daa6
5313aae
822daa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d375ed
822daa6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import sentencepiece
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
import os
import spacy
import spacy_transformers
import zipfile
from collections import defaultdict

class Models():
    """Bundle of NER models: a CamemBERT NER pipeline and a custom spaCy model.

    Both models are loaded eagerly when the instance is constructed.
    """

    def __init__(self) -> None:
        self.load_trained_models()

    def load_trained_models(self) -> None:
        """Load the Hugging Face NER pipeline and the custom spaCy model.

        Sets ``self.ner`` (a transformers ``'ner'`` pipeline using
        ``aggregation_strategy="simple"``) and ``self.custom_ner`` (a spaCy
        model loaded from ``spacy_model_v2/output/model-best`` next to this
        file). If that model directory does not exist yet, the archive
        ``spacy_model_v2.zip`` is extracted into ``spacy_model_v2/`` next to
        this file first.
        """
        model_name = "Jean-Baptiste/camembert-ner-with-dates"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.ner = pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy="simple")

        current_directory = os.path.dirname(os.path.realpath(__file__))
        destination_folder = os.path.join(current_directory, 'spacy_model_v2')
        custom_ner_path = os.path.join(destination_folder, 'output', 'model-best')
        if not os.path.exists(custom_ner_path):
            # Resolve the archive next to this file so loading does not depend
            # on the process's working directory; fall back to the CWD-relative
            # location used historically.
            zip_path = os.path.join(current_directory, 'spacy_model_v2.zip')
            if not os.path.exists(zip_path):
                zip_path = os.path.join('.', 'spacy_model_v2.zip')
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(destination_folder)
        self.custom_ner = spacy.load(custom_ner_path)

    def extract_ner(self, text, score_threshold=0.75):
        """Run the NER pipeline on *text* and group confident entities.

        Args:
            text: Input passed to ``self.ner``.
            score_threshold: Entities whose ``score`` is not strictly greater
                than this value are discarded. Defaults to 0.75.

        Returns:
            A ``(dates, organisations, locations)`` tuple built from the
            ``'DATE'``, ``'ORG'`` and ``'LOC'`` entity groups. Each element is
            a list of entity words in pipeline order, empty when none was kept.
            Other entity groups are ignored.
        """
        grouped = defaultdict(list)
        for entity in self.ner(text):
            if entity['score'] > score_threshold:
                grouped[entity['entity_group']].append(entity['word'])
        return grouped.get('DATE', []), grouped.get('ORG', []), grouped.get('LOC', [])

    def get_ner(self, text, recover_text):
        """Extract dates, organisations and locations, with a fallback text.

        Each category found in *text* is returned as-is. Any category that is
        empty for *text* is taken from *recover_text* instead. The pipeline is
        run on *recover_text* only when at least one category is empty, since
        otherwise its result would be discarded.

        Returns:
            A ``(dates, companies, locations)`` tuple of lists.
        """
        dates, companies, locations = self.extract_ner(text)
        if dates and companies and locations:
            return dates, companies, locations

        alt_dates, alt_companies, alt_locations = self.extract_ner(recover_text)
        return dates or alt_dates, companies or alt_companies, locations or alt_locations

    def get_custom_ner(self, text):
        """Run the custom spaCy model on *text* and group entities by label.

        Returns:
            A ``defaultdict(list)`` mapping each entity label (``ent.label_``)
            to the texts of the entities carrying that label, in the order
            they appear in ``doc.ents``.
        """
        grouped = defaultdict(list)
        for entity in self.custom_ner(text).ents:
            grouped[entity.label_].append(entity.text)
        return grouped