From 87112542a0b3f4f148f692760371e4c62a1a8702 Mon Sep 17 00:00:00 2001 From: Tingliang Wang Date: Thu, 31 Aug 2023 15:29:26 +0800 Subject: [PATCH] Feature #33: Azure OpenAI endpoint support --- .../chat/models/openai/ChatOpenAI.java | 8 ++ .../embeddings/openai/OpenAIEmbeddings.java | 8 ++ .../hw/langchain/llms/openai/BaseOpenAI.java | 10 ++ .../com/hw/langchain/llms/openai/OpenAI.java | 4 + .../hw/langchain/llms/openai/OpenAIChat.java | 15 +++ .../main/java/com/hw/openai/OpenAiClient.java | 70 ++++++++++--- .../openai/entity/common/OpenaiApiType.java | 52 ++++++++++ .../com/hw/openai/service/OpenAiService.java | 40 +++++++- .../com/hw/openai/AzureOpenAiClientTest.java | 98 +++++++++++++++++++ 9 files changed, 287 insertions(+), 18 deletions(-) create mode 100644 openai-client/src/main/java/com/hw/openai/entity/common/OpenaiApiType.java create mode 100644 openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java index fc02709c1..21e9061dd 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java @@ -75,6 +75,10 @@ public class ChatOpenAI extends BaseChatModel { protected String openaiApiBase; + protected String openaiApiType; + + protected String openaiApiVersion; + protected String openaiOrganization; /** @@ -123,10 +127,14 @@ public ChatOpenAI init() { openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", ""); openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", ""); + openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE",""); + openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION",""); this.client = 
OpenAiClient.builder() .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) + .openaiApiVersion(openaiApiVersion) + .openaiApiType(openaiApiType) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) .requestTimeout(requestTimeout) diff --git a/langchain-core/src/main/java/com/hw/langchain/embeddings/openai/OpenAIEmbeddings.java b/langchain-core/src/main/java/com/hw/langchain/embeddings/openai/OpenAIEmbeddings.java index 0d0d3e7eb..6b753a7c4 100644 --- a/langchain-core/src/main/java/com/hw/langchain/embeddings/openai/OpenAIEmbeddings.java +++ b/langchain-core/src/main/java/com/hw/langchain/embeddings/openai/OpenAIEmbeddings.java @@ -68,6 +68,10 @@ public class OpenAIEmbeddings implements Embeddings { private String openaiApiKey; + private String openaiApiType; + + private String openaiApiVersion; + protected String openaiOrganization; /** @@ -96,10 +100,14 @@ public OpenAIEmbeddings init() { openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", ""); openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", ""); openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); + openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE",""); + openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION",""); this.client = OpenAiClient.builder() .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) + .openaiApiVersion(openaiApiVersion) + .openaiApiType(openaiApiType) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) .requestTimeout(requestTimeout) diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java index f9020de48..4c7725781 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java @@ -101,6 +101,16 @@ public class BaseOpenAI extends BaseLLM { */ 
protected String openaiApiBase; + /** + * Api type for Azure OpenAI API. + */ + protected String openaiApiType; + + /** + * Api version for Azure OpenAI API. + */ + protected String openaiApiVersion; + /** * Organization ID for OpenAI. */ diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java index f68095b79..65670414a 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java @@ -43,10 +43,14 @@ public OpenAI init() { openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", ""); openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", ""); + openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE",""); + openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION",""); this.client = OpenAiClient.builder() .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) + .openaiApiVersion(openaiApiVersion) + .openaiApiType(openaiApiType) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) .proxyUsername(proxyUsername) diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java index 3e4a4c867..d66b4d514 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java @@ -33,6 +33,7 @@ import java.util.*; import static com.google.common.base.Preconditions.checkArgument; +import static com.hw.langchain.utils.Utils.getOrEnvOrDefault; /** * Wrapper around OpenAI Chat large language models. 
@@ -94,6 +95,16 @@ public class OpenAIChat extends BaseLLM { */ protected String openaiApiBase; + /** + * Api type for Azure OpenAI API. + */ + protected String openaiApiType; + + /** + * Api version for Azure OpenAI API. + */ + protected String openaiApiVersion; + /** * Organization ID for OpenAI. */ @@ -137,10 +148,14 @@ public OpenAIChat init() { openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY"); openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", ""); + openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE",""); + openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION",""); this.client = OpenAiClient.builder() .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) + .openaiApiVersion(openaiApiVersion) + .openaiApiType(openaiApiType) .openaiOrganization(openaiOrganization) .openaiProxy(openaiProxy) .requestTimeout(requestTimeout) diff --git a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java index e4b73b76e..ee283aefe 100644 --- a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java +++ b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.hw.openai.entity.chat.ChatCompletion; import com.hw.openai.entity.chat.ChatCompletionResp; +import com.hw.openai.entity.common.OpenaiApiType; import com.hw.openai.entity.completions.Completion; import com.hw.openai.entity.completions.CompletionResp; import com.hw.openai.entity.embeddings.Embedding; @@ -30,17 +31,15 @@ import com.hw.openai.entity.models.ModelResp; import com.hw.openai.service.OpenAiService; import com.hw.openai.utils.ProxyUtils; - -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import lombok.Builder; import 
lombok.Data; import okhttp3.Interceptor; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.logging.HttpLoggingInterceptor; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import retrofit2.Retrofit; import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; import retrofit2.converter.jackson.JacksonConverterFactory; @@ -63,6 +62,10 @@ public class OpenAiClient { private String openaiApiKey; + private String openaiApiType; + + private String openaiApiVersion; + private String openaiOrganization; private String openaiProxy; @@ -95,7 +98,24 @@ public class OpenAiClient { * @return the initialized OpenAiClient instance */ public OpenAiClient init() { - openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/"); + openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","openai"); + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE"); + if(openaiApiBase == null){ + throw new IllegalArgumentException( + String.format( + "Did not find %s, please add an environment variable `%s` which contains it, or pass `%s` as a named parameter.", + "OPENAI_API_BASE", "OPENAI_API_BASE", "OPENAI_API_BASE")); + } + openaiApiBase += (openaiApiBase.endsWith("/")?"":"/") + "openai/deployments/"; + }else if(openaiApiType.equals(OpenaiApiType.OPENAI.getValue())){ + openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/"); + }else { + throw new IllegalArgumentException( + String.format( + "The API type %s provided is invalid. 
Please select one of the supported API types: 'azure', 'azure_ad', 'openai'", + openaiApiType /* report the rejected value itself, not the variable name */)); + } openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY"); OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder() @@ -109,11 +129,16 @@ public OpenAiClient init() { openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY"); openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); - Request request = chain.request().newBuilder() - .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + openaiApiKey) - .header("OpenAI-Organization", openaiOrganization) - .build(); + Request.Builder requestBuilder = chain.request().newBuilder(); + requestBuilder.header("Content-Type", "application/json"); + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())){ /* key auth only; an azure_ad token is sent as a Bearer header below */ + requestBuilder.header("api-key", openaiApiKey); + }else { + requestBuilder.header("Authorization", "Bearer " + openaiApiKey); + requestBuilder.header("OpenAI-Organization", openaiOrganization); + } + + Request request = requestBuilder.build(); return chain.proceed(request); }); @@ -201,7 +226,12 @@ public Model retrieveModel(String model) { * @return the generated completion text */ public String completion(Completion completion) { - CompletionResp response = service.completion(completion).blockingGet(); + CompletionResp response; + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + response = service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet(); + }else { + response = service.completion(completion).blockingGet(); + } String text = response.getChoices().get(0).getText(); return StringUtils.trim(text); @@ -214,6 +244,9 @@ public String completion(Completion completion) { * @return the completion response */ public 
CompletionResp create(Completion completion) { + 
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + return service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet(); + } return service.completion(completion).blockingGet(); } @@ -224,7 +257,12 @@ public CompletionResp create(Completion completion) { * @return the generated model response text */ public String chatCompletion(ChatCompletion chatCompletion) { - ChatCompletionResp response = service.chatCompletion(chatCompletion).blockingGet(); + ChatCompletionResp response; + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + response = service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet(); + }else { + response = service.chatCompletion(chatCompletion).blockingGet(); + } String content = response.getChoices().get(0).getMessage().getContent(); return StringUtils.trim(content); @@ -237,6 +275,9 @@ public String chatCompletion(ChatCompletion chatCompletion) { * @return the chat completion response */ public ChatCompletionResp create(ChatCompletion chatCompletion) { + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + return service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet(); + } return service.chatCompletion(chatCompletion).blockingGet(); } @@ -247,6 +288,9 @@ public ChatCompletionResp create(ChatCompletion chatCompletion) { * @return The embedding vector response. 
*/ public EmbeddingResp embedding(Embedding embedding) { + if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){ + return service.embedding(embedding.getModel(),openaiApiVersion,embedding).blockingGet(); + } return service.embedding(embedding).blockingGet(); } } diff --git a/openai-client/src/main/java/com/hw/openai/entity/common/OpenaiApiType.java b/openai-client/src/main/java/com/hw/openai/entity/common/OpenaiApiType.java new file mode 100644 index 000000000..f1817aa99 --- /dev/null +++ b/openai-client/src/main/java/com/hw/openai/entity/common/OpenaiApiType.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.openai.entity.common; + +/** + * OpenaiApiType + * @author Tingliang Wang + */ +public enum OpenaiApiType { + + /** + * azure. + */ + AZURE("azure"), + + /** + * azure_ad. + */ + AZURE_AD("azure_ad"), + + /** + * openai. 
+ */ + OPENAI("openai"); + + private final String value; + + OpenaiApiType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + +} diff --git a/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java b/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java index 14e7488d3..a2e479ffb 100644 --- a/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java +++ b/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java @@ -26,12 +26,8 @@ import com.hw.openai.entity.embeddings.EmbeddingResp; import com.hw.openai.entity.models.Model; import com.hw.openai.entity.models.ModelResp; - import io.reactivex.Single; -import retrofit2.http.Body; -import retrofit2.http.GET; -import retrofit2.http.POST; -import retrofit2.http.Path; +import retrofit2.http.*; /** * Service interface for interacting with the OpenAI API. @@ -66,6 +62,17 @@ public interface OpenAiService { @POST("completions") Single<CompletionResp> completion(@Body Completion completion); + /** + * Creates a completion for the provided prompt and parameters, using azure openai. + * + * @param deploymentId The deploymentId for azure openai url. + * @param apiVersion The apiVersion for azure openai url parameter 'api-version'. + * @param completion the completion request object containing the prompt and parameters + * @return a Single emitting the response containing the completion result + */ + @POST("{deploymentId}/completions") + Single<CompletionResp> completion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Completion completion); + /** * Creates a model response for the given chat conversation. * @@ -76,6 +83,18 @@ public interface OpenAiService { @POST("chat/completions") Single<ChatCompletionResp> chatCompletion(@Body ChatCompletion chatCompletion); + /** + * Creates a model response for the given chat conversation, using azure openai. 
+ * + * @param deploymentId The deploymentId for azure openai url. 
+ * @param apiVersion The apiVersion for azure openai url parameter 'api-version'. + * @param chatCompletion the chat completion request object containing the chat conversation + * @return a Single emitting the response containing the chat completion result + + */ + @POST("{deploymentId}/chat/completions") + Single<ChatCompletionResp> chatCompletion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body ChatCompletion chatCompletion); + /** * Creates an embedding vector representing the input text. * @@ -85,4 +104,15 @@ public interface OpenAiService { @POST("embeddings") Single<EmbeddingResp> embedding(@Body Embedding embedding); + /** + * Creates an embedding vector representing the input text, using azure openai. + * + * @param deploymentId The deploymentId for azure openai url. + * @param apiVersion The apiVersion for azure openai url parameter 'api-version'. + * @param embedding The Embedding object containing the input text. + * @return A Single object that emits an EmbeddingResp, representing the response containing the embedding vector. + */ + @POST("{deploymentId}/embeddings") + Single<EmbeddingResp> embedding(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Embedding embedding); + } diff --git a/openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java b/openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java new file mode 100644 index 000000000..e7db389f2 --- /dev/null +++ b/openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.openai; + +import com.hw.openai.entity.chat.ChatCompletion; +import com.hw.openai.entity.chat.Message; +import com.hw.openai.entity.completions.Completion; +import com.hw.openai.entity.embeddings.Embedding; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * OpenAI API reference + * + * @author Tingliang Wang + */ +@Disabled("Test requires costly OpenAI calls, can be run manually.") +class AzureOpenAiClientTest { + + private static OpenAiClient client; + + @BeforeAll + static void setup() { + client = OpenAiClient.builder() + .openaiApiKey("xxx") + .openaiApiType("azure") + .openaiApiBase("https://xxx.openai.azure.com/") + .openaiApiVersion("2023-05-15") + .build() + .init(); + } + + @AfterAll + static void cleanup() { + client.close(); + } + + @Test + void testCompletion() { + Completion completion = Completion.builder() + .model("text-davinci-003") + .prompt(List.of("Say this is a test")) + .maxTokens(700) + .temperature(0) + .build(); + + assertThat(client.completion(completion)).isEqualTo("This is indeed a test."); + } + + @Test + void testChatCompletion() { + Message message = Message.of("Hello!"); + + ChatCompletion chatCompletion = ChatCompletion.builder() + .model("gpt-35-turbo") + .temperature(0) + .messages(List.of(message)) + .build(); + + assertThat(client.chatCompletion(chatCompletion)).isEqualTo("Hello there! 
How can I assist you today?"); + } + + @Test + void testEmbeddings() { + var embedding = Embedding.builder() + .model("text-embedding-ada-002") + .input(List.of("The food was delicious and the waiter...")) + .build(); + + var response = client.embedding(embedding); + + assertThat(response).as("Response should not be null").isNotNull(); + assertThat(response.getData()).as("Data list should have size 1").hasSize(1); + assertThat(response.getData().get(0).getEmbedding()).as("Embedding should have size 1536").hasSize(1536); + } +} \ No newline at end of file