Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#114 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
optimize code
  • Loading branch information
HamaWhiteGG authored Oct 5, 2023
2 parents 2a9c5ee + e6cbcd0 commit 3ab54a0
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ void testRunQuery() {
.init();

var llm = ChatOpenAI.builder()
.model("gpt-4")
.temperature(0)
.build().init();

Expand All @@ -105,7 +106,7 @@ void testRunQuery() {
// SELECT SQRT(AVG(Age)) FROM titanic
var actual = agentExecutor.run("whats the square root of the average age?");
// sometimes it's 'The square root of the average age is approximately 5.07.'
assertEquals("5.07", actual);
assertEquals("The square root of the average age is approximately 5.07.", actual);

// TODO: It should be DESC here, not ASC.
// SELECT Name FROM titanic WHERE Survived = 1 ORDER BY Age ASC LIMIT 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,12 @@ static void setup() {
void testDescribeTable() {
var actual = agentExecutor.run("Describe the titanic table");

var expected = """
The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), Embarked (STRING). Here are some sample rows from the table:
1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: null, Embarked: S
2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C
3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: null, Embarked: S""";
var expected =
"""
The titanic table has the following columns: PassengerId (INT), Survived (INT), Pclass (INT), Name (STRING), Sex (STRING), Age (DOUBLE), SibSp (INT), Parch (INT), Ticket (STRING), Fare (DOUBLE), Cabin (STRING), Embarked (STRING). Here are some sample rows from the table:
1. PassengerId: 1, Survived: 0, Pclass: 3, Name: Braund, Mr. Owen Harris, Sex: male, Age: 22.0, SibSp: 1, Parch: 0, Ticket: A/5 21171, Fare: 7.25, Cabin: null, Embarked: S
2. PassengerId: 2, Survived: 1, Pclass: 1, Name: Cumings, Mrs. John Bradley (Florence Briggs Thayer), Sex: female, Age: 38.0, SibSp: 1, Parch: 0, Ticket: PC 17599, Fare: 71.2833, Cabin: C85, Embarked: C
3. PassengerId: 3, Survived: 1, Pclass: 3, Name: Heikkinen, Miss. Laina, Sex: female, Age: 26.0, SibSp: 0, Parch: 0, Ticket: STON/O2. 3101282, Fare: 7.925, Cabin: null, Embarked: S""";
assertEquals(expected, actual);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ private static BasePromptTemplate getPrompt(String documentContents, List<Attrib
Map.of("i", examples.size() + 1, "content", documentContents, "attributes", attributeStr));

var outputParser = StructuredQueryOutputParser.fromComponents(allowedComparators, allowedOperators);
return new FewShotPromptTemplate(examples, EXAMPLE_PROMPT, prefix, suffix, List.of("query"), outputParser);
return new FewShotPromptTemplate(examples, EXAMPLE_PROMPT, prefix, suffix, List.of("query"), "\n\n",
outputParser);
}

private static Map<String, Object> createTemplateArguments(List<Comparator> allowedComparators,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,26 @@ public class FewShotPromptTemplate extends StringPromptTemplate {
/**
* String separator used to join the prefix, the examples, and suffix.
*/
private String exampleSeparator = "\n\n";
private String exampleSeparator;

public FewShotPromptTemplate(List<Map<String, Object>> examples, PromptTemplate examplePrompt, String prefix,
String suffix, List<String> inputVariables, BaseOutputParser<?> outputParser) {
String suffix, List<String> inputVariables) {
this(examples, examplePrompt, prefix, suffix, inputVariables, "\n\n");
}

public FewShotPromptTemplate(List<Map<String, Object>> examples, PromptTemplate examplePrompt, String prefix,
String suffix, List<String> inputVariables, String exampleSeparator) {
this(examples, examplePrompt, prefix, suffix, inputVariables, exampleSeparator, null);
}

public FewShotPromptTemplate(List<Map<String, Object>> examples, PromptTemplate examplePrompt, String prefix,
String suffix, List<String> inputVariables, String exampleSeparator, BaseOutputParser<?> outputParser) {
super(inputVariables, outputParser);
this.examples = examples;
this.examplePrompt = examplePrompt;
this.prefix = prefix;
this.suffix = suffix;
this.exampleSeparator = exampleSeparator;
}

public List<Map<String, Object>> getExamples() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ public List<Map<String, String>> results(String query, int num) {
}).toList();
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {

/**
Expand Down Expand Up @@ -163,7 +167,7 @@ public static final class Builder {
*/
private int readTimeout = 20 * 1000;

public Builder() {
private Builder() {
}

public Builder googleApiKey(String googleApiKey) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class GoogleSearchAPIWrapperTest {
void testGoogleSearch() {
String query = "2022 USA national auto sales by brand";

GoogleSearchAPIWrapper searchWrapper = new GoogleSearchAPIWrapper.Builder()
GoogleSearchAPIWrapper searchWrapper = GoogleSearchAPIWrapper.builder()
.connectTimeout(10 * 1000)
.build();

Expand Down

0 comments on commit 3ab54a0

Please sign in to comment.