Skip to content

Commit

Permalink
Add PromptTemplate and test case
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 2, 2023
1 parent 1beaecb commit 09ccc19
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 15 deletions.
5 changes: 5 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
<artifactId>guava</artifactId>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-text</artifactId>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public OpenAI init() {
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiOrganization(openaiOrganization)
.proxy(openaiProxy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ public class OpenAIChat extends BaseLLM {
protected boolean streaming;

public OpenAIChat init() {
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE");
openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");

this.client = OpenAiClient.builder()
.openaiApiBase(openaiApiBase)
.openaiApiKey(openaiApiKey)
.openaiOrganization(openaiOrganization)
.proxy(openaiProxy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
import com.hw.langchain.schema.BaseOutputParser;
import com.hw.langchain.schema.PromptValue;

import lombok.Data;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* @description: Base class for all prompt templates, returning a prompt.
* @author: HamaWhite
*/
@Data
public abstract class BasePromptTemplate {

/**
Expand All @@ -40,6 +44,8 @@ public abstract class BasePromptTemplate {
*/
protected BaseOutputParser outputParser;

private Map<String, Object> partialVariables = new HashMap<>();

public BasePromptTemplate(List<String> inputVariables) {
this.inputVariables = inputVariables;
}
Expand All @@ -54,6 +60,11 @@ public BasePromptTemplate(List<String> inputVariables, BaseOutputParser outputPa
*/
public abstract PromptValue formatPrompt(Map<String, Object> kwargs);

/**
* Format the prompt with the inputs.
* @param kwargs Any arguments to be passed to the prompt template.
* @return A formatted string.
*/
public abstract String format(Map<String, Object> kwargs);

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.hw.langchain.schema.HumanMessage;
import com.hw.langchain.schema.PromptValue;

import java.util.Collections;
import java.util.List;

/**
Expand All @@ -38,8 +37,8 @@ public StringPromptValue(String text) {
}

@Override
public List<BaseMessage> toMessageList() {
return Collections.singletonList(new HumanMessage(text));
public List<BaseMessage> toMessages() {
return List.of(new HumanMessage(text));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import com.hw.langchain.prompts.base.StringPromptTemplate;
import com.hw.langchain.schema.BaseOutputParser;

import org.apache.commons.text.StringSubstitutor;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

Expand All @@ -33,7 +36,7 @@ public class PromptTemplate extends StringPromptTemplate {
/**
* The prompt template.
*/
private String template;
private final String template;

/**
* Whether or not to try validating the template.
Expand All @@ -52,6 +55,18 @@ public PromptTemplate(List<String> inputVariables, String template, BaseOutputPa

@Override
public String format(Map<String, Object> kwargs) {
return "";
return StringSubstitutor.replace(template, kwargs, "{", "}");
}

public static PromptTemplate fromTemplate(String template) {
List<String> variableNames = new ArrayList<>();
StringSubstitutor substitutor = new StringSubstitutor(variable -> {
variableNames.add(variable);
return null;
});
substitutor.setVariablePrefix("{");
substitutor.setVariableSuffix("}");
substitutor.replace(template);
return new PromptTemplate(variableNames, template);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
public interface PromptValue {

/**
* Returns the prompt as a list of messages.
* Return prompt as messages.
*/
List<BaseMessage> toMessageList();
List<BaseMessage> toMessages();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.llms.openai;

import org.junit.jupiter.api.Test;

/**
* @description: DemoTest
* @author: HamaWhite
*/
public class DemoTest {

@Test
public void test() {
System.out.println("hello world");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.prompts.prompt;

import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;

/**
* <a href="https://platform.openai.com/docs/api-reference/completions">OpenAI API reference</a>
*
* @description: PromptTemplateTest
* @author: HamaWhite
*/
class PromptTemplateTest {

@Test
void testPromptTemplate() {
String template = """
I want you to act as a naming consultant for new companies.
What is a good name for a company that makes {product}?
""";

PromptTemplate prompt = new PromptTemplate(List.of("product"), template);
String actual = prompt.format(Map.of("product", "colorful socks"));

String expected = """
I want you to act as a naming consultant for new companies.
What is a good name for a company that makes colorful socks?
""";
assertEquals(expected, actual);
}

/**
* An example prompt with no input variables.
*/
@Test
void testPromptWithNoInputVariables() {
PromptTemplate noInputPrompt = new PromptTemplate(List.of(), "Tell me a joke.");

String actual = noInputPrompt.format(Map.of());
String expected = "Tell me a joke.";
assertEquals(expected, actual);
}

/**
* An example prompt with one input variable.
*/
@Test
void testPromptWithOneInputVariables() {
PromptTemplate oneInputPrompt = new PromptTemplate(List.of("adjective"), "Tell me a {adjective} joke.");

String actual = oneInputPrompt.format(Map.of("adjective", "funny"));
String expected = "Tell me a funny joke.";
assertEquals(expected, actual);
}

/**
* An example prompt with multiple input variable.
*/
@Test
void testPromptWithMultipleInputVariables() {
PromptTemplate oneInputPrompt =
new PromptTemplate(List.of("adjective", "content"), "Tell me a {adjective} joke about {content}.");

String actual = oneInputPrompt.format(Map.of("adjective", "funny", "content", "chickens"));
String expected = "Tell me a funny joke about chickens.";
assertEquals(expected, actual);
}

@Test
void testInferInputVariablesFromTemplate() {
String template = "Tell me a {adjective} joke about {content}.";

PromptTemplate promptTemplate = PromptTemplate.fromTemplate(template);
assertEquals(List.of("adjective", "content"), promptTemplate.getInputVariables());

String actual = promptTemplate.format(Map.of("adjective", "funny", "content", "chickens"));
String expected = "Tell me a funny joke about chickens.";
assertEquals(expected, actual);
}
}
8 changes: 5 additions & 3 deletions openai-client/src/main/java/com/hw/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public class OpenAiClient {

private static final Logger LOG = LoggerFactory.getLogger(OpenAiClient.class);

private static final String BASE_URL = "https://api.openai.com/";
@Builder.Default
private String openaiApiBase;

private String openaiApiKey;

Expand All @@ -63,8 +64,9 @@ public class OpenAiClient {
private OpenAiService service;

public OpenAiClient init() {
OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder();
openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "https://api.openai.com/v1/");

OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder();
httpClientBuilder.addInterceptor(chain -> {
// If openaiApiKey is not set, read the value of OPENAI_API_KEY from the environment.
openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
Expand Down Expand Up @@ -93,7 +95,7 @@ public OpenAiClient init() {
objectMapper.findAndRegisterModules();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(BASE_URL)
.baseUrl(openaiApiBase)
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.addConverterFactory(JacksonConverterFactory.create(objectMapper))
.client(httpClientBuilder.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,24 @@ public interface OpenAiService {
* Lists the currently available models, and provides basic information about each one
* such as the owner and availability.
*/
@GET("v1/models")
@GET("models")
Single<ModelResp> listModels();

/**
* Retrieves a model instance, providing basic information about the model such as the owner and permissions.
*/
@GET("v1/models/{model}")
@GET("models/{model}")
Single<Model> retrieveModel(@Path("model") String model);

/**
* Creates a completion for the provided prompt and parameters.
*/
@POST("v1/completions")
@POST("completions")
Single<CompletionResp> completion(@Body Completion completion);

/**
* Creates a model response for the given chat conversation.
*/
@POST("v1/chat/completions")
@POST("chat/completions")
Single<ChatCompletionResp> chatCompletion(@Body ChatCompletion chatCompletion);
}
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<retrofit.version>2.9.0</retrofit.version>
<slf4j-api.version>1.7.32</slf4j-api.version>
<slf4j-log4j12.version>1.7.25</slf4j-log4j12.version>
<commons-text.version>1.10.0</commons-text.version>
<commons-lang3.version>3.12.0</commons-lang3.version>
<jackson-annotation.version>2.15.1</jackson-annotation.version>
<commons-collections4.version>4.4</commons-collections4.version>
Expand Down Expand Up @@ -82,6 +83,12 @@
<version>${guava.version}</version>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-text</artifactId>
<version>${commons-text.version}</version>
</dependency>

<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
Expand Down

0 comments on commit 09ccc19

Please sign in to comment.