Skip to content

Commit

Permalink
rebase dev
Browse files Browse the repository at this point in the history
  • Loading branch information
lingjue@ubuntu committed Sep 11, 2023
1 parent 6fcf0f2 commit cc46f51
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 3 deletions.
1 change: 1 addition & 0 deletions .java-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
17
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.PromptValue;

import reactor.core.publisher.Flux;

import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import com.google.common.collect.Maps;
import com.hw.langchain.schema.BaseMemory;

import org.apache.commons.lang3.StringUtils;

import reactor.core.publisher.Flux;

import java.util.*;
Expand Down Expand Up @@ -150,7 +152,7 @@ private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String,
* Validate and async prep outputs.
*/
private Flux<Map<String, String>> prepaOutputs(Map<String, Object> inputs, Flux<Map<String, String>> outputs,
boolean returnOnlyOutputs) {
boolean returnOnlyOutputs) {
Map<String, String> collector = Maps.newHashMap();
return outputs.doOnNext(this::validateOutputs)
.doOnNext(m -> m.forEach((k, v) -> collector.compute(k, (s, old) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.hw.langchain.schema.Document;

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.HashMap;
Expand Down Expand Up @@ -66,7 +67,8 @@ public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> k
/**
* Combine documents into a single string async.
*/
public abstract Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs, Map<String, Object> kwargs);
public abstract Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs,
Map<String, Object> kwargs);

@Override
protected Map<String, String> innerCall(Map<String, Object> inputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.hw.langchain.schema.Document;

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import reactor.core.publisher.Flux;

import java.util.ArrayList;
Expand Down Expand Up @@ -171,7 +172,8 @@ private List<Map<String, String>> createOutputs(LLMResult llmResult) {
* Create outputs from response async.
*/
private Map<String, String> createAsyncOutputs(AsyncLLMResult llmResult) {
Map<String, String> result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()), "full_generation", llmResult.getGenerations().toString());
Map<String, String> result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()),
"full_generation", llmResult.getGenerations().toString());
if (returnFinalOnly) {
result = Map.of(outputKey, result.get(outputKey));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.schema.Document;

import reactor.core.publisher.Flux;

import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
package com.hw.langchain.chat.models.base;

import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;

import lombok.experimental.SuperBuilder;
import reactor.core.publisher.Flux;

import java.util.List;

Expand Down Expand Up @@ -54,4 +56,9 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {

return new LLMResult(generations);
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
package com.hw.langchain.llms.base;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.PromptValue;

import lombok.experimental.SuperBuilder;
import reactor.core.publisher.Flux;

import java.util.List;

Expand All @@ -43,6 +45,10 @@ public abstract class BaseLLM implements BaseLanguageModel {
* Run the LLM on the given prompts.
*/
protected abstract LLMResult innerGenerate(List<String> prompts, List<String> stop);
/**
* Run the LLM on the given prompts async.
*/
protected abstract Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop);

/**
* Check Cache and run the LLM on the given prompt and input.
Expand Down Expand Up @@ -70,11 +76,24 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {
return generate(promptStrings, stop);
}

@Override
public List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts, List<String> stop) {
List<String> promptStrings = prompts.stream()
.map(PromptValue::toString)
.toList();
return promptStrings.stream().map(s -> _agenerate(List.of(s), stop)).toList();
}

@Override
public String predict(String text, List<String> stop) {
return call(text, stop);
}

@Override
public Flux<String> apredict(String text, List<String> stop) {
return _agenerate(List.of(text), stop).map(result -> result.getGenerations().get(0).getText());
}

@Override
public BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.requests.TextRequestsWrapper;
import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.GenerationChunk;
import com.hw.langchain.schema.LLMResult;

import org.apache.commons.lang3.StringUtils;

import lombok.Builder;
import lombok.experimental.SuperBuilder;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -206,6 +208,11 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
return new LLMResult(generations);
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Convert a stream response to a generation chunk.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package com.hw.langchain.llms.openai;

import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;
import com.hw.openai.OpenAiClient;
Expand All @@ -29,6 +30,7 @@
import lombok.Builder;
import lombok.experimental.SuperBuilder;
import okhttp3.Interceptor;
import reactor.core.publisher.Flux;

import java.util.*;

Expand Down Expand Up @@ -212,6 +214,11 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
return createLLMResult(choices, prompts, Map.of());
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Create the LLMResult from the choices and prompts.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package com.hw.langchain.llms.openai;

import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.utils.Utils;
Expand All @@ -29,6 +30,7 @@

import lombok.Builder;
import lombok.experimental.SuperBuilder;
import reactor.core.publisher.Flux;

import java.util.*;

Expand Down Expand Up @@ -201,6 +203,11 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
return new LLMResult(generations, llmOutput);
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

private List<Message> getChatMessages(List<String> prompts) {
checkArgument(prompts.size() == 1, "OpenAIChat currently only supports single prompt, got %s", prompts);
List<Message> messages = new ArrayList<>(prefixMessages);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.schema;

import lombok.Data;
Expand Down
36 changes: 36 additions & 0 deletions langchain-server/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-java</artifactId>
<version>0.1.11-SNAPSHOT</version>
</parent>

<artifactId>langchain-server</artifactId>

<properties>
<maven.deploy.skip>true</maven.deploy.skip>
</properties>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
</plugins>
</build>
</project>
28 changes: 28 additions & 0 deletions langchain-web/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-java</artifactId>
<version>0.1.11-SNAPSHOT</version>
</parent>

<artifactId>langchain-web</artifactId>

<properties>
<maven.deploy.skip>true</maven.deploy.skip>
</properties>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
</plugins>
</build>
</project>

0 comments on commit cc46f51

Please sign in to comment.