forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c17b614
commit 45f68ce
Showing
8 changed files
with
362 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
205 changes: 205 additions & 0 deletions
205
langchain-core/src/main/java/com/hw/langchain/vectorstores/milvus/Milvus.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.vectorstores.milvus; | ||
|
||
import com.hw.langchain.embeddings.base.Embeddings; | ||
import com.hw.langchain.schema.Document; | ||
import com.hw.langchain.vectorstores.base.VectorStore; | ||
|
||
import org.apache.commons.collections4.CollectionUtils; | ||
import org.apache.commons.lang3.tuple.Pair; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import io.milvus.client.MilvusClient; | ||
import io.milvus.client.MilvusServiceClient; | ||
import io.milvus.common.clientenum.ConsistencyLevelEnum; | ||
import io.milvus.grpc.DataType; | ||
import io.milvus.param.ConnectParam; | ||
import io.milvus.param.Constant; | ||
import io.milvus.param.collection.CreateCollectionParam; | ||
import io.milvus.param.collection.DropCollectionParam; | ||
import io.milvus.param.collection.FieldType; | ||
import io.milvus.param.collection.HasCollectionParam; | ||
import lombok.Builder; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* Initialize wrapper around the milvus vector database. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@Builder | ||
public class Milvus extends VectorStore { | ||
|
||
private static final Logger LOG = LoggerFactory.getLogger(Milvus.class); | ||
|
||
/** | ||
* Function used to embed the text. | ||
*/ | ||
private Embeddings embeddingFunction; | ||
|
||
/** | ||
* Parameters for client connection. | ||
*/ | ||
private ConnectParam connectParam; | ||
|
||
/** | ||
* Which Milvus collection to use. | ||
*/ | ||
@Builder.Default | ||
private String collectionName = "LangChainCollection"; | ||
|
||
/** | ||
* The consistency level to use for a collection. | ||
*/ | ||
@Builder.Default | ||
private ConsistencyLevelEnum consistencyLevel = ConsistencyLevelEnum.STRONG; | ||
|
||
private boolean dropOld; | ||
|
||
private MilvusClient milvusClient; | ||
|
||
public Milvus init() { | ||
milvusClient = new MilvusServiceClient(connectParam); | ||
|
||
// If need to drop old, drop it | ||
if (hasCollection() && dropOld) { | ||
milvusClient.dropCollection( | ||
DropCollectionParam.newBuilder() | ||
.withCollectionName(collectionName) | ||
.build()); | ||
} | ||
return this; | ||
} | ||
|
||
private boolean hasCollection() { | ||
HasCollectionParam requestParam = HasCollectionParam.newBuilder() | ||
.withCollectionName(collectionName) | ||
.build(); | ||
return milvusClient.hasCollection(requestParam).getData(); | ||
} | ||
|
||
public void createCollection(List<List<Float>> embeddings, List<Map<String, Object>> metadatas) { | ||
CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder() | ||
.withCollectionName(collectionName) | ||
.withEnableDynamicField(true); | ||
|
||
// determine embedding dim | ||
int dim = embeddings.get(0).size(); | ||
// determine metadata schema | ||
if (CollectionUtils.isNotEmpty(metadatas)) { | ||
// create FieldSchema for each entry in metadata. | ||
metadatas.get(0).forEach((key, value) -> { | ||
// infer the corresponding datatype of the metadata | ||
DataType dataType = inferDataTypeByData(value); | ||
// dataType isn't compatible | ||
if (dataType == DataType.UNRECOGNIZED || dataType == DataType.None) { | ||
LOG.error("Failure to create collection, unrecognized dataType for key: {}", key); | ||
throw new IllegalArgumentException("Unrecognized datatype for " + key + "."); | ||
} else { | ||
FieldType fieldType = FieldType.newBuilder() | ||
.withName(key) | ||
.withDataType(dataType) | ||
.build(); | ||
builder.addFieldType(fieldType); | ||
} | ||
}); | ||
} | ||
// create the text field | ||
builder.addFieldType(FieldType.newBuilder() | ||
.withName("text") | ||
.withDataType(DataType.VarChar) | ||
.withTypeParams(Map.of(Constant.VARCHAR_MAX_LENGTH, "65535")) | ||
.build()); | ||
// create the primary key field | ||
builder.addFieldType(FieldType.newBuilder() | ||
.withName("pk") | ||
.withDataType(DataType.Int64) | ||
.withPrimaryKey(true) | ||
.withAutoID(true) | ||
.build()); | ||
// create the vector field, supports binary or float vectors | ||
builder.addFieldType(FieldType.newBuilder() | ||
.withName("vector") | ||
.withDataType(DataType.FloatVector) | ||
.withDimension(dim) | ||
.build()); | ||
|
||
// create the collection | ||
milvusClient.createCollection(builder.build()); | ||
} | ||
|
||
private DataType inferDataTypeByData(Object value) { | ||
// TODO: Find corresponding method in Java | ||
return DataType.valueOf(value.toString()); | ||
} | ||
|
||
@Override | ||
public List<String> addTexts(List<String> texts, List<Map<String, Object>> metadatas) { | ||
List<List<Float>> embeddings = embeddingFunction.embedDocuments(texts); | ||
if (embeddings.isEmpty()) { | ||
LOG.warn("Nothing to insert, skipping."); | ||
return List.of(); | ||
} | ||
if (!hasCollection()) { | ||
createCollection(embeddings, metadatas); | ||
} | ||
|
||
return List.of(); | ||
} | ||
|
||
@Override | ||
public boolean delete(List<String> ids) { | ||
return false; | ||
} | ||
|
||
@Override | ||
public List<Document> similaritySearch(String query, int k, Map<String, Object> filter) { | ||
return null; | ||
} | ||
|
||
@Override | ||
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<Document> similarSearchByVector(List<Float> embedding, int k, Map<String, Object> kwargs) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding, int k, int fetchK, | ||
float lambdaMult) { | ||
return null; | ||
} | ||
|
||
@Override | ||
public int fromTexts(List<String> texts, Embeddings embedding, List<Map<String, Object>> metadatas) { | ||
return addTexts(texts, metadatas).size(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
...hain-core/src/test/java/com/hw/langchain/vectorstores/fake/embeddings/FakeEmbeddings.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.vectorstores.fake.embeddings; | ||
|
||
import com.hw.langchain.embeddings.base.Embeddings; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
/** | ||
* Fake embeddings functionality for testing. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class FakeEmbeddings implements Embeddings { | ||
|
||
public static final List<String> FAKE_TEXTS = List.of("foo", "bar", "baz"); | ||
|
||
/** | ||
* Return simple embeddings. Embeddings encode each text as its index. | ||
*/ | ||
@Override | ||
public List<List<Float>> embedDocuments(List<String> texts) { | ||
List<List<Float>> embeddings = new ArrayList<>(); | ||
for (int i = 0; i < texts.size(); i++) { | ||
List<Float> embedding = new ArrayList<>(); | ||
for (int j = 0; j < 9; j++) { | ||
embedding.add(1.0f); | ||
} | ||
embedding.add((float) i); | ||
embeddings.add(embedding); | ||
} | ||
return embeddings; | ||
} | ||
|
||
/** | ||
* Return constant query embeddings. Embeddings are identical to embedDocuments(texts).get(0). | ||
* Distance to each text will be that text's index, as it was passed to embedDocuments. | ||
*/ | ||
@Override | ||
public List<Float> embedQuery(String text) { | ||
List<Float> embedding = new ArrayList<>(); | ||
for (int i = 0; i < 9; i++) { | ||
embedding.add(1.0f); | ||
} | ||
embedding.add(0.0f); | ||
return embedding; | ||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
langchain-core/src/test/java/com/hw/langchain/vectorstores/milvus/MilvusTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.vectorstores.milvus; | ||
|
||
import com.hw.langchain.embeddings.base.Embeddings; | ||
import com.hw.langchain.schema.Document; | ||
import com.hw.langchain.vectorstores.fake.embeddings.FakeEmbeddings; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
import io.milvus.param.ConnectParam; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.vectorstores.fake.embeddings.FakeEmbeddings.FAKE_TEXTS; | ||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
/** | ||
* Test Milvus functionality. | ||
* <p> | ||
* See the following documentation for how to run a Milvus instance: | ||
* <a href="https://milvus.io/docs/install_standalone-docker.md">install_standalone-docker</a> | ||
* | ||
* @author HamaWhite | ||
*/ | ||
class MilvusTest { | ||
|
||
private Milvus milvusFromTexts(List<Map<String, Object>> metadatas, boolean dropOld) { | ||
ConnectParam connectParam = ConnectParam.newBuilder() | ||
.withHost("127.0.0.1") | ||
.withPort(19530) | ||
.build(); | ||
|
||
Embeddings embedding = new FakeEmbeddings(); | ||
Milvus milvus = Milvus.builder() | ||
.embeddingFunction(embedding) | ||
.connectParam(connectParam) | ||
.collectionName("LangChainCollection_1") | ||
.dropOld(dropOld) | ||
.build() | ||
.init(); | ||
milvus.fromTexts(FAKE_TEXTS, embedding, metadatas); | ||
return milvus; | ||
} | ||
|
||
/** | ||
* Test end to end construction and search. | ||
*/ | ||
@Test | ||
void testMilvus() { | ||
Milvus docSearch = milvusFromTexts(List.of(), true); | ||
List<Document> output = docSearch.similaritySearch("foo", Map.of("k", 1)); | ||
|
||
// assertEquals(List.of(new Document("foo")), output); | ||
} | ||
} |
Oops, something went wrong.