[Grammar] Integrate with XGrammar #635

Merged Nov 22, 2024 (10 commits)

Commit: Update to new XGrammar pre-release APIs
CharlieFRuan committed Nov 22, 2024
commit 2fed3e8b7b0cf1a6a8523e41bba733de55ec98c6
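For context on the migration: the old free-standing xgrammar helpers (BuiltinGrammar.json, BNFGrammar.createBNFGrammar, a multi-argument GrammarMatcher) give way to a compiler-centered flow, where a persistent TokenizerInfo feeds a persistent GrammarCompiler that produces a per-round CompiledGrammar for the matcher. The sketch below is assembled only from calls that appear in this diff; treat it as a minimal illustration, not the exact pipeline code.

import * as xgr from "@mlc-ai/web-xgrammar";

// Minimal sketch of the new flow, using only APIs visible in this diff.
async function initGrammarMatcher(
  rawTokenTable: string[],
  tokenPostprocMethod: string,
  prependSpaceInEncode: boolean,
  fullVocabSize: number,
  jsonSchemaStr: string,
): Promise<xgr.GrammarMatcher> {
  // Persistent: built once per pipeline, reused across rounds.
  const tokenizerInfo = await xgr.TokenizerInfo.createTokenizerInfo(
    rawTokenTable,
    tokenPostprocMethod,
    prependSpaceInEncode,
    fullVocabSize,
  );
  // Persistent: the compiler specializes on tokenizerInfo.
  const compiler = await xgr.GrammarCompiler.createGrammarCompiler(tokenizerInfo);
  // Per-round: compile the requested response format into a CompiledGrammar.
  const grammar: xgr.CompiledGrammar = await compiler.compileJSONSchema(jsonSchemaStr);
  // The matcher now takes only the compiled grammar.
  const matcher = await xgr.GrammarMatcher.createGrammarMatcher(grammar);
  grammar.dispose(); // released immediately after matcher creation, as the diff does
  return matcher;
}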
src/llm_chat.ts (42 changes: 22 additions & 20 deletions)
@@ -1,7 +1,7 @@
 /* eslint-disable @typescript-eslint/no-non-null-assertion */
 /* eslint-disable no-prototype-builtins */
 import * as tvmjs from "@mlc-ai/web-runtime";
-import * as xgrammar from "@mlc-ai/web-xgrammar";
+import * as xgr from "@mlc-ai/web-xgrammar";
 import log from "loglevel";
 import { Tokenizer } from "@mlc-ai/web-tokenizers";
 import { ChatConfig, GenerationConfig, Role } from "./config";
@@ -104,14 +104,17 @@ export class LLMChatPipeline {
   // Grammar-related
   // A grammar matcher for this current round if response_format is set. Reinitialized upon
   // each step regardless of whether the chat is multi-round or not.
-  private grammarMatcher?: xgrammar.GrammarMatcher = undefined;
+  private grammarMatcher?: xgr.GrammarMatcher = undefined;
   // The current schema or grammar string used for grammarMatcher; if undefined, grammarMatcher is
   // simply using JSON mode. We use this field to determine whether we re-initiate a GrammarMatcher
   // or simply reset the state during each round (i.e. during prefillStep).
   private schemaOrGrammarStr?: string = undefined;
   // A string list of tokens ordered by their token id, post-processed. Once initialized, will not
   // be reinitialized since `this.tokenizer` does not change throughout the lifetime of LLMChatPipeline.
-  private xgTokenizerInfo?: xgrammar.TokenizerInfo = undefined;
+  private xgTokenizerInfo?: xgr.TokenizerInfo = undefined;
+  // Compiler for grammar. It is persistent since it specializes on xgTokenizerInfo.
+  private grammarCompiler?: xgr.GrammarCompiler = undefined;
   // Size of the bitmask for grammar, determined by fullVocabSize
   private bitmaskSize: number;
   // `vocab_size` read from `config.json`. Can be different from the size of the tokenTable for some
   // models due to dummy padded tokens.
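The bitmaskSize field pairs with the Int32Array bitmask the matcher returns (see the last hunk): xgrammar packs one allow/deny bit per vocabulary token into 32-bit words. Assuming that packing, the size works out as in the sketch below; the exact expression used elsewhere in the file is not shown in this diff.

// Assumption: one bit per token, packed into int32 words, rounded up to
// cover a vocabulary that is not a multiple of 32.
function computeBitmaskSize(fullVocabSize: number): number {
  return Math.ceil(fullVocabSize / 32);
}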
@@ -302,6 +305,7 @@ export class LLMChatPipeline {
     this.tvm.dispose();
     this.tokenizer.dispose();
     this.xgTokenizerInfo?.dispose();
+    this.grammarCompiler?.dispose();
   }

   /**
@@ -545,31 +549,29 @@ export class LLMChatPipeline {
       log.info("Initialize token table.");
       // Post process entire table
       const rawTokenTable = getTokenTableFromTokenizer(this.tokenizer);
-      this.xgTokenizerInfo =
-        await xgrammar.TokenizerInfo.createTokenizerInfo(
-          rawTokenTable,
-          this.token_postproc_method,
-          this.prepend_space_in_encode,
-        );
+      this.xgTokenizerInfo = await xgr.TokenizerInfo.createTokenizerInfo(
+        rawTokenTable,
+        this.token_postproc_method,
+        this.prepend_space_in_encode,
+        this.fullVocabSize,
+      );
+      this.grammarCompiler =
+        await xgr.GrammarCompiler.createGrammarCompiler(
+          this.xgTokenizerInfo,
+        );
     }
-    const grammar: xgrammar.BNFGrammar =
+    const grammar: xgr.CompiledGrammar =
       curSchemaOrGrammarStr === undefined
-        ? await xgrammar.BuiltinGrammar.json()
+        ? await this.grammarCompiler!.compileBuiltinJSONGrammar()
         : genConfig?.response_format?.type === "json_object"
-          ? await xgrammar.BuiltinGrammar.jsonSchema(
+          ? await this.grammarCompiler!.compileJSONSchema(
               curSchemaOrGrammarStr,
             )
-          : await xgrammar.BNFGrammar.createBNFGrammar(
+          : await this.grammarCompiler!.compileGrammar(
               curSchemaOrGrammarStr,
             );
     this.grammarMatcher =
-      await xgrammar.GrammarMatcher.createGrammarMatcher(
-        grammar,
-        this.xgTokenizerInfo,
-        undefined,
-        undefined,
-        this.fullVocabSize,
-      );
+      await xgr.GrammarMatcher.createGrammarMatcher(grammar);
+    grammar.dispose();
     this.schemaOrGrammarStr = curSchemaOrGrammarStr;
     this.curRoundGrammarInitTotalTime =
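The ternary chain above selects among three compile paths: no schema or grammar string means plain JSON mode, a "json_object" response format with a schema compiles that schema, and anything else is treated as a raw grammar string. A hypothetical caller-side illustration follows; the exact response_format field names are an assumption, not confirmed by this diff.

// Hypothetical GenerationConfig fragments exercising each path (field names assumed):
const jsonMode = { response_format: { type: "json_object" } }; // compileBuiltinJSONGrammar
const withSchema = {
  response_format: { type: "json_object", schema: '{"type":"object"}' }, // compileJSONSchema
};
const withGrammar = {
  response_format: { type: "grammar", grammar: 'root ::= "yes" | "no"' }, // compileGrammar
};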
@@ -1034,7 +1036,7 @@ export class LLMChatPipeline {

       const tBitmaskStart = performance.now();
       const bitMaskOnCPU: Int32Array =
-        await this.grammarMatcher.findNextTokenBitmask();
+        await this.grammarMatcher.getNextTokenBitmask();
       this.curRoundGrammarPerTokenTotalTime +=
         (performance.now() - tBitmaskStart) / 1e3;
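getNextTokenBitmask (renamed from findNextTokenBitmask) returns the mask to the CPU as an Int32Array, which the pipeline then applies to the logits on-device. Purely to illustrate the mask's semantics, under the assumption that bit i % 32 of word i / 32 marks token i as allowed, a scalar CPU version would look like this:

// Illustration only: web-llm applies the mask on-device via TVM, not with this loop.
function applyBitmaskOnCPU(logits: Float32Array, bitMaskOnCPU: Int32Array): void {
  for (let tokenId = 0; tokenId < logits.length; tokenId++) {
    // 32 tokens per int32 word; a set bit means the token is allowed.
    const allowed = (bitMaskOnCPU[tokenId >> 5] >> (tokenId & 31)) & 1;
    if (!allowed) logits[tokenId] = -Infinity; // disallowed tokens can never be sampled
  }
}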
