-
Notifications
You must be signed in to change notification settings - Fork 695
Add "load_by_name" API at wasi-nn #4267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add "load_by_name" API at wasi-nn #4267
Conversation
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding, | |||
__attribute__((import_module("wasi_nn"))); | |||
|
|||
wasi_nn_error | |||
load_by_name(const char *name, graph *g) | |||
load_by_name(char *name, uint32_t name_len, graph *g) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a bug fix?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, its a bugfix,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.
- I suggest we use
WASM_ENABLE_WASI_EPHEMERAL_NN
on the wasm side. - With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of
native_symbols_wasi_nn
in wasi_nn.c
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = { | |||
REG_NATIVE_FUNC(get_output, "(ii*i*)i"), | |||
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ | |||
REG_NATIVE_FUNC(load, "(*ii*)i"), | |||
REG_NATIVE_FUNC(load_by_name, "(*i*)i"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use -DWASM_ENABLE_WASI_EPHEMERAL_NN=1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at first, i just built with -DWASM_ENABLE_WASI_NN=1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If use -DWAMR_BUILD_WASI_EPHEMERAL_NN=1
during compilation, you will be able to use the set of APIs, including load_by_name()
. There is no need to change this line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If -DWAMR_BUILD_WASI_EPHEMERAL_NN=1
must be added , I think there is no need to add this line.
Currently, it is possible to use the default wasi-nn
instead of wasi_ephemeral_nn
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Please do it.
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding, | |||
__attribute__((import_module("wasi_nn"))); | |||
|
|||
wasi_nn_error | |||
load_by_name(const char *name, graph *g) | |||
load_by_name(char *name, uint32_t name_len, graph *g) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g) | |||
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST); | |||
return runtime_error; | |||
} | |||
if (tfl_ctx->models[g].model_pointer == NULL) { | |||
if (tfl_ctx->models[g].model_pointer == NULL | |||
&& tfl_ctx->models[g].model == NULL) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If call load_by_name(), there is no need to save the tflite buf to model_pointer.
https://github.com/HongxiaWangSSSS/wasm-micro-runtime/blob/c5414fd28baf973e3c95db1318de4d26f88007d3/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp#L141C49-L141C68
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why not free model_point after below operation in load()
https://github.com/HongxiaWangSSSS/wasm-micro-runtime/blob/c5414fd28baf973e3c95db1318de4d26f88007d3/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp#L151
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is required to validate TFLitesContext.models[g]
for both cases, using load()
and load_by_name()
. It will not be acceptable if the change disables one of these cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whether it is load
or load_by_name
, the check of models[g].model_pointer
does not seem to be necessary, just make sure the models[g].model
is not NULL maybe is enough.
Do you have any idea?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does not seem to be necessary.
Why is that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
after this operation, the models[g].model_pointer
's connect has been saved in models[g].model
.
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g) | |||
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST); | |||
return runtime_error; | |||
} | |||
if (tfl_ctx->models[g].model_pointer == NULL) { | |||
if (tfl_ctx->models[g].model_pointer == NULL | |||
&& tfl_ctx->models[g].model == NULL) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does not seem to be necessary.
Why is that?
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = { | |||
REG_NATIVE_FUNC(get_output, "(ii*i*)i"), | |||
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ | |||
REG_NATIVE_FUNC(load, "(*ii*)i"), | |||
REG_NATIVE_FUNC(load_by_name, "(*i*)i"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Please do it.
@@ -108,7 +109,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size, | |||
uint32_t num_output_tensors) | |||
{ | |||
graph graph; | |||
|
|||
#if USE_WASM_LOAD_BY_NAME == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once we have two API set in wasi_nn.h separated by WASM_ENABLE_WASI_EPHEMERAL_NN
:
- Wrap
wasm_load_by_name()
withWASM_ENABLE_WASI_EPHEMERAL_NN
. - There is no need to introduce
USE_WASM_LOAD_BY_NAME
. - Use
WASM_ENABLE_WASI_EPHEMERAL_NN
to decide whether to usewasm_load()
orwasm_load_by_name()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated!
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding, | |||
__attribute__((import_module("wasi_nn"))); | |||
|
|||
wasi_nn_error | |||
load_by_name(const char *name, graph *g) | |||
load_by_name(char *name, uint32_t name_len, graph *g) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.
- I suggest we use
WASM_ENABLE_WASI_EPHEMERAL_NN
on the wasm side. - With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of
native_symbols_wasi_nn
in wasi_nn.c
When using WASI-NN ,we want to reduce copying the AI model from the host to the WASM by using "load_by_name" API.
We also want to use it to improve the performance, keep the safety and WASI-NN also supports this method on different backends.
I test with 3 tflite models.(x86_64, Ubuntu 22.04)
Both coco_ssd_mobilenet_v1 and coco_ssd_mobilenet_v3 are for detection, its file size and input tensor size is different.
mobilenet_v2 is for classification and size is more bigger.
The time consumed does not show a linear growth as the file size increases.
For most cases, load_by_name will more faster whether the load tflite or the entire inference process (load+ set input +compute +get output).