Skip to content

Commit

Permalink
Support Milvus(30%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 31, 2023
1 parent c17b614 commit 45f68ce
Show file tree
Hide file tree
Showing 8 changed files with 362 additions and 5 deletions.
5 changes: 5 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@ public Document(String pageContent, Map<String, Object> metadata) {
this.pageContent = pageContent;
this.metadata = metadata;
}

public Document(String pageContent) {
this.pageContent = pageContent;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ public abstract class VectorStore {
*
* @param texts Iterable of strings to add to the vectorStore.
* @param metadatas list of metadatas associated with the texts.
* @param kwargs vectorStore specific parameters
* @return List of ids from adding the texts into the vectorStore.
*/
public abstract List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas,
Map<String, Object> kwargs);
public abstract List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas);

/**
* Delete by vector ID.
Expand All @@ -66,7 +64,7 @@ public abstract List<String> addTexts(List<String> texts, List<Map<String, Objec
public List<String> addDocuments(List<Document> documents, Map<String, Object> kwargs) {
var texts = documents.stream().map(Document::getPageContent).toList();
var metadatas = documents.stream().map(Document::getMetadata).toList();
return addTexts(texts, metadatas, kwargs);
return addTexts(texts, metadatas);
}

public List<Document> search(String query, SearchType searchType, Map<String, Object> filter) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.vectorstores.milvus;

import com.hw.langchain.embeddings.base.Embeddings;
import com.hw.langchain.schema.Document;
import com.hw.langchain.vectorstores.base.VectorStore;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.DataType;
import io.milvus.param.ConnectParam;
import io.milvus.param.Constant;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import lombok.Builder;

import java.util.List;
import java.util.Map;

/**
* Initialize wrapper around the milvus vector database.
*
* @author HamaWhite
*/
@Builder
public class Milvus extends VectorStore {

private static final Logger LOG = LoggerFactory.getLogger(Milvus.class);

/**
* Function used to embed the text.
*/
private Embeddings embeddingFunction;

/**
* Parameters for client connection.
*/
private ConnectParam connectParam;

/**
* Which Milvus collection to use.
*/
@Builder.Default
private String collectionName = "LangChainCollection";

/**
* The consistency level to use for a collection.
*/
@Builder.Default
private ConsistencyLevelEnum consistencyLevel = ConsistencyLevelEnum.STRONG;

private boolean dropOld;

private MilvusClient milvusClient;

public Milvus init() {
milvusClient = new MilvusServiceClient(connectParam);

// If need to drop old, drop it
if (hasCollection() && dropOld) {
milvusClient.dropCollection(
DropCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build());
}
return this;
}

private boolean hasCollection() {
HasCollectionParam requestParam = HasCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
return milvusClient.hasCollection(requestParam).getData();
}

public void createCollection(List<List<Float>> embeddings, List<Map<String, Object>> metadatas) {
CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withEnableDynamicField(true);

// determine embedding dim
int dim = embeddings.get(0).size();
// determine metadata schema
if (CollectionUtils.isNotEmpty(metadatas)) {
// create FieldSchema for each entry in metadata.
metadatas.get(0).forEach((key, value) -> {
// infer the corresponding datatype of the metadata
DataType dataType = inferDataTypeByData(value);
// dataType isn't compatible
if (dataType == DataType.UNRECOGNIZED || dataType == DataType.None) {
LOG.error("Failure to create collection, unrecognized dataType for key: {}", key);
throw new IllegalArgumentException("Unrecognized datatype for " + key + ".");
} else {
FieldType fieldType = FieldType.newBuilder()
.withName(key)
.withDataType(dataType)
.build();
builder.addFieldType(fieldType);
}
});
}
// create the text field
builder.addFieldType(FieldType.newBuilder()
.withName("text")
.withDataType(DataType.VarChar)
.withTypeParams(Map.of(Constant.VARCHAR_MAX_LENGTH, "65535"))
.build());
// create the primary key field
builder.addFieldType(FieldType.newBuilder()
.withName("pk")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build());
// create the vector field, supports binary or float vectors
builder.addFieldType(FieldType.newBuilder()
.withName("vector")
.withDataType(DataType.FloatVector)
.withDimension(dim)
.build());

// create the collection
milvusClient.createCollection(builder.build());
}

private DataType inferDataTypeByData(Object value) {
// TODO: Find corresponding method in Java
return DataType.valueOf(value.toString());
}

@Override
public List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas) {
List<List<Float>> embeddings = embeddingFunction.embedDocuments(texts);
if (embeddings.isEmpty()) {
LOG.warn("Nothing to insert, skipping.");
return List.of();
}
if (!hasCollection()) {
createCollection(embeddings, metadatas);
}

return List.of();
}

@Override
public boolean delete(List<String> ids) {
return false;
}

@Override
public List<Document> similaritySearch(String query, int k, Map<String, Object> filter) {
return null;
}

@Override
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k) {
return null;
}

