Skip to content

Commit

Permalink
sam : passing parameters and simple prompt (ggerganov#598)
Browse files Browse the repository at this point in the history
- most of the model hyperparameters can now be set on CLI
- user can define their own mask prefix
- user can define their own point prompt, although just one

Signed-off-by: Jiri Podivin <jpodivin@gmail.com>
  • Loading branch information
jpodivin authored Nov 2, 2023
1 parent 4058f5b commit d271d35
Showing 1 changed file with 86 additions and 24 deletions.
110 changes: 86 additions & 24 deletions examples/sam/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,22 @@ struct sam_image_f32 {
std::vector<float> data;
};

// Command-line configurable settings for the SAM example.
//
// The threshold/offset/epsilon fields are copied into `model.hparams` by
// sam_model_load() before the hyperparameters stored in the model file are read.
struct sam_params {
    int32_t seed = -1; // RNG seed

    // Worker thread count: at most 4, at least 1.
    // std::thread::hardware_concurrency() is allowed to return 0 when the value
    // cannot be determined, so clamp from below to avoid a zero-thread default.
    int32_t n_threads = std::max(1, std::min(4, static_cast<int32_t>(std::thread::hardware_concurrency())));

    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg";                             // input file
    std::string fname_out = "img.out";                             // mask file name prefix

    // SAM hyperparameter overrides (see the "SAM hyperparameters" usage section).
    float mask_threshold            = 0.f;
    float iou_threshold             = 0.88f;
    float stability_score_threshold = 0.95f;
    float stability_score_offset    = 1.0f;
    float eps                       = 1e-6f;
    float eps_decoder_transformer   = 1e-5f;

    // Single point used as the prompt for SAM (x, y).
    // NOTE(review): presumably expressed in input-image pixel coordinates - confirm.
    sam_point pt = { 414.375f, 162.796875f, };
};

void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
printf("%s\n", title);
float * data = (float *)t->data;
Expand Down Expand Up @@ -469,12 +485,12 @@ bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) {
}

// load the model's weights from a file
bool sam_model_load(const std::string & fname, sam_model & model) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
bool sam_model_load(const sam_params & params, sam_model & model) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
auto fin = std::ifstream(params.model, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str());
return false;
}

Expand All @@ -483,13 +499,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str());
return false;
}
}

// load hparams
{
// Override defaults with user choices
model.hparams.mask_threshold = params.mask_threshold;
model.hparams.iou_threshold = params.iou_threshold;
model.hparams.stability_score_threshold = params.stability_score_threshold;
model.hparams.stability_score_offset = params.stability_score_offset;
model.hparams.eps = params.eps;
model.hparams.eps_decoder_transformer = params.eps_decoder_transformer;

auto & hparams = model.hparams;

fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state));
Expand All @@ -510,14 +534,15 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
printf("%s: qntvr = %d\n", __func__, qntvr);

hparams.ftype %= GGML_QNT_VERSION_FACTOR;

}

// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
__func__, params.model.c_str(), model.hparams.ftype);
return false;
}

Expand Down Expand Up @@ -1791,7 +1816,7 @@ bool sam_decode_mask(
return true;
}

bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) {
if (state.low_res_masks->ne[2] == 0) return true;
if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
Expand Down Expand Up @@ -1938,7 +1963,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n",
i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy);

std::string filename = "mask_out_" + std::to_string(i) + ".png";
std::string filename = fname + std::to_string(i) + ".png";
if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) {
printf("%s: failed to write mask %s\n", __func__, filename.c_str());
return false;
Expand Down Expand Up @@ -1967,7 +1992,7 @@ struct ggml_cgraph * sam_build_fast_graph(

prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
fprintf(stderr, "%s: failed to encode prompt\n", __func__);
fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y);
return {};
}

Expand All @@ -1986,14 +2011,6 @@ struct ggml_cgraph * sam_build_fast_graph(

return gf;
}
// Command-line configurable settings for the SAM example.
struct sam_params {
    int32_t seed = -1; // RNG seed

