SchemaLLM

This is an attempt at an LLM completion API that lets you submit a JSONSchema (and eventually other schema languages) alongside your prompt, and reliably returns a result that conforms to the schema.

See the accompanying blog post for more context :)

Overview

This repo contains:

  • Functions for guiding LLM completion to comply with:
    • JSONSchema
    • Context free grammars
    • Regular expressions
  • A server that exposes these functions over HTTP
  • A Python client (built on requests) for interacting with the server
  • A LangChain client implementing the LLM object
  • Two vendored dependencies with some modifications:
    • ReLLM in schemallm/vendored/rellm
    • ParserLLM in schemallm/vendored/parserllm

Installation

TODO: bundle this up for pypi

Usage

These are pretty early days, so the API is likely to change a lot. The following is a rough overview of how it currently stands.

All model and tokenizer objects come from the Hugging Face transformers library, and can be loaded using:

from schemallm.server.load_model import load_model, load_tokenizer

model = load_model("gpt2")
tokenizer = load_tokenizer("gpt2")

Completion functions

Regex

complete_with_regex is a function that takes a prompt and a regular expression, and returns a completion that conforms to that regex.

It leverages ReLLM, with some slight modifications; details are in the blog post.

from schemallm.models.api import RegexCompletionRequest
from schemallm.server.server import complete_with_regex


result = complete_with_regex(
    model=model,
    tokenizer=tokenizer,
    completion_request=RegexCompletionRequest(
        prompt="An integer between 4 and 8: ",
        regex="[4-8]",
    ),
)
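Here "conforms" means the returned text fully matches the pattern, not just a prefix of it. An illustrative sanity check using Python's stdlib re (the completion string below is made up; actual output depends on the model):

```python
import re

# The request above constrains output to "[4-8]": a single digit from 4 to 8.
pattern = "[4-8]"
completion = "6"  # illustrative; real output comes from the model

# fullmatch checks that the entire completion matches the pattern
assert re.fullmatch(pattern, completion) is not None
print("completion conforms:", completion)
```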

CFG

complete_with_cfg is a function that takes a prompt and a context-free grammar, and returns a completion that conforms to the grammar.

It leverages ParserLLM, with some slight modifications; details are in the blog post.

from schemallm.models.api import CfgCompletionRequest
from schemallm.server.server import complete_with_cfg

JSON_CFG = """
JSON := OBJECT | ARRAY
OBJECT := '{' PAIRLIST? '}'
PAIRLIST := PAIR (',' PAIR)*
PAIR := STRING ':' VALUE
VALUE := STRING | NUMBER | OBJECT | ARRAY | 'true' | 'false' | 'null'
ARRAY := '[' VALUELIST? ']'
VALUELIST := VALUE (',' VALUE)*
STRING := '"' [a-z, A-Z, 0-9]* '"'
NUMBER := [0-9]+
"""

result = complete_with_cfg(
    model=model,
    tokenizer=tokenizer,
    completion_request=CfgCompletionRequest(
        prompt="A JSON object with a key called 'foo' and a string value: ",
        cfg=JSON_CFG,
        max_tokens=300,
    ),
)
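The grammar above describes a simplified JSON, so any conforming completion should also load with the stdlib json module. A quick illustrative check (the completion string is made up, not model output):

```python
import json

# A string derivable from JSON_CFG above:
# JSON -> OBJECT -> '{' PAIRLIST '}' -> '{' PAIR '}' -> STRING ':' STRING
completion = '{"foo": "bar"}'

parsed = json.loads(completion)
assert parsed == {"foo": "bar"}
print("parsed:", parsed)
```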

JSONSchema

complete_with_schema is a function that takes a prompt and a JSONSchema, and returns a completion that conforms to the schema.

We leverage pydantic to create the schema for us.

from pydantic import BaseModel
from schemallm.models.api import SchemaCompletionRequest
from schemallm.server.server import complete_with_schema


class PersonalDetails(BaseModel):
    name: str
    age: int
    location: str


schema = PersonalDetails.schema()

result = complete_with_schema(
    model=model,
    tokenizer=tokenizer,
    completion_request=SchemaCompletionRequest(
        prompt="A JSON object with a name, age and location: ",
        schema=schema,
        max_tokens=300,
    ),
)
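A result that conforms to the schema is plain JSON carrying the model's fields with the right types. An illustrative check with the stdlib json module (the completion string is made up; real output depends on the model):

```python
import json

# Illustrative output shaped like the PersonalDetails schema above
completion = '{"name": "Ada", "age": 36, "location": "London"}'

data = json.loads(completion)
# Conformance to PersonalDetails means exactly these keys, with these types:
assert set(data) == {"name", "age", "location"}
assert isinstance(data["name"], str)
assert isinstance(data["age"], int)
print("valid PersonalDetails:", data)
```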

Server

You can run the inference server with:

python -m schemallm.server.server

You will need some environment variables set:

export MODEL_PATH=gpt2

This can also point to other Hugging Face model paths.

Now navigate to http://localhost:8000/docs to see the OpenAPI docs for the server.
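If you want to call the server without the bundled client, you can POST JSON bodies to it directly. A sketch of building such a body, assuming the route accepts fields mirroring the SchemaCompletionRequest shown earlier (the exact route and field names are in the OpenAPI docs, so treat everything below as hypothetical):

```python
import json

# Hypothetical request body mirroring SchemaCompletionRequest; check
# http://localhost:8000/docs for the actual route and field names.
payload = {
    "prompt": "A JSON object with a name, age and location: ",
    "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
    "max_tokens": 300,
}

body = json.dumps(payload)
# With the server running, POST `body` to the schema-completion route listed
# in the docs, e.g.: requests.post(url, data=body).json()
print(body)
```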

HTTP Client

from schemallm.client import SchemaLlmClient
from pydantic import BaseModel

client = SchemaLlmClient()

class PersonalDetails(BaseModel):
    name: str
    age: int
    location: str

print(
    client.completion_with_pydantic(
        "Tell me about yourself, in JSON format!:\n",
        model=PersonalDetails,
    )
)

LangChain Client

Here is a quick example of how to use the JsonSchemaLLM LangChain client; note there are also CfgLLM, ReLLM, and StandardLLM LangChain implementations.

from schemallm.client.langchain_client import JsonSchemaLLM
from schemallm.models.jsonschema import parse_json_schema
from pydantic import BaseModel


# json schema
class PersonalDetails(BaseModel):
    name: str
    location: str


llm = JsonSchemaLLM(schema_restriction=parse_json_schema(PersonalDetails.schema()))
result = llm(
    "Tell me about yourself, in JSON format!:\n",
)
print(result)

Prior art:
