Skip to content

Commit

Permalink
add pinecone getRelevantDocuments(100%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 1, 2023
1 parent 60e1572 commit 987f42c
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,29 @@
public class MathUtils {

/**
* Private constructor: keeps this class from being instantiated by outside code.
*/
private MathUtils() {
}

/**
* Cosine similarity overload that accepts {@code X} as a nested list.
* <p>
* {@code X} is first converted to an {@link INDArray} via {@code listToArray} and
* {@code Nd4j.createFromArray}; the actual computation is delegated to the
* {@code (INDArray, INDArray)} overload.
*
* @param X      left-hand input as a list of float lists (presumably one inner list per row)
* @param yArray right-hand input, passed through unchanged
* @return the result of the {@code (INDArray, INDArray)} overload for the converted input
*/
public static INDArray cosineSimilarity(List<List<Float>> X, INDArray yArray) {
    Float[][] xRows = listToArray(X);
    INDArray xArray = Nd4j.createFromArray(xRows);
    return cosineSimilarity(xArray, yArray);
}

/**
* Row-wise cosine similarity between two equal-width matrices.
*/
public static INDArray cosineSimilarity(INDArray xArray, List<List<Float>> Y) {
if (xArray.isEmpty() || Y.isEmpty()) {
return cosineSimilarity(xArray, Nd4j.createFromArray(listToArray(Y)));
}

/**
* Row-wise cosine similarity between two equal-width matrices.
*/
public static INDArray cosineSimilarity(INDArray xArray, INDArray yArray) {
if (xArray.isEmpty() || yArray.isEmpty()) {
return Nd4j.create(new float[0][0]);
}
INDArray yArray = Nd4j.createFromArray(listToArray(Y));
if (xArray.shape()[1] != yArray.shape()[1]) {
throw new IllegalArgumentException(
String.format("Number of columns in X and Y must be the same. X has shape %s and Y has shape %s.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.hw.pinecone.entity.vector.Vector;

import org.apache.commons.lang3.tuple.Pair;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -39,7 +38,7 @@
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.hw.langchain.vectorstores.utils.ArrayUtils.listToArray;
import static com.hw.langchain.vectorstores.utils.Nd4jUtils.createFromList;
import static com.hw.langchain.vectorstores.utils.Utils.maximalMarginalRelevance;

/**
Expand Down Expand Up @@ -170,7 +169,7 @@ public List<Document> maxMarginalRelevanceSearchByVector(List<Float> embedding,
QueryResponse results = index.query(queryRequest);

List<Integer> mmrSelected = maximalMarginalRelevance(
Nd4j.createFromArray(listToArray(List.of(embedding))),
createFromList(embedding),
results.getMatches().stream().map(ScoredVector::getValues).toList(),
k,
lambdaMult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
*/
public class ArrayUtils {

/**
* Private constructor: keeps this class from being instantiated by outside code.
*/
private ArrayUtils() {
}

public static <T> List<List<T>> arrayToList(T[][] array) {
List<List<T>> result = Lists.newArrayListWithCapacity(array.length);
for (T[] subArray : array) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.vectorstores.utils;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.List;

import static com.hw.langchain.vectorstores.utils.ArrayUtils.listToArray;

/**
 * Static helpers for building ND4J arrays from Java collections.
 *
 * @author HamaWhite
 */
public class Nd4jUtils {

    /**
     * Private constructor: keeps this class from being instantiated by outside code.
     */
    private Nd4jUtils() {
    }

    /**
     * Creates an {@link INDArray} from a single list of floats.
     * <p>
     * The list is wrapped in a one-element outer list, converted to a two-dimensional
     * array by {@code ArrayUtils.listToArray}, and handed to {@code Nd4j.createFromArray}.
     *
     * @param list the values to convert; must not be {@code null} because {@link List#of}
     *             rejects null elements
     * @return the array built by ND4J (presumably of shape 1 x list.size(); confirm against
     *         {@code listToArray})
     */
    public static INDArray createFromList(List<Float> list) {
        return Nd4j.createFromArray(listToArray(List.of(list)));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@

package com.hw.langchain.vectorstores.utils;

import com.google.common.collect.Lists;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.ArrayList;
import java.util.List;

import static com.hw.langchain.math.utils.MathUtils.cosineSimilarity;
import static com.hw.langchain.vectorstores.utils.Nd4jUtils.createFromList;
import static java.lang.Float.NEGATIVE_INFINITY;

/**
* Utility functions for working with vectors and vectorStores.
Expand All @@ -48,23 +53,28 @@ public static List<Integer> maximalMarginalRelevance(INDArray queryEmbedding, Li
queryEmbedding = Nd4j.expandDims(queryEmbedding, 0);
}
INDArray similarityToQuery = cosineSimilarity(queryEmbedding, embeddingList).getRow(0);
// INDArray XArray =Nd4j.create(queryEmbedding);
// List<INDArray> indArrayList= embeddingList.stream().map(e->Nd4j.create(e)).toList();
// INDArray YArray=Nd4j.create(indArrayList);
//
//
// INDArray XNorm = XArray.norm2(1);
// INDArray YNorm = YArray.norm2(1);
//
// INDArray similarity = XArray.mmul(YArray.transpose()).div(XNorm.reshape(XRows, 1).mmul(YNorm.reshape(1,
// YRows)));
//
// // Handle NaN and Inf values
// similarity.maskedReplace(Double.NaN, 0.0);
// similarity.maskedReplace(Double.POSITIVE_INFINITY, 0.0);
// similarity.maskedReplace(Double.NEGATIVE_INFINITY, 0.0);
//
// return similarity;
return null;
int mostSimilar = Nd4j.argMax(similarityToQuery).getInt(0);
List<Integer> idxs = Lists.newArrayList(mostSimilar);
INDArray selected = createFromList(embeddingList.get(mostSimilar));

while (idxs.size() < Math.min(k, embeddingList.size())) {
float bestScore = NEGATIVE_INFINITY;
int idxToAdd = -1;
INDArray similarityToSelected = cosineSimilarity(embeddingList, selected);
for (int i = 0; i < similarityToQuery.columns(); i++) {
if (idxs.contains(i)) {
continue;
}
float redundantScore = Transforms.max(similarityToSelected.getRow(i), 0).getFloat(0);
float equationScore = lambdaMult * similarityToQuery.getFloat(i) - (1 - lambdaMult) * redundantScore;
if (equationScore > bestScore) {
bestScore = equationScore;
idxToAdd = i;
}
}
idxs.add(idxToAdd);
selected = Nd4j.vstack(selected, createFromList(embeddingList.get(idxToAdd)));
}
return idxs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,6 @@ void testGetRelevantDocuments() {
var retriever = pinecone.asRetriever(MMR);
var docs = retriever.getRelevantDocuments(query);

assertThat(docs).isNotNull();
assertThat(docs).isNotNull().hasSize(4);
}
}

0 comments on commit 987f42c

Please sign in to comment.