forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8cc4fb9
commit 6c00408
Showing
8 changed files
with
314 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
langchain-core/src/main/java/com/hw/langchain/chat/models/base/LLM.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chat.models.base; | ||
|
||
import com.hw.langchain.llms.base.BaseLLM; | ||
import com.hw.langchain.schema.Generation; | ||
import com.hw.langchain.schema.LLMResult; | ||
|
||
import lombok.experimental.SuperBuilder; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Base LLM abstract class. | ||
* The purpose of this class is to expose a simpler interface for working | ||
* with LLMs, rather than expect the user to implement the full innerGenerate method. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@SuperBuilder | ||
public abstract class LLM extends BaseLLM { | ||
|
||
/** | ||
* Run the LLM on the given prompt and input. | ||
* | ||
* @param prompt The prompt to pass into the model. | ||
* @param stop list of stop words to use when generating. | ||
* @return The string generated by the model. | ||
*/ | ||
public abstract String innerCall(String prompt, List<String> stop); | ||
|
||
@Override | ||
protected LLMResult innerGenerate(List<String> prompts, List<String> stop) { | ||
List<List<Generation>> generations = prompts.stream().map(prompt -> { | ||
String text = innerCall(prompt, stop); | ||
return List.of(new Generation(text)); | ||
}).toList(); | ||
|
||
return new LLMResult(generations); | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
langchain-core/src/main/java/com/hw/langchain/llms/Utils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.llms; | ||
|
||
import java.util.List; | ||
|
||
/**
 * Text post-processing helpers for LLM outputs.
 *
 * @author HamaWhite
 */
public class Utils {

    private Utils() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }

    /**
     * Cuts off the text at the first occurrence of any stop word.
     * <p>
     * Stop words are matched literally (not as regular expressions), so words containing
     * characters such as {@code .}, {@code (} or {@code |} behave as expected. The stop word
     * itself is not included in the result. If the text begins with a stop word, an empty
     * string is returned.
     *
     * @param text The input text to be processed; {@code null} is returned unchanged.
     * @param stop List of stop words to identify cut-off points; a {@code null} or empty list,
     *             as well as {@code null} or empty entries, are ignored.
     * @return The text preceding the earliest stop word, or the whole text if none occurs.
     */
    public static String enforceStopTokens(String text, List<String> stop) {
        if (text == null || stop == null || stop.isEmpty()) {
            return text;
        }
        int cutOff = text.length();
        for (String stopWord : stop) {
            // An empty stop word would match at index 0 and wipe out the whole text.
            if (stopWord == null || stopWord.isEmpty()) {
                continue;
            }
            int index = text.indexOf(stopWord);
            if (index >= 0 && index < cutOff) {
                cutOff = index;
            }
        }
        return text.substring(0, cutOff);
    }
}
122 changes: 122 additions & 0 deletions
122
langchain-core/src/main/java/com/hw/langchain/llms/chatglm/ChatGLM.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.llms.chatglm; | ||
|
||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.hw.langchain.chains.query.constructor.JsonUtils; | ||
import com.hw.langchain.chat.models.base.LLM; | ||
import com.hw.langchain.requests.TextRequestsWrapper; | ||
|
||
import org.apache.commons.collections4.CollectionUtils; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import lombok.Builder; | ||
import lombok.experimental.SuperBuilder; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.llms.Utils.enforceStopTokens; | ||
|
||
/** | ||
* ChatGLM LLM service. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@SuperBuilder | ||
public class ChatGLM extends LLM { | ||
|
||
private static final Logger LOG = LoggerFactory.getLogger(ChatGLM.class); | ||
|
||
/** | ||
* Endpoint URL to use. | ||
*/ | ||
@Builder.Default | ||
private String endpointUrl = "http://127.0.0.1:8000/"; | ||
|
||
/** | ||
* Max token allowed to pass to the model. | ||
*/ | ||
@Builder.Default | ||
private int maxToken = 20000; | ||
|
||
/** | ||
* LLM model temperature from 0 to 10. | ||
*/ | ||
@Builder.Default | ||
private float temperature = 0.1f; | ||
|
||
/** | ||
* History of the conversation | ||
*/ | ||
@Builder.Default | ||
private List<List<?>> history = new ArrayList<>(); | ||
|
||
/** | ||
* Top P for nucleus sampling from 0 to 1 | ||
*/ | ||
@Builder.Default | ||
private float topP = 0.7f; | ||
|
||
/** | ||
* Whether to use history or not | ||
*/ | ||
private boolean withHistory; | ||
|
||
private TextRequestsWrapper requestsWrapper; | ||
|
||
public ChatGLM init() { | ||
Map<String, String> headers = Map.of("Content-Type", "application/json"); | ||
this.requestsWrapper = new TextRequestsWrapper(headers); | ||
return this; | ||
} | ||
|
||
@Override | ||
public String llmType() { | ||
return "chat_glm"; | ||
} | ||
|
||
@Override | ||
public String innerCall(String prompt, List<String> stop) { | ||
Map<String, Object> payload = Map.of( | ||
"prompt", prompt, | ||
"temperature", temperature, | ||
"history", history, | ||
"max_length", maxToken, | ||
"top_p", topP); | ||
|
||
LOG.debug("ChatGLM payload: {}", payload); | ||
String response = requestsWrapper.post(endpointUrl, payload); | ||
LOG.debug("ChatGLM response: {}", response); | ||
|
||
Map<String, Object> parsedResponse = JsonUtils.convertFromJsonStr(response, new TypeReference<>() { | ||
}); | ||
String text = parsedResponse.get("response").toString(); | ||
|
||
if (CollectionUtils.isNotEmpty(stop)) { | ||
text = enforceStopTokens(text, stop); | ||
} | ||
if (withHistory) { | ||
history.add(List.of(text)); | ||
} | ||
return text; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
langchain-core/src/test/java/com/hw/langchain/llms/chatglm/ChatGLMTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.llms.chatglm; | ||
|
||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
/** | ||
* <a href="https://github.com/THUDM/ChatGLM2-6B/blob/main/api.py">ChatGLM2-6B API reference</a> | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@Disabled("Test requires ChatGLM environment, can be run manually.") | ||
class ChatGLMTest { | ||
|
||
@Test | ||
void testChatGLM() { | ||
var llm = ChatGLM.builder().build().init(); | ||
var actual = llm.predict("Say foo:"); | ||
|
||
var expected = "Foo."; | ||
assertEquals(expected, actual); | ||
} | ||
|
||
} |
42 changes: 42 additions & 0 deletions
42
langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.examples.llms; | ||
|
||
import com.hw.langchain.llms.chatglm.ChatGLM; | ||
|
||
import static com.hw.langchain.examples.utils.PrintUtils.println; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class ChatGLMExample { | ||
|
||
public static void main(String[] args) { | ||
var llm = ChatGLM.builder() | ||
.endpointUrl("http://127.0.0.1:8000/") | ||
.temperature(0.95f) | ||
.build() | ||
.init(); | ||
|
||
var result = llm.predict("What is the capital of China?"); | ||
|
||
// The capital of China is Beijing. | ||
println(result); | ||
} | ||
} |