Skip to content

Commit

Permalink
Optimize Azure OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 10, 2023
1 parent c3553c1 commit caf5d79
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 37 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The following example can view in the [langchain-example](langchain-examples/src
## 3. Integrations
### 3.1 LLMs
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
- [Azure OpenAI](openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java)
- [ChatGLM2-6B](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ public ChatOpenAI init() {
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ public OpenAIEmbeddings init() {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public OpenAI init() {
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;

/**
* Wrapper around OpenAI Chat large language models.
Expand Down Expand Up @@ -148,8 +147,8 @@ public OpenAIChat init() {
openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
Expand Down
56 changes: 33 additions & 23 deletions openai-client/src/main/java/com/hw/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@
import com.hw.openai.entity.models.ModelResp;
import com.hw.openai.service.OpenAiService;
import com.hw.openai.utils.ProxyUtils;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.Builder;
import lombok.Data;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.logging.HttpLoggingInterceptor;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
Expand Down Expand Up @@ -98,19 +100,20 @@ public class OpenAiClient {
* @return the initialized OpenAiClient instance
*/
public OpenAiClient init() {
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","openai");
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "openai");
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE");
if(openaiApiBase == null){
if (openaiApiBase == null) {
throw new IllegalArgumentException(
String.format(
"Did not find %s, please add an environment variable `%s` which contains it, or pass `%s` as a named parameter.",
"OPENAI_API_BASE", "OPENAI_API_BASE", "OPENAI_API_BASE"));
}
openaiApiBase += (openaiApiBase.endsWith("/")?"":"/") + "openai/deployments/";
}else if(openaiApiType.equals(OpenaiApiType.OPENAI.getValue())){
openaiApiBase += (openaiApiBase.endsWith("/") ? "" : "/") + "openai/deployments/";
} else if (openaiApiType.equals(OpenaiApiType.OPENAI.getValue())) {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
}else {
} else {
throw new IllegalArgumentException(
String.format(
"The API type %s provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'openai'",
Expand All @@ -131,9 +134,10 @@ public OpenAiClient init() {

Request.Builder requestBuilder = chain.request().newBuilder();
requestBuilder.header("Content-Type", "application/json");
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
requestBuilder.header("api-key", openaiApiKey);
}else {
} else {
requestBuilder.header("Authorization", "Bearer " + openaiApiKey);
requestBuilder.header("OpenAI-Organization", openaiOrganization);
}
Expand Down Expand Up @@ -227,9 +231,10 @@ public Model retrieveModel(String model) {
*/
public String completion(Completion completion) {
CompletionResp response;
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
response = service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet();
}else {
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
response = service.completion(completion.getModel(), openaiApiVersion, completion).blockingGet();
} else {
response = service.completion(completion).blockingGet();
}

Expand All @@ -244,8 +249,9 @@ public String completion(Completion completion) {
* @return the completion response
*/
public CompletionResp create(Completion completion) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet();
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
return service.completion(completion.getModel(), openaiApiVersion, completion).blockingGet();
}
return service.completion(completion).blockingGet();
}
Expand All @@ -258,9 +264,11 @@ public CompletionResp create(Completion completion) {
*/
public String chatCompletion(ChatCompletion chatCompletion) {
ChatCompletionResp response;
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
response = service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet();
}else {
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
response =
service.chatCompletion(chatCompletion.getModel(), openaiApiVersion, chatCompletion).blockingGet();
} else {
response = service.chatCompletion(chatCompletion).blockingGet();
}

Expand All @@ -275,8 +283,9 @@ public String chatCompletion(ChatCompletion chatCompletion) {
* @return the chat completion response
*/
public ChatCompletionResp create(ChatCompletion chatCompletion) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet();
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
return service.chatCompletion(chatCompletion.getModel(), openaiApiVersion, chatCompletion).blockingGet();
}
return service.chatCompletion(chatCompletion).blockingGet();
}
Expand All @@ -288,8 +297,9 @@ public ChatCompletionResp create(ChatCompletion chatCompletion) {
* @return The embedding vector response.
*/
public EmbeddingResp embedding(Embedding embedding) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.embedding(embedding.getModel(),openaiApiVersion,embedding).blockingGet();
if (openaiApiType.equals(OpenaiApiType.AZURE.getValue())
|| openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())) {
return service.embedding(embedding.getModel(), openaiApiVersion, embedding).blockingGet();
}
return service.embedding(embedding).blockingGet();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.hw.openai.entity.embeddings.EmbeddingResp;
import com.hw.openai.entity.models.Model;
import com.hw.openai.entity.models.ModelResp;

import io.reactivex.Single;
import retrofit2.http.*;

Expand Down Expand Up @@ -71,7 +72,8 @@ public interface OpenAiService {
* @return a Single emitting the response containing the completion result
*/
@POST("{deploymentId}/completions")
Single<CompletionResp> completion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Completion completion);
Single<CompletionResp> completion(@Path("deploymentId") String deploymentId,
@Query("api-version") String apiVersion, @Body Completion completion);

/**
* Creates a model response for the given chat conversation.
Expand All @@ -90,10 +92,11 @@ public interface OpenAiService {
* @param apiVersion The apiVersion for azure openai url parameter 'api-version'.
* @param chatCompletion the chat completion request object containing the chat conversation
* @return a Single emitting the response containing the chat completion result
*/
@POST("{deploymentId}/chat/completions")
Single<ChatCompletionResp> chatCompletion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body ChatCompletion chatCompletion);
Single<ChatCompletionResp> chatCompletion(@Path("deploymentId") String deploymentId,
@Query("api-version") String apiVersion, @Body ChatCompletion chatCompletion);

/**
* Creates an embedding vector representing the input text.
Expand All @@ -113,6 +116,7 @@ public interface OpenAiService {
* @return A Single object that emits an EmbeddingResp, representing the response containing the embedding vector.
*/
@POST("{deploymentId}/embeddings")
Single<EmbeddingResp> embedding(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Embedding embedding);
Single<EmbeddingResp> embedding(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion,
@Body Embedding embedding);

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.embeddings.Embedding;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
Expand All @@ -36,7 +37,7 @@
*
* @author Tingliang Wang
*/
@Disabled("Test requires costly OpenAI calls, can be run manually.")
@Disabled("Test requires costly Azure OpenAI calls, can be run manually.")
class AzureOpenAiClientTest {

private static OpenAiClient client;
Expand Down

0 comments on commit caf5d79

Please sign in to comment.