Skip to content

Commit

Permalink
support stream response HamaWhiteGG#60 HamaWhiteGG#104
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 21, 2023
1 parent fae00d6 commit 0d6c47d
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The following example can view in the [langchain-example](langchain-examples/src

## 3. Integrations
### 3.1 LLMs
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java) (supports [streaming](langchain-examples/src/main/java/com/hw/langchain/examples/llms/StreamOpenAIExample.java))
- [Azure OpenAI](openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java)
- [ChatGLM2](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)
Expand Down
10 changes: 10 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@
<artifactId>reactor-core</artifactId>
</dependency>

<dependency>
<groupId>io.projectreactor.addons</groupId>
<artifactId>reactor-adapter</artifactId>
</dependency>

<dependency>
<groupId>io.reactivex.rxjava2</groupId>
<artifactId>rxjava</artifactId>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
import com.hw.openai.common.OpenaiApiType;
import com.hw.openai.entity.completions.Choice;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.completions.CompletionChunk;
import com.hw.openai.entity.completions.CompletionResp;

import io.reactivex.Flowable;
import lombok.Builder;
import lombok.experimental.SuperBuilder;
import okhttp3.Interceptor;
import reactor.adapter.rxjava.RxJava2Adapter;
import reactor.core.publisher.Flux;

import java.util.*;
Expand Down Expand Up @@ -195,17 +198,7 @@ public String llmType() {
protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
List<Choice> choices = new ArrayList<>();
List<List<String>> subPrompts = getSubPrompts(prompts);
Completion completion = Completion.builder()
.model(model)
.temperature(temperature)
.maxTokens(maxTokens)
.topP(topP)
.frequencyPenalty(frequencyPenalty)
.presencePenalty(presencePenalty)
.n(n)
.logitBias(logitBias)
.stop(stop)
.build();
Completion completion = buildCompletion(stop);

for (var prompt : subPrompts) {
completion.setPrompt(prompt);
Expand All @@ -214,18 +207,42 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
choices.addAll(response.getChoices());
}

return createLLMResult(choices, prompts, Map.of());
return createLlmResult(choices, prompts, Map.of());
}

@Override
protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop) {
    // Build a streaming request: each emitted chunk carries an incremental piece of text.
    Completion completion = buildCompletion(stop);
    completion.setPrompt(prompts);
    completion.setStream(true);

    Flowable<CompletionChunk> chunks =
            retryWithExponentialBackoff(maxRetries, () -> client.streamCompletion(completion));

    // Bridge RxJava2 -> Reactor, wrapping each chunk's first choice as a single-generation result.
    return RxJava2Adapter.flowableToFlux(chunks)
            .map(chunk -> new AsyncLLMResult(
                    List.of(new Generation(chunk.getChoices().get(0).getText())), null));
}

/**
 * Assembles a {@link Completion} request carrying this model's sampling parameters.
 * The prompt is intentionally left unset; callers set it per sub-prompt batch.
 *
 * @param stop stop sequences that terminate generation (may be null)
 * @return a completion request populated from this LLM's configuration
 */
private Completion buildCompletion(List<String> stop) {
    return Completion.builder()
            .model(model)
            .n(n)
            .temperature(temperature)
            .topP(topP)
            .maxTokens(maxTokens)
            .presencePenalty(presencePenalty)
            .frequencyPenalty(frequencyPenalty)
            .logitBias(logitBias)
            .stop(stop)
            .build();
}

