Skip to content

Commit

Permalink
feat: add dtype info
Browse files Browse the repository at this point in the history
  • Loading branch information
LuTaoChen authored and Jeffwhen committed Dec 29, 2022
1 parent 16dfb67 commit dd76be4
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pipeline/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ blob_info_t *get_input_info(unsigned runner_id, unsigned *num)
blob.name = net_info->input_names[i];
bm_shape_t &s = net_info->stages[0].input_shapes[i];
blob.num_dims = s.num_dims;
blob.dtype = (net_info->input_dtypes)[i];
memcpy(blob.dims, s.dims, s.num_dims * sizeof(int));
blob.scale = net_info->input_scales[i];
}
Expand All @@ -277,6 +278,7 @@ blob_info_t *get_output_info(unsigned runner_id, unsigned *num)
blob.name = net_info->output_names[i];
bm_shape_t &s = net_info->stages[0].output_shapes[i];
blob.num_dims = s.num_dims;
blob.dtype = (net_info->output_dtypes)[i];
memcpy(blob.dims, s.dims, s.num_dims * sizeof(int));
blob.scale = (net_info->output_scales)[i];
}
Expand Down
1 change: 1 addition & 0 deletions pipeline/interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ unsigned int runner_release_output(unsigned int output_num, const tensor_data_t
struct blob_info_t {
const char *name;
int num_dims;
int dtype;
int dims[BM_MAX_DIMS_NUM];
float scale;
};
Expand Down
3 changes: 3 additions & 0 deletions python/tpu_perf/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class BlobInfo(ct.Structure):
_fields_ = [
("name", ct.c_char_p),
("dims_num", ct.c_int),
("dtype", ct.c_int),
("dims", ct.c_int * 8),
("scale", ct.c_float)]

Expand Down Expand Up @@ -93,6 +94,7 @@ def get_input_info(self):
for _, info in zip(range(num.value), infos):
result[info.name.decode()] = dict(
scale=info.scale,
dtype=info.dtype,
shape=[info.dims[i] for i in range(info.dims_num)])
self.__lib.release_input_info(self.runner_id, infos)
return result
Expand All @@ -105,6 +107,7 @@ def get_output_info(self):
for _, info in zip(range(num.value), infos):
result[info.name.decode()] = dict(
scale=info.scale,
dtype=info.dtype,
shape=[info.dims[i] for i in range(info.dims_num)])
self.__lib.release_input_info(self.runner_id, infos)
return result
Expand Down

0 comments on commit dd76be4

Please sign in to comment.