File size: 3,563 Bytes
a1551fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import pandas as pd
import time
import random
from sentence_transformers import SentenceTransformer
from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection, utility
import configparser
from tqdm import tqdm

# Initialize the SentenceTransformer model used to embed molecule names.
# NOTE(review): "bert-base-uncased" is a plain Hugging Face checkpoint rather
# than a sentence-transformers model, so the library wraps it with a default
# pooling layer; confirm this is the intended model.
embedding_model = SentenceTransformer(model_name_or_path="bert-base-uncased")

# Read molecule names and their numeric ids from CSV.
csv_path = 'molecules-small.csv'
df = pd.read_csv(csv_path)
# Rows missing either value cannot be inserted (both fields are non-null in the
# collection schema). Drop them here so `molecules` and `cids` stay aligned.
df = df.dropna(subset=['cid', 'cmpdname'])

# Milvus checks VARCHAR max_length against the byte length of the string, so
# truncate the UTF-8 encoding rather than the character count. errors="ignore"
# discards a multi-byte character that would otherwise be split at the cut.
max_name_length = 256
molecules = [
    str(name).encode('utf-8')[:max_name_length].decode('utf-8', errors='ignore')
    for name in df['cmpdname']
]

# cid is the INT64 primary key. Cast explicitly because pandas parses an
# integer column that contained NaN as float64.
cids = df['cid'].astype('int64').tolist()

# Encode all names in batches, which is much faster than one encode() call per
# name. list() keeps the original shape: a list of 1-D vectors, one per molecule.
embeddings_list = list(
    embedding_model.encode(molecules, batch_size=64, show_progress_bar=True)
)

# Load Milvus connection settings from the [example] section of config.ini.
# RawConfigParser disables '%' interpolation, so tokens containing '%' are
# read verbatim.
config_path = 'config.ini'
cfp = configparser.RawConfigParser()
# read() silently skips files it cannot open. Fail loudly here instead of with
# a confusing NoSectionError from get() below.
if not cfp.read(config_path):
    raise FileNotFoundError(f"Milvus config file not found: {config_path}")
milvus_uri = cfp.get('example', 'uri')
token = cfp.get('example', 'token')
print(f"Connecting to DB: {milvus_uri}")
connections.connect("default",
                    uri=milvus_uri,
                    token=token)

# Start from a clean slate: drop any existing collection with this name.
collection_name = 'molecule_embeddings'
if utility.has_collection(collection_name):
    print(f"Dropping existing collection: {collection_name}")
    utility.drop_collection(collection_name)
print("Success!")
# The vector field's dimensionality must match the embedding model's output,
# so query the model instead of hard-coding it.
dim = embedding_model.get_sentence_embedding_dimension()

# Define the collection schema: cid primary key, molecule name, embedding vector.
molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
# Reuse the limit the names were truncated to, so the two cannot drift apart.
molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=max_name_length, description="name")
molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
schema = CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                          auto_id=False,
                          description="my first collection!")

print(f"Creating example collection: {collection_name}")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")

# Insert rows in batches of `batch_size` and report the time spent in insert().
batch_size = 1000
total_rt = 0.0  # cumulative seconds spent inside collection.insert()
print(f"Inserting {len(embeddings_list)} entities... ")
for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
    # Column-based insert: one list per field, in schema field order
    # (cid, name, embedding).
    entities = [
        cids[i:i + batch_size],
        molecules[i:i + batch_size],
        embeddings_list[i:i + batch_size],
    ]
    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring durations (time.time() can jump with wall-clock changes).
    t0 = time.perf_counter()
    collection.insert(entities)
    total_rt += time.perf_counter() - t0

print(f"Succeed in inserting {len(embeddings_list)} entities in {round(total_rt, 4)} seconds!")

# Flush pending inserts before building the index.
print("Flushing collection...")
collection.flush()

# Build an AUTOINDEX on the vector field using Euclidean (L2) distance, then
# load the collection so it can be searched.
index_params = dict(index_type="AUTOINDEX", metric_type="L2", params={})
print("Building index...")
collection.create_index('molecule_embedding', index_params)
collection.load()
    
# Example search: query with the embedding of a name taken from the dataset,
# so the hits are meaningful. The query molecule itself is expected among the
# top results, at a distance near 0.
nq = 1
topk = 5
search_params = {"metric_type": "L2"}  # must match the index metric
query_names = molecules[:nq]
if query_names:
    search_vec = embedding_model.encode(query_names).tolist()
    print(f"Searching for: {query_names}")
    results = collection.search(search_vec, anns_field='molecule_embedding', param=search_params,
                                limit=topk, output_fields=['molecule_name'])
    # One Hits object per query vector, in query order.
    for query_name, hits in zip(query_names, results):
        print(f"Top {topk} matches for {query_name!r}:")
        for hit in hits:
            print(f"  cid={hit.id} distance={hit.distance:.4f} name={hit.entity.get('molecule_name')}")
else:
    print("No molecules available; skipping example search.")

# Disconnect from Milvus server
connections.disconnect("default")
print("Disconnected from Milvus server.")