/**
* Create the LLMResult from the choices and prompts.
*/
private LLMResult createLLMResult(List<Choice> choices, List<String> prompts, Map<String, Integer> tokenUsage) {
private LLMResult createLlmResult(List<Choice> choices, List<String> prompts, Map<String, Integer> tokenUsage) {
List<List<Generation>> generations = new ArrayList<>();
for (int i = 0; i < prompts.size(); i++) {
List<Choice> subChoices = choices.subList(i * n, (i + 1) * n);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import reactor.core.publisher.Flux;

import java.sql.SQLException;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;

/**
* LLMChainTest
Expand Down Expand Up @@ -74,6 +77,19 @@ void testLLMChainWithOneInputVariables() {
assertEquals(expected, actual);
}

@Test
void testAsyncRun() {
    var prompt = new PromptTemplate(List.of("product"),
            "What is a good name for a company that makes {product}?");
    Chain chain = new LLMChain(llm, prompt);

    // Collect the streamed fragments; their concatenation must equal the full answer.
    List<String> pieces = chain.asyncRun("colorful socks").collectList().block();

    assertThat(pieces).isNotNull();
    assertThat(String.join("", pieces)).isEqualTo("\n\nSocktastic!");
}

@Test
void testLLMChainWithMultipleInputVariables() {
PromptTemplate prompt = new PromptTemplate(List.of("company", "product"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,45 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import reactor.core.publisher.Flux;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

/**
* OpenAITest
*
* @author HamaWhite
*/
@Disabled("Test requires costly OpenAI calls, can be run manually.")
class OpenAITest {

/**
* Test valid call to openai
*/
@Test
void testOpenAICall() {
OpenAI llm = OpenAI.builder()
void testPredict() {
var llm = OpenAI.builder()
.temperature(0)
.maxTokens(16)
.requestTimeout(16)
.build()
.init();

assertThat(llm.call("Say foo:")).isEqualTo("\n\nFoo!");
var result = llm.predict("Say this is a test");
assertThat(result).isEqualTo("\n\nThis is indeed a test.");
}

@Test
void testAsyncPredict() {
var llm = OpenAI.builder()
.temperature(0)
.maxTokens(1600)
.requestTimeout(120)
.build()
.init();

Flux<String> result = llm.asyncPredict("Say this is a test");

List<String> resultList = result.collectList().block();
assertThat(resultList).isEqualTo(List.of("\n", "\n", "This", " is", " indeed", " a", " test", ".", ""));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public static void main(String[] args) {
.build()
.init();

var result = llm.predict("What would be a good company name for a company that makes colorful socks?");
var result = llm.predict("Introduce West Lake in Hangzhou, China.");
println(result);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.examples.llms;

import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.langchain.llms.openai.OpenAI;

/**
 * Demonstrates streaming output from the OpenAI LLM: fragments are printed
 * to stdout as they arrive instead of waiting for the full response.
 *
 * @author HamaWhite
 */
@RunnableExample
public class StreamOpenAIExample {

    public static void main(String[] args) {
        var llm = OpenAI.builder()
                .maxTokens(1000)
                .temperature(0)
                .requestTimeout(120)
                .build()
                .init();

        // Print each streamed fragment immediately; blockLast() waits for stream completion.
        llm.asyncPredict("Introduce West Lake in Hangzhou, China.")
                .doOnNext(System.out::print)
                .blockLast();
    }
}
16 changes: 8 additions & 8 deletions openai-client/src/test/java/com/hw/openai/OpenAiClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ void testStreamCompletion() {
.build();

// Call client.streamCompletion(completion) and verify the results
List<String> outputList = client.streamCompletion(completion)
List<String> resultList = client.streamCompletion(completion)
.doOnError(Throwable::printStackTrace)
.map(e -> e.getChoices().get(0).getText())
.toList()
.blockingGet();

assertThat(outputList).isNotNull();
assertThat(outputList).isNotEmpty();
assertThat(outputList).isEqualTo(List.of("\n\n", "This", " is", " a", " test", ".", ""));
assertThat(resultList).isNotNull();
assertThat(resultList).isNotEmpty();
assertThat(resultList).isEqualTo(List.of("\n\n", "This", " is", " a", " test", ".", ""));
}

@Test
Expand All @@ -136,7 +136,7 @@ void testStreamChatCompletion() {
.stream(true)
.build();

List<String> outputList = client.streamChatCompletion(chatCompletion)
List<String> resultList = client.streamChatCompletion(chatCompletion)
.doOnError(Throwable::printStackTrace)
.map(e -> {
String content = e.getChoices().get(0).getMessage().getContent();
Expand All @@ -145,9 +145,9 @@ void testStreamChatCompletion() {
.toList()
.blockingGet();

assertThat(outputList).isNotNull();
assertThat(outputList).isNotEmpty();
assertThat(outputList)
assertThat(resultList).isNotNull();
assertThat(resultList).isNotEmpty();
assertThat(resultList)
.isEqualTo(List.of("", "Hello", "!", " How", " can", " I", " assist", " you", " today", "?", ""));
}

Expand Down
18 changes: 16 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<milvus.version>2.2.9</milvus.version>
<jython.version>2.7.3</jython.version>
<jsoup.version>1.16.1</jsoup.version>
<rxjava.version>2.2.21</rxjava.version>
<nd4j.version>1.0.0-M2.1</nd4j.version>
<lombok.version>1.18.28</lombok.version>
<jtokkit.version>0.5.0</jtokkit.version>
Expand All @@ -47,7 +48,8 @@
<hibernate-validator.version>8.0.0.Final</hibernate-validator.version>
<redisson.version>3.17.3</redisson.version>
<netty-resolver.version>4.1.43.Final</netty-resolver.version>
<reactor.version>3.5.8</reactor.version>
<reactor-core.version>3.5.8</reactor-core.version>
<reactor-adapter.version>3.5.1</reactor-adapter.version>
<mockito-core.version>3.12.4</mockito-core.version>

<scala.binary.version>2.12</scala.binary.version>
Expand Down Expand Up @@ -265,7 +267,19 @@
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
<version>${reactor.version}</version>
<version>${reactor-core.version}</version>
</dependency>

<dependency>
<groupId>io.projectreactor.addons</groupId>
<artifactId>reactor-adapter</artifactId>
<version>${reactor-adapter.version}</version>
</dependency>

<dependency>
<groupId>io.reactivex.rxjava2</groupId>
<artifactId>rxjava</artifactId>
<version>${rxjava.version}</version>
</dependency>

<dependency>
Expand Down

0 comments on commit 0d6c47d

Please sign in to comment.