Skip to content

Commit

Permalink
support stream response for openai-client, next will be langchain-core
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 20, 2023
1 parent 464dea0 commit 095bd66
Show file tree
Hide file tree
Showing 28 changed files with 941 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ public ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop) {
.stop(stop)
.build();

var response = retryWithExponentialBackoff(maxRetries, () -> client.create(chatCompletion));
var response = retryWithExponentialBackoff(maxRetries, () -> client.createChatCompletion(chatCompletion));
return createChatResult(response);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,6 @@ public EmbeddingResp embedWithRetry(List<?> input) {
.model(model)
.input(input)
.build();
return client.embedding(embedding);
return client.createEmbedding(embedding);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {

for (var prompt : subPrompts) {
completion.setPrompt(prompt);
CompletionResp response = retryWithExponentialBackoff(maxRetries, () -> client.create(completion));
CompletionResp response =
retryWithExponentialBackoff(maxRetries, () -> client.createCompletion(completion));
choices.addAll(response.getChoices());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
.stop(stop)
.build();

ChatCompletionResp response = client.create(chatCompletion);
ChatCompletionResp response = client.createChatCompletion(chatCompletion);

List<List<Generation>> generations = new ArrayList<>();
Generation generation = Generation.builder()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.openai.example.chat;

import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;

import java.util.List;

import static com.hw.langchain.examples.utils.PrintUtils.println;

/**
* @author HamaWhite
*/
@RunnableExample
public class ChatCompletionExample {

public static void main(String[] args) {
OpenAiClient client = OpenAiClient.builder()
.requestTimeout(120)
.build()
.init();

Message message = Message.of("Introduce West Lake in Hangzhou, China.");
ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-4")
.temperature(0)
.messages(List.of(message))
.build();

ChatCompletionResp response = client.createChatCompletion(chatCompletion);
println(response.getChoices().get(0).getMessage().getContent());

client.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.openai.example.chat;

import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.Message;

import java.util.List;

/**
* @author HamaWhite
*/
@RunnableExample
public class StreamChatCompletionExample {

public static void main(String[] args) {
OpenAiClient client = OpenAiClient.builder()
.requestTimeout(120)
.build()
.init();

Message message = Message.of("Introduce West Lake in Hangzhou, China.");
ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-4")
.messages(List.of(message))
.temperature(0)
.stream(true)
.build();

client.streamChatCompletion(chatCompletion)
.doOnError(Throwable::printStackTrace)
.blockingForEach(e -> {
String content = e.getChoices().get(0).getMessage().getContent();
if (content != null) {
System.out.print(content);
}
});
client.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.openai.example.completions;

import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.completions.CompletionResp;

import java.util.List;

import static com.hw.langchain.examples.utils.PrintUtils.println;

/**
* @author HamaWhite
*/
@RunnableExample
public class CompletionExample {

public static void main(String[] args) {
OpenAiClient client = OpenAiClient.builder()
.requestTimeout(120)
.build()
.init();

Completion completion = Completion.builder()
.model("gpt-3.5-turbo-instruct")
.prompt(List.of("Introduce West Lake in Hangzhou, China."))
.maxTokens(1000)
.temperature(0)
.build();

CompletionResp response = client.createCompletion(completion);
println(response.getChoices().get(0).getText());

client.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,37 @@
* limitations under the License.
*/

package com.hw.openai;
package com.hw.openai.example.completions;

import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.completions.Completion;

import java.util.List;

/**
* @author HamaWhite
*/
public class QuickStart {
@RunnableExample
public class StreamCompletionExample {

public static void main(String[] args) {
OpenAiClient client = OpenAiClient.builder()
.requestTimeout(120)
.build()
.init();

Completion completion = Completion.builder()
.model("text-davinci-003")
.prompt(List.of("Say this is a test"))
.maxTokens(700)
.model("gpt-3.5-turbo-instruct")
.prompt(List.of("Introduce West Lake in Hangzhou, China."))
.maxTokens(1000)
.temperature(0)
.stream(true)
.build();

System.out.println(client.completion(completion));
client.streamCompletion(completion)
.doOnError(Throwable::printStackTrace)
.blockingForEach(e -> System.out.print(e.getChoices().get(0).getText()));

client.close();
}
Expand Down
2 changes: 1 addition & 1 deletion langchain-examples/src/main/resources/log4j.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# limitations under the License.
################################################################################

log4j.rootLogger=DEBUG, console
log4j.rootLogger=INFO, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d %5p %t %-20c.%M:%L - %m%n
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
@Disabled("Test requires costly OpenAI and Pinecone calls, can be run manually.")
class ExampleRunnerTest {

private static final String PACKAGE_NAME = "com.hw.langchain.examples";
private static final String PACKAGE_NAME = "com.hw";

@Test
void testExecuteExamples() {
Expand Down
2 changes: 2 additions & 0 deletions openai-client/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The code for the stream part is referenced from https://github.com/TheoKanning/openai-java.
Thank you for this.
5 changes: 5 additions & 0 deletions openai-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
<artifactId>adapter-rxjava2</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit-mock</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-jackson</artifactId>
Expand Down
Loading

0 comments on commit 095bd66

Please sign in to comment.