Skip to content

Commit

Permalink
Support maxRetries for OpenAI and ChatOpenAI HamaWhiteGG#96
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 10, 2023
1 parent ad62e7a commit 5560621
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 6 deletions.
5 changes: 5 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
<artifactId>jsoup</artifactId>
</dependency>

<dependency>
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-retry</artifactId>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public LLMResult generate(List<List<BaseMessage>> messages) {
*/
public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
List<ChatResult> results = messages.stream()
.map(message -> _generate(message, stop))
.map(message -> innerGenerate(message, stop))
.toList();

List<Map<String, Object>> llmOutputs = results.stream()
Expand All @@ -75,7 +75,7 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {
/**
* Top Level call
*/
public abstract ChatResult _generate(List<BaseMessage> messages, List<String> stop);
public abstract ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop);

public BaseMessage call(List<BaseMessage> messages) {
return call(messages, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.Objects;

import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain;
import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;

/**
Expand Down Expand Up @@ -158,7 +159,7 @@ public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> llmOutput
}

@Override
public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
public ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop) {
var chatMessages = convertMessages(messages);

ChatCompletion chatCompletion = ChatCompletion.builder()
Expand All @@ -171,7 +172,7 @@ public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
.stop(stop)
.build();

var response = client.create(chatCompletion);
var response = retryWithExponentialBackoff(maxRetries, () -> client.create(chatCompletion));
return createChatResult(response);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;

/**
* Wrapper around OpenAI large language models.
Expand Down Expand Up @@ -194,7 +195,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {

for (var prompt : subPrompts) {
completion.setPrompt(prompt);
CompletionResp response = client.create(completion);
CompletionResp response = retryWithExponentialBackoff(maxRetries, () -> client.create(completion));
choices.addAll(response.getChoices());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;

import java.time.Duration;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;

/**
* @author HamaWhite
*/
public class Resilience4jRetryUtils {

private static final Logger LOG = LoggerFactory.getLogger(Resilience4jRetryUtils.class);

public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> action) {
return retryWithExponentialBackoff(maxRetries, action, Duration.ofSeconds(4), 2, Duration.ofSeconds(16));
}

public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> action, Duration initialInterval,
double multiplier, Duration maxInterval) {
RetryConfig retryConfig = RetryConfig.custom()
.maxAttempts(maxRetries)
.intervalFunction(IntervalFunction.ofExponentialBackoff(initialInterval, multiplier, maxInterval))
.build();
Retry retry = Retry.of("retryWithExponentialBackoff", retryConfig);

retry.getEventPublisher().onRetry(event -> LOG.warn("Retry failed on attempt #{} with exception: {}",
event.getNumberOfRetryAttempts(), requireNonNull(event.getLastThrowable()).getMessage()));

return retry.executeSupplier(action);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void testChatWithSingleMessage() {
var message = new HumanMessage("Translate this sentence from English to French. I love programming.");
var actual = chat.call(List.of(message));

var expected = new AIMessage("J'aime programmer.");
var expected = new AIMessage("J'adore la programmation.");
assertEquals(expected, actual);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.resilience4j.retry;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.github.resilience4j.core.IntervalFunction;

import java.time.Duration;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;

/**
* @author HamaWhite
*/
public class Resilience4jRetryExample {

private static final Logger LOG = LoggerFactory.getLogger(Resilience4jRetryExample.class);

public static void main(String[] args) {
int maxRetries = 6;
String result = retryWithExponentialBackoff(maxRetries, () -> {
double value = Math.random();
LOG.info("Attempt: value is {}", value);
if (value < 0.7) {
throw new RuntimeException("Operation failed");
}
return "Operation succeeded";
});
LOG.info("Final result is {}", result);
}

public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> action) {
RetryConfig retryConfig = RetryConfig.custom()
.maxAttempts(maxRetries)
.intervalFunction(
IntervalFunction.ofExponentialBackoff(Duration.ofSeconds(4), 2, Duration.ofSeconds(16)))
.build();
Retry retry = Retry.of("retryWithExponential", retryConfig);

retry.getEventPublisher().onRetry(event -> LOG.warn("Retry failed on attempt #{} with exception: {}",
event.getNumberOfRetryAttempts(), requireNonNull(event.getLastThrowable()).getMessage()));

return retry.executeSupplier(action);
}
}
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
<slf4j-api.version>1.7.32</slf4j-api.version>
<awaitility.version>4.2.0</awaitility.version>
<reflections.version>0.10.2</reflections.version>
<resilience4j.version>2.1.0</resilience4j.version>
<slf4j-log4j12.version>1.7.25</slf4j-log4j12.version>
<commons-text.version>1.10.0</commons-text.version>
<commons-lang3.version>3.12.0</commons-lang3.version>
Expand Down Expand Up @@ -162,6 +163,12 @@
<version>${jsoup.version}</version>
</dependency>

<dependency>
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-retry</artifactId>
<version>${resilience4j.version}</version>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
Expand Down

0 comments on commit 5560621

Please sign in to comment.