Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#25 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
Vector stores integrations pinecone
  • Loading branch information
HamaWhiteGG authored Jul 1, 2023
2 parents 1e5b48d + 4336f00 commit de2be50
Show file tree
Hide file tree
Showing 9 changed files with 395 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.math.utils;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.List;

import static com.hw.langchain.vectorstores.utils.ArrayUtils.listToArray;
import static org.nd4j.linalg.ops.transforms.Transforms.allCosineSimilarities;

/**
 * Static math helpers built on ND4J.
 *
 * @author HamaWhite
 */
public class MathUtils {

    private MathUtils() {
        // Static utility holder; never instantiated.
    }

    /**
     * Cosine similarity between {@code X} and {@code yArray}.
     * <p>
     * {@code X} is converted to an {@link INDArray} via {@code ArrayUtils.listToArray} and
     * {@code Nd4j.createFromArray}, then handed to {@link #cosineSimilarity(INDArray, INDArray)}.
     *
     * @param X      left-hand matrix as a list of rows
     * @param yArray right-hand matrix
     * @return the result of {@link #cosineSimilarity(INDArray, INDArray)}
     * @throws IllegalArgumentException see {@link #cosineSimilarity(INDArray, INDArray)}
     */
    public static INDArray cosineSimilarity(List<List<Float>> X, INDArray yArray) {
        return cosineSimilarity(Nd4j.createFromArray(listToArray(X)), yArray);
    }

    /**
     * Cosine similarity between {@code xArray} and {@code Y}.
     * <p>
     * {@code Y} is converted to an {@link INDArray} via {@code ArrayUtils.listToArray} and
     * {@code Nd4j.createFromArray}, then handed to {@link #cosineSimilarity(INDArray, INDArray)}.
     *
     * @param xArray left-hand matrix
     * @param Y      right-hand matrix as a list of rows
     * @return the result of {@link #cosineSimilarity(INDArray, INDArray)}
     * @throws IllegalArgumentException see {@link #cosineSimilarity(INDArray, INDArray)}
     */
    public static INDArray cosineSimilarity(INDArray xArray, List<List<Float>> Y) {
        return cosineSimilarity(xArray, Nd4j.createFromArray(listToArray(Y)));
    }

