Skip to content

Add "load_by_name" API at wasi-nn #4267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/iwasm/libraries/wasi-nn/include/wasi_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn")));

wasi_nn_error
load_by_name(const char *name, graph *g)
load_by_name(char *name, uint32_t name_len, graph *g)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a bug fix?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, its a bugfix,

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.

  • I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
  • With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of native_symbols_wasi_nn in wasi_nn.c

__attribute__((import_module("wasi_nn")));

/**
Expand Down
7 changes: 2 additions & 5 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error;
}
if (tfl_ctx->models[g].model_pointer == NULL) {
if (tfl_ctx->models[g].model_pointer == NULL
&& tfl_ctx->models[g].model == NULL) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is required to validate TFLitesContext.models[g] for both cases, using load() and load_by_name(). It will not be acceptable if the change disables one of these cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether it is load or load_by_name, the check of models[g].model_pointerdoes not seem to be necessary, just make sure the models[g].model is not NULL maybe is enough.
Do you have any idea?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does not seem to be necessary.

Why is that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after this operation, the models[g].model_pointer's connect has been saved in models[g].model.

NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) {
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
return runtime_error;
}
return success;
}

Expand Down
7 changes: 6 additions & 1 deletion core/iwasm/libraries/wasi-nn/test/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, g);
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
return res;
}

Expand Down Expand Up @@ -108,7 +108,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors)
{
graph graph;

#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
if (wasm_load(model_name, &graph, target) != success) {
#else
if (wasm_load_by_name(model_name, &graph) != success) {
#endif
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}
Expand Down
Loading