Skip to content

Commit

Permalink
optimize OpenAiClient
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 12, 2023
1 parent 1307a64 commit 464dea0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 50 deletions.
87 changes: 38 additions & 49 deletions openai-client/src/main/java/com/hw/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.io.Closeable;
import java.util.EnumSet;
import java.util.List;
import java.util.concurrent.TimeUnit;
Expand All @@ -57,7 +58,7 @@
*/
@Data
@Builder
public class OpenAiClient {
public class OpenAiClient implements Closeable {

private static final Logger LOG = LoggerFactory.getLogger(OpenAiClient.class);

Expand Down Expand Up @@ -101,18 +102,8 @@ public class OpenAiClient {
*
* @return the initialized OpenAiClient instance
*/

public OpenAiClient init() {
if (isAzureApiType()) {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE");
if (openaiApiBase == null) {
throw new NullPointerException(
"Did not find OPENAI_API_BASE, please add an environment variable `OPENAI_API_BASE` which contains it, or pass `OPENAI_API_BASE` as a named parameter.");
}
openaiApiBase = StringUtils.appendIfMissing(openaiApiBase, "/") + "openai/deployments/";
} else {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
}
initializeOpenaiApiBase();
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY");

OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder()
Expand Down Expand Up @@ -168,16 +159,17 @@ public OpenAiClient init() {
return this;
}

/**
* Closes the HttpClient connection pool.
*/
public void close() {
// Cancel all ongoing requests
httpClient.dispatcher().cancelAll();

// Shut down the connection pool (if any)
httpClient.connectionPool().evictAll();
httpClient.dispatcher().executorService().shutdown();
private void initializeOpenaiApiBase() {
if (isAzureApiType()) {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE");
if (openaiApiBase == null) {
throw new NullPointerException(
"Did not find OPENAI_API_BASE, please add an environment variable `OPENAI_API_BASE` which contains it, or pass `OPENAI_API_BASE` as a named parameter.");
}
openaiApiBase = StringUtils.appendIfMissing(openaiApiBase, "/") + "openai/deployments/";
} else {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
}
}

private String getOrEnvOrDefault(String originalValue, String envKey, String... defaultValue) {
Expand Down Expand Up @@ -221,13 +213,7 @@ public Model retrieveModel(String model) {
* @return the generated completion text
*/
public String completion(Completion completion) {
CompletionResp response;
if (isAzureApiType()) {
response = service.completion(completion.getModel(), openaiApiVersion, completion).blockingGet();
} else {
response = service.completion(completion).blockingGet();
}

CompletionResp response = create(completion);
String text = response.getChoices().get(0).getText();
return StringUtils.trim(text);
}
Expand All @@ -239,10 +225,9 @@ public String completion(Completion completion) {
* @return the completion response
*/
public CompletionResp create(Completion completion) {
if (isAzureApiType()) {
return service.completion(completion.getModel(), openaiApiVersion, completion).blockingGet();
}
return service.completion(completion).blockingGet();
return isAzureApiType()
? service.completion(completion.getModel(), openaiApiVersion, completion).blockingGet()
: service.completion(completion).blockingGet();
}

/**
Expand All @@ -252,14 +237,7 @@ public CompletionResp create(Completion completion) {
* @return the generated model response text
*/
public String chatCompletion(ChatCompletion chatCompletion) {
ChatCompletionResp response;
if (isAzureApiType()) {
response =
service.chatCompletion(chatCompletion.getModel(), openaiApiVersion, chatCompletion).blockingGet();
} else {
response = service.chatCompletion(chatCompletion).blockingGet();
}

ChatCompletionResp response = create(chatCompletion);
String content = response.getChoices().get(0).getMessage().getContent();
return StringUtils.trim(content);
}
Expand All @@ -271,10 +249,9 @@ public String chatCompletion(ChatCompletion chatCompletion) {
* @return the chat completion response
*/
public ChatCompletionResp create(ChatCompletion chatCompletion) {
if (isAzureApiType()) {
return service.chatCompletion(chatCompletion.getModel(), openaiApiVersion, chatCompletion).blockingGet();
}
return service.chatCompletion(chatCompletion).blockingGet();
return isAzureApiType()
? service.chatCompletion(chatCompletion.getModel(), openaiApiVersion, chatCompletion).blockingGet()
: service.chatCompletion(chatCompletion).blockingGet();
}

/**
Expand All @@ -284,10 +261,9 @@ public ChatCompletionResp create(ChatCompletion chatCompletion) {
* @return The embedding vector response.
*/
public EmbeddingResp embedding(Embedding embedding) {
if (isAzureApiType()) {
return service.embedding(embedding.getModel(), openaiApiVersion, embedding).blockingGet();
}
return service.embedding(embedding).blockingGet();
return isAzureApiType()
? service.embedding(embedding.getModel(), openaiApiVersion, embedding).blockingGet()
: service.embedding(embedding).blockingGet();
}

/**
Expand All @@ -298,4 +274,17 @@ public EmbeddingResp embedding(Embedding embedding) {
private boolean isAzureApiType() {
return EnumSet.of(OpenaiApiType.AZURE, OpenaiApiType.AZURE_AD).contains(openaiApiType);
}

/**
* Closes the HttpClient connection pool.
*/
@Override
public void close() {
// Cancel all ongoing requests
httpClient.dispatcher().cancelAll();

// Shut down the connection pool (if any)
httpClient.connectionPool().evictAll();
httpClient.dispatcher().executorService().shutdown();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.io.Closeable;
import java.util.List;
import java.util.concurrent.TimeUnit;

Expand All @@ -50,7 +51,7 @@
*/
@Data
@Builder
public class PineconeClient {
public class PineconeClient implements Closeable {

private static final Logger LOG = LoggerFactory.getLogger(PineconeClient.class);

Expand Down Expand Up @@ -139,6 +140,7 @@ public Retrofit createRetrofit(String baseUrl) {
/**
* Closes the HttpClient connection pool.
*/
@Override
public void close() {
// Cancel all ongoing requests
httpClient.dispatcher().cancelAll();
Expand Down

0 comments on commit 464dea0

Please sign in to comment.