    // Worker thread count: at most 4, at least 1.
    // std::thread::hardware_concurrency() is allowed to return 0 when the value
    // cannot be determined, so clamp from below to avoid a zero-thread default.
    int32_t n_threads = std::max(1, std::min(4, static_cast<int32_t>(std::thread::hardware_concurrency())));

    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg";                             // input file
    std::string fname_out = "img.out";                             // value of the -o/--out option
};

void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
Expand All @@ -2007,7 +2024,23 @@ void sam_print_usage(int argc, char ** argv, const sam_params & params) {
fprintf(stderr, " -i FNAME, --inp FNAME\n");
fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
fprintf(stderr, " -o FNAME, --out FNAME\n");
fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str());
fprintf(stderr, "SAM hyperparameters:\n");
fprintf(stderr, " -mt FLOAT, --mask-threshold\n");
fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold);
fprintf(stderr, " -it FLOAT, --iou-threshold\n");
fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold);
fprintf(stderr, " -st FLOAT, --score-threshold\n");
fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold);
fprintf(stderr, " -so FLOAT, --score-offset\n");
fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset);
fprintf(stderr, " -e FLOAT, --epsilon\n");
fprintf(stderr, " epsilon (default: %f)\n", params.eps);
fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n");
fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer);
fprintf(stderr, "SAM prompt:\n");
fprintf(stderr, " -p TUPLE, --point-prompt\n");
fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y);
fprintf(stderr, "\n");
}

Expand All @@ -2025,6 +2058,34 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
params.fname_inp = argv[++i];
} else if (arg == "-o" || arg == "--out") {
params.fname_out = argv[++i];
} else if (arg == "-mt" || arg == "--mask-threshold") {
params.mask_threshold = std::stof(argv[++i]);
} else if (arg == "-it" || arg == "--iou-threshold") {
params.iou_threshold = std::stof(argv[++i]);
} else if (arg == "-st" || arg == "--score-threshold") {
params.stability_score_threshold = std::stof(argv[++i]);
} else if (arg == "-so" || arg == "--score-offset") {
params.stability_score_offset = std::stof(argv[++i]);
} else if (arg == "-e" || arg == "--epsilon") {
params.eps = std::stof(argv[++i]);
} else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") {
params.eps_decoder_transformer = std::stof(argv[++i]);
} else if (arg == "-p" || arg == "--point-prompt") {
// TODO multiple points per model invocation
char* point = argv[++i];

char* coord = strtok(point, ",");
if (!coord){
fprintf(stderr, "Error while parsing prompt!\n");
exit(1);
}
params.pt.x = std::stof(coord);
coord = strtok(NULL, ",");
if (!coord){
fprintf(stderr, "Error while parsing prompt!\n");
exit(1);
}
params.pt.y = std::stof(coord);
} else if (arg == "-h" || arg == "--help") {
sam_print_usage(argc, argv, params);
exit(0);
Expand Down Expand Up @@ -2078,7 +2139,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

if (!sam_model_load(params.model, model)) {
if (!sam_model_load(params, model)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
Expand Down Expand Up @@ -2147,10 +2208,11 @@ int main(int argc, char ** argv) {
state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
state.allocr = ggml_allocr_new_measure(tensor_alignment);

// TODO: user input
const sam_point pt = { 414.375f, 162.796875f, };
// TODO: more varied prompts
fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y);

// measure memory requirements for the graph
struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf_measure) {
fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
return 1;
Expand All @@ -2166,7 +2228,7 @@ int main(int argc, char ** argv) {
// compute the graph with the measured exact memory requirements from above
ggml_allocr_reset(state.allocr);

struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt);
if (!gf) {
fprintf(stderr, "%s: failed to build fast graph\n", __func__);
return 1;
Expand All @@ -2182,7 +2244,7 @@ int main(int argc, char ** argv) {
state.allocr = NULL;
}

if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) {
fprintf(stderr, "%s: failed to write masks\n", __func__);
return 1;
}
Expand Down

0 comments on commit d271d35

Please sign in to comment.