# Chem-210-Autograder / embeddings.py
# Author: anthony-chen -- commit a1551fc: "kek" (3.56 kB).
import pandas as pd
import time
import random
from sentence_transformers import SentenceTransformer
from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection, utility
import configparser
from tqdm import tqdm
# Load the model that maps each molecule name to a dense vector.
_model_id = "bert-base-uncased"
embedding_model = SentenceTransformer(model_name_or_path=_model_id)
# Read molecule names and their cids from the CSV.
csv_path = 'molecules-small.csv'
df = pd.read_csv(csv_path)
# Upper bound for a stored name; must agree with the molecule_name VARCHAR
# max_length in the collection schema.
max_name_length = 256


def _truncate_utf8(text, max_bytes):
    """Return *text* shortened so its UTF-8 encoding is at most *max_bytes* bytes.

    Cutting by bytes keeps the value within the VARCHAR limit whether that
    limit is counted in characters or in UTF-8 bytes (a plain character slice
    can overflow a byte limit for non-ASCII names). ``errors="ignore"`` drops
    a multi-byte character that the cut would otherwise split in half.
    ASCII names are truncated exactly as a character slice would.
    """
    return text.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")


# pandas yields NaN (a float) for empty cells; fillna("") + astype(str) makes
# every entry a string so truncation and encoding cannot fail on it.
# NOTE(review): blank names are kept as "" so rows stay aligned with cids.
molecules = [
    _truncate_utf8(name, max_name_length)
    for name in df['cmpdname'].fillna("").astype(str)
]
cids = df['cid'].tolist()
# Encode embeddings for every molecule name in one batched call instead of one
# model call per name, so the model processes many names per forward pass.
# encode() on a list returns a 2-D array; list() splits it back into one 1-D
# vector per molecule, the same per-row element the insert loop slices.
# Values should match per-name encoding up to floating-point noise.
embeddings_list = list(
    embedding_model.encode(molecules, show_progress_bar=True)
)
# Read the Milvus connection settings from the [example] section of config.ini
# and open the "default" connection.
cfp = configparser.RawConfigParser()
# read() returns the list of files it parsed and silently skips missing ones;
# without this check a missing file surfaces later as a NoSectionError.
if not cfp.read('config.ini'):
    raise FileNotFoundError(
        "config.ini not found; it must define [example] uri and token"
    )
milvus_uri = cfp.get('example', 'uri')
token = cfp.get('example', 'token')
# Announce the target before connecting so a failing URI is visible in the log.
print(f"Connecting to DB: {milvus_uri}")
connections.connect("default",
                    uri=milvus_uri,
                    token=token)
# Start from a clean slate: remove a collection with this name left over from
# a previous run before it is recreated below.
collection_name = 'molecule_embeddings'
if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)
print("Success!")
# Width of each stored vector. It must equal what the embedding model emits,
# so ask the model instead of hard-coding it (bert-base-uncased gives 768).
dim = embedding_model.get_sentence_embedding_dimension()
# Define collection schema. molecule_cid is the primary key and is supplied by
# this script (auto_id=False).
molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
# Same limit the names are truncated to before insertion.
molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=max_name_length, description="name")
molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
schema = CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                          auto_id=False,
                          description="my first collection!")
print(f"Creating example collection: {collection_name}")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")
# Insert all rows in fixed-size batches, timing only the insert calls.
batch_size = 1000
total_rt = 0
print(f"Inserting {len(embeddings_list)} entities... ")
for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
    stop = i + batch_size
    # Column-based insert: the column order must match the schema field order
    # (molecule_cid, molecule_name, molecule_embedding).
    entities = [cids[i:stop], molecules[i:stop], embeddings_list[i:stop]]
    # perf_counter is monotonic; time.time() can jump if the wall clock changes.
    t0 = time.perf_counter()
    collection.insert(entities)
    total_rt += time.perf_counter() - t0
print(f"Succeed in inserting {len(embeddings_list)} entities in {round(total_rt, 4)} seconds!")
# Flush the inserted rows, build an AUTOINDEX (L2 distance) on the vector
# field, then load the collection so it can be searched.
index_params = dict(index_type="AUTOINDEX", metric_type="L2", params={})
print("Flushing collection...")
collection.flush()
print("Building index...")
collection.create_index(field_name='molecule_embedding', index_params=index_params)
collection.load()
# Demo query: search with nq random vectors (each component in [0, 1)) and
# print the topk nearest neighbours by L2 distance.
nq = 1
topk = 5
search_params = {"metric_type": "L2"}
search_vec = []
for _ in range(nq):
    search_vec.append([random.random() for _ in range(dim)])
print(f"Searching vector: {search_vec}")
results = collection.search(
    search_vec,
    anns_field='molecule_embedding',
    param=search_params,
    limit=topk,
)
print(f"Search results: {results}")
# Tear down the connection opened earlier under the "default" alias.
_alias = "default"
connections.disconnect(_alias)
print("Disconnected from Milvus server.")