From 09ccc19fc989c3950b1c79ae4f5ba53a298830e8 Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Fri, 2 Jun 2023 11:53:33 +0800 Subject: [PATCH] Add PromptTemplate and test case --- langchain-core/pom.xml | 5 + .../com/hw/langchain/llms/openai/OpenAI.java | 1 + .../hw/langchain/llms/openai/OpenAIChat.java | 3 +- .../prompts/base/BasePromptTemplate.java | 11 ++ .../prompts/base/StringPromptValue.java | 5 +- .../prompts/prompt/PromptTemplate.java | 19 +++- .../com/hw/langchain/schema/PromptValue.java | 4 +- .../hw/langchain/llms/openai/DemoTest.java | 33 ++++++ .../prompts/prompt/PromptTemplateTest.java | 101 ++++++++++++++++++ .../main/java/com/hw/openai/OpenAiClient.java | 8 +- .../com/hw/openai/service/OpenAiService.java | 8 +- pom.xml | 7 ++ 12 files changed, 190 insertions(+), 15 deletions(-) create mode 100644 langchain-core/src/test/java/com/hw/langchain/llms/openai/DemoTest.java create mode 100644 langchain-core/src/test/java/com/hw/langchain/prompts/prompt/PromptTemplateTest.java diff --git a/langchain-core/pom.xml b/langchain-core/pom.xml index e918c7620..9e5291dc8 100644 --- a/langchain-core/pom.xml +++ b/langchain-core/pom.xml @@ -22,6 +22,11 @@ guava + + org.apache.commons + commons-text + + org.apache.commons commons-lang3 diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java index d26c4135b..95e37f343 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAI.java @@ -44,6 +44,7 @@ public OpenAI init() { openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); this.client = OpenAiClient.builder() + .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) .openaiOrganization(openaiOrganization) .proxy(openaiProxy) diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java index b7bce9c72..dd5f1e3d5 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java @@ -130,11 +130,12 @@ public class OpenAIChat extends BaseLLM { protected boolean streaming; public OpenAIChat init() { - openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", ""); + openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE"); openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY"); openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", ""); this.client = OpenAiClient.builder() + .openaiApiBase(openaiApiBase) .openaiApiKey(openaiApiKey) .openaiOrganization(openaiOrganization) .proxy(openaiProxy) diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java index c6eee5cbe..150d3d600 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java @@ -21,6 +21,9 @@ import com.hw.langchain.schema.BaseOutputParser; import com.hw.langchain.schema.PromptValue; +import lombok.Data; + +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +31,7 @@ * @description: Base class for all prompt templates, returning a prompt. * @author: HamaWhite */ +@Data public abstract class BasePromptTemplate { /** @@ -40,6 +44,8 @@ public abstract class BasePromptTemplate { */ protected BaseOutputParser outputParser; + private Map partialVariables = new HashMap<>(); + public BasePromptTemplate(List inputVariables) { this.inputVariables = inputVariables; } @@ -54,6 +60,11 @@ public BasePromptTemplate(List inputVariables, BaseOutputParser outputPa */ public abstract PromptValue formatPrompt(Map kwargs); + /** + * Format the prompt with the inputs. + * @param kwargs Any arguments to be passed to the prompt template. + * @return A formatted string. + */ public abstract String format(Map kwargs); } diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptValue.java b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptValue.java index 4802f9eb3..5fcd6172c 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptValue.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptValue.java @@ -22,7 +22,6 @@ import com.hw.langchain.schema.HumanMessage; import com.hw.langchain.schema.PromptValue; -import java.util.Collections; import java.util.List; /** @@ -38,8 +37,8 @@ public StringPromptValue(String text) { } @Override - public List toMessageList() { - return Collections.singletonList(new HumanMessage(text)); + public List toMessages() { + return List.of(new HumanMessage(text)); } /** diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java index 1f4d2c606..5d05227fa 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java @@ -21,6 +21,9 @@ import com.hw.langchain.prompts.base.StringPromptTemplate; import com.hw.langchain.schema.BaseOutputParser; +import org.apache.commons.text.StringSubstitutor; + +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -33,7 +36,7 @@ public class PromptTemplate extends StringPromptTemplate { /** * The prompt template. */ - private String template; + private final String template; /** * Whether or not to try validating the template. @@ -52,6 +55,18 @@ public PromptTemplate(List inputVariables, String template, BaseOutputPa @Override public String format(Map kwargs) { - return ""; + return StringSubstitutor.replace(template, kwargs, "{", "}"); + } + + public static PromptTemplate fromTemplate(String template) { + List variableNames = new ArrayList<>(); + StringSubstitutor substitutor = new StringSubstitutor(variable -> { + variableNames.add(variable); + return null; + }); + substitutor.setVariablePrefix("{"); + substitutor.setVariableSuffix("}"); + substitutor.replace(template); + return new PromptTemplate(variableNames, template); } } diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/PromptValue.java b/langchain-core/src/main/java/com/hw/langchain/schema/PromptValue.java index 516ef720e..73ce688e6 100644 --- a/langchain-core/src/main/java/com/hw/langchain/schema/PromptValue.java +++ b/langchain-core/src/main/java/com/hw/langchain/schema/PromptValue.java @@ -27,7 +27,7 @@ public interface PromptValue { /** - * Returns the prompt as a list of messages. + * Return prompt as messages. */ - List toMessageList(); + List toMessages(); } diff --git a/langchain-core/src/test/java/com/hw/langchain/llms/openai/DemoTest.java b/langchain-core/src/test/java/com/hw/langchain/llms/openai/DemoTest.java new file mode 100644 index 000000000..43a6a14a4 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/llms/openai/DemoTest.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.llms.openai; + +import org.junit.jupiter.api.Test; + +/** + * @description: DemoTest + * @author: HamaWhite + */ +public class DemoTest { + + @Test + public void test() { + System.out.println("hello world"); + } +} diff --git a/langchain-core/src/test/java/com/hw/langchain/prompts/prompt/PromptTemplateTest.java b/langchain-core/src/test/java/com/hw/langchain/prompts/prompt/PromptTemplateTest.java new file mode 100644 index 000000000..cdc642120 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/prompts/prompt/PromptTemplateTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.prompt; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * OpenAI API reference + * + * @description: PromptTemplateTest + * @author: HamaWhite + */ +class PromptTemplateTest { + + @Test + void testPromptTemplate() { + String template = """ + I want you to act as a naming consultant for new companies. + What is a good name for a company that makes {product}? + """; + + PromptTemplate prompt = new PromptTemplate(List.of("product"), template); + String actual = prompt.format(Map.of("product", "colorful socks")); + + String expected = """ + I want you to act as a naming consultant for new companies. + What is a good name for a company that makes colorful socks? + """; + assertEquals(expected, actual); + } + + /** + * An example prompt with no input variables. + */ + @Test + void testPromptWithNoInputVariables() { + PromptTemplate noInputPrompt = new PromptTemplate(List.of(), "Tell me a joke."); + + String actual = noInputPrompt.format(Map.of()); + String expected = "Tell me a joke."; + assertEquals(expected, actual); + } + + /** + * An example prompt with one input variable. + */ + @Test + void testPromptWithOneInputVariables() { + PromptTemplate oneInputPrompt = new PromptTemplate(List.of("adjective"), "Tell me a {adjective} joke."); + + String actual = oneInputPrompt.format(Map.of("adjective", "funny")); + String expected = "Tell me a funny joke."; + assertEquals(expected, actual); + } + + /** + * An example prompt with multiple input variable. + */ + @Test + void testPromptWithMultipleInputVariables() { + PromptTemplate oneInputPrompt = + new PromptTemplate(List.of("adjective", "content"), "Tell me a {adjective} joke about {content}."); + + String actual = oneInputPrompt.format(Map.of("adjective", "funny", "content", "chickens")); + String expected = "Tell me a funny joke about chickens."; + assertEquals(expected, actual); + } + + @Test + void testInferInputVariablesFromTemplate() { + String template = "Tell me a {adjective} joke about {content}."; + + PromptTemplate promptTemplate = PromptTemplate.fromTemplate(template); + assertEquals(List.of("adjective", "content"), promptTemplate.getInputVariables()); + + String actual = promptTemplate.format(Map.of("adjective", "funny", "content", "chickens")); + String expected = "Tell me a funny joke about chickens."; + assertEquals(expected, actual); + } +} \ No newline at end of file diff --git a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java index b1a372a8d..5171655ad 100644 --- a/openai-client/src/main/java/com/hw/openai/OpenAiClient.java +++ b/openai-client/src/main/java/com/hw/openai/OpenAiClient.java @@ -52,7 +52,8 @@ public class OpenAiClient { private static final Logger LOG = LoggerFactory.getLogger(OpenAiClient.class); - private static final String BASE_URL = "https://api.openai.com/"; + @Builder.Default + private String openaiApiBase; private String openaiApiKey; @@ -63,8 +64,9 @@ public class OpenAiClient { private OpenAiService service; public OpenAiClient init() { - OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/"); + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); httpClientBuilder.addInterceptor(chain -> { // If openaiApiKey is not set, read the value of OPENAI_API_KEY from the environment. openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY"); @@ -93,7 +95,7 @@ public OpenAiClient init() { objectMapper.findAndRegisterModules(); Retrofit retrofit = new Retrofit.Builder() - .baseUrl(BASE_URL) + .baseUrl(openaiApiBase) .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) .addConverterFactory(JacksonConverterFactory.create(objectMapper)) .client(httpClientBuilder.build()) diff --git a/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java b/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java index d9e4498aa..7181fe508 100644 --- a/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java +++ b/openai-client/src/main/java/com/hw/openai/service/OpenAiService.java @@ -41,24 +41,24 @@ public interface OpenAiService { * Lists the currently available models, and provides basic information about each one * such as the owner and availability. */ - @GET("v1/models") + @GET("models") Single listModels(); /** * Retrieves a model instance, providing basic information about the model such as the owner and permissions. */ - @GET("v1/models/{model}") + @GET("models/{model}") Single retrieveModel(@Path("model") String model); /** * Creates a completion for the provided prompt and parameters. */ - @POST("v1/completions") + @POST("completions") Single completion(@Body Completion completion); /** * Creates a model response for the given chat conversation. */ - @POST("v1/chat/completions") + @POST("chat/completions") Single chatCompletion(@Body ChatCompletion chatCompletion); } diff --git a/pom.xml b/pom.xml index eb492d0f8..97a5dfc97 100644 --- a/pom.xml +++ b/pom.xml @@ -25,6 +25,7 @@ 2.9.0 1.7.32 1.7.25 + 1.10.0 3.12.0 2.15.1 4.4 @@ -82,6 +83,12 @@ ${guava.version} + + org.apache.commons + commons-text + ${commons-text.version} + + org.apache.commons commons-lang3