Commit

Merge branch 'HamaWhiteGG:main' into main
Jashinck authored Sep 12, 2023
2 parents 83708c6 + 85f2d23 commit 81a6d8e
Showing 39 changed files with 794 additions and 123 deletions.
16 changes: 9 additions & 7 deletions README.md
@@ -1,8 +1,8 @@
 # 🦜️ LangChain.Java
 
-Java version of LangChain, bringing the capabilities of LLM to big data platforms like Flink and Spark.
+Java version of LangChain, while empowering LLM for Big Data.
 
-> If you are interested, you can add me on WeChat: HamaWhite, or send email to baisongxx@gmail.com
+> If you are interested, you can add me on WeChat: HamaWhite, or send email to [me](mailto:baisongxx@gmail.com).
 ## 1. What is this?
 
@@ -19,17 +19,19 @@ The following example can view in the [langchain-example](langchain-examples/src
 
 - [SQL Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/SqlChainExample.java)
 - [API Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/ApiChainExample.java)
-- [QA-Milvus](langchain-examples/src/main/java/com/hw/langchain/examples/chains/MilvusExample.java)
-- [QA-Pinecone](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
+- [QA-Milvus-Text](langchain-examples/src/main/java/com/hw/langchain/examples/chains/MilvusExample.java)
+- [QA-Pinecone-Text](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
+- [QA-Pinecone-Markdown](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java)
 - [Summarization](langchain-examples/src/main/java/com/hw/langchain/examples/chains/SummarizationExample.java)
 - [Agent with Google Search](langchain-examples/src/main/java/com/hw/langchain/examples/agents/LlmAgentExample.java)
 - [Spark SQL AI](langchain-bigdata/langchain-spark/src/test/java/com/hw/langchain/agents/toolkits/spark/sql/toolkit/SparkSqlToolkitTest.java)
 - [Flink SQL AI](langchain-bigdata/langchain-flink/src/test/java/com/hw/langchain/agents/toolkits/flink/sql/toolkit/FlinkSqlToolkitTest.java)
 
 ## 3. Integrations
 ### 3.1 LLMs
 - [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
-- [ChatGLM2-6B](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
 - [Azure OpenAI](openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java)
+- [ChatGLM2](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
+- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)
 
 ### 3.2 Vector stores
@@ -51,7 +53,7 @@ Prerequisites for building:
 <dependency>
     <groupId>io.github.hamawhitegg</groupId>
     <artifactId>langchain-core</artifactId>
-    <version>0.1.11</version>
+    <version>0.1.12</version>
 </dependency>
 ```
 
@@ -379,5 +381,5 @@ Don’t hesitate to ask!
 ## 8. Fork and Contribute
 This is an active open-source project. We are always open to people who want to use the system or contribute to it. Please note that pull requests should be merged into the **dev** branch.
-Contact [me](baisongxx@gmail.com) if you are looking for implementation tasks that fit your skills.
+Contact [me](mailto:baisongxx@gmail.com) if you are looking for implementation tasks that fit your skills.
2 changes: 1 addition & 1 deletion langchain-bigdata/langchain-flink/pom.xml
@@ -5,7 +5,7 @@
     <parent>
         <groupId>io.github.hamawhitegg</groupId>
         <artifactId>langchain-bigdata</artifactId>
-        <version>0.1.11</version>
+        <version>0.1.12</version>
     </parent>
 
     <artifactId>langchain-flink</artifactId>
2 changes: 1 addition & 1 deletion langchain-bigdata/langchain-spark/pom.xml
@@ -5,7 +5,7 @@
     <parent>
         <groupId>io.github.hamawhitegg</groupId>
         <artifactId>langchain-bigdata</artifactId>
-        <version>0.1.11</version>
+        <version>0.1.12</version>
     </parent>
 
     <artifactId>langchain-spark</artifactId>
2 changes: 1 addition & 1 deletion langchain-bigdata/pom.xml
@@ -5,7 +5,7 @@
     <parent>
         <groupId>io.github.hamawhitegg</groupId>
         <artifactId>langchain-java</artifactId>
-        <version>0.1.11</version>
+        <version>0.1.12</version>
     </parent>
 
     <artifactId>langchain-bigdata</artifactId>
12 changes: 11 additions & 1 deletion langchain-core/pom.xml
@@ -5,7 +5,7 @@
     <parent>
         <groupId>io.github.hamawhitegg</groupId>
         <artifactId>langchain-java</artifactId>
-        <version>0.1.11</version>
+        <version>0.1.12</version>
     </parent>
 
     <artifactId>langchain-core</artifactId>
@@ -75,6 +75,16 @@
             <artifactId>jtokkit</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>org.jsoup</groupId>
+            <artifactId>jsoup</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>io.github.resilience4j</groupId>
+            <artifactId>resilience4j-retry</artifactId>
+        </dependency>
+
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>

Chain.java
@@ -84,7 +84,7 @@ private void validateOutputs(Map<String, String> outputs) {
      * If False, both input keys and new keys generated by this chain will be returned.
      * Defaults to False.
      */
-    public Map<String, String> call(String input, boolean returnOnlyOutputs) {
+    public Map<String, String> call(Object input, boolean returnOnlyOutputs) {
         Map<String, Object> inputs = prepInputs(input);
         return call(inputs, returnOnlyOutputs);
     }
@@ -126,7 +126,7 @@ private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String,
     /**
      * Validate and prep inputs.
      */
-    private Map<String, Object> prepInputs(String input) {
+    private Map<String, Object> prepInputs(Object input) {
         Set<String> inputKeys = new HashSet<>(inputKeys());
         if (memory != null) {
             // If there are multiple input keys, but some get set by memory so that only one is not set,
@@ -162,7 +162,7 @@ public Map<String, Object> prepInputs(Map<String, Object> inputs) {
     /**
      * Run the chain as text in, text out
      */
-    public String run(String args) {
+    public String run(Object args) {
         if (outputKeys().size() != 1) {
             throw new IllegalArgumentException(
                     "The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()

SummarizeUtils.java
@@ -18,6 +18,12 @@
 
 package com.hw.langchain.chains.summarize;
 
+import com.hw.langchain.base.language.BaseLanguageModel;
+import com.hw.langchain.chains.combine.documents.stuff.StuffDocumentsChain;
+import com.hw.langchain.chains.combine.documents.stuff.StuffUtils;
+import com.hw.langchain.chains.llm.LLMChain;
+import com.hw.langchain.prompts.base.BasePromptTemplate;
+
 /**
  * @author HamaWhite
  */
@@ -26,4 +32,15 @@ public class SummarizeUtils {
     private SummarizeUtils() {
         throw new IllegalStateException("Utility class");
     }
+
+    public static StuffDocumentsChain loadStuffChain(BaseLanguageModel llm) {
+        return loadStuffChain(llm, StuffPrompt.PROMPT, "text", "\n\n");
+    }
+
+    public static StuffDocumentsChain loadStuffChain(BaseLanguageModel llm, BasePromptTemplate prompt,
+            String documentVariableName, String documentSeparator) {
+        LLMChain llmChain = new LLMChain(llm, prompt);
+        return new StuffDocumentsChain(llmChain, StuffUtils.getDefaultDocumentPrompt(), documentVariableName,
+                documentSeparator);
+    }
 }
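
A usage sketch for the new loadStuffChain helper, not part of this commit. It leans on the Chain call/run widening from String to Object shown above: the sketch assumes a List<Document> can be passed straight to run(...), that the OpenAI LLM exposes the Lombok-style builder()/init() pattern used in the project's examples, and that the com.hw.langchain.llms.openai.OpenAI import path is correct; treat those names as assumptions rather than confirmed API.

```java
import com.hw.langchain.chains.combine.documents.stuff.StuffDocumentsChain;
import com.hw.langchain.chains.summarize.SummarizeUtils;
import com.hw.langchain.llms.openai.OpenAI;
import com.hw.langchain.schema.Document;

import java.util.List;
import java.util.Map;

public class SummarizeSketch {

    public static void main(String[] args) {
        // Assumes OPENAI_API_KEY is set; builder()/temperature()/init() follow the
        // project's example style and are not shown in this commit.
        var llm = OpenAI.builder()
                .temperature(0)
                .build()
                .init();

        // Document(String, Map) matches the constructor used by WebBaseLoader later in this commit.
        List<Document> docs = List.of(
                new Document("LangChain.Java ports LangChain's chains and agents to the JVM.", Map.of()),
                new Document("It also targets big data engines such as Flink and Spark.", Map.of()));

        StuffDocumentsChain chain = SummarizeUtils.loadStuffChain(llm);
        // run(Object) now accepts non-String inputs; passing the document list is an assumption.
        String summary = chain.run(docs);
        System.out.println(summary);
    }
}
```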

BaseChatModel.java
@@ -50,7 +50,7 @@ public LLMResult generate(List<List<BaseMessage>> messages) {
      */
     public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
         List<ChatResult> results = messages.stream()
-                .map(message -> _generate(message, stop))
+                .map(message -> innerGenerate(message, stop))
                 .toList();
 
         List<Map<String, Object>> llmOutputs = results.stream()
@@ -75,7 +75,7 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {
     /**
      * Top Level call
      */
-    public abstract ChatResult _generate(List<BaseMessage> messages, List<String> stop);
+    public abstract ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop);
 
     public BaseMessage call(List<BaseMessage> messages) {
         return call(messages, null);

ChatOpenAI.java
@@ -38,6 +38,7 @@
 import java.util.Objects;
 
 import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain;
+import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;
 import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;
 
 /**
@@ -75,6 +76,10 @@ public class ChatOpenAI extends BaseChatModel {
 
     protected String openaiApiBase;
 
+    protected String openaiApiType;
+
+    protected String openaiApiVersion;
+
     protected String openaiOrganization;
 
     /**
@@ -123,10 +128,14 @@ public ChatOpenAI init() {
         openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
         openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
         openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
+        openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
+        openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");
 
         this.client = OpenAiClient.builder()
                 .openaiApiBase(openaiApiBase)
                 .openaiApiKey(openaiApiKey)
+                .openaiApiVersion(openaiApiVersion)
+                .openaiApiType(openaiApiType)
                 .openaiOrganization(openaiOrganization)
                 .openaiProxy(openaiProxy)
                 .requestTimeout(requestTimeout)
@@ -158,7 +167,7 @@ public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> llmOutput
     }
 
     @Override
-    public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
+    public ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop) {
         var chatMessages = convertMessages(messages);
 
         ChatCompletion chatCompletion = ChatCompletion.builder()
@@ -171,7 +180,7 @@ public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
                 .stop(stop)
                 .build();
 
-        var response = client.create(chatCompletion);
+        var response = retryWithExponentialBackoff(maxRetries, () -> client.create(chatCompletion));
         return createChatResult(response);
     }
 
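
ChatOpenAI (above) and BaseOpenAI (below) now wrap their OpenAI calls in retryWithExponentialBackoff(maxRetries, ...), backed by the resilience4j-retry dependency added to langchain-core earlier in this commit. The commit does not show Resilience4jRetryUtils itself, so the following is only a sketch of what such a helper might look like with the resilience4j API; the class-name suffix, one-second initial interval, and doubling multiplier are assumptions.

```java
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;

import java.util.function.Supplier;

/**
 * Sketch of a retry helper built on resilience4j-retry. The real Resilience4jRetryUtils in the
 * repository may differ (backoff interval, retried exception types, logging).
 */
public final class Resilience4jRetryUtilsSketch {

    private Resilience4jRetryUtilsSketch() {
    }

    public static <T> T retryWithExponentialBackoff(int maxRetries, Supplier<T> supplier) {
        RetryConfig config = RetryConfig.custom()
                // first retry after 1s, then 2s, 4s, ... (assumed values)
                .intervalFunction(IntervalFunction.ofExponentialBackoff(1000, 2.0))
                // maxRetries is treated as the total number of attempts here (assumption)
                .maxAttempts(maxRetries)
                .build();
        Retry retry = Retry.of("openai", config);
        return Retry.decorateSupplier(retry, supplier).get();
    }
}
```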

WebBaseLoader.java (new file)
@@ -0,0 +1,76 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.hw.langchain.document.loaders;

import com.google.common.collect.Maps;
import com.hw.langchain.document.loaders.base.BaseLoader;
import com.hw.langchain.exception.LangChainException;
import com.hw.langchain.schema.Document;

import org.jsoup.Jsoup;
import org.jsoup.nodes.Element;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * @author HamaWhite
 */
public class WebBaseLoader extends BaseLoader {

    private final List<String> webUrls;

    public WebBaseLoader(List<String> webUrls) {
        this.webUrls = webUrls;
    }

    @Override
    public List<Document> load() {
        List<Document> documents = new ArrayList<>(webUrls.size());
        for (String url : webUrls) {
            try {
                org.jsoup.nodes.Document doc = Jsoup.connect(url).get();
                Map<String, Object> metadata = buildMetadata(doc, url);

                documents.add(new Document(doc.wholeText(), metadata));
            } catch (IOException e) {
                throw new LangChainException(errorMessage(url), e);
            }
        }
        return documents;
    }

    private Map<String, Object> buildMetadata(org.jsoup.nodes.Document doc, String url) {
        Map<String, Object> metadata = Maps.newHashMap();
        metadata.put("source", url);

        Element title = doc.select("title").first();
        if (title != null) {
            metadata.put("title", title.text());
        }
        Element description = doc.select("meta[name=description]").first();
        metadata.put("description", description != null ? description.attr("content") : "No description found.");

        Element html = doc.select("html").first();
        metadata.put("language", html != null ? html.attr("lang") : "No language found.");
        return metadata;
    }
}
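
A short usage sketch for the new loader, not part of this commit. The URL is a placeholder; the sketch relies only on the constructor and load() shown above, and prints each Document via its toString.

```java
import com.hw.langchain.document.loaders.WebBaseLoader;
import com.hw.langchain.schema.Document;

import java.util.List;

public class WebBaseLoaderSketch {

    public static void main(String[] args) {
        // Placeholder URL; any reachable HTML page works.
        var loader = new WebBaseLoader(List.of("https://en.wikipedia.org/wiki/LangChain"));

        // Each page becomes one Document whose metadata carries source, title,
        // description and language, as populated by buildMetadata above.
        List<Document> documents = loader.load();
        documents.forEach(System.out::println);
    }
}
```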

OpenAIEmbeddings.java
@@ -68,6 +68,10 @@ public class OpenAIEmbeddings implements Embeddings {
 
     private String openaiApiKey;
 
+    private String openaiApiType;
+
+    private String openaiApiVersion;
+
     protected String openaiOrganization;
 
     /**
@@ -96,10 +100,14 @@ public OpenAIEmbeddings init() {
         openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
         openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
         openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
+        openaiApiType = getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
+        openaiApiVersion = getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");
 
         this.client = OpenAiClient.builder()
                 .openaiApiBase(openaiApiBase)
                 .openaiApiKey(openaiApiKey)
+                .openaiApiVersion(openaiApiVersion)
+                .openaiApiType(openaiApiType)
                 .openaiOrganization(openaiOrganization)
                 .openaiProxy(openaiProxy)
                 .requestTimeout(requestTimeout)
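
Together with the matching fields added to ChatOpenAI above and BaseOpenAI below, the new openaiApiType and openaiApiVersion settings fall back to the OPENAI_API_TYPE and OPENAI_API_VERSION environment variables, which makes it possible to point the clients at Azure OpenAI. A configuration sketch, not from this commit: the import path and builder setter names are assumed to mirror the fields via the project's Lombok builders, embedQuery(...) is assumed to mirror the Python Embeddings interface, and the endpoint, key variable, and API version are placeholders.

```java
import com.hw.langchain.embeddings.openai.OpenAIEmbeddings;

public class AzureEmbeddingsSketch {

    public static void main(String[] args) {
        // Placeholder Azure values; setter names are assumed to match the fields shown above.
        var embeddings = OpenAIEmbeddings.builder()
                .openaiApiType("azure")
                .openaiApiBase("https://my-resource.openai.azure.com/")
                .openaiApiVersion("2023-05-15")
                .openaiApiKey(System.getenv("AZURE_OPENAI_API_KEY"))
                .build()
                .init();

        // Equivalent alternative: leave the builder fields unset and export OPENAI_API_TYPE,
        // OPENAI_API_VERSION, OPENAI_API_BASE and OPENAI_API_KEY, matching the
        // getOrEnvOrDefault(...) lookups in init() above.
        var vector = embeddings.embedQuery("Hello, Azure OpenAI!"); // embedQuery is an assumed method name
        System.out.println(vector);
    }
}
```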

BaseOpenAI.java
@@ -33,6 +33,7 @@
 import java.util.*;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.hw.langchain.utils.Resilience4jRetryUtils.retryWithExponentialBackoff;
 
 /**
  * Wrapper around OpenAI large language models.
@@ -101,6 +102,16 @@ public class BaseOpenAI extends BaseLLM {
      */
     protected String openaiApiBase;
 
+    /**
+     * Api type for Azure OpenAI API.
+     */
+    protected String openaiApiType;
+
+    /**
+     * Api version for Azure OpenAI API.
+     */
+    protected String openaiApiVersion;
+
     /**
      * Organization ID for OpenAI.
      */
@@ -194,7 +205,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
 
         for (var prompt : subPrompts) {
             completion.setPrompt(prompt);
-            CompletionResp response = client.create(completion);
+            CompletionResp response = retryWithExponentialBackoff(maxRetries, () -> client.create(completion));
             choices.addAll(response.getChoices());
         }
 

OpenAI.java
@@ -43,10 +43,14 @@ public OpenAI init() {
         openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
         openaiOrganization = Utils.getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
         openaiProxy = Utils.getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");
+        openaiApiType = Utils.getOrEnvOrDefault(openaiApiType, "OPENAI_API_TYPE", "");
+        openaiApiVersion = Utils.getOrEnvOrDefault(openaiApiVersion, "OPENAI_API_VERSION", "");
 
         this.client = OpenAiClient.builder()
                 .openaiApiBase(openaiApiBase)
                 .openaiApiKey(openaiApiKey)
+                .openaiApiVersion(openaiApiVersion)
+                .openaiApiType(openaiApiType)
                 .openaiOrganization(openaiOrganization)
                 .openaiProxy(openaiProxy)
                 .proxyUsername(proxyUsername)