Skip to content

Commit

Permalink
Support Milvus(70%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Aug 2, 2023
1 parent ad21a4f commit 84464e2
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ public static String toJsonStringWithIndent(Object object) {
return toJsonStringWithIndent(object, 4);
}

public static String writeValueAsString(Object object) {
try {
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(object);
} catch (JsonProcessingException e) {
throw new LangChainException("Failed to format attribute info.", e);
}
}

public static <T> T convertFromJsonStr(String jsonStr, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(jsonStr, clazz);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,20 @@
import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.DataType;
import io.milvus.param.ConnectParam;
import io.milvus.param.Constant;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.grpc.*;
import io.milvus.param.*;
import io.milvus.param.collection.*;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import lombok.Builder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.hw.langchain.chains.query.constructor.JsonUtils.writeValueAsString;

/**
* Initialize wrapper around the milvus vector database.
*
Expand Down Expand Up @@ -76,12 +78,51 @@ public class Milvus extends VectorStore {

private boolean dropOld;

@Builder.Default
private int batchSize = 1000;

private MilvusClient milvusClient;

/**
* In order for a collection to be compatible, pk needs to be auto-id and int
*/
@Builder.Default
private String primaryField = "pk";

/**
* In order for compatibility, the text field will need to be called "text"
*/
@Builder.Default
private String textField = "text";

/**
* In order for compatibility, the vector field needs to be called "vector"
*/
@Builder.Default
private String vectorField = "vector";

@Builder.Default
private List<String> fields = new ArrayList<>();

private Map<String, Map<String, Object>> defaultSearchParams;

public Milvus init() {
milvusClient = new MilvusServiceClient(connectParam);

// If need to drop old, drop it
// default search params when one is not provided.
defaultSearchParams = Map.of(
"IVF_FLAT", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"IVF_SQ8", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"IVF_PQ", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10)),
"HNSW", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_FLAT", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_SQ", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"RHNSW_PQ", Map.of("metric_type", "L2", "params", Map.of("ef", 10)),
"IVF_HNSW", Map.of("metric_type", "L2", "params", Map.of("nprobe", 10, "ef", 10)),
"ANNOY", Map.of("metric_type", "L2", "params", Map.of("search_k", 10)),
"AUTOINDEX", Map.of("metric_type", "L2", "params", Map.of()));

// if need to drop old, drop it
if (hasCollection() && dropOld) {
milvusClient.dropCollection(
DropCollectionParam.newBuilder()
Expand All @@ -98,6 +139,15 @@ private boolean hasCollection() {
return milvusClient.hasCollection(requestParam).getData();
}

private void innerInit(List<List<Float>> embeddings, List<Map<String, Object>> metadatas) {
if (CollectionUtils.isNotEmpty(embeddings)) {
createCollection(embeddings, metadatas);
}
extractFields();
createIndex();
load();
}

public void createCollection(List<List<Float>> embeddings, List<Map<String, Object>> metadatas) {
CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
Expand Down Expand Up @@ -126,20 +176,20 @@ public void createCollection(List<List<Float>> embeddings, List<Map<String, Obje
}
// create the text field
builder.addFieldType(FieldType.newBuilder()
.withName("text")
.withName(textField)
.withDataType(DataType.VarChar)
.withTypeParams(Map.of(Constant.VARCHAR_MAX_LENGTH, "65535"))
.build());
// create the primary key field
builder.addFieldType(FieldType.newBuilder()
.withName("pk")
.withName(primaryField)
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build());
// create the vector field, supports binary or float vectors
builder.addFieldType(FieldType.newBuilder()
.withName("vector")
.withName(vectorField)
.withDataType(DataType.FloatVector)
.withDimension(dim)
.build());
Expand All @@ -153,28 +203,129 @@ private DataType inferDataTypeByData(Object value) {
return DataType.valueOf(value.toString());
}

/**
* Grab the existing fields from the Collection
*/
private void extractFields() {
R<DescribeCollectionResponse> response = milvusClient.describeCollection(
// Return the name and schema of the collection.
DescribeCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build());

CollectionSchema schema = response.getData().getSchema();
for (FieldSchema x : schema.getFieldsList()) {
fields.add(x.getName());
}
// since primary field is auto-id, no need to track it
fields.remove(primaryField);
}

private Map<String, Object> getIndex() {
return null;
}

/**
* Create a index on the collection
*/
private void createIndex() {
Map<String, Object> extraParam = Map.of("M", 8, "efConstruction", 64);
CreateIndexParam requestParam = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(vectorField)
.withIndexType(IndexType.HNSW)
.withMetricType(MetricType.L2)
.withExtraParam(writeValueAsString(extraParam))
.withSyncMode(false)
.build();
milvusClient.createIndex(requestParam);
LOG.info("Successfully created an index on collection: {}", collectionName);
}

/**
* Generate search params based on the current index type
*/
private void createSearchParams() {

}

/**
* Load the collection if available.
*/
private void load() {
LoadCollectionParam requestParam = LoadCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();

milvusClient.loadCollection(requestParam);
}

@Override
public List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas) {
List<List<Float>> embeddings = embeddingFunction.embedDocuments(texts);
if (embeddings.isEmpty()) {
LOG.warn("Nothing to insert, skipping.");
return List.of();
}
if (!hasCollection()) {
createCollection(embeddings, metadatas);
}

return List.of();
// if the collection hasn't been initialized yet, perform all steps to take so
innerInit(embeddings, metadatas);

// dict to hold all insert columns
Map<String, List<?>> insertDict = Map.of(textField, texts, vectorField, embeddings);
int totalCount = embeddings.size();
List<String> pks = new ArrayList<>();
for (int i = 0; i < totalCount; i += batchSize) {
// grab end index
int end = Math.min(i + batchSize, totalCount);
// convert map to batch list for insertion
List<InsertParam.Field> insertFields = new ArrayList<>();
for (String field : fields) {
insertFields.add(new InsertParam.Field(field, insertDict.get(field).subList(i, end)));
}
// insert into the collection.
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withFields(insertFields)
.build();
var res = milvusClient.insert(insertParam);
pks.addAll(res.getData().getIDs().getStrId().getDataList());
}
return pks;
}

@Override
public boolean delete(List<String> ids) {
return false;
}

private List<Pair<Document, Float>> similaritySearchWithScore(String query, int k, Map<String, Object> filter) {
// embed the query text.
List<Float> embedding = embeddingFunction.embedQuery(query);

// determine result metadata fields.
List<String> outputFields = new ArrayList<>(fields);
outputFields.remove(vectorField);

// perform the search.
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(collectionName)
.withConsistencyLevel(consistencyLevel)
.withMetricType(MetricType.L2)
.withOutFields(outputFields)
.withTopK(k)
.withVectors(List.of(embedding))
.withVectorFieldName(vectorField)
// .withParams(SEARCH_PARAM)
.build();
R<SearchResults> respSearch = milvusClient.search(searchParam);
return null;
}

@Override
public List<Document> similaritySearch(String query, int k, Map<String, Object> filter) {
return null;
List<Pair<Document, Float>> docsAndScores = similaritySearchWithScore(query, k, filter);
return docsAndScores.stream().map(Pair::getLeft).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ private Milvus milvusFromTexts(List<Map<String, Object>> metadatas, boolean drop
@Test
void testMilvus() {
Milvus docSearch = milvusFromTexts(List.of(), true);
List<Document> output = docSearch.similaritySearch("foo", Map.of("k", 1));

// assertEquals(List.of(new Document("foo")), output);
List<Document> output = docSearch.similaritySearch("foo", 1, Map.of());
assertEquals(List.of(new Document("foo")), output);
}
}

0 comments on commit 84464e2

Please sign in to comment.