Skip to content

Commit

Permalink
optimize code
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Sep 11, 2023
1 parent bbc4302 commit 0c1fc73
Show file tree
Hide file tree
Showing 17 changed files with 46 additions and 112 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ logPath_IS_UNDEFINED
target

# other ignore
.java-version
*.log
*.tmp
Thumbs.db
Expand Down
1 change: 0 additions & 1 deletion .java-version

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -64,37 +64,37 @@ default BaseMessage predictMessages(List<BaseMessage> messages) {
BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop);

/**
* Take in a list of prompt values and return an Flux&lt;AsyncLLMResult&gt; for every PromptValue.
* Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts) {
return asyncGeneratePrompt(prompts, null);
}

/**
* Take in a list of prompt values and return an Flux&lt;AsyncLLMResult&gt; for every PromptValue.
* Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
};
}

/**
* Predict text from text async.
*/
default Flux<String> apredict(String text) {
return apredict(text, null);
default Flux<String> asyncPredict(String text) {
return asyncPredict(text, null);
}

/**
* Predict text from text async.
*/
default Flux<String> apredict(String text, List<String> stop) {
default Flux<String> asyncPredict(String text, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Predict message from messages async.
*/
default Flux<BaseMessage> apredictMessages(List<BaseMessage> messages, List<String> stop) {
default Flux<BaseMessage> asyncPredictMessages(List<BaseMessage> messages, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private void validateOutputs(Map<String, String> outputs) {
* @param inputs the inputs to be processed by the chain
* @return a map flux containing the output generated event by the chain
*/
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
throw new UnsupportedOperationException("Not supported yet.");
}

Expand Down Expand Up @@ -118,15 +118,15 @@ public Map<String, String> call(Map<String, Object> inputs, boolean returnOnlyOu
return prepOutputs(inputs, outputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> acall(String input, boolean returnOnlyOutputs) {
public Flux<Map<String, String>> asyncCall(Object input, boolean returnOnlyOutputs) {
Map<String, Object> inputs = prepInputs(input);
return acall(inputs, returnOnlyOutputs);
return asyncCall(inputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> acall(Map<String, Object> inputs, boolean returnOnlyOutputs) {
public Flux<Map<String, String>> asyncCall(Map<String, Object> inputs, boolean returnOnlyOutputs) {
inputs = prepInputs(inputs);
Flux<Map<String, String>> outputs = ainnerCall(inputs);
return prepaOutputs(inputs, outputs, returnOnlyOutputs);
Flux<Map<String, String>> outputs = asyncInnerCall(inputs);
return asyncPrepOutputs(inputs, outputs, returnOnlyOutputs);
}

/**
Expand All @@ -151,7 +151,7 @@ private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String,
/**
* Validate and async prep outputs.
*/
private Flux<Map<String, String>> prepaOutputs(Map<String, Object> inputs, Flux<Map<String, String>> outputs,
private Flux<Map<String, String>> asyncPrepOutputs(Map<String, Object> inputs, Flux<Map<String, String>> outputs,
boolean returnOnlyOutputs) {
Map<String, String> collector = Maps.newHashMap();
return outputs.doOnNext(this::validateOutputs)
Expand Down Expand Up @@ -240,25 +240,25 @@ public String run(Map<String, Object> args) {
/**
* Run the chain as text in, text out async
*/
public Flux<String> arun(String args) {
public Flux<String> asyncRun(Object args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return acall(args, false).map(m -> m.get(outputKeys().get(0)));
return asyncCall(args, false).map(m -> m.get(outputKeys().get(0)));
}

/**
* Run the chain as multiple variables, text out async.
*/
public Flux<String> arun(Map<String, Object> args) {
public Flux<String> asyncRun(Map<String, Object> args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return acall(args, false).map(m -> m.get(outputKeys().get(0)));
return asyncCall(args, false).map(m -> m.get(outputKeys().get(0)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> k
* Combine documents into a single string.
*/
public abstract Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs);

/**
* Combine documents into a single string async.
*/
public abstract Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs,
public abstract Flux<Pair<String, Map<String, String>>> asyncCombineDocs(List<Document> docs,
Map<String, Object> kwargs);

@Override
Expand All @@ -84,12 +85,12 @@ protected Map<String, String> innerCall(Map<String, Object> inputs) {
}

@Override
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

Map<String, Object> otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey));
var result = this.acombineDocs(docs, otherKeys);
var result = this.asyncCombineDocs(docs, otherKeys);

return result.map(pair -> {
var extraReturnDict = new HashMap<>(pair.getRight());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ public Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<St
}

@Override
public Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs, Map<String, Object> kwargs) {
public Flux<Pair<String, Map<String, String>>> asyncCombineDocs(List<Document> docs, Map<String, Object> kwargs) {
var inputs = getInputs(docs, kwargs);
return llmChain.apredict(inputs).map(s -> Pair.of(s, Map.of()));
return llmChain.asyncPredict(inputs).map(s -> Pair.of(s, Map.of()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import lombok.Getter;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
Expand All @@ -47,6 +48,7 @@ public class LLMChain extends Chain {
/**
* Prompt object to use.
*/
@Getter
protected BasePromptTemplate prompt;

protected String outputKey = "text";
Expand Down Expand Up @@ -102,8 +104,8 @@ protected Map<String, String> innerCall(Map<String, Object> inputs) {
}

@Override
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
var response = agenerate(List.of(inputs));
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
var response = asyncGenerate(List.of(inputs));
return response.get(0).map(this::createAsyncOutputs);
}

Expand All @@ -119,7 +121,7 @@ private LLMResult generate(List<Map<String, Object>> inputList) {
/**
* Generate LLM result from inputs async.
*/
private List<Flux<AsyncLLMResult>> agenerate(List<Map<String, Object>> inputList) {
private List<Flux<AsyncLLMResult>> asyncGenerate(List<Map<String, Object>> inputList) {
List<String> stop = prepStop(inputList);
List<PromptValue> prompts = prepPrompts(inputList);
return llm.asyncGeneratePrompt(prompts, stop);
Expand Down Expand Up @@ -197,14 +199,15 @@ public String predict(Map<String, Object> kwargs) {
* @param kwargs Keys to pass to prompt template.
* @return Completion from LLM.
*/
public Flux<String> apredict(Map<String, Object> kwargs) {
var flux = acall(kwargs, false);
public Flux<String> asyncPredict(Map<String, Object> kwargs) {
var flux = asyncCall(kwargs, false);
return flux.map(m -> m.get(outputKey));
}

/**
* Call predict and then parse the results.
*/
@SuppressWarnings("all")
public <T> T predictAndParse(Map<String, Object> kwargs) {
String result = predict(kwargs);
if (prompt.getOutputParser() != null) {
Expand All @@ -213,8 +216,4 @@ public <T> T predictAndParse(Map<String, Object> kwargs) {
return (T) result;
}

public BasePromptTemplate getPrompt() {
return prompt;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ public class LLMMathChain extends Chain {

private LLMChain llmChain;

private String inputKey = "question";
private final String inputKey = "question";

private String outputKey = "answer";
private final String outputKey = "answer";

public LLMMathChain() {
super();
Expand Down Expand Up @@ -99,7 +99,7 @@ public String evaluateExpression(String expression) {
var localDict = Map.of("pi", Math.PI, "e", Math.E);
// Set local variables in the interpreter
localDict.forEach(interpreter::set);
// Evaluate the expression using jython
// Evaluate the expression
result = interpreter.eval(expression.strip());
}
// Convert the result to a string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ protected Map<String, String> innerCall(Map<String, Object> inputs) {
}

@Override
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
var question = inputs.get(inputKey).toString();

List<Document> docs = getDocs(question);
inputs.put("input_documents", docs);
if (!inputs.containsKey("question")) {
inputs.put("question", question);
}
Flux<String> answer = combineDocumentsChain.arun(inputs);
Flux<String> answer = combineDocumentsChain.asyncRun(inputs);
return answer.map(s -> Map.of(outputKey, s));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public abstract class BaseLLM implements BaseLanguageModel {
/**
* Run the LLM on the given prompts async.
*/
protected abstract Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop);
protected abstract Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop);

/**
* Check Cache and run the LLM on the given prompt and input.
Expand Down Expand Up @@ -81,7 +81,7 @@ public List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts,
List<String> promptStrings = prompts.stream()
.map(PromptValue::toString)
.toList();
return promptStrings.stream().map(s -> _agenerate(List.of(s), stop)).toList();
return promptStrings.stream().map(s -> asyncInnerGenerate(List.of(s), stop)).toList();
}

@Override
Expand All @@ -90,8 +90,8 @@ public String predict(String text, List<String> stop) {
}

@Override
public Flux<String> apredict(String text, List<String> stop) {
return _agenerate(List.of(text), stop).map(result -> result.getGenerations().get(0).getText());
public Flux<String> asyncPredict(String text, List<String> stop) {
return asyncInnerGenerate(List.of(text), stop).map(result -> result.getGenerations().get(0).getText());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
}

@Override
protected Flux<AsyncLLMResult> _agenerate(List<String> prompts, List<String> stop) {
protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
public class AsyncLLMResult {

/**
* List of the things generated. This is List<List<Generation>> because each input could have multiple generations.
* List of the things generated.
*/
private List<? extends Generation> generations;

Expand Down
36 changes: 0 additions & 36 deletions langchain-server/pom.xml

This file was deleted.

Loading

0 comments on commit 0c1fc73

Please sign in to comment.