Skip to content

Instantly share code, notes, and snippets.

@flashburns
Last active January 15, 2025 18:49
Show Gist options
  • Save flashburns/c4df2e6da499e5c80ee08512988b7ece to your computer and use it in GitHub Desktop.
Save flashburns/c4df2e6da499e5c80ee08512988b7ece to your computer and use it in GitHub Desktop.
D lang bindings for rwkv.cpp with a D native tokenizer.

The binding was written for rwkv.cpp at commit hash 96ce1d9b3cb57b9fbec071602bd603a61757d74b (2024-07-12). You can generate rwkv_vocab_tokens.json by downloading rwkv_vocab_v20230424.txt and running vocab_txt_to_json.py. If you don't want to deal with Python, you can download and unpack rwkv_vocab_tokens.json.xz.

module cog.rwkv;
@safe:
alias Token = uint;
/// Greedy longest-match tokenizer for the RWKV byte-level ("World") vocabulary.
/// Loads the vocabulary from a JSON file (an array of byte arrays, one entry per
/// token id) and converts text to token ids by repeatedly consuming the longest
/// matching token at the front of the input.
class RwkvTokenizer{
    /// Backing storage: all token byte strings concatenated in id order.
    immutable(ubyte)[] raw_chars;

    /// Slice descriptor into `raw_chars` for one token.
    struct TokenSegment{
        uint offset;
        uint length;
    }
    /// Indexed by token id.
    TokenSegment[] tokens;

    /// All token ids (of byte length >= 2) sharing a two-byte prefix,
    /// sorted longest-first so greedy matching can stop at the first hit.
    struct TokenPrefix{
        Token[] tokens;
    }
    // FIX: a ubyte spans 0 .. 255 inclusive — 256 distinct values. The original
    // used `ubyte.max` (255) as the array length, which drops index 255 and
    // raises a RangeError whenever a 0xFF byte appears in a prefix position.
    TokenPrefix[ubyte.max + 1][ubyte.max + 1] prefix_lookup;

    /// Maps a token's exact byte string back to its id.
    uint[immutable(ubyte)[]] string_to_token_index;

    /// Builds the tokenizer from `vocab_file` (JSON: array of arrays of byte values).
    this(string vocab_file){
        import std.json : parseJSON;
        import std.file : readText;
        import std.algorithm : map, sort;
        import std.conv : to;
        import std.array : array;
        auto token_bytes_from_json = vocab_file.readText.parseJSON
            .arrayNoRef.map!(json_object => json_object.arrayNoRef)
            .map!(json_byte_array => json_byte_array.map!( json_byte => json_byte.integer.to!ubyte ).array() );
        foreach(immutable(ubyte)[] token_str; token_bytes_from_json){
            const Token token_id = this.tokens.length.to!Token;
            this.tokens ~= TokenSegment(raw_chars.length.to!uint, token_str.length.to!uint);
            this.raw_chars ~= token_str;
            this.string_to_token_index[token_str] = token_id;
            // Only multi-byte tokens enter the two-byte prefix index; single-byte
            // tokens are resolved by the fallback path in consume_into_token.
            if(token_str.length >= 2)
                this.prefix_lookup[token_str[0]][token_str[1]].tokens ~= token_id;
        }
        // Sort every prefix bucket longest-first so the first startsWith hit
        // during matching is guaranteed to be the longest possible token.
        size_t len(Token token){ return this.stringify_token(token).length; }
        alias sort_token_len = (a, b) => len(a) > len(b);
        foreach(ref second_byte_lookup; prefix_lookup) foreach(ref prefix; second_byte_lookup)
            prefix.tokens.sort!sort_token_len;
    }

    /// Returns the token's bytes reinterpreted as a string (may not be valid UTF-8
    /// on its own — tokens can split multi-byte sequences).
    string stringify_token(Token token) const {
        const segment = this.tokens[token];
        return cast(string)this.raw_chars[segment.offset .. segment.offset + segment.length];
    }

    private immutable(ubyte)[] stringify_token_as_bytes(Token token) const {
        const segment = this.tokens[token];
        return this.raw_chars[segment.offset .. segment.offset + segment.length];
    }

