Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#99 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
Azure OpenAI endpoint support
  • Loading branch information
HamaWhiteGG authored Sep 10, 2023
2 parents 1963421 + 3480728 commit 402dbd2
Show file tree
Hide file tree
Showing 30 changed files with 555 additions and 45 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# 🦜️ LangChain Java

Java version of LangChain, bringing the capabilities of LLM to big data platforms like Flink and Spark.
Java version of LangChain, empowering LLMs for Big Data.

> If you are interested, you can add me on WeChat: HamaWhite, or send email to baisongxx@gmail.com
> If you are interested, you can add me on WeChat: HamaWhite, or send email to [me](mailto:baisongxx@gmail.com).
## 1. What is this?

Expand Down Expand Up @@ -30,6 +30,7 @@ The following example can view in the [langchain-example](langchain-examples/src
## 3. Integrations
### 3.1 LLMs
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
- [Azure OpenAI](openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java)
- [ChatGLM2-6B](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)

Expand All @@ -52,7 +53,7 @@ Prerequisites for building:
<dependency>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-core</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
</dependency>
```

Expand Down Expand Up @@ -380,5 +381,5 @@ Don’t hesitate to ask!
## 8. Fork and Contribute
This is an active open-source project. We are always open to people who want to use the system or contribute to it. Please note that pull requests should be merged into the **dev** branch.
Contact [me](baisongxx@gmail.com) if you are looking for implementation tasks that fit your skills.
Contact [me](mailto:baisongxx@gmail.com) if you are looking for implementation tasks that fit your skills.
2 changes: 1 addition & 1 deletion langchain-bigdata/langchain-flink/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-bigdata</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
</parent>

<artifactId>langchain-flink</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion langchain-bigdata/langchain-spark/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-bigdata</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
</parent>

<artifactId>langchain-spark</artifactId>
Expand Down
2 changes: 1 addition & 1 deletion langchain-bigdata/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-java</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
</parent>

<artifactId>langchain-bigdata</artifactId>
Expand Down
7 changes: 6 additions & 1 deletion langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-java</artifactId>
<version>0.1.11</version>
<version>0.1.12</version>
</parent>

<artifactId>langchain-core</artifactId>
Expand Down Expand Up @@ -80,6 +80,11 @@
<artifactId>jsoup</artifactId>
</dependency>

<dependency>
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-retry</artifactId>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public LLMResult generate(List<List<BaseMessage>> messages) {
*/
public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
List<ChatResult> results = messages.stream()
.map(message -> _generate(message, stop))
.map(message -> innerGenerate(message, stop))
.toList();

List<Map<String, Object>> llmOutputs = results.stream()
Expand All @@ -75,7 +75,7 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {
/**
* Top Level call
*/
public abstract ChatResult _generate(List<BaseMessage> messages, List<String> stop);
public abstract ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop);

public BaseMessage call(List<BaseMessage> messages) {
return call(messages, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.Objects;

import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain;
import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;

/**
Expand Down Expand Up @@ -75,6 +76,10 @@ public class ChatOpenAI extends BaseChatModel {

protected String openaiApiBase;

protected String openaiApiType;

protected String openaiApiVersion;

protected String openaiOrganization;

/**
Expand Down Expand Up @@ -123,10 +128,14 @@ public ChatOpenAI init() {
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down Expand Up @@ -158,7 +167,7 @@ public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> llmOutput
}

@Override
public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
public ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop) {
var chatMessages = convertMessages(messages);

ChatCompletion chatCompletion = ChatCompletion.builder()
Expand All @@ -171,7 +180,7 @@ public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
.stop(stop)
.build();

var response = client.create(chatCompletion);
var response = retryWithExponentialBackoff(maxRetries, () -> client.create(chatCompletion));
return createChatResult(response);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public class OpenAIEmbeddings implements Embeddings {

private String openaiApiKey;

private String openaiApiType;

private String openaiApiVersion;

protected String openaiOrganization;

/**
Expand Down Expand Up @@ -96,10 +100,14 @@ public OpenAIEmbeddings init() {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;

/**
* Wrapper around OpenAI large language models.
Expand Down Expand Up @@ -101,6 +102,16 @@ public class BaseOpenAI extends BaseLLM {
*/
protected String openaiApiBase;

/**
* Api type for Azure OpenAI API.
*/
protected String openaiApiType;

/**
* Api version for Azure OpenAI API.
*/
protected String openaiApiVersion;

/**
* Organization ID for OpenAI.
*/
Expand Down Expand Up @@ -194,7 +205,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {

for (var prompt : subPrompts) {
completion.setPrompt(prompt);
CompletionResp response = client.create(completion);
CompletionResp response = retryWithExponentialBackoff(maxRetries, () -> client.create(completion));
choices.addAll(response.getChoices());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ public OpenAI init() {
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.proxyUsername(proxyUsername)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ public class OpenAIChat extends BaseLLM {
*/
protected String openaiApiBase;

/**
* Api type for Azure OpenAI API.
*/
protected String openaiApiType;

/**
* Api version for Azure OpenAI API.
*/
protected String openaiApiVersion;

/**
* Organization ID for OpenAI.
*/
Expand Down Expand Up @@ -137,10 +147,14 @@ public OpenAIChat init() {
openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;

import java.time.Duration;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;

/**
* @author HamaWhite
*/
/**
 * Utility for executing an action with automatic retries using resilience4j,
 * backing off exponentially between attempts.
 */
public class Resilience4jRetryUtils {

    private Resilience4jRetryUtils() {
        // Utility class, not meant to be instantiated.
    }

    private static final Logger LOG = LoggerFactory.getLogger(Resilience4jRetryUtils.class);

    /**
     * Executes the action, retrying up to {@code maxRetries} times with the default backoff:
     * initial interval 4s, multiplier 2, capped at 16s.
     *
     * @param maxRetries number of retries allowed after the initial attempt
     * @param action     the action to execute; its result is returned on success
     * @param <T>        result type of the action
     * @return the action's result
     */
    public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> action) {
        return retryWithExponentialBackoff(maxRetries, action, Duration.ofSeconds(4), 2, Duration.ofSeconds(16));
    }

    /**
     * Executes the action, retrying up to {@code maxRetries} times with a configurable
     * exponential backoff. Each retry is logged at WARN level with the attempt number
     * and the triggering exception message.
     *
     * @param maxRetries      number of retries allowed after the initial attempt
     * @param action          the action to execute; its result is returned on success
     * @param initialInterval delay before the first retry
     * @param multiplier      factor applied to the interval after each failed attempt
     * @param maxInterval     upper bound for the delay between attempts
     * @param <T>             result type of the action
     * @return the action's result
     */
    public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> action, Duration initialInterval,
            double multiplier, Duration maxInterval) {
        RetryConfig retryConfig = RetryConfig.custom()
                // resilience4j's maxAttempts counts the initial call as attempt #1, so use
                // maxRetries + 1 to actually allow maxRetries retries. This also keeps
                // maxRetries = 0 valid (maxAttempts must be >= 1, or RetryConfig throws).
                .maxAttempts(maxRetries + 1)
                .intervalFunction(IntervalFunction.ofExponentialBackoff(initialInterval, multiplier, maxInterval))
                .build();
        Retry retry = Retry.of("retryWithExponentialBackoff", retryConfig);

        retry.getEventPublisher().onRetry(event -> LOG.warn("Retry failed on attempt #{} with exception: {}",
                event.getNumberOfRetryAttempts(), requireNonNull(event.getLastThrowable()).getMessage()));

        return retry.executeSupplier(action);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ public abstract class VectorStore {
* Delete by vector ID.
*
* @param ids List of ids to delete.
* @return true if deletion is successful, false otherwise
*/
public abstract boolean delete(List<String> ids);
public abstract void delete(List<String> ids);

/**
* Run more documents through the embeddings and add to the vectorStore.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,7 @@ public List<String> addTexts(List<String> texts, List<Map<String, Object>> metad
}

@Override
public boolean delete(List<String> ids) {
return false;
public void delete(List<String> ids) {
}

private List<Pair<Document, Float>> similaritySearchWithScore(String query, int k, Map<String, Object> filter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.slf4j.LoggerFactory;

import lombok.Builder;
import lombok.Getter;

import java.util.*;
import java.util.function.Function;
Expand All @@ -52,6 +53,7 @@ public class Pinecone extends VectorStore {

private PineconeClient client;

@Getter
private IndexClient index;

private String indexName;
Expand Down Expand Up @@ -93,8 +95,12 @@ public List<String> addTexts(List<String> texts, List<Map<String, Object>> metad
}

@Override
public boolean delete(List<String> ids) {
return false;
public void delete(List<String> ids) {
DeleteRequest deleteRequest = DeleteRequest.builder()
.ids(ids)
.namespace(namespace)
.build();
index.delete(deleteRequest);
}

/**
Expand Down Expand Up @@ -239,8 +245,4 @@ private List<Vector> createVectors(List<String> idsBatch, List<List<Float>> embe
.mapToObj(k -> new Vector(idsBatch.get(k), embeds.get(k), metadata.get(k)))
.toList();
}

public IndexClient getIndex() {
return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void testChatWithSingleMessage() {
var message = new HumanMessage("Translate this sentence from English to French. I love programming.");
var actual = chat.call(List.of(message));

var expected = new AIMessage("J'aime programmer.");
var expected = new AIMessage("J'adore la programmation.");
assertEquals(expected, actual);
}

Expand Down
Loading

0 comments on commit 402dbd2

Please sign in to comment.