|
import pandas as pd |
|
import time |
|
import random |
|
from sentence_transformers import SentenceTransformer |
|
from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection, utility |
|
import configparser |
|
from tqdm import tqdm |
|
|
|
|
|
# Text encoder used for every molecule name. The first positional argument of
# SentenceTransformer is the model name/path.
embedding_model = SentenceTransformer("bert-base-uncased")
|
|
|
|
|
# Load the molecule table. The columns used below are "cmpdname" (the name
# that gets embedded and stored) and "cid" (used as the primary key).
csv_path = 'molecules-small.csv'

df = pd.read_csv(csv_path)

# Upper bound for stored names; it must agree with the VARCHAR max_length of
# the "molecule_name" field in the collection schema.
max_name_length = 256


def _truncate_utf8(text, max_bytes):
    """Return the longest prefix of *text* whose UTF-8 encoding fits in *max_bytes*.

    Milvus checks VARCHAR max_length against the encoded byte length, so a
    character-based cut can still be rejected for non-ASCII names. Cutting on
    bytes keeps the value within the limit whether it is counted in bytes or
    in characters.
    """
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    # errors="ignore" drops a multi-byte character split at the boundary.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


# pandas reads empty cells as NaN (a float), which would break len()/encode()
# and the VARCHAR field; store those as empty strings instead.
molecules = [
    _truncate_utf8(name, max_name_length)
    for name in df['cmpdname'].fillna("").astype(str)
]

cids = df['cid'].tolist()
|
|
|
|
|
# Encode all names with a single batched encode() call instead of one call per
# name. The model then processes a whole batch per forward pass. The returned
# array has one row per molecule (in input order); list() turns it back into
# the list of 1-D vectors the insert step slices into batches.
embeddings_list = list(
    embedding_model.encode(molecules, show_progress_bar=True)
)
|
|
|
# Read the server URI and API token from the [example] section of config.ini.
# RawConfigParser performs no '%' interpolation, so tokens containing '%'
# are read verbatim.
cfp = configparser.RawConfigParser()

# read() silently skips files it cannot open and returns the names it did
# read. Fail here with a clear message rather than a later NoSectionError.
if not cfp.read('config.ini'):
    raise FileNotFoundError("config.ini not found or unreadable")

milvus_uri = cfp.get('example', 'uri')

token = cfp.get('example', 'token')

# Announce the target before connecting so a hang or failure is attributable.
print(f"Connecting to DB: {milvus_uri}")

connections.connect("default",
                    uri=milvus_uri,
                    token=token)
|
|
|
# Start from a clean slate: drop any existing collection with this name so the
# schema defined below is always applied fresh.
collection_name = 'molecule_embeddings'

if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)

print("Success!")

# Take the vector size from the encoder itself (768 for bert-base-uncased) so
# the collection schema cannot drift from the embeddings actually produced.
dim = embedding_model.get_sentence_embedding_dimension()
|
|
|
|
|
# Collection schema: the compound id is the caller-supplied primary key
# (auto_id=False), the name is a VARCHAR, and the embedding is a float vector
# of length `dim`.
molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)

# Reuse the same limit the names were truncated to, so the two cannot disagree.
molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=max_name_length, description="name")

molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)

# Field order here is the column order expected by collection.insert().
schema = CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                          auto_id=False,
                          description="my first collection!")

print(f"Creating example collection: {collection_name}")

collection = Collection(name=collection_name, schema=schema)

print(f"Schema: {schema}")

print("Success!")
|
|
|
# Insert in fixed-size batches. Each batch is a list of columns ordered like
# the schema fields: (cid, name, embedding).
batch_size = 1000

total_rt = 0.0  # cumulative seconds spent inside collection.insert()

inserted = 0  # rows the server acknowledged

print(f"Inserting {len(embeddings_list)} entities... ")

for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
    entities = [
        cids[i:i + batch_size],
        molecules[i:i + batch_size],
        embeddings_list[i:i + batch_size],
    ]

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments, so it is unsuitable for interval timing.
    t0 = time.perf_counter()
    ins_resp = collection.insert(entities)
    total_rt += time.perf_counter() - t0

    # Count what the server reports as inserted, not what we sent.
    inserted += ins_resp.insert_count

print(f"Succeed in inserting {inserted} entities in {round(total_rt, 4)} seconds!")
|
|
|
|
|
# Index settings for the vector field: AUTOINDEX leaves the index type to the
# server, and the L2 metric matches the one used by the search below.
index_params = dict(index_type="AUTOINDEX", metric_type="L2", params={})

# Persist the inserted rows before building the index.
print("Flushing collection...")
collection.flush()

print("Building index...")
collection.create_index(
    field_name='molecule_embedding',
    index_params=index_params,
)

# The collection must be loaded before it can be searched.
collection.load()
|
|
|
|
|
# Smoke-test search: issue `nq` random query vectors of length `dim` and print
# the `topk` nearest stored embeddings by L2 distance.
nq = 1
topk = 5
search_params = {"metric_type": "L2"}

search_vec = []
for _ in range(nq):
    search_vec.append([random.random() for _ in range(dim)])
print(f"Searching vector: {search_vec}")

results = collection.search(
    search_vec,
    anns_field='molecule_embedding',
    param=search_params,
    limit=topk,
)
print(f"Search results: {results}")
|
|
|
|
|
# Close the "default" connection opened at the start of the script.
connections.disconnect(alias="default")
print("Disconnected from Milvus server.")
|
|