Skip to content

Add support for Lambda-Extesion-Accept-Feature header #887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions lambda-extension/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use hyper::service::service_fn;

use hyper_util::rt::tokio::TokioIo;
use lambda_runtime_api_client::Client;
use serde::Deserialize;
use std::{
convert::Infallible, fmt, future::ready, future::Future, net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc,
};
Expand Down Expand Up @@ -230,8 +231,7 @@ where
pub async fn register(self) -> Result<RegisteredExtension<E>, Error> {
let client = &Client::builder().build()?;

let extension_id = register(client, self.extension_name, self.events).await?;
let extension_id = extension_id.to_str()?;
let register_res = register(client, self.extension_name, self.events).await?;

// Logs API subscriptions must be requested during the Lambda init phase (see
// https://docs.aws.amazon.com/lambda/latest/dg/runtimes-logs-api.html#runtimes-logs-api-subscribing).
Expand Down Expand Up @@ -266,7 +266,7 @@ where
// Call Logs API to start receiving events
let req = requests::subscribe_request(
Api::LogsApi,
extension_id,
&register_res.extension_id,
self.log_types,
self.log_buffering,
self.log_port_number,
Expand Down Expand Up @@ -312,7 +312,7 @@ where
// Call Telemetry API to start receiving events
let req = requests::subscribe_request(
Api::TelemetryApi,
extension_id,
&register_res.extension_id,
self.telemetry_types,
self.telemetry_buffering,
self.telemetry_port_number,
Expand All @@ -326,7 +326,11 @@ where
}

Ok(RegisteredExtension {
extension_id: extension_id.to_string(),
extension_id: register_res.extension_id,
function_name: register_res.function_name,
function_version: register_res.function_version,
handler: register_res.handler,
account_id: register_res.account_id,
events_processor: self.events_processor,
})
}
Expand All @@ -339,7 +343,17 @@ where

/// An extension registered by calling [`Extension::register`].
pub struct RegisteredExtension<E> {
extension_id: String,
/// The ID of the registered extension. This ID is unique per extension and remains constant
pub extension_id: String,
/// The ID of the account the extension was registered to.
/// This will be `None` if the register request doesn't send the Lambda-Extension-Accept-Feature header
pub account_id: Option<String>,
/// The name of the Lambda function that the extension is registered with
pub function_name: String,
/// The version of the Lambda function that the extension is registered with
pub function_version: String,
/// The Lambda function handler that AWS Lambda invokes
pub handler: String,
events_processor: E,
}

Expand Down Expand Up @@ -468,12 +482,30 @@ where
}
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseBody {
function_name: String,
function_version: String,
handler: String,
account_id: Option<String>,
}

#[derive(Debug)]
struct RegisterResponse {
extension_id: String,
function_name: String,
function_version: String,
handler: String,
account_id: Option<String>,
}

/// Initialize and register the extension in the Extensions API
async fn register<'a>(
client: &'a Client,
extension_name: Option<&'a str>,
events: Option<&'a [&'a str]>,
) -> Result<http::HeaderValue, Error> {
) -> Result<RegisterResponse, Error> {
let name = match extension_name {
Some(name) => name.into(),
None => {
Expand Down Expand Up @@ -501,5 +533,17 @@ async fn register<'a>(
.get(requests::EXTENSION_ID_HEADER)
.ok_or_else(|| ExtensionError::boxed("missing extension id header"))
.map_err(|e| ExtensionError::boxed(e.to_string()))?;
Ok(header.clone())
let extension_id = header.to_str()?.to_string();

let (_, body) = res.into_parts();
let body = body.collect().await?.to_bytes();
let response: RegisterResponseBody = serde_json::from_slice(&body)?;

Ok(RegisterResponse {
extension_id,
function_name: response.function_name,
function_version: response.function_version,
handler: response.handler,
account_id: response.account_id,
})
}
6 changes: 6 additions & 0 deletions lambda-extension/src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ const EXTENSION_ERROR_TYPE_HEADER: &str = "Lambda-Extension-Function-Error-Type"
const CONTENT_TYPE_HEADER_NAME: &str = "Content-Type";
const CONTENT_TYPE_HEADER_VALUE: &str = "application/json";

// Comma separated list of features the extension supports.
// `accountId` is currently the only supported feature.
const EXTENSION_ACCEPT_FEATURE: &str = "Lambda-Extension-Accept-Feature";
const EXTENSION_ACCEPT_FEATURE_VALUE: &str = "accountId";

pub(crate) fn next_event_request(extension_id: &str) -> Result<Request<Body>, Error> {
let req = build_request()
.method(Method::GET)
Expand All @@ -25,6 +30,7 @@ pub(crate) fn register_request(extension_name: &str, events: &[&str]) -> Result<
.method(Method::POST)
.uri("/2020-01-01/extension/register")
.header(EXTENSION_NAME_HEADER, extension_name)
.header(EXTENSION_ACCEPT_FEATURE, EXTENSION_ACCEPT_FEATURE_VALUE)
.header(CONTENT_TYPE_HEADER_NAME, CONTENT_TYPE_HEADER_VALUE)
.body(Body::from(serde_json::to_string(&events)?))?;

Expand Down
Loading