diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java index 1f24da13b..eccca51e9 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java @@ -36,6 +36,7 @@ /** * Wrapper around OpenAI large language models. + * * @author HamaWhite */ @SuperBuilder @@ -117,9 +118,10 @@ public class BaseOpenAI extends BaseLLM { protected int batchSize = 20; /** - * Timeout for requests to OpenAI completion API. Default is 600 seconds. + * Timeout for requests to OpenAI completion API. Default is 10 seconds. */ - protected float requestTimeout; + @Builder.Default + protected long requestTimeout = 10; /** * Adjust the probability of specific tokens being generated. diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java index db4fbb30f..2a73f39ab 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java @@ -49,6 +49,7 @@ public OpenAI init() { .openaiApiKey(openaiApiKey) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) + .requestTimeout(requestTimeout) .build() .init(); return this; diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java index c3c0b2437..b8686e880 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java @@ -117,6 +117,12 @@ public class OpenAIChat extends BaseLLM { @Builder.Default private List prefixMessages = new ArrayList<>(); + /** + * Timeout for requests to OpenAI completion API. Default is 10 seconds. + */ + @Builder.Default + protected long requestTimeout = 10; + /** * Adjust the probability of specific tokens being generated. */ @@ -138,6 +144,7 @@ public OpenAIChat init() { .openaiApiKey(openaiApiKey) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) + .requestTimeout(requestTimeout) .build() .init(); return this; diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java b/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java index fb50e72f5..6ac21750b 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java @@ -20,6 +20,7 @@ import com.hw.langchain.llms.openai.OpenAI; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,6 +32,7 @@ /** * @author HamaWhite */ +@Disabled("Test requires costly OpenAI calls, can be run manually.") class ConversationChainTest { private static final Logger LOG = LoggerFactory.getLogger(ConversationChainTest.class); diff --git a/langchain-core/src/test/java/com/hw/langchain/llms/openai/OpenAITest.java b/langchain-core/src/test/java/com/hw/langchain/llms/openai/OpenAITest.java index 6f806274f..69eb0d36f 100644 --- a/langchain-core/src/test/java/com/hw/langchain/llms/openai/OpenAITest.java +++ b/langchain-core/src/test/java/com/hw/langchain/llms/openai/OpenAITest.java @@ -36,7 +36,8 @@ class OpenAITest { @Test void testOpenAICall() { OpenAI llm = OpenAI.builder() - .maxTokens(10) + .maxTokens(16) + .requestTimeout(15) .build() .init(); diff --git a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java index c431619e2..86bc58b55 100644 --- a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java +++ b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java @@ -41,6 +41,8 @@ import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; import retrofit2.converter.jackson.JacksonConverterFactory; +import java.util.concurrent.TimeUnit; + /** * Represents a client for interacting with the OpenAI API. * @@ -60,6 +62,12 @@ public class OpenAiClient { private String openaiProxy; + /** + * Timeout for requests to OpenAI completion API. Default is 10 seconds. + */ + @Builder.Default + protected long requestTimeout = 10; + private OpenAiService service; private OkHttpClient httpClient; @@ -73,7 +81,12 @@ public OpenAiClient init() { openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/"); openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY"); - OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() + .connectTimeout(requestTimeout, TimeUnit.SECONDS) + .readTimeout(requestTimeout, TimeUnit.SECONDS) + .writeTimeout(requestTimeout, TimeUnit.SECONDS) + .callTimeout(requestTimeout, TimeUnit.SECONDS); + httpClientBuilder.addInterceptor(chain -> { // If openaiApiKey is not set, read the value of OPENAI_API_KEY from the environment. openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY"); diff --git a/openai-client/src/test/java/com/hw/openai/HttpClientTimeoutTest.java b/openai-client/src/test/java/com/hw/openai/HttpClientTimeoutTest.java new file mode 100644 index 000000000..3a1edff5b --- /dev/null +++ b/openai-client/src/test/java/com/hw/openai/HttpClientTimeoutTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.openai; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import static org.hibernate.validator.internal.util.Contracts.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * @author HamaWhite + */ +@Disabled("Test is currently disabled due to network timeout issue") +class HttpClientTimeoutTest { + + @Test + void testRequestTimeout() { + OpenAiClient client = OpenAiClient.builder() + .requestTimeout(15) + .build() + .init(); + + long startTime = System.currentTimeMillis(); + /* + * Tests the request timeout functionality. Expects a RuntimeException to be thrown, which is the expected + * behavior when using Retrofit for synchronous requests. The actual network exception is wrapped in a + * RuntimeException by Retrofit for unified exception handling. + */ + assertThrows(RuntimeException.class, client::listModels); + + long endTime = System.currentTimeMillis(); + long executionTime = endTime - startTime; + + assertTrue(executionTime >= 15_000, "Execution time should be greater than or equal to 30 seconds."); + } +}