Skip to content

Commit

Permalink
Support ChatGLM HamaWhiteGG#22
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Aug 22, 2023
1 parent 8cc4fb9 commit 6c00408
Show file tree
Hide file tree
Showing 8 changed files with 314 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The following example can view in the [langchain-example](langchain-examples/src
## 3. Integrations
### 3.1 LLMs
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
- [ChatGLM](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)

### 3.2 Vector stores
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chat.models.base;

import com.hw.langchain.llms.base.BaseLLM;
import com.hw.langchain.schema.Generation;
import com.hw.langchain.schema.LLMResult;

import lombok.experimental.SuperBuilder;

import java.util.List;

/**
 * Simplified base class for LLM implementations.
 * <p>
 * Rather than implementing the full batch-oriented {@code innerGenerate} method,
 * subclasses only provide {@link #innerCall(String, List)}, which handles a single
 * prompt; this class adapts that per-prompt call to the batch interface.
 *
 * @author HamaWhite
 */
@SuperBuilder
public abstract class LLM extends BaseLLM {

    /**
     * Run the LLM on the given prompt and input.
     *
     * @param prompt The prompt to pass into the model.
     * @param stop   list of stop words to use when generating.
     * @return The string generated by the model.
     */
    public abstract String innerCall(String prompt, List<String> stop);

    /**
     * Delegates each prompt to {@link #innerCall(String, List)} and wraps every
     * generated text as a single-element generation list.
     */
    @Override
    protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
        var generations = prompts.stream()
                .map(prompt -> List.of(new Generation(innerCall(prompt, stop))))
                .toList();
        return new LLMResult(generations);
    }
}
44 changes: 44 additions & 0 deletions langchain-core/src/main/java/com/hw/langchain/llms/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.llms;

import java.util.List;

/**
 * Helper utilities shared by LLM implementations.
 *
 * @author HamaWhite
 */
public class Utils {

