From 55606210c31e9b3bef7520e214ea20878817e188 Mon Sep 17 00:00:00 2001 From: HamaWhiteGG Date: Sun, 10 Sep 2023 14:43:21 +0800 Subject: [PATCH] Support maxRetries for OpenAI and ChatOpenAI #96 --- langchain-core/pom.xml | 5 ++ .../chat/models/base/BaseChatModel.java | 4 +- .../chat/models/openai/ChatOpenAI.java | 5 +- .../hw/langchain/llms/openai/BaseOpenAI.java | 3 +- .../utils/Resilience4jRetryUtils.java | 57 +++++++++++++++++ .../chat/models/openai/ChatOpenAITest.java | 2 +- .../retry/Resilience4jRetryExample.java | 64 +++++++++++++++++++ pom.xml | 7 ++ 8 files changed, 141 insertions(+), 6 deletions(-) create mode 100644 langchain-core/src/main/java/com/hw/langchain/utils/Resilience4jRetryUtils.java create mode 100644 langchain-core/src/test/java/io/github/resilience4j/retry/Resilience4jRetryExample.java diff --git a/langchain-core/pom.xml b/langchain-core/pom.xml index cf5fb14cf..e1b9d9fbb 100644 --- a/langchain-core/pom.xml +++ b/langchain-core/pom.xml @@ -80,6 +80,11 @@ jsoup + + io.github.resilience4j + resilience4j-retry + + org.slf4j slf4j-api diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java index 39007620b..8109d6e90 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java @@ -50,7 +50,7 @@ public LLMResult generate(List> messages) { */ public LLMResult generate(List> messages, List stop) { List results = messages.stream() - .map(message -> _generate(message, stop)) + .map(message -> innerGenerate(message, stop)) .toList(); List> llmOutputs = results.stream() @@ -75,7 +75,7 @@ public LLMResult generatePrompt(List prompts, List stop) { /** * Top Level call */ - public abstract ChatResult _generate(List messages, List stop); + public abstract ChatResult innerGenerate(List messages, List stop); public BaseMessage call(List messages) { return call(messages, null); diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java index fc02709c1..543602c4b 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java @@ -38,6 +38,7 @@ import java.util.Objects; import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain; +import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff; import static com.hw.langchain.utils.Utils.getOrEnvOrDefault; /** @@ -158,7 +159,7 @@ public Map combineLlmOutputs(List> llmOutput } @Override - public ChatResult _generate(List messages, List stop) { + public ChatResult innerGenerate(List messages, List stop) { var chatMessages = convertMessages(messages); ChatCompletion chatCompletion = ChatCompletion.builder() @@ -171,7 +172,7 @@ public ChatResult _generate(List messages, List stop) { .stop(stop) .build(); - var response = client.create(chatCompletion); + var response = retryWithExponentialBackoff(maxRetries, () -> client.create(chatCompletion)); return createChatResult(response); } diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java index f9020de48..b1733d1e8 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java @@ -33,6 +33,7 @@ import java.util.*; import static com.google.common.base.Preconditions.checkArgument; +import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff; /** * Wrapper around OpenAI large language models. @@ -194,7 +195,7 @@ protected LLMResult innerGenerate(List prompts, List stop) { for (var prompt : subPrompts) { completion.setPrompt(prompt); - CompletionResp response = client.create(completion); + CompletionResp response = retryWithExponentialBackoff(maxRetries, () -> client.create(completion)); choices.addAll(response.getChoices()); } diff --git a/langchain-core/src/main/java/com/hw/langchain/utils/Resilience4jRetryUtils.java b/langchain-core/src/main/java/com/hw/langchain/utils/Resilience4jRetryUtils.java new file mode 100644 index 000000000..56cc69309 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/utils/Resilience4jRetryUtils.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.retry.Retry; +import io.github.resilience4j.retry.RetryConfig; + +import java.time.Duration; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +/** + * @author HamaWhite + */ +public class Resilience4jRetryUtils { + + private static final Logger LOG = LoggerFactory.getLogger(Resilience4jRetryUtils.class); + + public static T retryWithExponentialBackoff(int maxRetries, Supplier action) { + return retryWithExponentialBackoff(maxRetries, action, Duration.ofSeconds(4), 2, Duration.ofSeconds(16)); + } + + public static T retryWithExponentialBackoff(int maxRetries, Supplier action, Duration initialInterval, + double multiplier, Duration maxInterval) { + RetryConfig retryConfig = RetryConfig.custom() + .maxAttempts(maxRetries) + .intervalFunction(IntervalFunction.ofExponentialBackoff(initialInterval, multiplier, maxInterval)) + .build(); + Retry retry = Retry.of("retryWithExponentialBackoff", retryConfig); + + retry.getEventPublisher().onRetry(event -> LOG.warn("Retry failed on attempt #{} with exception: {}", + event.getNumberOfRetryAttempts(), requireNonNull(event.getLastThrowable()).getMessage())); + + return retry.executeSupplier(action); + } +} diff --git a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java index 872846c1a..d03cc27a5 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java @@ -60,7 +60,7 @@ void testChatWithSingleMessage() { var message = new HumanMessage("Translate this sentence from English to French. I love programming."); var actual = chat.call(List.of(message)); - var expected = new AIMessage("J'aime programmer."); + var expected = new AIMessage("J'adore la programmation."); assertEquals(expected, actual); } diff --git a/langchain-core/src/test/java/io/github/resilience4j/retry/Resilience4jRetryExample.java b/langchain-core/src/test/java/io/github/resilience4j/retry/Resilience4jRetryExample.java new file mode 100644 index 000000000..68279dafe --- /dev/null +++ b/langchain-core/src/test/java/io/github/resilience4j/retry/Resilience4jRetryExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.resilience4j.retry; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.github.resilience4j.core.IntervalFunction; + +import java.time.Duration; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +/** + * @author HamaWhite + */ +public class Resilience4jRetryExample { + + private static final Logger LOG = LoggerFactory.getLogger(Resilience4jRetryExample.class); + + public static void main(String[] args) { + int maxRetries = 6; + String result = retryWithExponentialBackoff(maxRetries, () -> { + double value = Math.random(); + LOG.info("Attempt: value is {}", value); + if (value < 0.7) { + throw new RuntimeException("Operation failed"); + } + return "Operation succeeded"; + }); + LOG.info("Final result is {}", result); + } + + public static T retryWithExponentialBackoff(int maxRetries, Supplier action) { + RetryConfig retryConfig = RetryConfig.custom() + .maxAttempts(maxRetries) + .intervalFunction( + IntervalFunction.ofExponentialBackoff(Duration.ofSeconds(4), 2, Duration.ofSeconds(16))) + .build(); + Retry retry = Retry.of("retryWithExponential", retryConfig); + + retry.getEventPublisher().onRetry(event -> LOG.warn("Retry failed on attempt #{} with exception: {}", + event.getNumberOfRetryAttempts(), requireNonNull(event.getLastThrowable()).getMessage())); + + return retry.executeSupplier(action); + } +} diff --git a/pom.xml b/pom.xml index 8aea52979..9cb073c3d 100644 --- a/pom.xml +++ b/pom.xml @@ -38,6 +38,7 @@ 1.7.32 4.2.0 0.10.2 + 2.1.0 1.7.25 1.10.0 3.12.0 @@ -162,6 +163,12 @@ ${jsoup.version} + + io.github.resilience4j + resilience4j-retry + ${resilience4j.version} + + org.slf4j slf4j-api