    /**
     * Cosine similarity between two matrices that have the same number of columns.
     * <p>
     * The computation is delegated to ND4J's {@code Transforms.allCosineSimilarities}, applied
     * along the last dimension of {@code xArray}.
     * NOTE(review): the result is presumably the pairwise (rows of X) x (rows of Y) similarity
     * matrix - confirm against the ND4J documentation.
     *
     * @param xArray left-hand matrix
     * @param yArray right-hand matrix
     * @return an empty array (built from {@code new float[0][0]}) if either input is empty,
     *         otherwise the array returned by {@code allCosineSimilarities}
     * @throws IllegalArgumentException if either non-empty input has rank below 2, or if the
     *                                  two inputs differ in their number of columns
     */
    public static INDArray cosineSimilarity(INDArray xArray, INDArray yArray) {
        if (xArray.isEmpty() || yArray.isEmpty()) {
            return Nd4j.create(new float[0][0]);
        }
        // shape()[1] below needs at least two dimensions; reject vectors/scalars with a clear
        // message instead of letting an ArrayIndexOutOfBoundsException escape.
        if (xArray.rank() < 2 || yArray.rank() < 2) {
            throw new IllegalArgumentException(
                    String.format("X and Y must both have at least 2 dimensions. X has shape %s and Y has shape %s.",
                            Arrays.toString(xArray.shape()), Arrays.toString(yArray.shape())));
        }
        if (xArray.shape()[1] != yArray.shape()[1]) {
            throw new IllegalArgumentException(
                    String.format("Number of columns in X and Y must be the same. X has shape %s and Y has shape %s.",
                            Arrays.toString(xArray.shape()), Arrays.toString(yArray.shape())));
        }
        return allCosineSimilarities(xArray, yArray, xArray.rank() - 1);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.vectorstores.base;

/**
 * Search strategies a vector store can be asked to use. Each constant carries the
 * lowercase string identifier returned by {@link #getValue()}.
 *
 * @author HamaWhite
 */
public enum SearchType {

    /** Retrieve vectors based on similarity. */
    SIMILARITY("similarity"),

    /** Retrieve vectors based on a similarity score threshold. */
    SIMILARITY_SCORE_THRESHOLD("similarity_score_threshold"),

    /** Retrieve vectors using the maximum marginal relevance (MMR) algorithm. */
    MMR("mmr");

    /** String identifier associated with this search type. */
    private final String identifier;

    SearchType(String identifier) {
        this.identifier = identifier;
    }

    /**
     * Returns the string identifier of this search type.
     *
     * @return the identifier, e.g. {@code "similarity"} for {@link #SIMILARITY}
     */
    public String getValue() {
        return this.identifier;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/**
* @author HamaWhite
Expand Down Expand Up @@ -68,10 +67,10 @@ public List<String> addDocuments(List<Document> documents, Map<String, Object> k
return addTexts(texts, metadatas, kwargs);
}

public List<Document> search(String query, String searchType, Map<String, Object> kwargs) {
public List<Document> search(String query, SearchType searchType) {
return switch (searchType) {
case "similarity" -> similaritySearch(query);
case "mmr" -> maxMarginalRelevanceSearch(query, kwargs);
case SIMILARITY -> similaritySearch(query);
case MMR -> maxMarginalRelevanceSearch(query);
default -> throw new IllegalArgumentException(
"searchType of " + searchType + " not allowed. Expected searchType to be 'similarity' or 'mmr'.");
};
Expand All @@ -92,44 +91,31 @@ public List<Document> similaritySearch(String query) {
/**
* Return docs and relevance scores in the range [0, 1]. 0 is dissimilar, 1 is most similar.
*/
public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String query, Map<String, Object> kwargs) {
return similaritySearchWithRelevanceScores(query, 4, kwargs);
public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String query) {
return similaritySearchWithRelevanceScores(query, 4);
}

/**
* Return docs and relevance scores in the range [0, 1]. 0 is dissimilar, 1 is most similar.
*
* @param query input text
* @param k Number of Documents to return.
* @param kwargs kwargs to be passed to similarity search. Should include: score_threshold: Optional, a floating point value between 0 to 1 to filter the resulting set of retrieved docs
* @param query input text
* @param k Number of Documents to return.
* @return List of Tuples of (doc, similarity_score)
*/
public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String query, int k,
Map<String, Object> kwargs) {
List<Pair<Document, Float>> docsAndSimilarities = _similaritySearchWithRelevanceScores(query, k, kwargs);
public List<Pair<Document, Float>> similaritySearchWithRelevanceScores(String query, int k) {
List<Pair<Document, Float>> docsAndSimilarities = _similaritySearchWithRelevanceScores(query, k);

// Check relevance scores and filter by threshold
if (docsAndSimilarities.stream().anyMatch(pair -> pair.getRight() < 0.0f || pair.getRight() > 1.0f)) {
LOG.warn("Relevance scores must be between 0 and 1, got {} ", docsAndSimilarities);
}

if (kwargs.containsKey("score_threshold")) {
float scoreThreshold = (float) kwargs.get("score_threshold");
Predicate<Pair<Document, Float>> thresholdFilter = pair -> pair.getRight() >= scoreThreshold;
docsAndSimilarities = docsAndSimilarities.stream().filter(thresholdFilter).toList();

if (docsAndSimilarities.isEmpty()) {
LOG.warn("No relevant docs were retrieved using the relevance score threshold {}", scoreThreshold);
}
}
return docsAndSimilarities;
}

/**
* Return docs and relevance scores, normalized on a scale from 0 to 1. 0 is dissimilar, 1 is most similar.
*/
protected abstract List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k,
Map<String, Object> kwargs);
protected abstract List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k);

/**
* Return docs most similar to embedding vector.
Expand All @@ -141,8 +127,8 @@ protected abstract List<Pair<Document, Float>> _similaritySearchWithRelevanceSco
*/
public abstract List<Document> similarSearchByVector(List<Float> embedding, int k, Map<String, Object> kwargs);

public List<Document> maxMarginalRelevanceSearch(String query, Map<String, Object> kwargs) {
return maxMarginalRelevanceSearch(query, 4, 20, 0.5f, kwargs);
public List<Document> maxMarginalRelevanceSearch(String query) {
return maxMarginalRelevanceSearch(query, 4, 20, 0.5f);
}

/**
Expand All @@ -154,11 +140,27 @@ public List<Document> maxMarginalRelevanceSearch(String query, Map<String, Objec
* @param fetchK Number of Documents to fetch to pass to MMR algorithm.
* @param lambdaMult Number between 0 and 1 that determines the degree of diversity among the results with 0
* corresponding to maximum diversity and 1 to minimum diversity.
* @param kwargs kwargs
* @return List of Documents selected by maximal marginal relevance.
*/
public abstract List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult,
Map<String, Object> kwargs);
public abstract List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult);

public List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding) {
return maxMarginalRelevanceSearchByVector(embedding, 4, 20, 0.5f);
}

/**
* Return docs selected using the maximal marginal relevance.
* Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents.
*
* @param embedding Embedding to look up documents similar to.
* @param k Number of Documents to return.
* @param fetchK Number of Documents to fetch to pass to MMR algorithm.
* @param lambdaMult Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding
* to maximum diversity and 1 to minimum diversity.
* @return List of Documents selected by maximal marginal relevance.
*/
public abstract List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding, int k, int fetchK,
float lambdaMult);

