Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#92 from wangtl/dev
Browse files Browse the repository at this point in the history
good, thanks
  • Loading branch information
HamaWhiteGG authored Sep 10, 2023
2 parents 1dc6423 + 8711254 commit c3553c1
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ public class ChatOpenAI extends BaseChatModel {

protected String openaiApiBase;

protected String openaiApiType;

protected String openaiApiVersion;

protected String openaiOrganization;

/**
Expand Down Expand Up @@ -124,10 +128,14 @@ public ChatOpenAI init() {
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public class OpenAIEmbeddings implements Embeddings {

private String openaiApiKey;

private String openaiApiType;

private String openaiApiVersion;

protected String openaiOrganization;

/**
Expand Down Expand Up @@ -96,10 +100,14 @@ public OpenAIEmbeddings init() {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ public class BaseOpenAI extends BaseLLM {
*/
protected String openaiApiBase;

/**
* Api type for Azure OpenAI API.
*/
protected String openaiApiType;

/**
* Api version for Azure OpenAI API.
*/
protected String openaiApiVersion;

/**
* Organization ID for OpenAI.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ public OpenAI init() {
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.proxyUsername(proxyUsername)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.*;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;

/**
* Wrapper around OpenAI Chat large language models.
Expand Down Expand Up @@ -94,6 +95,16 @@ public class OpenAIChat extends BaseLLM {
*/
protected String openaiApiBase;

/**
* Api type for Azure OpenAI API.
*/
protected String openaiApiType;

/**
* Api version for Azure OpenAI API.
*/
protected String openaiApiVersion;

/**
* Organization ID for OpenAI.
*/
Expand Down Expand Up @@ -137,10 +148,14 @@ public OpenAIChat init() {
openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","");
openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION","");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiApiVersion(openaiApiVersion)
.openaiApiType(openaiApiType)
.openaiOrganization(openaiOrganization)
.openaiProxy(openaiProxy)
.requestTimeout(requestTimeout)
Expand Down
70 changes: 57 additions & 13 deletions openai-client/src/main/java/com/hw/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.common.OpenaiApiType;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.completions.CompletionResp;
import com.hw.openai.entity.embeddings.Embedding;
Expand All @@ -30,17 +31,15 @@
import com.hw.openai.entity.models.ModelResp;
import com.hw.openai.service.OpenAiService;
import com.hw.openai.utils.ProxyUtils;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.Builder;
import lombok.Data;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.logging.HttpLoggingInterceptor;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
Expand All @@ -63,6 +62,10 @@ public class OpenAiClient {

private String openaiApiKey;

private String openaiApiType;

private String openaiApiVersion;

private String openaiOrganization;

private String openaiProxy;
Expand Down Expand Up @@ -95,7 +98,24 @@ public class OpenAiClient {
* @return the initialized OpenAiClient instance
*/
public OpenAiClient init() {
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE","openai");
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE");
if(openaiApiBase == null){
throw new IllegalArgumentException(
String.format(
"Did not find %s, please add an environment variable `%s` which contains it, or pass `%s` as a named parameter.",
"OPENAI_API_BASE", "OPENAI_API_BASE", "OPENAI_API_BASE"));
}
openaiApiBase += (openaiApiBase.endsWith("/")?"":"/") + "openai/deployments/";
}else if(openaiApiType.equals(OpenaiApiType.OPENAI.getValue())){
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");
}else {
throw new IllegalArgumentException(
String.format(
"The API type %s provided in invalid. Please select one of the supported API types: 'azure', 'azure_ad', 'openai'",
"OPENAI_API_TYPE"));
}
openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY");

OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder()
Expand All @@ -109,11 +129,16 @@ public OpenAiClient init() {
openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");

Request request = chain.request().newBuilder()
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + openaiApiKey)
.header("OpenAI-Organization", openaiOrganization)
.build();
Request.Builder requestBuilder = chain.request().newBuilder();
requestBuilder.header("Content-Type", "application/json");
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
requestBuilder.header("api-key", openaiApiKey);
}else {
requestBuilder.header("Authorization", "Bearer " + openaiApiKey);
requestBuilder.header("OpenAI-Organization", openaiOrganization);
}

Request request = requestBuilder.build();

return chain.proceed(request);
});
Expand Down Expand Up @@ -201,7 +226,12 @@ public Model retrieveModel(String model) {
* @return the generated completion text
*/
public String completion(Completion completion) {
CompletionResp response = service.completion(completion).blockingGet();
CompletionResp response;
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
response = service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet();
}else {
response = service.completion(completion).blockingGet();
}

String text = response.getChoices().get(0).getText();
return StringUtils.trim(text);
Expand All @@ -214,6 +244,9 @@ public String completion(Completion completion) {
* @return the completion response
*/
public CompletionResp create(Completion completion) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.completion(completion.getModel(),openaiApiVersion,completion).blockingGet();
}
return service.completion(completion).blockingGet();
}

Expand All @@ -224,7 +257,12 @@ public CompletionResp create(Completion completion) {
* @return the generated model response text
*/
public String chatCompletion(ChatCompletion chatCompletion) {
ChatCompletionResp response = service.chatCompletion(chatCompletion).blockingGet();
ChatCompletionResp response;
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
response = service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet();
}else {
response = service.chatCompletion(chatCompletion).blockingGet();
}

String content = response.getChoices().get(0).getMessage().getContent();
return StringUtils.trim(content);
Expand All @@ -237,6 +275,9 @@ public String chatCompletion(ChatCompletion chatCompletion) {
* @return the chat completion response
*/
public ChatCompletionResp create(ChatCompletion chatCompletion) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.chatCompletion(chatCompletion.getModel(),openaiApiVersion,chatCompletion).blockingGet();
}
return service.chatCompletion(chatCompletion).blockingGet();
}

Expand All @@ -247,6 +288,9 @@ public ChatCompletionResp create(ChatCompletion chatCompletion) {
* @return The embedding vector response.
*/
public EmbeddingResp embedding(Embedding embedding) {
if(openaiApiType.equals(OpenaiApiType.AZURE.getValue())||openaiApiType.equals(OpenaiApiType.AZURE_AD.getValue())){
return service.embedding(embedding.getModel(),openaiApiVersion,embedding).blockingGet();
}
return service.embedding(embedding).blockingGet();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.openai.entity.common;

/**
* OpenaiApiType
* @author Tingliang Wang
*/
public enum OpenaiApiType {

/**
* azure.
*/
AZURE("azure"),

/**
* azure_ad.
*/
AZURE_AD("azure_ad"),

/**
* openai.
*/
OPENAI("openai");

private final String value;

OpenaiApiType(String value) {
this.value = value;
}

public String getValue() {
return value;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@
import com.hw.openai.entity.embeddings.EmbeddingResp;
import com.hw.openai.entity.models.Model;
import com.hw.openai.entity.models.ModelResp;

import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.GET;
import retrofit2.http.POST;
import retrofit2.http.Path;
import retrofit2.http.*;

/**
* Service interface for interacting with the OpenAI API.
Expand Down Expand Up @@ -66,6 +62,17 @@ public interface OpenAiService {
@POST("completions")
Single<CompletionResp> completion(@Body Completion completion);

/**
* Creates a completion for the provided prompt and parameters, using azure openai.
*
* @param deploymentId The deploymentId for azure openai url.
* @param apiVersion The apiVersion for azure openai url parameter 'api-version'.
* @param completion the completion request object containing the prompt and parameters
* @return a Single emitting the response containing the completion result
*/
@POST("{deploymentId}/completions")
Single<CompletionResp> completion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Completion completion);

/**
* Creates a model response for the given chat conversation.
*
Expand All @@ -76,6 +83,18 @@ public interface OpenAiService {
@POST("chat/completions")
Single<ChatCompletionResp> chatCompletion(@Body ChatCompletion chatCompletion);

/**
* Creates a model response for the given chat conversation, using azure openai.
*
* @param deploymentId The deploymentId for azure openai url.
* @param apiVersion The apiVersion for azure openai url parameter 'api-version'.
* @param chatCompletion the chat completion request object containing the chat conversation
* @return a Single emitting the response containing the chat completion result
*/
@POST("{deploymentId}/chat/completions")
Single<ChatCompletionResp> chatCompletion(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body ChatCompletion chatCompletion);

/**
* Creates an embedding vector representing the input text.
*
Expand All @@ -85,4 +104,15 @@ public interface OpenAiService {
@POST("embeddings")
Single<EmbeddingResp> embedding(@Body Embedding embedding);

/**
* Creates an embedding vector representing the input text, using azure openai.
*
* @param deploymentId The deploymentId for azure openai url.
* @param apiVersion The apiVersion for azure openai url parameter 'api-version'.
* @param embedding The Embedding object containing the input text.
* @return A Single object that emits an EmbeddingResp, representing the response containing the embedding vector.
*/
@POST("{deploymentId}/embeddings")
Single<EmbeddingResp> embedding(@Path("deploymentId") String deploymentId, @Query("api-version") String apiVersion, @Body Embedding embedding);

}
Loading

0 comments on commit c3553c1

Please sign in to comment.