Skip to content

Commit

Permalink
add StructuredOutputParser
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 1, 2023
1 parent 8e4e0be commit d81c3af
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.output.parsers;

/**
 * Holder for the prompt templates that tell a language model how to format its answer.
 *
 * <p>The templates use Python-style formatting: a named placeholder in single braces
 * ({@code {format}}, {@code {schema}}) and doubled braces ({@code {{}, {@code }}}) wherever a
 * literal brace is meant. They contain no {@code %} specifiers, so they must be rendered by
 * placeholder substitution rather than with {@link String#format(String, Object...)}.
 *
 * @author HamaWhite
 */
public class FormatInstructions {

    /**
     * Template asking for a fenced {@code json} Markdown snippet; {@code {format}} marks where
     * the per-field schema lines go.
     */
    public static final String STRUCTURED_FORMAT_INSTRUCTIONS =
            """
            The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":
            ```json
            {{
            {format}
            }}
            ```""";

    /**
     * Template asking for a JSON instance conforming to a JSON schema; {@code {schema}} marks
     * where the schema text goes.
     */
    public static final String PYDANTIC_FORMAT_INSTRUCTIONS =
            """
            The output should be formatted as a JSON instance that conforms to the JSON schema below.
            As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}}}
            the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
            Here is the output schema:
            ```
            {schema}
            ```""";

