add pinecone-client(20%)
HamaWhiteGG committed Jun 27, 2023
1 parent da761ae commit 958b36e
Showing 18 changed files with 1,095 additions and 17 deletions.
15 changes: 15 additions & 0 deletions langchain-core/pom.xml
@@ -43,6 +43,21 @@
<artifactId>commons-collections4</artifactId>
</dependency>

<!-- &lt;!&ndash; https://mvnrepository.com/artifact/org.nd4j/nd4j-native &ndash;&gt;-->
<!-- <dependency>-->
<!-- <groupId>org.nd4j</groupId>-->
<!-- <artifactId>nd4j-native</artifactId>-->
<!-- <version>1.0.0-M2.1</version>-->
<!-- <scope>test</scope>-->
<!-- </dependency>-->

<!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-native-platform -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M2.1</version>
</dependency>

<dependency>
<groupId>org.python</groupId>
<artifactId>jython-standalone</artifactId>
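The nd4j-native-platform dependency added above bundles pre-built native backends for common platforms, which the embedding code later in this commit uses for vector math. A minimal, illustrative sketch of that kind of usage (class name and values are not part of the commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Nd4jExample {

    public static void main(String[] args) {
        // two toy "chunk embeddings" of dimension 3
        INDArray chunks = Nd4j.create(new float[][] {{1f, 2f, 3f}, {3f, 2f, 1f}});
        INDArray mean = chunks.mean(0);               // column-wise mean -> [2, 2, 2]
        INDArray unit = mean.div(mean.norm2Number()); // L2-normalize
        System.out.println(unit);
    }
}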
@@ -137,9 +137,9 @@ public Map<String, String> _call(Map<String, Object> inputs) {
String result = database.run(sqlCmd, false);
LOG.info("SQLResult: \n{}", result);

-/**
- * If return direct, we just set the final result equal to the result of the sql query result,
- * otherwise try to get a human readable final answer
+/*
+ * If return direct, we just set the final result equal to the result of the sql query result, otherwise try to
+ * get a human readable final answer
 */
String finalResult;
if (returnDirect) {
@@ -18,17 +18,26 @@

package com.hw.langchain.embeddings.openai;

import com.google.common.primitives.Doubles;
import com.google.common.primitives.Floats;
import com.hw.langchain.embeddings.base.Embeddings;
import com.hw.langchain.exception.LangChainException;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.embeddings.Embedding;
import com.hw.openai.entity.embeddings.EmbeddingResp;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import lombok.AllArgsConstructor;
import lombok.Builder;

-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
+import java.util.*;

import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;
import static java.util.Collections.nCopies;

/**
* Wrapper around OpenAI embedding models.
@@ -61,12 +70,6 @@ public class OpenAIEmbeddings implements Embeddings {

protected String openaiOrganization;

-@Builder.Default
-private Set<String> allowedSpecial = new HashSet<>();
-
-@Builder.Default
-private Set<String> disallowedSpecial = Set.of("all");

/**
* Maximum number of texts to embed in each batch
*/
@@ -112,13 +115,108 @@ public OpenAIEmbeddings() {
init();
}

/**
 * Embed texts that may exceed the model's context length by splitting them into token windows.
 * Please refer to https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
 */
private List<List<Float>> getLenSafeEmbeddings(List<String> texts) {
List<List<Float>> embeddings = new ArrayList<>(texts.size());

List<List<Integer>> tokens = new ArrayList<>();
List<Integer> indices = new ArrayList<>();
Encoding encoding = Encodings.newDefaultEncodingRegistry()
.getEncodingForModel(model)
.orElseThrow(() -> new LangChainException("Encoding not found."));

for (int i = 0; i < texts.size(); i++) {
String text = texts.get(i);
if (model.endsWith("001")) {
// See https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
// replace newlines, which can negatively affect performance.
text = text.replace("\n", " ");
}
List<Integer> token = encoding.encode(text);
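// split the token sequence into context-length windows; indices records which input text each window came from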
for (int j = 0; j < token.size(); j += embeddingCtxLength) {
tokens.add(token.subList(j, Math.min(j + embeddingCtxLength, token.size())));
indices.add(i);
}
}

List<List<Float>> batchedEmbeddings = new ArrayList<>();
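// embed the token windows in batches of chunkSize; the API returns one embedding per window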
for (int i = 0; i < tokens.size(); i += chunkSize) {
List<String> input = tokens.subList(i, Math.min(i + chunkSize, tokens.size()))
.stream()
.map(Object::toString)
.toList();
var response = embedWithRetry(input);
response.getData().forEach(result -> batchedEmbeddings.add(result.getEmbedding()));
}

// regroup by original text: each text needs its own accumulator (nCopies would make every slot share a single list)
List<List<List<Float>>> results = new ArrayList<>();
List<List<Integer>> numTokensInBatch = new ArrayList<>();
for (int i = 0; i < texts.size(); i++) {
    results.add(new ArrayList<>());
    numTokensInBatch.add(new ArrayList<>());
}
for (int i = 0; i < indices.size(); i++) {
int index = indices.get(i);
results.get(index).add(batchedEmbeddings.get(i));
numTokensInBatch.get(index).add(tokens.get(i).size());
}

for (int i = 0; i < texts.size(); i++) {
INDArray average;
try (INDArray resultArray =
Nd4j.create(results.get(i).stream().map(Floats::toArray).toArray(float[][]::new))) {
INDArray weightsArray = Nd4j.create(Doubles.toArray(numTokensInBatch.get(i)));
average = resultArray.mean(0).mulColumnVector(weightsArray).sum(0);
}
INDArray normalizedAverage = average.div(average.norm2Number());
embeddings.add(Floats.asList(normalizedAverage.toFloatVector()));
}
return embeddings;
}
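The reduction step above is modeled on the cookbook notebook referenced in the Javadoc: weight each window's embedding by its token count, sum, and L2-normalize (dividing by the total weight is immaterial once the vector is normalized). A dependency-free sketch of that computation, with an illustrative helper name and plain arrays:

// chunkEmbeddings: one row per token window; tokenCounts: tokens per window
static float[] weightedAverage(float[][] chunkEmbeddings, int[] tokenCounts) {
    int dim = chunkEmbeddings[0].length;
    double[] sum = new double[dim];
    for (int j = 0; j < chunkEmbeddings.length; j++) {
        for (int d = 0; d < dim; d++) {
            sum[d] += tokenCounts[j] * chunkEmbeddings[j][d];
        }
    }
    double norm = 0.0;
    for (double v : sum) {
        norm += v * v;
    }
    norm = Math.sqrt(norm);
    float[] result = new float[dim];
    for (int d = 0; d < dim; d++) {
        result[d] = (float) (sum[d] / norm);
    }
    return result;
}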

/**
* Call out to OpenAI's embedding endpoint.
*/
public List<Float> embeddingFunc(String text) {
if (text.length() > embeddingCtxLength) {
return getLenSafeEmbeddings(List.of(text)).get(0);
} else {
if (model.endsWith("001")) {
// See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
// replace newlines, which can negatively affect performance.
text = text.replace("\n", " ");
}
return embedWithRetry(List.of(text)).getData().get(0).getEmbedding();
}
}

/**
* Call out to OpenAI's embedding endpoint for embedding search docs.
*
* @param texts The list of texts to embed.
* @return List of embeddings, one for each text.
*/
@Override
public List<List<Float>> embedDocuments(List<String> texts) {
-return null;
+// NOTE: to keep things simple, we assume the list may contain texts longer
+// than the maximum context and use length-safe embedding function.
+return this.getLenSafeEmbeddings(texts);
}

/**
* Call out to OpenAI's embedding endpoint for embedding query text.
*
* @param text The text to embed.
* @return Embedding for the text.
*/
@Override
public List<Float> embedQuery(String text) {
-return null;
+return embeddingFunc(text);
}

public EmbeddingResp embedWithRetry(List<String> input) {
var embedding = Embedding.builder()
.model(model)
.input(input)
.build();
return client.embedding(embedding);
}
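embedWithRetry mirrors the name of the Python helper but does not retry yet; a minimal sketch of what a retry wrapper could look like, with an illustrative method name and backoff policy (not part of the commit):

// illustrative only: retry the embedding call a few times with simple linear backoff
private EmbeddingResp embedWithBackoff(List<String> input, int maxAttempts) {
    RuntimeException last = null;
    for (int attempt = 1; attempt <= maxAttempts; attempt++) {
        try {
            return embedWithRetry(input);
        } catch (RuntimeException e) {
            last = e;
            try {
                Thread.sleep(1000L * attempt);
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(ie);
            }
        }
    }
    throw last;
}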
}
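For orientation, a hypothetical usage of the class as it stands after this commit, assuming OPENAI_API_KEY is set in the environment so that init() (invoked by the no-arg constructor) can configure the client; the input strings are illustrative:

OpenAIEmbeddings embeddings = new OpenAIEmbeddings();

// one vector for a query, one vector per document
List<Float> queryVector = embeddings.embedQuery("How are long inputs embedded?");
List<List<Float>> docVectors = embeddings.embedDocuments(List.of("first document", "second document"));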
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.schema;

import java.util.List;

/**
* Base interface for retrievers.
*
* @author HamaWhite
*/
public interface BaseRetriever {

/**
* Get documents relevant for a query.
*
* @param query string to find relevant documents for
* @return List of relevant documents
*/
List<Document> getRelevantDocuments(String query);
}
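A hypothetical implementation sketch of the new interface, written to avoid assumptions about the Document API by letting the caller supply documents keyed by their raw text (class and field names are illustrative):

package com.hw.langchain.schema;

import java.util.List;
import java.util.Map;

// hypothetical example: a naive keyword retriever over caller-supplied documents
public class KeywordRetriever implements BaseRetriever {

    // raw text -> document, supplied by the caller
    private final Map<String, Document> corpus;

    public KeywordRetriever(Map<String, Document> corpus) {
        this.corpus = corpus;
    }

    @Override
    public List<Document> getRelevantDocuments(String query) {
        String q = query.toLowerCase();
        return corpus.entrySet().stream()
                .filter(entry -> entry.getKey().toLowerCase().contains(q))
                .map(Map.Entry::getValue)
                .toList();
    }
}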
