Fix Stable LM 3B build (mlc-ai#1061)
* [stablelm 3b] Rename the dynamic vocab-size shape variable from "v" to "vocab_size"

* Add get_num_key_value_heads method to StableLM3bConfig
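
Both changes are small but easy to misread in the raw diff below. A minimal, standalone sketch of what they do (illustrative names and values only, not mlc-llm code; assumes TVM is installed) follows:

import tvm

# The dynamic vocabulary dimension is a symbolic int64 shape variable; the
# commit only changes its name hint from "v" to "vocab_size", which is the
# label that shows up when the built Relax functions are printed.
vocab_size = tvm.tir.Var("vocab_size", "int64")

# Fallback added to the config: if num_key_value_heads is unset, use one KV
# head per attention head (plain multi-head attention); otherwise use the
# explicit value (e.g. for grouped-query attention). The field names mirror
# the StableLM3bConfig fields touched by the diff; the defaults are made up.
class ConfigSketch:
    def __init__(self, num_attention_heads=32, num_key_value_heads=None):
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

    def get_num_key_value_heads(self):
        if self.num_key_value_heads is None:
            return self.num_attention_heads
        return self.num_key_value_heads

assert ConfigSketch().get_num_key_value_heads() == 32
assert ConfigSketch(num_key_value_heads=8).get_num_key_value_heads() == 8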
jeethu authored Oct 14, 2023
1 parent d854105 commit c2b8cbc
Showing 1 changed file with 11 additions and 4 deletions.
mlc_llm/relax_model/stablelm_3b.py (15 changed lines: 11 additions, 4 deletions)
@@ -66,6 +66,11 @@ def __init__(
         self.num_shards = 1
         self.kwargs = kwargs
 
+    def get_num_key_value_heads(self):
+        if self.num_key_value_heads is None:
+            return self.num_attention_heads
+        return self.num_key_value_heads
+
 
 class LayerNorm(nn.Module):
     def __init__(
@@ -579,7 +584,7 @@ def create_embed_func(
     bsz = 1
     seq_len = tvm.tir.Var("n", "int64")
     with bb.function(func_name):
-        model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64"))
+        model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64"))
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
 
         input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
@@ -608,7 +613,7 @@ def create_encoding_func(
     all_seq_len = tvm.tir.Var("m", "int64")
     hidden_size = config.hidden_size
     with bb.function(func_name):
-        model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed)
+        model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed)
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
 
         inputs = (
@@ -652,7 +657,7 @@ def create_decoding_func(
     all_seq_len = tvm.tir.Var("n", "int64")
 
     with bb.function(func_name):
-        model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"))
+        model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"))
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
 
         input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids")
@@ -714,7 +719,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None:
 
 def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None:
     with bb.function("softmax_with_temperature"):
-        logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits")
+        logits = nn.Placeholder(
+            (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits"
+        )
         temperature = nn.Placeholder((), dtype="float32", name="temperature")
         with bb.dataflow():
             div = bb.emit(relax.op.divide(logits, temperature))