    // Constants-only holder: instances would carry no state.
    private FormatInstructions() {
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.output.parsers.json;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.hw.langchain.schema.OutputParserException;

import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static helpers for pulling a JSON document out of Markdown-formatted text and parsing it
 * with Jackson.
 *
 * @author HamaWhite
 */
public class Json {

    /**
     * Matches the first triple-backtick fenced block, with an optional {@code json} tag right
     * after the opening fence. Group 2 is the fenced content; {@code DOTALL} lets it span lines.
     */
    private static final Pattern PATTERN = Pattern.compile("```(json)?(.*?)```", Pattern.DOTALL);

    /**
     * Shared mapper, created once instead of on every parse call. It is never reconfigured
     * after construction, which is the condition under which Jackson documents it as safe to
     * share.
     */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    // Static utility holder; not meant to be instantiated.
    private Json() {
    }

    /**
     * Parses JSON found in a Markdown string.
     *
     * <p>If the text contains a triple-backtick fenced block, only the content of the first
     * such block is parsed; otherwise the whole string is treated as JSON. Surrounding
     * whitespace is stripped before parsing.
     *
     * @param jsonString the Markdown (or plain JSON) text
     * @return the parsed JSON as a {@link JsonNode}
     * @throws OutputParserException if Jackson rejects the extracted text
     */
    public static JsonNode parseJsonMarkdown(String jsonString) {
        Matcher matcher = PATTERN.matcher(jsonString);

        // Use the fenced content when present; fall back to the entire input.
        String jsonStr = matcher.find() ? matcher.group(2) : jsonString;
        jsonStr = jsonStr.strip();

        try {
            return OBJECT_MAPPER.readTree(jsonStr);
        } catch (JsonProcessingException e) {
            // NOTE(review): only the message is forwarded, so the original stack trace is lost.
            // If OutputParserException offers a (String, Throwable) constructor, pass `e` as the cause.
            throw new OutputParserException("Got invalid JSON object. Error: " + e.getMessage());
        }
    }

    /**
     * Parses JSON found in a Markdown string and verifies that every expected key is present
     * on the resulting node.
     *
     * @param markdown     the Markdown (or plain JSON) text
     * @param expectedKeys the field names that must exist on the parsed node
     * @return the parsed JSON as a {@link JsonNode}
     * @throws OutputParserException if the text is not valid JSON or any expected key is missing
     */
    public static JsonNode parseAndCheckJsonMarkdown(String markdown, List<String> expectedKeys) {
        JsonNode jsonNode = parseJsonMarkdown(markdown);
        for (String key : expectedKeys) {
            if (!jsonNode.has(key)) {
                throw new OutputParserException(String.format(
                        "Got invalid return object. Expected key `%s` to be present, but got %s", key, jsonNode));
            }
        }
        return jsonNode;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.output.parsers.structured;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Describes one field of a structured response: its name, a human-readable description,
 * and a type label.
 *
 * <p>Accessors, {@code equals}/{@code hashCode}/{@code toString}, the no-argument constructor
 * and the three-argument constructor are generated by Lombok.
 *
 * @author HamaWhite
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ResponseSchema {

    /** Type label used whenever a constructor does not receive one explicitly. */
    private static final String DEFAULT_TYPE = "string";

    /** Field name. */
    private String name;

    /** Free-text description of the field. */
    private String description;

    /** Type label; starts out as {@code "string"}. */
    private String type = DEFAULT_TYPE;

    /**
     * Creates a schema entry that keeps the default type label.
     *
     * @param name        the field name
     * @param description the description of the field
     */
    public ResponseSchema(String name, String description) {
        this.description = description;
        this.name = name;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.output.parsers.structured;

import com.fasterxml.jackson.databind.JsonNode;
import com.hw.langchain.schema.BaseOutputParser;

import java.util.List;

import static com.hw.langchain.output.parsers.FormatInstructions.STRUCTURED_FORMAT_INSTRUCTIONS;
import static com.hw.langchain.output.parsers.json.Json.parseAndCheckJsonMarkdown;

/**
 * Output parser that asks the model for a fenced JSON snippet with a fixed set of fields and
 * parses the reply into a {@link JsonNode}, checking that every declared field is present.
 *
 * @author HamaWhite
 */
public class StructuredOutputParser extends BaseOutputParser<JsonNode> {

    /** One schema line: tab, quoted field name, type label, then the description as a trailing comment. */
    private static final String LINE_TEMPLATE = "\t\"%s\": %s // %s";

    /** Placeholder inside {@code STRUCTURED_FORMAT_INSTRUCTIONS} that receives the schema lines. */
    private static final String FORMAT_PLACEHOLDER = "{format}";

    private final List<ResponseSchema> responseSchemas;

    /**
     * @param responseSchemas the fields the model's reply is expected to contain
     */
    public StructuredOutputParser(List<ResponseSchema> responseSchemas) {
        this.responseSchemas = responseSchemas;
    }

    /**
     * Static factory equivalent to calling the constructor.
     *
     * @param responseSchemas the fields the model's reply is expected to contain
     * @return a parser for those fields
     */
    public static StructuredOutputParser fromResponseSchemas(List<ResponseSchema> responseSchemas) {
        return new StructuredOutputParser(responseSchemas);
    }

    /** Renders a single schema entry as one line of the format instructions. */
    private String getSubString(ResponseSchema schema) {
        return String.format(LINE_TEMPLATE, schema.getName(), schema.getType(), schema.getDescription());
    }

    /**
     * Parses the model output and verifies that each declared field name is present.
     *
     * @param text the raw model output, either fenced Markdown or plain JSON
     * @return the parsed JSON
     */
    @Override
    public JsonNode parse(String text) {
        var expectedKeys = responseSchemas.stream()
                .map(ResponseSchema::getName)
                .toList();
        return parseAndCheckJsonMarkdown(text, expectedKeys);
    }

    /**
     * Builds the instructions to embed in a prompt: the structured template with one line per
     * declared field substituted for its {@code {format}} placeholder.
     *
     * <p>The template is written in Python format-string style ({@code {format}} placeholder,
     * {@code {{}/{@code }}} for literal braces) and contains no {@code %} specifiers, so
     * {@code String.format} would return it unchanged and silently drop the schema. The
     * placeholder is therefore substituted literally.
     *
     * @return the rendered format instructions
     */
    @Override
    public String getFormatInstructions() {
        var schemaStr = String.join("\n", responseSchemas.stream().map(this::getSubString).toList());

        // Unescape the template's doubled braces BEFORE inserting the schema, so braces that
        // happen to appear in a field description are left untouched.
        var template = STRUCTURED_FORMAT_INSTRUCTIONS
                .replace("{{", "{")
                .replace("}}", "}");
        return template.replace(FORMAT_PLACEHOLDER, schemaStr);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.output.parsers.structured;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.hw.langchain.schema.OutputParserException;

import org.junit.jupiter.api.Test;

import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for {@link StructuredOutputParser#parse(String)}.
 *
 * @author HamaWhite
 */
class StructuredOutputParserTest {

    /** Builds a parser that expects the two fields {@code name} and {@code age}. */
    private static StructuredOutputParser nameAndAgeParser() {
        return StructuredOutputParser.fromResponseSchemas(List.of(
                new ResponseSchema("name", "desc"),
                new ResponseSchema("age", "desc")));
    }

    /** A fenced JSON snippet containing both expected fields parses to the matching object node. */
    @Test
    void testParse() {
        var fenced = "```json\n{\"name\": \"John\", \"age\": 30}\n```";

        var expected = new ObjectMapper().createObjectNode()
                .put("name", "John")
                .put("age", 30);
        var actual = nameAndAgeParser().parse(fenced);

        assertEquals(expected, actual);
    }

    /** A snippet lacking the {@code age} field is rejected with an {@link OutputParserException}. */
    @Test
    void testInvalidJsonInput() {
        var parser = nameAndAgeParser();
        var missingAge = "```json\n{\"name\": \"John\"}\n```";

        assertThrows(OutputParserException.class, () -> parser.parse(missingAge));
    }

}

0 comments on commit d81c3af

Please sign in to comment.