Skip to content

Commit

Permalink
Support Milvus
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Aug 3, 2023
1 parent 84464e2 commit b49b1d4
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 44 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ The following example can view in the [langchain-example](langchain-examples/src
- [API Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java)
- [Spark SQL Agent](langchain-bigdata/langchain-spark/src/test/java/com/hw/langchain/agents/toolkits/spark/sql/toolkit/SparkSqlToolkitTest.java)
- [Flink SQL Agent](langchain-bigdata/langchain-flink/src/test/java/com/hw/langchain/agents/toolkits/flink/sql/toolkit/FlinkSqlToolkitTest.java)
- [QA-Milvus](langchain-examples/src/main/java/com/hw/langchain/examples/chains/MilvusExample.java)
- [QA-Pinecone](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
- [QA-Pinecone-Markdown](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java)
- [Agent with Google Search](langchain-examples/src/main/java/com/hw/langchain/examples/agents/LlmAgentExample.java)
- [Question answering over documents](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
- [Context aware text splitting and QA](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java)

## 3. Quickstart Guide
This tutorial gives you a quick walkthrough about building an end-to-end language model application with LangChain.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package com.hw.langchain.schema;

import com.google.common.collect.Maps;

import lombok.Data;

import java.util.Map;
Expand All @@ -40,6 +42,6 @@ public Document(String pageContent, Map<String, Object> metadata) {
}

public Document(String pageContent) {
this.pageContent = pageContent;
this(pageContent, Maps.newHashMap());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String qu
* @return List of Tuples of (doc, similarityScore)
*/
public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String query, int k) {
List<Pair<Document, Float>> docsAndSimilarities = _similaritySearchWithRelevanceScores(query, k);
List<Pair<Document, Float>> docsAndSimilarities = innerSimilaritySearchWithRelevanceScores(query, k);

// Check relevance scores and filter by threshold
if (docsAndSimilarities.stream().anyMatch(pair -> pair.getRight() < 0.0f || pair.getRight() > 1.0f)) {
Expand All @@ -139,7 +139,7 @@ public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String qu
* @param k Number of Documents to return.
* @return List of Tuples of (doc, similarityScore)
*/
protected abstract List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k);
protected abstract List<Pair<Document, Float>> innerSimilaritySearchWithRelevanceScores(String query, int k);

/**
* Return docs most similar to embedding vector.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.hw.langchain.vectorstores.milvus;

import com.google.common.collect.Maps;
import com.hw.langchain.embeddings.base.Embeddings;
import com.hw.langchain.schema.Document;
import com.hw.langchain.vectorstores.base.VectorStore;
Expand All @@ -36,11 +37,12 @@
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.response.DescIndexResponseWrapper;
import io.milvus.response.SearchResultsWrapper;
import lombok.Builder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.*;

import static com.hw.langchain.chains.query.constructor.JsonUtils.writeValueAsString;

Expand All @@ -54,6 +56,8 @@ public class Milvus extends VectorStore {

private static final Logger LOG = LoggerFactory.getLogger(Milvus.class);

private static final String METRIC_TYPE = "metric_type";

/**
* Function used to embed the text.
*/
Expand All @@ -76,7 +80,8 @@ public class Milvus extends VectorStore {
@Builder.Default
private ConsistencyLevelEnum consistencyLevel = ConsistencyLevelEnum.STRONG;

private boolean dropOld;
@Builder.Default
private boolean dropOld = true;

@Builder.Default
private int batchSize = 1000;
Expand Down Expand Up @@ -106,23 +111,14 @@ public class Milvus extends VectorStore {

private Map<String, Map<String, Object>> defaultSearchParams;

private Map<String, Object> searchParams;

public Milvus init() {
milvusClient = new MilvusServiceClient(connectParam);

// default search params when one is not provided.
defaultSearchParams = Map.of(
"IVF_FLAT", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"IVF_SQ8", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"IVF_PQ", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"HNSW", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_FLAT", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_SQ", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_PQ", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"IVF_HNSW", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10, "ef", 10)),
"ANNOY", Map.of("metric_type", "L2", "params", Map.of("search_k", 10)),
"AUTOINDEX", Map.of("metric_type", "L2", "params", Map.of()));

// if need to drop old, drop it
initDefaultSearchParams();

// if you need to drop old, drop it
if (hasCollection() && dropOld) {
milvusClient.dropCollection(
DropCollectionParam.newBuilder()
Expand All @@ -132,6 +128,45 @@ public Milvus init() {
return this;
}

private void initDefaultSearchParams() {
// Initialize mutable maps
Map<String, Object> innerParams1 = Maps.newHashMap();
innerParams1.put("nprobe", 10);

Map<String, Object> innerParams2 = Maps.newHashMap();
innerParams2.put("ef", 10);

Map<String, Object> innerParams3 = Maps.newHashMap();
innerParams3.put("nprobe", 10);
innerParams3.put("ef", 10);

// Initialize the main map
defaultSearchParams = Maps.newHashMap();
defaultSearchParams.put("IVF_FLAT", createInnerMap("L2", innerParams1));
defaultSearchParams.put("IVF_SQ8", createInnerMap("L2", innerParams1));
defaultSearchParams.put("IVF_PQ", createInnerMap("L2", innerParams1));
defaultSearchParams.put("HNSW", createInnerMap("L2", innerParams2));
defaultSearchParams.put("RHNSW_FLAT", createInnerMap("L2", innerParams2));
defaultSearchParams.put("RHNSW_SQ", createInnerMap("L2", innerParams2));
defaultSearchParams.put("RHNSW_PQ", createInnerMap("L2", innerParams2));
defaultSearchParams.put("IVF_HNSW", createInnerMap("L2", innerParams3));
defaultSearchParams.put("ANNOY", createInnerMap("L2", createInnerParams("search_k", 10)));
defaultSearchParams.put("AUTOINDEX", createInnerMap("L2", Maps.newHashMap()));
}

private Map<String, Object> createInnerMap(String metricType, Map<String, Object> params) {
Map<String, Object> innerMap = Maps.newHashMap();
innerMap.put(METRIC_TYPE, metricType);
innerMap.put("params", params);
return innerMap;
}

private Map<String, Object> createInnerParams(String key, Object value) {
Map<String, Object> innerParams = Maps.newHashMap();
innerParams.put(key, value);
return innerParams;
}

private boolean hasCollection() {
HasCollectionParam requestParam = HasCollectionParam.newBuilder()
.withCollectionName(collectionName)
Expand All @@ -145,6 +180,7 @@ private void innerInit(List<List<Float>> embeddings, List<Map<String, Object>> m
}
extractFields();
createIndex();
createSearchParams();
load();
}

Expand All @@ -169,6 +205,7 @@ public void createCollection(List<List<Float>> embeddings, List<Map<String, Obje
FieldType fieldType = FieldType.newBuilder()
.withName(key)
.withDataType(dataType)
.withTypeParams(Map.of(Constant.VARCHAR_MAX_LENGTH, "65535"))
.build();
builder.addFieldType(fieldType);
}
Expand Down Expand Up @@ -199,8 +236,8 @@ public void createCollection(List<List<Float>> embeddings, List<Map<String, Obje
}

private DataType inferDataTypeByData(Object value) {
// TODO: Find corresponding method in Java
return DataType.valueOf(value.toString());
LOG.debug("meta value: {}", value);
return DataType.VarChar;
}

/**
Expand All @@ -221,32 +258,53 @@ private void extractFields() {
fields.remove(primaryField);
}

private Map<String, Object> getIndex() {
private DescIndexResponseWrapper.IndexDesc getIndex() {
DescribeIndexParam requestParam = DescribeIndexParam.newBuilder()
.withCollectionName(collectionName)
.build();

R<DescribeIndexResponse> response = milvusClient.describeIndex(requestParam);
if (response.getData() != null) {
DescIndexResponseWrapper wrapper = new DescIndexResponseWrapper(response.getData());
for (DescIndexResponseWrapper.IndexDesc desc : wrapper.getIndexDescriptions()) {
if (desc.getFieldName().equals(vectorField)) {
return desc;
}
}
}
return null;
}

/**
* Create a index on the collection
*/
private void createIndex() {
Map<String, Object> extraParam = Map.of("M", 8, "efConstruction", 64);
CreateIndexParam requestParam = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(vectorField)
.withIndexType(IndexType.HNSW)
.withMetricType(MetricType.L2)
.withExtraParam(writeValueAsString(extraParam))
.withSyncMode(false)
.build();
milvusClient.createIndex(requestParam);
LOG.info("Successfully created an index on collection: {}", collectionName);
if (getIndex() == null) {
Map<String, Object> extraParam = Map.of("M", 8, "efConstruction", 64);
CreateIndexParam requestParam = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(vectorField)
.withIndexType(IndexType.HNSW)
.withMetricType(MetricType.L2)
.withExtraParam(writeValueAsString(extraParam))
.withSyncMode(false)
.build();
milvusClient.createIndex(requestParam);
LOG.info("Successfully created an index on collection: {}", collectionName);
}
}

/**
* Generate search params based on the current index type
*/
private void createSearchParams() {

DescIndexResponseWrapper.IndexDesc index = getIndex();
if (index != null) {
String indexType = index.getParams().get("index_type");
String metricType = index.getParams().get(METRIC_TYPE);
searchParams = defaultSearchParams.get(indexType);
searchParams.put(METRIC_TYPE, metricType);
}
}

/**
Expand All @@ -272,7 +330,28 @@ public List<String> addTexts(List<String> texts, List<Map<String, Object>> metad
innerInit(embeddings, metadatas);

// dict to hold all insert columns
Map<String, List<?>> insertDict = Map.of(textField, texts, vectorField, embeddings);
Map<String, List<?>> insertDict = Maps.newHashMap();
insertDict.put(textField, texts);
insertDict.put(vectorField, embeddings);

// collect the metadata into the insert dict.
if (metadatas != null) {
for (var meta : metadatas) {
meta.forEach((key, value) -> {
if (fields.contains(key)) {
@SuppressWarnings("unchecked")
List<Object> dict = (List<Object>) insertDict.get(key);
if (dict == null) {
dict = new ArrayList<>();
insertDict.put(key, dict);
}
dict.add(value);
}
});
}
}

// total insert count
int totalCount = embeddings.size();
List<String> pks = new ArrayList<>();
for (int i = 0; i < totalCount; i += batchSize) {
Expand Down Expand Up @@ -311,15 +390,28 @@ private List<Pair<Document, Float>> similaritySearchWithScore(String query, int
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(collectionName)
.withConsistencyLevel(consistencyLevel)
.withMetricType(MetricType.L2)
.withMetricType(MetricType.valueOf(searchParams.get(METRIC_TYPE).toString()))
.withOutFields(outputFields)
.withTopK(k)
.withVectors(List.of(embedding))
.withVectorFieldName(vectorField)
// .withParams(SEARCH_PARAM)
.withParams(writeValueAsString(searchParams.get("params")))
.build();
R<SearchResults> respSearch = milvusClient.search(searchParam);
return null;
SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(respSearch.getData().getResults());

// organize results.
List<Pair<Document, Float>> ret = new ArrayList<>();
for (var result : wrapperSearch.getRowRecords()) {
Map<String, Object> meta = Maps.newHashMap();
for (String x : outputFields) {
meta.put(x, result.get(x));
}
Document doc = new Document((String) meta.remove(textField), meta);
Pair<Document, Float> pair = Pair.of(doc, (Float) result.get("distance"));
ret.add(pair);
}
return ret;
}

@Override
Expand All @@ -329,7 +421,7 @@ public List<Document> similaritySearch(String query, int k, Map<String, Object>
}

@Override
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k) {
protected List<Pair<Document, Float>> innerSimilaritySearchWithRelevanceScores(String query, int k) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public List<Document> similaritySearch(String query, int k, Map<String, Object>
}

@Override
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k) {
protected List<Pair<Document, Float>> innerSimilaritySearchWithRelevanceScores(String query, int k) {
return null;
}

Expand Down
Loading

0 comments on commit b49b1d4

Please sign in to comment.