为了实现从 Milvus 中获取数据,并基于嵌入向量重新排序的功能,你可以参考以下步骤对原代码进行完善和修改。
关键问题和修改:
- Milvus 数据获取:确保
query()
能获取插入的数据,尤其是向量。确保Collection
正确连接,并且output_fields
中字段名匹配 Milvus 中的定义。 - 加载数据:插入数据后,需要通过
load()
加载数据,确保数据在 Milvus 中可见。 - 优化向量查询逻辑:可以考虑使用
search()
方法,而不是直接从 Milvus 查询所有数据,再计算相似度。search()
可以直接根据查询文本的向量进行向量相似度搜索。
代码修改:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from pymilvus import Collection, connections
# 计算余弦相似度的函数
def cos_sim(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
# 生成文本的嵌入向量
def embed_texts(texts, model, tokenizer):
"""
使用 Jina-embeddings-v2 模型将文本转化为嵌入向量
"""
inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
outputs = model(**inputs)
embeddings = outputs.last_hidden_state[:, 0, :] # 提取CLS token嵌入
return embeddings.numpy()
# 从 Milvus 获取候选文本和嵌入向量
def get_candidates_from_milvus(collection_name, limit=10):
"""
从 Milvus 中获取候选文本及其对应的嵌入
"""
# 连接到 Milvus
connections.connect(alias="default", host="localhost", port="19530")
collection = Collection(collection_name)
# 加载数据确保可用
collection.load()
# 简单查询所有候选文本和嵌入向量
results = collection.query(expr="", output_fields=["text", "embedding"], limit=limit)
# 提取文本和嵌入向量
candidate_texts = [result['text'] for result in results]
candidate_embeddings = np.array([result['embedding'] for result in results])
return candidate_texts, candidate_embeddings
def rerank_candidates(query_text, model, tokenizer, candidate_texts, candidate_embeddings):
"""
重新排序候选文本,基于与查询文本的相似度
"""
# 为查询文本生成嵌入向量
query_embedding = embed_texts([query_text], model, tokenizer)[0]
# 计算查询与候选文本的相似度
similarities = [cos_sim(query_embedding, embedding) for embedding in candidate_embeddings]
# 根据相似度进行重排序
sorted_candidates = sorted(zip(candidate_texts, similarities), key=lambda x: x[1], reverse=True)
# 返回重排序后的结果
return sorted_candidates
def main():
# 加载 Jina-embeddings-v2-base-zh 模型和对应的 tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# 获取候选文本和嵌入向量(从 Milvus)
candidate_texts, candidate_embeddings = get_candidates_from_milvus("weather", limit=10)
# 查询文本
query_text = "天气咋样?"
# 重新排序候选文本
sorted_candidates = rerank_candidates(query_text, model, tokenizer, candidate_texts, candidate_embeddings)
# 输出重排序结果
print("Ranked results based on similarity:")
for text, score in sorted_candidates:
print(f"Candidate: {text}, Similarity: {score}")
if __name__ == "__main__":
main()
关键改动:
collection.load()
:确保在查询前加载集合中的数据,使其可被查询。Milvus 使用延迟加载,所以需要显式调用load()
。query()
:获取集合中的text
和embedding
字段。根据你在 Milvus 集合中插入的字段来确定output_fields
。
运行步骤:
- 确保 Milvus 服务器已经启动并连接正常。
- 插入的数据已成功提交并在 Milvus 集合中可见。
- 执行脚本,通过从 Milvus 中查询文本和嵌入数据,然后对查询文本进行重排序。
小提示:
如果你希望直接在 Milvus 中进行相似度搜索(不手动计算余弦相似度),可以使用 search()
方法,Milvus 本身支持基于向量的相似度查询,会更加高效。