Add "load_by_name" API at wasi-nn #4267

Open. Wants to merge 3 commits into base: main

HongxiaWangSSSS
Contributor

When using WASI-NN, we want to avoid copying the AI model from the host into the Wasm module by using the "load_by_name" API.
We also want to use it to improve performance while preserving safety, and WASI-NN already supports this method on different backends.
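
To make the copying argument concrete, here is a minimal sketch of the two call shapes. The types are simplified stand-ins and `mock_load` / `mock_load_by_name` are hypothetical stubs, not the real wasi-nn imports; they only record how many bytes the caller hands across the host/Wasm boundary.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Simplified stand-ins for the wasi-nn types (assumption: the real
 * definitions in the wasi-nn headers may differ). */
typedef uint32_t graph;
typedef enum { success = 0, invalid_argument } wasi_nn_error;

/* Bytes handed across the host/Wasm boundary by the last mock call. */
static uint32_t bytes_passed_to_host;

/* Hypothetical stub for load(): the whole model buffer is passed. */
static wasi_nn_error
mock_load(const uint8_t *model, uint32_t model_size, graph *g)
{
    if (model == NULL || model_size == 0 || g == NULL)
        return invalid_argument;
    bytes_passed_to_host = model_size;
    *g = 0;
    return success;
}

/* Hypothetical stub for load_by_name(): only the name is passed, and the
 * host is expected to locate and read the model itself. */
static wasi_nn_error
mock_load_by_name(const char *name, uint32_t name_len, graph *g)
{
    if (name == NULL || name_len == 0 || g == NULL)
        return invalid_argument;
    bytes_passed_to_host = name_len;
    *g = 0;
    return success;
}
```

With a 1024-byte model named "model.tflite", the first stub records 1024 bytes and the second records 12, which is the saving the description refers to.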

I tested with 3 tflite models (x86_64, Ubuntu 22.04).

image

Both coco_ssd_mobilenet_v1 and coco_ssd_mobilenet_v3 are detection models; their file sizes and input tensor sizes differ.
mobilenet_v2 is a classification model and is larger.

The time consumed does not grow linearly with the file size.
In most cases, load_by_name is faster, both for loading the tflite model alone and for the entire inference process (load + set input + compute + get output).
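
For readers who want to reproduce this kind of comparison, the following is a rough sketch of how a single phase (for example, load alone, or the whole load + set input + compute + get output sequence) could be timed. It is not the harness used for the numbers above; `phase_fn`, `time_phase` and `noop_phase` are hypothetical names, and a POSIX system with `clock_gettime` is assumed.

```c
#ifndef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 199309L /* for clock_gettime on strict compilers */
#endif
#include <stdint.h>
#include <time.h>

/* Elapsed time between two timespec values, in microseconds. */
static int64_t
elapsed_us(struct timespec start, struct timespec end)
{
    return (int64_t)(end.tv_sec - start.tv_sec) * 1000000
           + (int64_t)(end.tv_nsec - start.tv_nsec) / 1000;
}

/* A phase is any callback; ctx carries whatever state it needs. */
typedef void (*phase_fn)(void *ctx);

/* Run one phase and return how long it took, in microseconds. */
static int64_t
time_phase(phase_fn fn, void *ctx)
{
    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    fn(ctx);
    clock_gettime(CLOCK_MONOTONIC, &t1);
    return elapsed_us(t0, t1);
}

/* Placeholder phase; a real test would call load() or load_by_name() here. */
static void
noop_phase(void *ctx)
{
    (void)ctx;
}
```

Timing each phase separately as well as the full sequence makes it possible to tell whether a difference comes from loading or from the rest of the inference.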

HongxiaWangSSSS marked this pull request as ready for review May 12, 2025 05:24
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
     __attribute__((import_module("wasi_nn")));

 wasi_nn_error
-load_by_name(const char *name, graph *g)
+load_by_name(char *name, uint32_t name_len, graph *g)
Collaborator

is this a bug fix?

Collaborator

wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?

Contributor Author

Yes, it's a bugfix.
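
One way to see why the header change can be called a bug fix: the registration discussed in this PR uses the signature string "(*i*)i". Assuming WAMR's native signature notation, where each character between the parentheses stands for one parameter ('*' a buffer address, 'i' an i32), that string describes three parameters, which matches (name, name_len, g) rather than the two-parameter declaration. The helper below is a hypothetical illustration for counting them, not part of WAMR.

```c
#include <stddef.h>

/* Count the parameters in a native signature string such as "(*i*)i".
 * Assumption: every character between '(' and ')' is one parameter.
 * Returns -1 if the string does not have the form "(...)". */
static int
count_signature_params(const char *sig)
{
    int count = 0;

    if (sig == NULL || sig[0] != '(')
        return -1;
    for (sig++; *sig != '\0' && *sig != ')'; sig++)
        count++;
    return *sig == ')' ? count : -1;
}
```

For "(*i*)i" this yields 3, and for the load signature "(*ii*)i" it yields 4.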

Collaborator

Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.

  • I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
  • With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of native_symbols_wasi_nn in wasi_nn.c

@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
     REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
 #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
     REG_NATIVE_FUNC(load, "(*ii*)i"),
+    REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
Collaborator

Why not use -DWASM_ENABLE_WASI_EPHEMERAL_NN=1?

Contributor Author

At first, I just built with -DWASM_ENABLE_WASI_NN=1.

Collaborator

If you use -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 during compilation, you will be able to use that set of APIs, including load_by_name(). There is no need to change this line.

Contributor Author

If -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 must be added, I think there is no need to add this line.
Currently, it is possible to use the default wasi-nn instead of wasi_ephemeral_nn.

Collaborator

Yes. Please do it.

@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
         NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
         return runtime_error;
     }
-    if (tfl_ctx->models[g].model_pointer == NULL) {
+    if (tfl_ctx->models[g].model_pointer == NULL
+        && tfl_ctx->models[g].model == NULL) {
Collaborator

The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?

Collaborator

TFLiteContext.models[g] must be validated in both cases, load() and load_by_name(). The change is not acceptable if it disables validation for one of these cases.

Contributor Author

Whether it is load or load_by_name, the check of models[g].model_pointer does not seem to be necessary; making sure that models[g].model is not NULL may be enough.
What do you think?

Collaborator

> does not seem to be necessary.

Why is that?

Contributor Author

After this operation, the content of models[g].model_pointer has been saved in models[g].model.
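
The two positions in this thread are not necessarily incompatible. As a rough sketch only, a check can accept graphs from both load paths while still reporting each invalid case separately. `ModelSlot`, `GraphCheck` and `check_graph` are hypothetical names, the value of `MAX_GRAPHS_PER_INST` is a placeholder, and the sketch assumes that load() sets both `model_pointer` and `model` while load_by_name() sets only `model`.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#define MAX_GRAPHS_PER_INST 10 /* placeholder value for this sketch */

typedef uint32_t graph;

/* Simplified stand-in for one TFLiteContext.models[] entry. Field names
 * follow the diff; the types are placeholders. */
typedef struct {
    void *model_pointer; /* assumed: model bytes copied by load() */
    void *model;         /* assumed: built model, set by both load paths */
} ModelSlot;

typedef enum {
    GRAPH_OK,
    GRAPH_OUT_OF_RANGE,
    GRAPH_NOT_LOADED,
    GRAPH_MODEL_NOT_BUILT
} GraphCheck;

static GraphCheck
check_graph(const ModelSlot *models, graph g)
{
    if (g >= MAX_GRAPHS_PER_INST) {
        fprintf(stderr, "Invalid graph: %u >= %d.\n", (unsigned)g,
                MAX_GRAPHS_PER_INST);
        return GRAPH_OUT_OF_RANGE;
    }
    if (models[g].model != NULL)
        return GRAPH_OK; /* valid for load() and load_by_name() alike */
    if (models[g].model_pointer == NULL) {
        fprintf(stderr, "Graph %u: nothing has been loaded.\n", (unsigned)g);
        return GRAPH_NOT_LOADED;
    }
    fprintf(stderr, "Graph %u: buffer present but model not built.\n",
            (unsigned)g);
    return GRAPH_MODEL_NOT_BUILT;
}
```

The point of the design is that each failure keeps its own message and return value, so accepting name-loaded graphs does not merge the invalid-argument cases.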

@@ -108,7 +109,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
               uint32_t num_output_tensors)
 {
     graph graph;
+
+#if USE_WASM_LOAD_BY_NAME == 0
Collaborator

Once we have two API sets in wasi_nn.h separated by WASM_ENABLE_WASI_EPHEMERAL_NN:

  • Wrap wasm_load_by_name() with WASM_ENABLE_WASI_EPHEMERAL_NN.
  • There is no need to introduce USE_WASM_LOAD_BY_NAME.
  • Use WASM_ENABLE_WASI_EPHEMERAL_NN to decide whether to use wasm_load() or wasm_load_by_name().
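
The suggestion above could look roughly like the sketch below. `stub_wasm_load`, `stub_wasm_load_by_name` and `load_model` are hypothetical stand-ins for the test helpers, with simplified types; only the macro name comes from this discussion, and it defaults to 0 here when the build does not define it.

```c
#include <stdint.h>
#include <string.h>

#ifndef WASM_ENABLE_WASI_EPHEMERAL_NN
#define WASM_ENABLE_WASI_EPHEMERAL_NN 0 /* default when the build sets nothing */
#endif

typedef uint32_t graph;
typedef enum { success = 0, invalid_argument } wasi_nn_error;

/* Records which path was compiled in, for illustration only. */
static const char *last_api_used = "none";

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
/* Hypothetical stand-in for wasm_load_by_name(). */
static wasi_nn_error
stub_wasm_load_by_name(const char *model_name, graph *g)
{
    (void)model_name;
    last_api_used = "load_by_name";
    *g = 0;
    return success;
}
#else
/* Hypothetical stand-in for wasm_load(). */
static wasi_nn_error
stub_wasm_load(const char *model_name, graph *g)
{
    (void)model_name;
    last_api_used = "load";
    *g = 0;
    return success;
}
#endif

/* A single call site; the flag alone decides which API is used. */
static wasi_nn_error
load_model(const char *model_name, graph *g)
{
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
    return stub_wasm_load_by_name(model_name, g);
#else
    return stub_wasm_load(model_name, g);
#endif
}
```

Because the selection hangs on the existing flag, no separate USE_WASM_LOAD_BY_NAME switch is needed.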

Contributor Author

updated!
