This binding was written for rwkv.cpp at commit 96ce1d9b3cb57b9fbec071602bd603a61757d74b (2024-07-12).
You can generate rwkv_vocab_tokens.json by downloading rwkv_vocab_v20230424.txt and running vocab_txt_to_json.py. If you don't want to deal with Python, you can download and unpack rwkv_vocab_tokens.json.xz.
D lang bindings for rwkv.cpp with a D native tokenizer.
module cog.rwkv;
@safe:
alias Token = uint;
class RwkvTokenizer{
	immutable(ubyte)[] raw_chars;
	struct TokenSegment{
		uint offset;
		uint length;
	}
	TokenSegment[] tokens;
	struct TokenPrefix{
		Token[] tokens;
	}
	// NOTE: bytes range over 0..=255, so each axis needs ubyte.max + 1 entries
	// (ubyte.max alone would make prefix_lookup[0xFF] go out of bounds).
	TokenPrefix[ubyte.max + 1][ubyte.max + 1] prefix_lookup;
	uint[immutable(ubyte)[]] string_to_token_index;
	this(string vocab_file){
		import std.json : parseJSON;
		import std.file : readText;
		import std.algorithm : map, sort;
		import std.conv : to;
		import std.array : array;
		// The vocab JSON is an array of tokens, each token an array of byte values.
		auto token_bytes_from_json = vocab_file.readText.parseJSON
			.arrayNoRef.map!(json_object => json_object.arrayNoRef)
			.map!(json_byte_array => json_byte_array.map!(json_byte => json_byte.integer.to!ubyte).array());
		foreach(immutable(ubyte)[] token_str; token_bytes_from_json){
			const Token token_id = this.tokens.length.to!Token;
			// Token strings are stored back-to-back in raw_chars; each TokenSegment slices into it.
			this.tokens ~= TokenSegment(raw_chars.length.to!uint, token_str.length.to!uint);
			this.raw_chars ~= token_str;
			this.string_to_token_index[token_str] = token_id;
			if(token_str.length >= 2)
				this.prefix_lookup[token_str[0]][token_str[1]].tokens ~= token_id;
		}
		// Sort each prefix bucket longest-first, so consume_into_token's first match is the longest one.
		size_t len(Token token){ return this.stringify_token(token).length; }
		alias sort_token_len = (a, b) => len(a) > len(b);
		foreach(ref second_byte_lookup; prefix_lookup) foreach(ref prefix; second_byte_lookup)
			prefix.tokens.sort!sort_token_len;
	}
	string stringify_token(Token token) const {
		const segment = this.tokens[token];
		return cast(string)this.raw_chars[segment.offset .. segment.offset + segment.length];
	}
	private immutable(ubyte)[] stringify_token_as_bytes(Token token) const {
		const segment = this.tokens[token];
		return this.raw_chars[segment.offset .. segment.offset + segment.length];
	}
	auto consume_into_token(scope const(ubyte)[] str) const {
		import std.algorithm : min, map, startsWith;
		struct Ret{ Token token = -1; int consumed = -1; }
		if(str.length == 0) return Ret();
		else if(str.length >= 2){
			ubyte first_byte = str[0];
			ubyte second_byte = str[1];
			const prefix = this.prefix_lookup[first_byte][second_byte];
			if(prefix.tokens.length > 0){
				const search_length = min(this.stringify_token(prefix.tokens[0]).length, str.length);
				const(ubyte)[] match_string = str[0 .. search_length];
				struct PotentialToken{ Token token; immutable(ubyte)[] token_str; }
				// NOTE: the prefix.tokens array is sorted longest to shortest, so the first match is the longest.
				foreach(potential_match; prefix.tokens.map!( t => PotentialToken(t, this.stringify_token_as_bytes(t)) )){
					if(startsWith(match_string, potential_match.token_str))
						return Ret(this.string_to_token_index[potential_match.token_str], cast(int)potential_match.token_str.length);
				}
			}
		}
		// Fall back to the single-byte token; the vocabulary contains every byte value.
		return Ret(this.string_to_token_index[str[0 .. 1]], 1);
	}
	auto consume_into_token(scope const(char)[] str) const { return this.consume_into_token(cast(const(ubyte)[])str); }
	Token[] to_tokens(scope const(char)[] str) const {
		Token[] ret;
		size_t consumed;
		// A while loop (rather than do-while) keeps to_tokens("") from reading past the end.
		while(consumed < str.length){
			const new_token = this.consume_into_token(cast(const(ubyte)[])str[consumed .. $]);
			assert(new_token.consumed > 0);
			consumed += new_token.consumed;
			ret ~= new_token.token;
		}
		return ret;
	}
}
/*import std.stdio : writeln;
import std.conv : text, to;
scope RwkvTokenizer tokenizer = new RwkvTokenizer("rwkv_vocab_tokens.json");
assert( tokenizer.to_tokens("Hello world") == [33155, 40213] );
assert( tokenizer.to_tokens("🍿") == [3319, 142, 192] );
import std.json : parseJSON;
import std.file : readText;
import std.algorithm : equal, map;
import std.conv : to;
foreach(test_index, test_set; readText("test.json").parseJSON.arrayNoRef){
	assert(
		tokenizer.to_tokens(test_set["text"].str).equal( test_set["as_bytes"].arrayNoRef.map!(json_num => json_num.integer.to!int) ),
		i"Failed test $(test_index)".text
	);
}*/
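// A decoding sketch in the same spirit as the test above; a minimal example assuming
// rwkv_vocab_tokens.json is in the working directory. stringify_token maps each token id
// back to its bytes, so joining the pieces round-trips the input:
/*
import std.algorithm : map;
import std.array : join;
scope RwkvTokenizer tokenizer = new RwkvTokenizer("rwkv_vocab_tokens.json");
auto token_ids = tokenizer.to_tokens("Hello world");
assert( token_ids.map!(t => tokenizer.stringify_token(t)).join == "Hello world" );
*/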
// ---------------------- RWKV.H ----------------------
@nogc nothrow @system:
extern(C): __gshared:
public import core.stdc.stddef;
public import core.stdc.stdint;
// 'ggmf' in hex.
enum RWKV_FILE_MAGIC = 0x67676d66;
enum RWKV_FILE_VERSION_0 = 100;
enum RWKV_FILE_VERSION_1 = 101;
enum RWKV_FILE_VERSION_MIN = RWKV_FILE_VERSION_0;
enum RWKV_FILE_VERSION_MAX = RWKV_FILE_VERSION_1;
// Default file version is the latest version.
enum RWKV_FILE_VERSION = RWKV_FILE_VERSION_MAX;
// Represents an error encountered during a function call.
// These are flags, so an actual value might contain multiple errors.
enum rwkv_error_flags {
	RWKV_ERROR_NONE = 0,
	RWKV_ERROR_ARGS = 1 << 8,
	RWKV_ERROR_FILE = 2 << 8,
	RWKV_ERROR_MODEL = 3 << 8,
	RWKV_ERROR_MODEL_PARAMS = 4 << 8,
	RWKV_ERROR_GRAPH = 5 << 8,
	RWKV_ERROR_CTX = 6 << 8,
	RWKV_ERROR_ALLOC = 1,
	RWKV_ERROR_FILE_OPEN = 2,
	RWKV_ERROR_FILE_STAT = 3,
	RWKV_ERROR_FILE_READ = 4,
	RWKV_ERROR_FILE_WRITE = 5,
	RWKV_ERROR_FILE_MAGIC = 6,
	RWKV_ERROR_FILE_VERSION = 7,
	RWKV_ERROR_DATA_TYPE = 8,
	RWKV_ERROR_UNSUPPORTED = 9,
	RWKV_ERROR_SHAPE = 10,
	RWKV_ERROR_DIMENSION = 11,
	RWKV_ERROR_KEY = 12,
	RWKV_ERROR_DATA = 13,
	RWKV_ERROR_PARAM_MISSING = 14
}
alias RWKV_ERROR_NONE = rwkv_error_flags.RWKV_ERROR_NONE;
alias RWKV_ERROR_ARGS = rwkv_error_flags.RWKV_ERROR_ARGS;
alias RWKV_ERROR_FILE = rwkv_error_flags.RWKV_ERROR_FILE;
alias RWKV_ERROR_MODEL = rwkv_error_flags.RWKV_ERROR_MODEL;
alias RWKV_ERROR_MODEL_PARAMS = rwkv_error_flags.RWKV_ERROR_MODEL_PARAMS;
alias RWKV_ERROR_GRAPH = rwkv_error_flags.RWKV_ERROR_GRAPH;
alias RWKV_ERROR_CTX = rwkv_error_flags.RWKV_ERROR_CTX;
alias RWKV_ERROR_ALLOC = rwkv_error_flags.RWKV_ERROR_ALLOC;
alias RWKV_ERROR_FILE_OPEN = rwkv_error_flags.RWKV_ERROR_FILE_OPEN;
alias RWKV_ERROR_FILE_STAT = rwkv_error_flags.RWKV_ERROR_FILE_STAT;
alias RWKV_ERROR_FILE_READ = rwkv_error_flags.RWKV_ERROR_FILE_READ;
alias RWKV_ERROR_FILE_WRITE = rwkv_error_flags.RWKV_ERROR_FILE_WRITE;
alias RWKV_ERROR_FILE_MAGIC = rwkv_error_flags.RWKV_ERROR_FILE_MAGIC;
alias RWKV_ERROR_FILE_VERSION = rwkv_error_flags.RWKV_ERROR_FILE_VERSION;
alias RWKV_ERROR_DATA_TYPE = rwkv_error_flags.RWKV_ERROR_DATA_TYPE;
alias RWKV_ERROR_UNSUPPORTED = rwkv_error_flags.RWKV_ERROR_UNSUPPORTED;
alias RWKV_ERROR_SHAPE = rwkv_error_flags.RWKV_ERROR_SHAPE;
alias RWKV_ERROR_DIMENSION = rwkv_error_flags.RWKV_ERROR_DIMENSION;
alias RWKV_ERROR_KEY = rwkv_error_flags.RWKV_ERROR_KEY;
alias RWKV_ERROR_DATA = rwkv_error_flags.RWKV_ERROR_DATA;
alias RWKV_ERROR_PARAM_MISSING = rwkv_error_flags.RWKV_ERROR_PARAM_MISSING;
// RWKV context that can be used for inference.
// All functions that operate on rwkv_context are thread-safe.
// rwkv_context can be sent to different threads between calls to rwkv_eval.
// There is no requirement for rwkv_context to be freed on the creating thread.
struct rwkv_context;
// Sets whether errors are automatically printed to stderr.
// If this is set to false, you are responsible for calling rwkv_get_last_error manually if an operation fails.
// - ctx: the context to suppress error messages for.
//   If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
//   as well as the default for new contexts.
// - print_errors: whether error messages should be automatically printed.
void rwkv_set_print_errors(rwkv_context* ctx, const(bool) print_errors);
// Gets whether errors are automatically printed to stderr.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
bool rwkv_get_print_errors(const(rwkv_context)* ctx);
// Retrieves and clears the error flags.
// - ctx: the context to retrieve the error for, or NULL for the global error.
rwkv_error_flags rwkv_get_last_error(rwkv_context* ctx);
// Loads the model from a file and prepares it for inference.
// Returns NULL on any error.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
// - n_gpu_layers: count of layers to offload to the GPU.
rwkv_context* rwkv_init_from_file(const(char)* model_file_path, const(uint) n_threads, const(uint) n_gpu_layers);
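// A minimal loading sketch; the model path is a placeholder, not a file shipped with this gist:
/*
rwkv_context* ctx = rwkv_init_from_file("rwkv-model.bin", 4, 0);
if(ctx is null){
	// NULL means loading failed; the global error flags describe why.
	rwkv_error_flags err = rwkv_get_last_error(null);
}
*/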
// Creates a new context from an existing one.
// This allows you to run multiple rwkv_eval calls in parallel, without having to load a single model multiple times.
// Each rwkv_context can have one eval running at a time.
// Every rwkv_context must be freed using rwkv_free.
// - ctx: context to be cloned.
// - n_threads: count of threads to use, must be positive.
rwkv_context* rwkv_clone_context(rwkv_context* ctx, const(uint) n_threads);
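// A parallel-inference sketch: one cloned context per worker thread, each freed with
// rwkv_free when done. (ctx is assumed to come from rwkv_init_from_file above.)
/*
rwkv_context* worker_ctx = rwkv_clone_context(ctx, 4);
// ... run rwkv_eval on worker_ctx from another thread ...
rwkv_free(worker_ctx);
*/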
// Evaluates the model for a single token.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is the first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval(rwkv_context* ctx, const(uint) token, const(float)* state_in, float* state_out, float* logits_out);
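// A sketch of a token-by-token generation loop with explicit state and logits buffers.
// Greedy argmax sampling is shown purely for illustration, and ctx is assumed to be loaded;
// the state buffer is reused as both state_in and state_out.
/*
import std.algorithm : maxIndex;
float[] state = new float[rwkv_get_state_len(ctx)];
float[] logits = new float[rwkv_get_logits_len(ctx)];
rwkv_init_state(ctx, state.ptr); // equivalent to passing NULL as state_in on the first pass
uint token = 0;                  // placeholder start token
foreach(_; 0 .. 16){
	if(!rwkv_eval(ctx, token, state.ptr, state.ptr, logits.ptr)) break;
	token = cast(uint)logits.maxIndex; // greedy: pick the most likely next token
}
*/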
// Evaluates the model for a sequence of tokens.
// Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
//
// NOTE ON GGML NODE LIMIT
//
// ggml has a hard-coded limit on the max number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
// this limit when using large models and/or large sequence lengths.
// Fortunately, rwkv.cpp's fork of ggml has an increased limit which was tested to work for sequence lengths up to 64 for 14B models.
//
// If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
// To get rid of the assertion failure, reduce the model size and/or sequence length.
//
// TODO When Metal (MPS) support is implemented, check that large sequence lengths work
//
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is the first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval_sequence(rwkv_context* ctx, const(uint)* tokens, const(size_t) sequence_len, const(float)* state_in, float* state_out, float* logits_out);
// Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
// This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
// It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
//
// Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
// A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
// and choose the one that works best in your use case.
//
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - chunk_size: size of each chunk in tokens, must be positive.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is the first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
bool rwkv_eval_sequence_in_chunks(rwkv_context* ctx, const(uint)* tokens, const(size_t) sequence_len, const(size_t) chunk_size, const(float)* state_in, float* state_out, float* logits_out);
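// A prompt-ingestion sketch with the recommended chunk size of 16. prompt_tokens is
// assumed to come from RwkvTokenizer.to_tokens; ctx, state and logits as in the
// rwkv_eval sketch above. Token aliases uint, so the array pointer matches const(uint)*.
/*
Token[] prompt_tokens = tokenizer.to_tokens("Hello world");
bool ok = rwkv_eval_sequence_in_chunks(ctx, prompt_tokens.ptr, prompt_tokens.length,
                                       16, null, state.ptr, logits.ptr);
*/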
// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
size_t rwkv_get_n_vocab(const(rwkv_context)* ctx);
// Returns the number of elements in the given model's embedding.
// Useful for reading individual fields of a model's hidden state.
size_t rwkv_get_n_embed(const(rwkv_context)* ctx);
// Returns the number of layers in the given model.
// A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
// The embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
// Useful for always offloading the entire model to the GPU.
size_t rwkv_get_n_layer(const(rwkv_context)* ctx);
// Returns the number of float elements in a complete state for the given model.
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
size_t rwkv_get_state_len(const(rwkv_context)* ctx);
// Returns the number of float elements in the logits output of a given model.
// This is currently always identical to n_vocab.
size_t rwkv_get_logits_len(const(rwkv_context)* ctx);
// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
// State must be initialized for behavior to be defined; passing a zeroed state to rwkv.cpp functions will result in NaNs.
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
void rwkv_init_state(const(rwkv_context)* ctx, float* state);
// Frees all allocated memory and the context.
// Does not need to be called on the same thread that created the rwkv_context.
void rwkv_free(rwkv_context* ctx);
// Quantizes an FP32 or FP16 model to one of the quantized formats.
// Returns false on any error. Error messages will be printed to stderr.
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
// - format_name: must be one of the available format names below.
// Available format names:
// - Q4_0
// - Q4_1
// - Q5_0
// - Q5_1
// - Q8_0
bool rwkv_quantize_model_file(const(char)* model_file_path_in, const(char)* model_file_path_out, const(char)* format_name);
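// A quantization sketch; both file names are placeholders:
/*
bool quantized = rwkv_quantize_model_file("rwkv-model-f16.bin", "rwkv-model-q5_1.bin", "Q5_1");
*/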
// Returns the system information string.
const(char)* rwkv_get_system_info_string();
vocab_txt_to_json.py:
import ast
import json

# Each line of the vocab file looks like: <token_id> <python literal (str or bytes)> <byte length>
lines: list[str] = open("rwkv_vocab_v20230424.txt", "r", encoding="utf-8").readlines()
tokens: list[bytes] = []
for line in lines:
    token_id = int(line[:line.index(' ')])
    # The token content field is a Python literal; literal_eval is a safer drop-in for eval here.
    token_content_raw = ast.literal_eval(line[line.index(' '):line.rindex(' ')])
    token_content: bytes = token_content_raw.encode("utf-8") if isinstance(token_content_raw, str) else token_content_raw
    assert isinstance(token_content, bytes)
    # The trailing field is the expected byte length; use it as a sanity check.
    assert len(token_content) == int(line[line.rindex(' '):])
    tokens += [token_content]

# Emit the tokens as a JSON array of byte-value arrays.
json_data = json.dumps([list(arr) for arr in tokens], indent='\t')
with open("rwkv_vocab_tokens.json", 'w') as file:
    file.write(json_data)