/**
* Return VectorStore initialized from documents and embeddings.
Expand All @@ -174,7 +176,7 @@ public int fromDocuments(List<Document> documents, Embeddings embedding) {
*/
public abstract int fromTexts(List<String> texts, Embeddings embedding, List<Map<String, Object>> metadatas);

public VectorStoreRetriever asRetriever(Map<String, Object> kwargs) {
return new VectorStoreRetriever(this, kwargs);
public VectorStoreRetriever asRetriever(SearchType searchType) {
return new VectorStoreRetriever(this, searchType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,33 @@
import java.util.List;
import java.util.Map;

import static com.hw.langchain.vectorstores.base.SearchType.SIMILARITY_SCORE_THRESHOLD;

/**
* @author HamaWhite
*/
public class VectorStoreRetriever implements BaseRetriever {

private static final List<String> ALLOWED_SEARCH_TYPES = List.of(
"similarity",
"similarity_score_threshold",
"mmr");
private final VectorStore vectorstore;

private VectorStore vectorstore;
private final SearchType searchType;

private String searchType = "similarity";
private final Map<String, Object> searchKwargs;

private Map<String, Object> searchKwargs;
public VectorStoreRetriever(VectorStore vectorstore, SearchType searchType) {
this(vectorstore, searchType, null);
}

public VectorStoreRetriever(VectorStore vectorstore, Map<String, Object> searchKwargs) {
public VectorStoreRetriever(VectorStore vectorstore, SearchType searchType, Map<String, Object> searchKwargs) {
this.vectorstore = vectorstore;
this.searchType = searchType;
this.searchKwargs = searchKwargs;

validateSearchType();
}

private void validateSearchType() {
if (!ALLOWED_SEARCH_TYPES.contains(searchType)) {
throw new IllegalArgumentException(
"searchType of " + searchType + " not allowed. Valid values are: " + ALLOWED_SEARCH_TYPES);
}
if ("similarity_score_threshold".equals(searchType)) {
if (SIMILARITY_SCORE_THRESHOLD.equals(searchType)) {
Object scoreThreshold = searchKwargs.get("score_threshold");
if (!(scoreThreshold instanceof Float)) {
throw new IllegalArgumentException(
Expand All @@ -66,13 +64,12 @@ private void validateSearchType() {
@Override
public List<Document> getRelevantDocuments(String query) {
return switch (searchType) {
case "similarity" -> vectorstore.similaritySearch(query);
case "similarity_score_threshold" -> vectorstore.similaritySearchWithRelevanceScores(query, searchKwargs)
case SIMILARITY -> vectorstore.similaritySearch(query);
case SIMILARITY_SCORE_THRESHOLD -> vectorstore.similaritySearchWithRelevanceScores(query)
.stream()
.map(Pair::getLeft)
.toList();
case "mmr" -> vectorstore.maxMarginalRelevanceSearch(query, searchKwargs);
default -> throw new IllegalArgumentException("searchType of " + searchType + " not allowed.");
case MMR -> vectorstore.maxMarginalRelevanceSearch(query);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.hw.langchain.vectorstores.utils.Nd4jUtils.createFromList;
import static com.hw.langchain.vectorstores.utils.Utils.maximalMarginalRelevance;

/**
* @author HamaWhite
*/
Expand Down Expand Up @@ -137,8 +141,7 @@ public List<Document> similaritySearch(String query, int k) {
}

@Override
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k,
Map<String, Object> kwargs) {
protected List<Pair<Document, Float>> _similaritySearchWithRelevanceScores(String query, int k) {
return null;
}

Expand All @@ -148,9 +151,37 @@ public List<Document> similarSearchByVector(List<Float> embedding, int k, Map<St
}

@Override
public List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult,
Map<String, Object> kwargs) {
return null;
public List<Document> maxMarginalRelevanceSearch(String query, int k, int fetchK, float lambdaMult) {
List<Float> embedding = embeddingFunction.apply(query);
return maxMarginalRelevanceSearchByVector(embedding, k, fetchK, lambdaMult);
}

@Override
public List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding, int k, int fetchK,
float lambdaMult) {
QueryRequest queryRequest = QueryRequest.builder()
.vector(embedding)
.topK(fetchK)
.namespace(namespace)
.includeValues(true)
.includeMetadata(true)
.build();
QueryResponse results = index.query(queryRequest);

List<Integer> mmrSelected = maximalMarginalRelevance(
createFromList(embedding),
results.getMatches().stream().map(ScoredVector::getValues).toList(),
k,
lambdaMult);

checkNotNull(mmrSelected, "mmrSelected must not be null");
List<Map<String, Object>> selected = mmrSelected.stream()
.map(i -> results.getMatches().get(i).getMetadata())
.toList();

return selected.stream()
.map(metadata -> new Document(metadata.remove(textKey).toString(), metadata))
.toList();
}

@Override
Expand Down
Loading

0 comments on commit de2be50

Please sign in to comment.