@Override
public List<Document> similarSearchByVector(List<Float> embedding, int k, Map<String, Object> kwargs) {
return null;
}

@Override
public List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult) {
return null;
}

@Override
public List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding, int k, int fetchK,
float lambdaMult) {
return null;
}

@Override
public int fromTexts(List<String> texts, Embeddings embedding, List<Map<String, Object>> metadatas) {
return addTexts(texts, metadatas).size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public Pinecone init() {
}

@Override
public List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas, Map<String, Object> kwargs) {
public List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.vectorstores.fake.embeddings;

import com.hw.langchain.embeddings.base.Embeddings;

import java.util.ArrayList;
import java.util.List;

/**
* Fake embeddings functionality for testing.
*
* @author HamaWhite
*/
public class FakeEmbeddings implements Embeddings {

public static final List<String> FAKE_TEXTS = List.of("foo", "bar", "baz");

/**
* Return simple embeddings. Embeddings encode each text as its index.
*/
@Override
public List<List<Float>> embedDocuments(List<String> texts) {
List<List<Float>> embeddings = new ArrayList<>();
for (int i = 0; i < texts.size(); i++) {
List<Float> embedding = new ArrayList<>();
for (int j = 0; j < 9; j++) {
embedding.add(1.0f);
}
embedding.add((float) i);
embeddings.add(embedding);
}
return embeddings;
}

/**
* Return constant query embeddings. Embeddings are identical to embedDocuments(texts).get(0).
* Distance to each text will be that text's index, as it was passed to embedDocuments.
*/
@Override
public List<Float> embedQuery(String text) {
List<Float> embedding = new ArrayList<>();
for (int i = 0; i < 9; i++) {
embedding.add(1.0f);
}
embedding.add(0.0f);
return embedding;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.vectorstores.milvus;

import com.hw.langchain.embeddings.base.Embeddings;
import com.hw.langchain.schema.Document;
import com.hw.langchain.vectorstores.fake.embeddings.FakeEmbeddings;

import org.junit.jupiter.api.Test;

import io.milvus.param.ConnectParam;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.vectorstores.fake.embeddings.FakeEmbeddings.FAKE_TEXTS;
import static org.junit.jupiter.api.Assertions.*;

/**
* Test Milvus functionality.
* <p>
* See the following documentation for how to run a Milvus instance:
* <a href="https://milvus.io/docs/install_standalone-docker.md">install_standalone-docker</a>
*
* @author HamaWhite
*/
class MilvusTest {

private Milvus milvusFromTexts(List<Map<String, Object>> metadatas, boolean dropOld) {
ConnectParam connectParam = ConnectParam.newBuilder()
.withHost("127.0.0.1")
.withPort(19530)
.build();

Embeddings embedding = new FakeEmbeddings();
Milvus milvus = Milvus.builder()
.embeddingFunction(embedding)
.connectParam(connectParam)
.collectionName("LangChainCollection_1")
.dropOld(dropOld)
.build()
.init();
milvus.fromTexts(FAKE_TEXTS, embedding, metadatas);
return milvus;
}

/**
* Test end to end construction and search.
*/
@Test
void testMilvus() {
Milvus docSearch = milvusFromTexts(List.of(), true);
List<Document> output = docSearch.similaritySearch("foo", Map.of("k", 1));

// assertEquals(List.of(new Document("foo")), output);
}
}
Loading

0 comments on commit 45f68ce

Please sign in to comment.