    /// Greedily matches the longest token at the start of `str`.
    /// Returns (token, bytes consumed); (Token(-1), -1) for empty input.
    auto consume_into_token(scope const(ubyte)[] str) const {
        import std.algorithm : min, startsWith;
        struct Ret{ Token token = -1; int consumed = -1; }
        if(str.length == 0) return Ret();
        else if(str.length >= 2){
            const prefix = this.prefix_lookup[str[0]][str[1]];
            if(prefix.tokens.length > 0){
                // The longest candidate bounds how much of `str` must be compared.
                const search_length = min(this.stringify_token(prefix.tokens[0]).length, str.length);
                const(ubyte)[] match_string = str[0..search_length];
                // NOTE: prefix.tokens is sorted longest-to-shortest, so the first
                // startsWith hit is the longest match.
                foreach(candidate; prefix.tokens){
                    const candidate_str = this.stringify_token_as_bytes(candidate);
                    if(startsWith(match_string, candidate_str))
                        // FIX: return the candidate id directly; the original did a
                        // redundant AA lookup of a string we already know the id of.
                        return Ret(candidate, cast(int)candidate_str.length);
                }
            }
        }
        // Fall back to the single-byte token. Assumes the vocabulary contains
        // every single byte value — TODO confirm for vocabularies other than
        // rwkv_vocab_v20230424; a missing byte would raise a RangeError here.
        return Ret(this.string_to_token_index[str[0..1]], 1);
    }
    auto consume_into_token(scope const(char)[] str) const { return this.consume_into_token(cast(const(ubyte)[])str); }