    private Utils() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }

    /**
     * Cuts off the text at the first occurrence of any stop word.
     * <p>
     * Stop words are matched literally. The previous implementation joined the stop
     * words into a regex alternation without quoting, so stop words containing regex
     * metacharacters such as "." or "?" split at the wrong places, and an empty stop
     * list produced a zero-width pattern that truncated the text to its first character.
     *
     * @param text The input text to be processed.
     * @param stop List of stop words to identify cut-off points; may be null or empty.
     * @return The text truncated before the earliest stop word, or the full text if
     *         no stop word occurs.
     */
    public static String enforceStopTokens(String text, List<String> stop) {
        if (text == null || stop == null || stop.isEmpty()) {
            return text;
        }
        // Find the earliest position at which any stop word occurs.
        int cutoff = text.length();
        for (String token : stop) {
            int index = text.indexOf(token);
            if (index >= 0 && index < cutoff) {
                cutoff = index;
            }
        }
        return text.substring(0, cutoff);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.llms.chatglm;

import com.fasterxml.jackson.core.type.TypeReference;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.chat.models.base.LLM;
import com.hw.langchain.requests.TextRequestsWrapper;

import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static com.hw.langchain.llms.Utils.enforceStopTokens;

/**
 * ChatGLM LLM service.
 * <p>
 * Sends a prompt as a JSON POST to a locally deployed ChatGLM HTTP endpoint
 * and returns the generated text from the "response" field of the reply.
 *
 * @author HamaWhite
 */
@SuperBuilder
public class ChatGLM extends LLM {

    private static final Logger LOG = LoggerFactory.getLogger(ChatGLM.class);

    /**
     * Endpoint URL to use.
     */
    @Builder.Default
    private String endpointUrl = "http://127.0.0.1:8000/";

    /**
     * Max token allowed to pass to the model.
     */
    @Builder.Default
    private int maxToken = 20000;

    /**
     * LLM model temperature from 0 to 10.
     */
    @Builder.Default
    private float temperature = 0.1f;

    /**
     * History of the conversation.
     * NOTE(review): innerCall appends single-element lists containing only the response
     * text; confirm whether the ChatGLM server expects [prompt, response] pairs instead.
     */
    @Builder.Default
    private List<List<?>> history = new ArrayList<>();

    /**
     * Top P for nucleus sampling from 0 to 1
     */
    @Builder.Default
    private float topP = 0.7f;

    /**
     * Whether to use history or not
     */
    private boolean withHistory;

    // HTTP helper used to POST the JSON payload; created in init(), not by the builder.
    private TextRequestsWrapper requestsWrapper;

    /**
     * Initializes the HTTP requests wrapper with a JSON content-type header.
     * Must be called after building and before the first model call.
     *
     * @return this instance, for fluent chaining
     */
    public ChatGLM init() {
        Map<String, String> headers = Map.of("Content-Type", "application/json");
        this.requestsWrapper = new TextRequestsWrapper(headers);
        return this;
    }

    @Override
    public String llmType() {
        return "chat_glm";
    }

    /**
     * Calls the ChatGLM endpoint with the given prompt and returns the response text.
     *
     * @param prompt the prompt to send to the model
     * @param stop   list of stop words; output is truncated at the first occurrence
     * @return the text generated by the model
     */
    @Override
    public String innerCall(String prompt, List<String> stop) {
        // Request body mirrors the ChatGLM api.py parameter names (snake_case).
        Map<String, Object> payload = Map.of(
                "prompt", prompt,
                "temperature", temperature,
                "history", history,
                "max_length", maxToken,
                "top_p", topP);

        LOG.debug("ChatGLM payload: {}", payload);
        String response = requestsWrapper.post(endpointUrl, payload);
        LOG.debug("ChatGLM response: {}", response);

        // NOTE(review): get("response") throws NullPointerException if the server reply
        // lacks a "response" field — consider an explicit check with a clearer error.
        Map<String, Object> parsedResponse = JsonUtils.convertFromJsonStr(response, new TypeReference<>() {
        });
        String text = parsedResponse.get("response").toString();

        if (CollectionUtils.isNotEmpty(stop)) {
            text = enforceStopTokens(text, stop);
        }
        if (withHistory) {
            history.add(List.of(text));
        }
        return text;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
}
generations.add(List.of(requireNonNull(finalChunk)));
}
return new LLMResult(generations, null);
return new LLMResult(generations);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ public class LLMResult {
*/
private Map<String, Object> llmOutput;

    /**
     * Constructs an LLMResult with the given generations and no LLM output metadata.
     *
     * @param generations generated outputs, one inner list per input prompt
     */
    public LLMResult(List<? extends List<? extends Generation>> generations) {
        this.generations = generations;
    }

public LLMResult(List<? extends List<? extends Generation>> generations, Map<String, Object> llmOutput) {
this.generations = generations;
this.llmOutput = llmOutput;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.llms.chatglm;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Manual integration test against a running ChatGLM server.
 * <a href="https://github.com/THUDM/ChatGLM2-6B/blob/main/api.py">ChatGLM2-6B API reference</a>
 *
 * @author HamaWhite
 */
@Disabled("Test requires ChatGLM environment, can be run manually.")
class ChatGLMTest {

    @Test
    void testChatGLM() {
        ChatGLM llm = ChatGLM.builder().build().init();
        String actual = llm.predict("Say foo:");
        assertEquals("Foo.", actual);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.examples.llms;

import com.hw.langchain.llms.chatglm.ChatGLM;

import static com.hw.langchain.examples.utils.PrintUtils.println;

/**
 * Example that queries a locally running ChatGLM server and prints the answer.
 *
 * @author HamaWhite
 */
public class ChatGLMExample {

    public static void main(String[] args) {
        ChatGLM llm = ChatGLM.builder()
                .endpointUrl("http://127.0.0.1:8000/")
                .temperature(0.95f)
                .build()
                .init();

        String result = llm.predict("What is the capital of China?");

        // Expected output: The capital of China is Beijing.
        println(result);
    }
}

0 comments on commit 6c00408

Please sign in to comment.