Skip to content

Commit

Permalink
Fix issue HamaWhiteGG#6: Setting timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 15, 2023
1 parent a8cf70c commit 424c8d8
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

/**
* Wrapper around OpenAI large language models.
*
* @author HamaWhite
*/
@SuperBuilder
Expand Down Expand Up @@ -117,9 +118,10 @@ public class BaseOpenAI extends BaseLLM {
protected int batchSize = 20;

/**
* Timeout for requests to OpenAI completion API. Default is 600 seconds.
* Timeout for requests to OpenAI completion API. Default is 10 seconds.
*/
protected float requestTimeout;
@Builder.Default
protected long requestTimeout = 10;

/**
* Adjust the probability of specific tokens being generated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public OpenAI init() {
.openaiApiKey(openaiApiKey)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
.build()
.init();
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ public class OpenAIChat extends BaseLLM {
@Builder.Default
private List<Message> prefixMessages = new ArrayList<>();

/**
* Timeout for requests to OpenAI completion API. Default is 10 seconds.
*/
@Builder.Default
protected long requestTimeout = 10;

/**
* Adjust the probability of specific tokens being generated.
*/
Expand All @@ -138,6 +144,7 @@ public OpenAIChat init() {
.openaiApiKey(openaiApiKey)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
.build()
.init();
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.hw.langchain.llms.openai.OpenAI;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -31,6 +32,7 @@
/**
* @author HamaWhite
*/
@Disabled("Test requires costly OpenAI calls, can be run manually.")
class ConversationChainTest {

private static final Logger LOG = LoggerFactory.getLogger(ConversationChainTest.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class OpenAITest {
@Test
void testOpenAICall() {
OpenAI llm = OpenAI.builder()
.maxTokens(10)
.maxTokens(16)
.requestTimeout(15)
.build()
.init();

Expand Down
15 changes: 14 additions & 1 deletion openai-client/src/main/java/com/hw/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.util.concurrent.TimeUnit;

/**
* Represents a client for interacting with the OpenAI API.
*
Expand All @@ -60,6 +62,12 @@ public class OpenAiClient {

private String openaiProxy;

/**
* Timeout for requests to OpenAI completion API. Default is 10 seconds.
*/
@Builder.Default
protected long requestTimeout = 10;

private OpenAiService service;

private OkHttpClient httpClient;
Expand All @@ -73,7 +81,12 @@ public OpenAiClient init() {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY");

OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder();
OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder()
.connectTimeout(requestTimeout, TimeUnit.SECONDS)
.readTimeout(requestTimeout, TimeUnit.SECONDS)
.writeTimeout(requestTimeout, TimeUnit.SECONDS)
.callTimeout(requestTimeout, TimeUnit.SECONDS);

httpClientBuilder.addInterceptor(chain -> {
// If openaiApiKey is not set, read the value of OPENAI_API_KEY from the environment.
openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.openai;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.hibernate.validator.internal.util.Contracts.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
* @author HamaWhite
*/
@Disabled("Test is currently disabled due to network timeout issue")
class HttpClientTimeoutTest {

@Test
void testRequestTimeout() {
OpenAiClient client = OpenAiClient.builder()
.requestTimeout(15)
.build()
.init();

long startTime = System.currentTimeMillis();
/*
* Tests the request timeout functionality. Expects a RuntimeException to be thrown, which is the expected
* behavior when using Retrofit for synchronous requests. The actual network exception is wrapped in a
* RuntimeException by Retrofit for unified exception handling.
*/
assertThrows(RuntimeException.class, client::listModels);

long endTime = System.currentTimeMillis();
long executionTime = endTime - startTime;

assertTrue(executionTime >= 15_000, "Execution time should be greater than or equal to 30 seconds.");
}
}

0 comments on commit 424c8d8

Please sign in to comment.