    /// Tokenizes `str` into token ids. Returns an empty array for empty input.
    Token[] to_tokens(scope const(char)[] str) const {
        Token[] ret;
        size_t consumed;
        // FIX: `while` instead of do-while — the original consumed once even for
        // empty input, tripping the assert (or wrapping `consumed` in release).
        while(consumed < str.length){
            const new_token = this.consume_into_token(cast(const(ubyte)[])str[consumed..$]);
            assert(new_token.consumed > 0);
            consumed += new_token.consumed;
            ret ~= new_token.token;
        }
        return ret;
    }
}
/*import std.stdio : writeln;
import std.conv : text, to;
scope RwkvTokenizer tokenizer = new RwkvTokenizer("rwkv_vocab_tokens.json");
assert( tokenizer.to_tokens("Hello world") == [33155, 40213] );
assert( tokenizer.to_tokens("🍿") == [3319, 142, 192] );
import std.json : parseJSON;
import std.file : readText;
import std.algorithm : equal, map;
import std.conv : to;
foreach(test_index, test_set; readText("test.json").parseJSON.arrayNoRef){
assert(
tokenizer.to_tokens(test_set["text"].str).equal( test_set["as_bytes"].arrayNoRef.map!(json_num => json_num.integer.to!int) ),
i"Failed test $(test_index)".text
);
}*/
// ---------------------- RWKV.H ----------------------
@nogc nothrow @system:
extern(C): __gshared:
public import core.stdc.stddef;
public import core.stdc.stdint;
// 'ggmf' in hex.
enum RWKV_FILE_MAGIC = 0x67676d66;
enum RWKV_FILE_VERSION_0 = 100;
enum RWKV_FILE_VERSION_1 = 101;
enum RWKV_FILE_VERSION_MIN = RWKV_FILE_VERSION_0;
enum RWKV_FILE_VERSION_MAX = RWKV_FILE_VERSION_1;
// Default file version is the latest version.
enum RWKV_FILE_VERSION = RWKV_FILE_VERSION_MAX;
// Represents an error encountered during a function call.
// These are flags, so an actual value might contain multiple errors.
enum rwkv_error_flags {
RWKV_ERROR_NONE = 0,
// Error categories: occupy bits 8 and above (value << 8). Note that e.g.
// 3 << 8 shares bits with 1 << 8 and 2 << 8, so these are small integers
// shifted up rather than independent one-hot flags.
RWKV_ERROR_ARGS = 1 << 8,
RWKV_ERROR_FILE = 2 << 8,
RWKV_ERROR_MODEL = 3 << 8,
RWKV_ERROR_MODEL_PARAMS = 4 << 8,
RWKV_ERROR_GRAPH = 5 << 8,
RWKV_ERROR_CTX = 6 << 8,
// Detail codes: occupy the low byte, presumably combined with a category
// above by the library (see the comment preceding this enum).
RWKV_ERROR_ALLOC = 1,
RWKV_ERROR_FILE_OPEN = 2,
RWKV_ERROR_FILE_STAT = 3,
RWKV_ERROR_FILE_READ = 4,
RWKV_ERROR_FILE_WRITE = 5,
RWKV_ERROR_FILE_MAGIC = 6,
RWKV_ERROR_FILE_VERSION = 7,
RWKV_ERROR_DATA_TYPE = 8,
RWKV_ERROR_UNSUPPORTED = 9,
RWKV_ERROR_SHAPE = 10,
RWKV_ERROR_DIMENSION = 11,
RWKV_ERROR_KEY = 12,
RWKV_ERROR_DATA = 13,
RWKV_ERROR_PARAM_MISSING = 14
}
// Re-export every enum member at module scope so D callers can write
// `RWKV_ERROR_FILE` exactly as C callers of rwkv.h do, without the
// `rwkv_error_flags.` qualifier D normally requires.
alias RWKV_ERROR_NONE = rwkv_error_flags.RWKV_ERROR_NONE;
alias RWKV_ERROR_ARGS = rwkv_error_flags.RWKV_ERROR_ARGS;
alias RWKV_ERROR_FILE = rwkv_error_flags.RWKV_ERROR_FILE;
alias RWKV_ERROR_MODEL = rwkv_error_flags.RWKV_ERROR_MODEL;
alias RWKV_ERROR_MODEL_PARAMS = rwkv_error_flags.RWKV_ERROR_MODEL_PARAMS;
alias RWKV_ERROR_GRAPH = rwkv_error_flags.RWKV_ERROR_GRAPH;
alias RWKV_ERROR_CTX = rwkv_error_flags.RWKV_ERROR_CTX;
alias RWKV_ERROR_ALLOC = rwkv_error_flags.RWKV_ERROR_ALLOC;
alias RWKV_ERROR_FILE_OPEN = rwkv_error_flags.RWKV_ERROR_FILE_OPEN;
alias RWKV_ERROR_FILE_STAT = rwkv_error_flags.RWKV_ERROR_FILE_STAT;
alias RWKV_ERROR_FILE_READ = rwkv_error_flags.RWKV_ERROR_FILE_READ;
alias RWKV_ERROR_FILE_WRITE = rwkv_error_flags.RWKV_ERROR_FILE_WRITE;
alias RWKV_ERROR_FILE_MAGIC = rwkv_error_flags.RWKV_ERROR_FILE_MAGIC;
alias RWKV_ERROR_FILE_VERSION = rwkv_error_flags.RWKV_ERROR_FILE_VERSION;
alias RWKV_ERROR_DATA_TYPE = rwkv_error_flags.RWKV_ERROR_DATA_TYPE;
alias RWKV_ERROR_UNSUPPORTED = rwkv_error_flags.RWKV_ERROR_UNSUPPORTED;
alias RWKV_ERROR_SHAPE = rwkv_error_flags.RWKV_ERROR_SHAPE;
alias RWKV_ERROR_DIMENSION = rwkv_error_flags.RWKV_ERROR_DIMENSION;
alias RWKV_ERROR_KEY = rwkv_error_flags.RWKV_ERROR_KEY;
alias RWKV_ERROR_DATA = rwkv_error_flags.RWKV_ERROR_DATA;
alias RWKV_ERROR_PARAM_MISSING = rwkv_error_flags.RWKV_ERROR_PARAM_MISSING;
// RWKV context that can be used for inference.
// All functions that operate on rwkv_context are thread-safe.
// rwkv_context can be sent to different threads between calls to rwkv_eval.
// There is no requirement for rwkv_context to be freed on the creating thread.
struct rwkv_context;
// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
// If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
// as well as the default for new context.
// - print_errors: whether error messages should be automatically printed.
void rwkv_set_print_errors(rwkv_context* ctx, const(bool) print_errors);
// Gets whether errors are automatically printed to stderr.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
bool rwkv_get_print_errors(const(rwkv_context)* ctx);
// Retrieves and clears the error flags.
// - ctx: the context the retrieve the error for, or NULL for the global error.
rwkv_error_flags rwkv_get_last_error(rwkv_context* ctx);
// Loads the model from a file and prepares it for inference.
// Returns NULL on any error.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
// - n_gpu_layer: count of layers need to load to gpu
rwkv_context* rwkv_init_from_file(const(char)* model_file_path, const(uint) n_threads, const(uint) n_gpu_layers);
// Creates a new context from an existing one.
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
// Each rwkv_context can have one eval running at a time.
// Every rwkv_context must be freed using rwkv_free.
// - ctx: context to be cloned.
// - n_threads: count of threads to use, must be positive.
rwkv_context* rwkv_clone_context(rwkv_context* ctx, const(uint) n_threads);
// Evaluates the model for a single token.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval(rwkv_context* ctx, const(uint) token, const(float)* state_in, float* state_out, float* logits_out);
// Evaluates the model for a sequence of tokens.
// Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
//
// NOTE ON GGML NODE LIMIT
//
// ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
// this limit when using large models and/or large sequence lengths.
// Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
//
// If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
// To get rid of the assertion failure, reduce the model size and/or sequence length.
//
// TODO When Metal (MPS) support is implemented, check that large sequence lengths work
//
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval_sequence(rwkv_context* ctx, const(uint)* tokens, const(size_t) sequence_len, const(float)* state_in, float* state_out, float* logits_out);
// Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
// This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
// It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
//
// Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
// A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
// and choose one that works the best in your use case.
//
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - chunk_size: size of each chunk in tokens, must be positive.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval_sequence_in_chunks(rwkv_context* ctx, const(uint)* tokens, const(size_t) sequence_len, const(size_t) chunk_size, const(float)* state_in, float* state_out, float* logits_out);
// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
size_t rwkv_get_n_vocab(const(rwkv_context)* ctx);
// Returns the number of elements in the given model's embedding.
// Useful for reading individual fields of a model's hidden state.
size_t rwkv_get_n_embed(const(rwkv_context)* ctx);
// Returns the number of layers in the given model.
// A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
// Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
// Useful for always offloading the entire model to GPU.
size_t rwkv_get_n_layer(const(rwkv_context)* ctx);
// Returns the number of float elements in a complete state for the given model.
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
size_t rwkv_get_state_len(const(rwkv_context)* ctx);
// Returns the number of float elements in the logits output of a given model.
// This is currently always identical to n_vocab.
size_t rwkv_get_logits_len(const(rwkv_context)* ctx);
// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
void rwkv_init_state(const(rwkv_context)* ctx, float* state);
// Frees all allocated memory and the context.
// Does not need to be called on the same thread that created the rwkv_context.
void rwkv_free(rwkv_context* ctx);
// Quantizes FP32 or FP16 model to one of quantized formats.
// Returns false on any error. Error messages would be printed to stderr.
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
// - format_name: must be one of available format names below.
// Available format names:
// - Q4_0
// - Q4_1
// - Q5_0
// - Q5_1
// - Q8_0
bool rwkv_quantize_model_file(const(char)* model_file_path_in, const(char)* model_file_path_out, const(char)* format_name);
// Returns system information string.
const(char)* rwkv_get_system_info_string();
import ast
import json


