Skip to content

Implement custom deserializer for LambdaRequest #666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lambda-events/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "aws_lambda_events"
version = "0.10.0"
version = "0.11.0"
description = "AWS Lambda event definitions"
authors = [
"Christian Legnitto <christian@legnitto.com>",
Expand Down
1 change: 1 addition & 0 deletions lambda-events/src/event/alb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
/// `AlbTargetGroupRequest` contains data originating from the ALB Lambda target group integration
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct AlbTargetGroupRequest {
#[serde(with = "http_method")]
pub http_method: Method,
Expand Down
23 changes: 22 additions & 1 deletion lambda-events/src/event/apigw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::collections::HashMap;
/// `ApiGatewayProxyRequest` contains data coming from the API Gateway proxy
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct ApiGatewayProxyRequest<T1 = Value>
where
T1: DeserializeOwned,
Expand Down Expand Up @@ -118,12 +119,25 @@ where
/// `ApiGatewayV2httpRequest` contains data coming from the new HTTP API Gateway
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct ApiGatewayV2httpRequest {
#[serde(default, rename = "type")]
pub kind: Option<String>,
#[serde(default)]
pub method_arn: Option<String>,
#[serde(with = "http_method", default = "default_http_method")]
pub http_method: Method,
#[serde(default)]
pub identity_source: Option<String>,
#[serde(default)]
pub authorization_token: Option<String>,
#[serde(default)]
pub resource: Option<String>,
#[serde(default)]
pub version: Option<String>,
#[serde(default)]
pub route_key: Option<String>,
#[serde(default)]
#[serde(default, alias = "path")]
pub raw_path: Option<String>,
#[serde(default)]
pub raw_query_string: Option<String>,
Expand Down Expand Up @@ -319,6 +333,7 @@ pub struct ApiGatewayRequestIdentity {
/// `ApiGatewayWebsocketProxyRequest` contains data coming from the API Gateway proxy
#[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
#[serde(deny_unknown_fields)]
pub struct ApiGatewayWebsocketProxyRequest<T1 = Value, T2 = Value>
where
T1: DeserializeOwned,
Expand Down Expand Up @@ -747,6 +762,10 @@ pub struct IamPolicyStatement {
pub resource: Vec<String>,
}

fn default_http_method() -> Method {
Method::GET
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -901,6 +920,8 @@ mod test {
let output: String = serde_json::to_string(&parsed).unwrap();
let reparsed: ApiGatewayV2httpRequest = serde_json::from_slice(output.as_bytes()).unwrap();
assert_eq!(parsed, reparsed);
assert_eq!("REQUEST", parsed.kind.unwrap());
assert_eq!(Method::GET, parsed.http_method);
}

#[test]
Expand Down
145 changes: 92 additions & 53 deletions lambda-events/src/fixtures/example-apigw-request.json
Original file line number Diff line number Diff line change
@@ -1,55 +1,95 @@
{
"resource": "/{proxy+}",
"path": "/hello/world",
"httpMethod": "POST",
"headers": {
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate",
"cache-control": "no-cache",
"CloudFront-Forwarded-Proto": "https",
"CloudFront-Is-Desktop-Viewer": "true",
"CloudFront-Is-Mobile-Viewer": "false",
"CloudFront-Is-SmartTV-Viewer": "false",
"CloudFront-Is-Tablet-Viewer": "false",
"CloudFront-Viewer-Country": "US",
"Content-Type": "application/json",
"headerName": "headerValue",
"Host": "gy415nuibc.execute-api.us-east-1.amazonaws.com",
"Postman-Token": "9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f",
"User-Agent": "PostmanRuntime/2.4.5",
"Via": "1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)",
"X-Amz-Cf-Id": "pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A==",
"X-Forwarded-For": "54.240.196.186, 54.182.214.83",
"X-Forwarded-Port": "443",
"X-Forwarded-Proto": "https"
},
"multiValueHeaders": {
"Accept": ["*/*"],
"Accept-Encoding": ["gzip, deflate"],
"cache-control": ["no-cache"],
"CloudFront-Forwarded-Proto": ["https"],
"CloudFront-Is-Desktop-Viewer": ["true"],
"CloudFront-Is-Mobile-Viewer": ["false"],
"CloudFront-Is-SmartTV-Viewer": ["false"],
"CloudFront-Is-Tablet-Viewer": ["false"],
"CloudFront-Viewer-Country": ["US"],
"Content-Type": ["application/json"],
"headerName": ["headerValue"],
"Host": ["gy415nuibc.execute-api.us-east-1.amazonaws.com"],
"Postman-Token": ["9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f"],
"User-Agent": ["PostmanRuntime/2.4.5"],
"Via": ["1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)"],
"X-Amz-Cf-Id": ["pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A=="],
"X-Forwarded-For": ["54.240.196.186, 54.182.214.83"],
"X-Forwarded-Port": ["443"],
"X-Forwarded-Proto": ["https"]
},
"path": "/hello/world",
"httpMethod": "POST",
"headers": {
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate",
"cache-control": "no-cache",
"CloudFront-Forwarded-Proto": "https",
"CloudFront-Is-Desktop-Viewer": "true",
"CloudFront-Is-Mobile-Viewer": "false",
"CloudFront-Is-SmartTV-Viewer": "false",
"CloudFront-Is-Tablet-Viewer": "false",
"CloudFront-Viewer-Country": "US",
"Content-Type": "application/json",
"headerName": "headerValue",
"Host": "gy415nuibc.execute-api.us-east-1.amazonaws.com",
"Postman-Token": "9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f",
"User-Agent": "PostmanRuntime/2.4.5",
"Via": "1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)",
"X-Amz-Cf-Id": "pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A==",
"X-Forwarded-For": "54.240.196.186, 54.182.214.83",
"X-Forwarded-Port": "443",
"X-Forwarded-Proto": "https"
},
"multiValueHeaders": {
"Accept": [
"*/*"
],
"Accept-Encoding": [
"gzip, deflate"
],
"cache-control": [
"no-cache"
],
"CloudFront-Forwarded-Proto": [
"https"
],
"CloudFront-Is-Desktop-Viewer": [
"true"
],
"CloudFront-Is-Mobile-Viewer": [
"false"
],
"CloudFront-Is-SmartTV-Viewer": [
"false"
],
"CloudFront-Is-Tablet-Viewer": [
"false"
],
"CloudFront-Viewer-Country": [
"US"
],
"Content-Type": [
"application/json"
],
"headerName": [
"headerValue"
],
"Host": [
"gy415nuibc.execute-api.us-east-1.amazonaws.com"
],
"Postman-Token": [
"9f583ef0-ed83-4a38-aef3-eb9ce3f7a57f"
],
"User-Agent": [
"PostmanRuntime/2.4.5"
],
"Via": [
"1.1 d98420743a69852491bbdea73f7680bd.cloudfront.net (CloudFront)"
],
"X-Amz-Cf-Id": [
"pn-PWIJc6thYnZm5P0NMgOUglL1DYtl0gdeJky8tqsg8iS_sgsKD1A=="
],
"X-Forwarded-For": [
"54.240.196.186, 54.182.214.83"
],
"X-Forwarded-Port": [
"443"
],
"X-Forwarded-Proto": [
"https"
]
},
"queryStringParameters": {
"name": "me"
},
"multiValueQueryStringParameters": {
"name": ["me"]
},
},
"multiValueQueryStringParameters": {
"name": [
"me"
]
},
"pathParameters": {
"proxy": "hello/world"
},
Expand All @@ -70,9 +110,9 @@
"accountId": "theAccountId",
"cognitoIdentityId": "theCognitoIdentityId",
"caller": "theCaller",
"apiKey": "theApiKey",
"apiKeyId": "theApiKeyId",
"accessKey": "ANEXAMPLEOFACCESSKEY",
"apiKey": "theApiKey",
"apiKeyId": "theApiKeyId",
"accessKey": "ANEXAMPLEOFACCESSKEY",
"sourceIp": "192.168.196.186",
"cognitoAuthenticationType": "theCognitoAuthenticationType",
"cognitoAuthenticationProvider": "theCognitoAuthenticationProvider",
Expand All @@ -92,5 +132,4 @@
"apiId": "gy415nuibc"
},
"body": "{\r\n\t\"a\": 1\r\n}"
}

}
2 changes: 1 addition & 1 deletion lambda-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ percent-encoding = "2.2"

[dependencies.aws_lambda_events]
path = "../lambda-events"
version = "0.10.0"
version = "0.11.0"
default-features = false
features = ["alb", "apigw"]

Expand Down
117 changes: 117 additions & 0 deletions lambda-http/src/deserializer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use crate::request::LambdaRequest;
use aws_lambda_events::{
alb::AlbTargetGroupRequest,
apigw::{ApiGatewayProxyRequest, ApiGatewayV2httpRequest, ApiGatewayWebsocketProxyRequest},
};
use serde::{de::Error, Deserialize};

const ERROR_CONTEXT: &str = "this function expects a JSON payload from Amazon API Gateway, Amazon Elastic Load Balancer, or AWS Lambda Function URLs, but the data doesn't match any of those services' events";

impl<'de> Deserialize<'de> for LambdaRequest {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let content = match serde::__private::de::Content::deserialize(deserializer) {
Ok(content) => content,
Err(err) => return Err(err),
};
#[cfg(feature = "apigw_rest")]
if let Ok(res) =
ApiGatewayProxyRequest::deserialize(serde::__private::de::ContentRefDeserializer::<D::Error>::new(&content))
{
return Ok(LambdaRequest::ApiGatewayV1(res));
}
#[cfg(feature = "apigw_http")]
if let Ok(res) = ApiGatewayV2httpRequest::deserialize(
serde::__private::de::ContentRefDeserializer::<D::Error>::new(&content),
) {
return Ok(LambdaRequest::ApiGatewayV2(res));
}
#[cfg(feature = "alb")]
if let Ok(res) =
AlbTargetGroupRequest::deserialize(serde::__private::de::ContentRefDeserializer::<D::Error>::new(&content))
{
return Ok(LambdaRequest::Alb(res));
}
#[cfg(feature = "apigw_websockets")]
if let Ok(res) = ApiGatewayWebsocketProxyRequest::deserialize(serde::__private::de::ContentRefDeserializer::<
D::Error,
>::new(&content))
{
return Ok(LambdaRequest::WebSocket(res));
}

Err(Error::custom(ERROR_CONTEXT))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_deserialize_apigw_rest() {
let data = include_bytes!("../../lambda-events/src/fixtures/example-apigw-request.json");

let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw rest data");
match req {
LambdaRequest::ApiGatewayV1(req) => {
assert_eq!("12345678912", req.request_context.account_id.unwrap());
}
other => panic!("unexpected request variant: {:?}", other),
}
}

#[test]
fn test_deserialize_apigw_http() {
let data = include_bytes!("../../lambda-events/src/fixtures/example-apigw-v2-request-iam.json");

let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw http data");
match req {
LambdaRequest::ApiGatewayV2(req) => {
assert_eq!("123456789012", req.request_context.account_id.unwrap());
}
other => panic!("unexpected request variant: {:?}", other),
}
}

#[test]
fn test_deserialize_alb() {
let data = include_bytes!(
"../../lambda-events/src/fixtures/example-alb-lambda-target-request-multivalue-headers.json"
);

let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze alb rest data");
match req {
LambdaRequest::Alb(req) => {
assert_eq!(
"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh",
req.request_context.elb.target_group_arn.unwrap()
);
}
other => panic!("unexpected request variant: {:?}", other),
}
}

#[test]
fn test_deserialize_apigw_websocket() {
let data =
include_bytes!("../../lambda-events/src/fixtures/example-apigw-websocket-request-without-method.json");

let req: LambdaRequest = serde_json::from_slice(data).expect("failed to deserialze apigw websocket data");
match req {
LambdaRequest::WebSocket(req) => {
assert_eq!("CONNECT", req.request_context.event_type.unwrap());
}
other => panic!("unexpected request variant: {:?}", other),
}
}

#[test]
fn test_deserialize_error() {
let err = serde_json::from_str::<LambdaRequest>("{\"command\": \"hi\"}").unwrap_err();

assert_eq!(ERROR_CONTEXT, err.to_string());
}
}
1 change: 1 addition & 0 deletions lambda-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service};
use request::RequestFuture;
use response::ResponseFuture;

mod deserializer;
pub mod ext;
pub mod request;
mod response;
Expand Down
5 changes: 3 additions & 2 deletions lambda-http/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ use aws_lambda_events::apigw::{ApiGatewayWebsocketProxyRequest, ApiGatewayWebsoc
use aws_lambda_events::{encodings::Body, query_map::QueryMap};
use http::header::HeaderName;
use http::{HeaderMap, HeaderValue};

use serde::{Deserialize, Serialize};
use serde_json::error::Error as JsonError;

use std::future::Future;
use std::pin::Pin;
use std::{env, io::Read, mem};
Expand All @@ -33,8 +35,7 @@ use url::Url;
/// This is not intended to be a type consumed by crate users directly. The order
/// of the variants are notable. Serde will try to deserialize in this order.
#[doc(hidden)]
#[derive(Deserialize, Debug)]
#[serde(untagged)]
#[derive(Debug)]
pub enum LambdaRequest {
#[cfg(feature = "apigw_rest")]
ApiGatewayV1(ApiGatewayProxyRequest),
Expand Down