def parse_vocab_line(line: str) -> bytes:
    """Parse one line of rwkv_vocab_v20230424.txt into the token's raw bytes.

    Each line has the form ``<token_id> <python-literal> <byte_length>`` where
    the literal is a Python ``str`` or ``bytes`` literal (it may itself contain
    spaces, hence the first/last-space slicing). ``str`` literals are UTF-8
    encoded; the trailing number is verified as a length checksum.

    Raises ValueError on a malformed literal and AssertionError on a length
    mismatch.
    """
    # SECURITY FIX: the original used eval() on file content; ast.literal_eval
    # accepts the same str/bytes literals but cannot execute arbitrary code.
    token_content_raw = ast.literal_eval(line[line.index(' '):line.rindex(' ')].strip())
    token_content: bytes = (
        token_content_raw.encode("utf-8")
        if isinstance(token_content_raw, str)
        else token_content_raw
    )
    assert isinstance(token_content, bytes)
    # The leading token id is positional (line order defines ids) and is
    # therefore ignored; the trailing field is the expected byte length.
    assert len(token_content) == int(line[line.rindex(' '):])
    return token_content


def main() -> None:
    """Convert rwkv_vocab_v20230424.txt into rwkv_vocab_tokens.json (array of byte arrays)."""
    # `with` ensures the file handles are closed (the original leaked both).
    with open("rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as vocab_file:
        tokens = [parse_vocab_line(line) for line in vocab_file]
    json_data = json.dumps([list(arr) for arr in tokens], indent='\t')
    with open("rwkv_vocab_tokens.json", 'w', encoding="utf-8") as out_file:
        out_file.write(json_data)


if __name__ == "__main__":
    main()
@flashburns
Copy link
Author

flashburns commented Jan 12, 2025

The binding was written for rwkv.cpp at commit hash 96ce1d9b3cb57b9fbec071602bd603a61757d74b (2024-07-12).
You can generate rwkv_vocab_tokens.json by downloading rwkv_vocab_v20230424.txt and running vocab_txt_to_json.py. If you don't want to deal with Python, you can download and unpack rwkv_vocab_tokens.json.xz.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment