From aa9953104fe8e3e727e7e4dc0a82f773460d7ff4 Mon Sep 17 00:00:00 2001 From: Will Godbe Date: Tue, 2 Aug 2022 14:07:06 -0700 Subject: [PATCH 01/10] RateLimitingMiddleware updates --- .../samples/RateLimitingSample/Program.cs | 49 +++++++------ .../src/DisableRateLimitingAttribute.cs | 12 ++++ ...data.cs => EnableRateLimitingAttribute.cs} | 7 +- .../RateLimiting/src/IRateLimiterMetadata.cs | 15 ---- .../RateLimiting/src/PublicAPI.Unshipped.txt | 17 +++-- ...iterEndpointConventionBuilderExtensions.cs | 20 +++++- .../src/RateLimiterOptionsExtensions.cs | 59 ++++++++++------ .../RateLimiterServiceCollectionExtensions.cs | 28 ++++++++ .../src/RateLimitingMiddleware.cs | 11 ++- .../test/RateLimitingMiddlewareTests.cs | 70 ++++++++++++++----- 10 files changed, 198 insertions(+), 90 deletions(-) create mode 100644 src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs rename src/Middleware/RateLimiting/src/{RateLimiterMetadata.cs => EnableRateLimitingAttribute.cs} (64%) delete mode 100644 src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs create mode 100644 src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs diff --git a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs index d8a6fe95832e..ec532bf7c5de 100644 --- a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs +++ b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs @@ -16,35 +16,38 @@ // Inject an ILogger builder.Services.AddLogging(); -var app = builder.Build(); - var todoName = "todoPolicy"; var completeName = "completePolicy"; var helloName = "helloPolicy"; -// Define endpoint limiters and a global limiter. -var options = new RateLimiterOptions() - .AddTokenBucketLimiter(todoName, new TokenBucketRateLimiterOptions - { - TokenLimit = 1, - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, - QueueLimit = 1, - ReplenishmentPeriod = TimeSpan.FromSeconds(10), - TokensPerPeriod = 1 - }) - .AddPolicy(completeName, new SampleRateLimiterPolicy(NullLogger.Instance)) - .AddPolicy(helloName); -// The global limiter will be a concurrency limiter with a max permit count of 10 and a queue depth of 5. -options.GlobalLimiter = PartitionedRateLimiter.Create(context => +builder.Services.AddRateLimiter(options => +{ + // Define endpoint limiters and a global limiter. + options.AddTokenBucketLimiter(todoName, new TokenBucketRateLimiterOptions + { + TokenLimit = 1, + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, + QueueLimit = 1, + ReplenishmentPeriod = TimeSpan.FromSeconds(10), + TokensPerPeriod = 1 + }) + .AddPolicy(completeName, new SampleRateLimiterPolicy(NullLogger.Instance)) + .AddPolicy(helloName); + // The global limiter will be a concurrency limiter with a max permit count of 10 and a queue depth of 5. + options.GlobalLimiter = PartitionedRateLimiter.Create(context => + { + return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions { - return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions - { - PermitLimit = 10, - QueueProcessingOrder = QueueProcessingOrder.NewestFirst, - QueueLimit = 5 - }); + PermitLimit = 10, + QueueProcessingOrder = QueueProcessingOrder.NewestFirst, + QueueLimit = 5 }); -app.UseRateLimiter(options); + }); +}); + +var app = builder.Build(); + +app.UseRateLimiter(); // The limiter on this endpoint allows 1 request every 5 seconds app.MapGet("/", () => "Hello World!").RequireRateLimiting(helloName); diff --git a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs new file mode 100644 index 000000000000..5990c18ae251 --- /dev/null +++ b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.RateLimiting; + +/// +/// Metadata that disables request rate limiting on an endpoint. +/// +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public sealed class DisableRateLimitingAttribute: Attribute +{ +} diff --git a/src/Middleware/RateLimiting/src/RateLimiterMetadata.cs b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs similarity index 64% rename from src/Middleware/RateLimiting/src/RateLimiterMetadata.cs rename to src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs index 08498eb77199..242c6ba69634 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterMetadata.cs +++ b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs @@ -6,13 +6,14 @@ namespace Microsoft.AspNetCore.RateLimiting; /// /// Metadata that provides endpoint-specific request rate limiting. /// -internal sealed class RateLimiterMetadata : IRateLimiterMetadata +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public sealed class EnableRateLimitingAttribute : Attribute { /// - /// Creates a new instance of using the specified policy. + /// Creates a new instance of using the specified policy. /// /// The name of the policy which needs to be applied. - public RateLimiterMetadata(string policyName) + public EnableRateLimitingAttribute(string policyName) { PolicyName = policyName; } diff --git a/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs b/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs deleted file mode 100644 index 54b8306cd5a7..000000000000 --- a/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.RateLimiting; - -/// -/// An interface which can be used to identify a type which provides metadata needed for enabling request rate limiting support. -/// -internal interface IRateLimiterMetadata -{ - /// - /// The name of the policy which needs to be applied. - /// - string PolicyName { get; } -} diff --git a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt index 3bbcc5cb4c20..5adea6d5b898 100644 --- a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt @@ -1,5 +1,11 @@ Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions +Microsoft.AspNetCore.Builder.RateLimiterServiceCollectionExtensions +Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute +Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute.DisableRateLimitingAttribute() -> void +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.EnableRateLimitingAttribute(string! policyName) -> void +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string! Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.GetPartition(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.RateLimiting.RateLimitPartition Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.OnRejected.get -> System.Func? @@ -23,9 +29,10 @@ Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.RejectionStatusCode.set -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! +static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.DisableRateLimiting(this TBuilder builder) -> TBuilder static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.RequireRateLimiting(this TBuilder builder, string! policyName) -> TBuilder -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddConcurrencyLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.ConcurrencyLimiterOptions! concurrencyLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddFixedWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.FixedWindowRateLimiterOptions! fixedWindowRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddNoLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddSlidingWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.SlidingWindowRateLimiterOptions! slidingWindowRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddTokenBucketLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.TokenBucketRateLimiterOptions! tokenBucketRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.Builder.RateLimiterServiceCollectionExtensions.AddRateLimiter(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddConcurrencyLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddFixedWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddSlidingWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddTokenBucketLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! diff --git a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs index 333499a7dea3..bf1aa933dcad 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs @@ -24,7 +24,25 @@ public static TBuilder RequireRateLimiting(this TBuilder builder, stri builder.Add(endpointBuilder => { - endpointBuilder.Metadata.Add(new RateLimiterMetadata(policyName)); + endpointBuilder.Metadata.Add(new EnableRateLimitingAttribute(policyName)); + }); + + return builder; + } + + /// + /// Disables rate limiting on the endpoint(s). + /// + /// The endpoint convention builder. + /// The original convention builder parameter. + /// Will skip both the global limiter, and any endpoint-specific limiters that apply to the endpoint(s). + public static TBuilder DisableRateLimiting(this TBuilder builder) where TBuilder : IEndpointConventionBuilder + { + ArgumentNullException.ThrowIfNull(builder); + + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new DisableRateLimitingAttribute()); }); return builder; diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs index 2a6f40dc7058..acffe5f6017e 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs @@ -15,11 +15,18 @@ public static class RateLimiterOptionsExtensions /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions options, string policyName, TokenBucketRateLimiterOptions tokenBucketRateLimiterOptions) + public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + if (configureOptions is null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + var key = new PolicyNameKey() { PolicyName = policyName }; + var tokenBucketRateLimiterOptions = new TokenBucketRateLimiterOptions(); + configureOptions.Invoke(tokenBucketRateLimiterOptions); return options.AddPolicy(policyName, context => { return RateLimitPartition.GetTokenBucketLimiter(key, @@ -32,11 +39,18 @@ public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions o /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions options, string policyName, FixedWindowRateLimiterOptions fixedWindowRateLimiterOptions) + public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + if (configureOptions is null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + var key = new PolicyNameKey() { PolicyName = policyName }; + var fixedWindowRateLimiterOptions = new FixedWindowRateLimiterOptions(); + configureOptions.Invoke(fixedWindowRateLimiterOptions); return options.AddPolicy(policyName, context => { return RateLimitPartition.GetFixedWindowLimiter(key, @@ -49,11 +63,18 @@ public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions o /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions options, string policyName, SlidingWindowRateLimiterOptions slidingWindowRateLimiterOptions) + public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + if (configureOptions is null) + { + throw new ArgumentNullException(nameof(configureOptions)); + } + var key = new PolicyNameKey() { PolicyName = policyName }; + var slidingWindowRateLimiterOptions = new SlidingWindowRateLimiterOptions(); + configureOptions.Invoke(slidingWindowRateLimiterOptions); return options.AddPolicy(policyName, context => { return RateLimitPartition.GetSlidingWindowLimiter(key, @@ -66,30 +87,22 @@ public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddConcurrencyLimiter(this RateLimiterOptions options, string policyName, ConcurrencyLimiterOptions concurrencyLimiterOptions) + public static RateLimiterOptions AddConcurrencyLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { - var key = new PolicyNameKey() { PolicyName = policyName }; - return options.AddPolicy(policyName, context => + if (configureOptions is null) { - return RateLimitPartition.GetConcurrencyLimiter(key, - _ => concurrencyLimiterOptions); - }); - } + throw new ArgumentNullException(nameof(configureOptions)); + } - /// - /// Adds a new no-op to the . - /// - /// The to add a limiter to. - /// The name that will be associated with the limiter. - /// This . - public static RateLimiterOptions AddNoLimiter(this RateLimiterOptions options, string policyName) - { var key = new PolicyNameKey() { PolicyName = policyName }; + var concurrencyLimiterOptions = new ConcurrencyLimiterOptions(); + configureOptions.Invoke(concurrencyLimiterOptions); return options.AddPolicy(policyName, context => { - return RateLimitPartition.GetNoLimiter(key); + return RateLimitPartition.GetConcurrencyLimiter(key, + _ => concurrencyLimiterOptions); }); } } diff --git a/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs new file mode 100644 index 000000000000..ac3c6718000f --- /dev/null +++ b/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.RateLimiting; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for the RateLimiting middleware. +/// +public static class RateLimiterServiceCollectionExtensions +{ + /// + /// Add rate limiting services and configure the related options. + /// + /// The for adding services. + /// A delegate to configure the . + /// + public static IServiceCollection AddRateLimiter(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.Configure(configureOptions); + return services; + } +} diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 0ffa8e1d485d..3ced3a507e15 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -60,6 +60,13 @@ public RateLimitingMiddleware(RequestDelegate next, ILoggerA that completes when the request leaves. public async Task Invoke(HttpContext context) { + var endpoint = context.GetEndpoint(); + // If this endpoint has a DisableRateLimitingAttribute, don't apply any rate limits. + if (endpoint?.Metadata.GetMetadata() is not null) + { + await _next(context); + return; + } using var leaseContext = await TryAcquireAsync(context); if (leaseContext.Lease.IsAcquired) { @@ -77,7 +84,7 @@ public async Task Invoke(HttpContext context) if (leaseContext.GlobalRejected == false) { DefaultRateLimiterPolicy? policy; - var policyName = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; + var policyName = endpoint?.Metadata.GetMetadata()?.PolicyName; // Use custom policy OnRejected if available, else use OnRejected from the Options if available. if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) { @@ -171,7 +178,7 @@ private PartitionedRateLimiter CreateEndpointLimiter() // If we have a policy for this endpoint, use its partitioner. Else use a NoLimiter. return PartitionedRateLimiter.Create(context => { - var name = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; + var name = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; if (name is not null) { if (_policyMap.TryGetValue(name, out var policy)) diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index aece21337bb6..e9140b9f530d 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -148,7 +148,7 @@ public async Task EndpointLimiterRequested_NoPolicy_Throws() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await Assert.ThrowsAsync(() => middleware.Invoke(context)).DefaultTimeout(); } @@ -181,7 +181,7 @@ public async Task EndpointLimiter_Rejects() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(onRejectedInvoked); Assert.Equal(StatusCodes.Status429TooManyRequests, context.Response.StatusCode); @@ -193,13 +193,13 @@ public async Task EndpointLimiterConvenienceMethod_Rejects() var onRejectedInvoked = false; var options = CreateOptionsAccessor(); var name = "myEndpoint"; - options.Value.AddFixedWindowLimiter(name, new FixedWindowRateLimiterOptions + options.Value.AddFixedWindowLimiter(name, options => { - PermitLimit = 1, - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, - QueueLimit = 0, - Window = TimeSpan.Zero, - AutoReplenishment = false + options.PermitLimit = 1; + options.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + options.QueueLimit = 0; + options.Window = TimeSpan.Zero; + options.AutoReplenishment = false; }); options.Value.OnRejected = (context, token) => { @@ -217,7 +217,7 @@ public async Task EndpointLimiterConvenienceMethod_Rejects() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(onRejectedInvoked); await middleware.Invoke(context).DefaultTimeout(); @@ -250,7 +250,7 @@ public async Task EndpointLimiterRejects_EndpointOnRejectedFires() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -283,7 +283,7 @@ public async Task GlobalAndEndpoint_GlobalRejects_GlobalWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(globalOnRejectedInvoked); @@ -316,7 +316,7 @@ public async Task GlobalAndEndpoint_EndpointRejects_EndpointWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -349,7 +349,7 @@ public async Task GlobalAndEndpoint_BothReject_GlobalWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(globalOnRejectedInvoked); @@ -393,7 +393,7 @@ public async Task EndpointLimiterRejects_EndpointOnRejectedFires_WithIRateLimite mockServiceProvider.Object); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -428,8 +428,8 @@ public async Task EndpointLimiter_DuplicatePartitionKey_NoCollision() Mock.Of()); var context = new DefaultHttpContext(); - var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName1)), "Test endpoint 1"); - var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName2)), "Test endpoint 2"); + var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName1)), "Test endpoint 1"); + var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName2)), "Test endpoint 2"); context.SetEndpoint(endpoint1); await middleware.Invoke(context).DefaultTimeout(); @@ -483,8 +483,8 @@ public async Task EndpointLimiter_DuplicatePartitionKey_Lambda_NoCollision() Mock.Of()); var context = new DefaultHttpContext(); - var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName1)), "Test endpoint 1"); - var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName2)), "Test endpoint 2"); + var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName1)), "Test endpoint 1"); + var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName2)), "Test endpoint 2"); context.SetEndpoint(endpoint1); await middleware.Invoke(context).DefaultTimeout(); @@ -499,6 +499,40 @@ public async Task EndpointLimiter_DuplicatePartitionKey_Lambda_NoCollision() Assert.False(globalOnRejectedInvoked); } + [Fact] + public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + var name = "myEndpoint"; + // Endpoint never allows + options.Value.AddPolicy(name, new TestRateLimiterPolicy("myKey", 404, false)); + // Global never allows + options.Value.GlobalLimiter = new TestPartitionedRateLimiter(new TestRateLimiter(false)); + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name), new DisableRateLimitingAttribute()), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.NotEqual(StatusCodes.Status429TooManyRequests, context.Response.StatusCode); + Assert.NotEqual(StatusCodes.Status404NotFound, context.Response.StatusCode); + } + private IOptions CreateOptionsAccessor() => Options.Create(new RateLimiterOptions()); } From 9313b49e06d8e8e69c7cf038e2e05d54156e089c Mon Sep 17 00:00:00 2001 From: Will Godbe Date: Fri, 5 Aug 2022 12:38:32 -0700 Subject: [PATCH 02/10] Feedback --- .../src/DisableRateLimitingAttribute.cs | 3 ++- ...iterEndpointConventionBuilderExtensions.cs | 2 +- .../src/RateLimiterOptionsExtensions.cs | 26 +++++++------------ .../src/RateLimitingMiddleware.cs | 15 ++++++++--- .../test/RateLimitingMiddlewareTests.cs | 3 +-- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs index 5990c18ae251..bdb40b8cd451 100644 --- a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs +++ b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs @@ -7,6 +7,7 @@ namespace Microsoft.AspNetCore.RateLimiting; /// Metadata that disables request rate limiting on an endpoint. /// [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] -public sealed class DisableRateLimitingAttribute: Attribute +public sealed class DisableRateLimitingAttribute : Attribute { + internal static DisableRateLimitingAttribute Instance { get; } = new DisableRateLimitingAttribute(); } diff --git a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs index bf1aa933dcad..9851689447e6 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs @@ -42,7 +42,7 @@ public static TBuilder DisableRateLimiting(this TBuilder builder) wher builder.Add(endpointBuilder => { - endpointBuilder.Metadata.Add(new DisableRateLimitingAttribute()); + endpointBuilder.Metadata.Add(DisableRateLimitingAttribute.Instance); }); return builder; diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs index acffe5f6017e..55260072812c 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs @@ -19,14 +19,13 @@ public static class RateLimiterOptionsExtensions /// This . public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { - if (configureOptions is null) - { - throw new ArgumentNullException(nameof(configureOptions)); - } + ArgumentNullException.ThrowIfNull(configureOptions); var key = new PolicyNameKey() { PolicyName = policyName }; var tokenBucketRateLimiterOptions = new TokenBucketRateLimiterOptions(); configureOptions.Invoke(tokenBucketRateLimiterOptions); + // Saves an allocation in GetTokenBucketLimiter, which would have created a new set of options if this was true. + tokenBucketRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetTokenBucketLimiter(key, @@ -43,14 +42,13 @@ public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions o /// This . public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { - if (configureOptions is null) - { - throw new ArgumentNullException(nameof(configureOptions)); - } + ArgumentNullException.ThrowIfNull(configureOptions); var key = new PolicyNameKey() { PolicyName = policyName }; var fixedWindowRateLimiterOptions = new FixedWindowRateLimiterOptions(); configureOptions.Invoke(fixedWindowRateLimiterOptions); + // Saves an allocation in GetFixedWindowLimiter, which would have created a new set of options if this was true. + fixedWindowRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetFixedWindowLimiter(key, @@ -67,14 +65,13 @@ public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions o /// This . public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { - if (configureOptions is null) - { - throw new ArgumentNullException(nameof(configureOptions)); - } + ArgumentNullException.ThrowIfNull(configureOptions); var key = new PolicyNameKey() { PolicyName = policyName }; var slidingWindowRateLimiterOptions = new SlidingWindowRateLimiterOptions(); configureOptions.Invoke(slidingWindowRateLimiterOptions); + // Saves an allocation in GetSlidingWindowLimiter, which would have created a new set of options if this was true. + slidingWindowRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetSlidingWindowLimiter(key, @@ -91,10 +88,7 @@ public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions /// This . public static RateLimiterOptions AddConcurrencyLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { - if (configureOptions is null) - { - throw new ArgumentNullException(nameof(configureOptions)); - } + ArgumentNullException.ThrowIfNull(configureOptions); var key = new PolicyNameKey() { PolicyName = policyName }; var concurrencyLimiterOptions = new ConcurrencyLimiterOptions(); diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 963fd814d690..1f47668975a1 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -58,15 +58,24 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger /// The . /// A that completes when the request leaves. - public async Task Invoke(HttpContext context) + public Task Invoke(HttpContext context) { var endpoint = context.GetEndpoint(); // If this endpoint has a DisableRateLimitingAttribute, don't apply any rate limits. if (endpoint?.Metadata.GetMetadata() is not null) { - await _next(context); - return; + return _next(context); + } + // If this endpoint has no EnableRateLimitingAttribute & there's no global limiter, don't apply any rate limits. + if (endpoint?.Metadata.GetMetadata() is null && _globalLimiter is null) + { + return _next(context); } + return InvokeInternal(context, endpoint); + } + + private async Task InvokeInternal(HttpContext context, Endpoint? endpoint) + { using var leaseContext = await TryAcquireAsync(context); if (leaseContext.Lease.IsAcquired) { diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index e9140b9f530d..49116dd9e1e5 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -529,8 +529,7 @@ public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); - Assert.NotEqual(StatusCodes.Status429TooManyRequests, context.Response.StatusCode); - Assert.NotEqual(StatusCodes.Status404NotFound, context.Response.StatusCode); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } private IOptions CreateOptionsAccessor() => Options.Create(new RateLimiterOptions()); From 99e4d659a9454808e98bfd3c57a56d58d6db5bf2 Mon Sep 17 00:00:00 2001 From: Will Godbe Date: Fri, 5 Aug 2022 14:38:12 -0700 Subject: [PATCH 03/10] Feedback, API --- .../RateLimiting/src/DefaultKeyType.cs | 4 +- .../src/EnableRateLimitingAttribute.cs | 14 +++- .../Microsoft.AspNetCore.RateLimiting.csproj | 3 + .../RateLimiting/src/PublicAPI.Unshipped.txt | 3 +- ...iterEndpointConventionBuilderExtensions.cs | 19 ++++++ .../RateLimiting/src/RateLimiterOptions.cs | 2 +- .../src/RateLimitingMiddleware.cs | 27 ++++++-- .../test/RateLimitingMiddlewareTests.cs | 66 +++++++++++++++++++ 8 files changed, 129 insertions(+), 9 deletions(-) diff --git a/src/Middleware/RateLimiting/src/DefaultKeyType.cs b/src/Middleware/RateLimiting/src/DefaultKeyType.cs index ea9b2b0d6c36..fd599833f537 100644 --- a/src/Middleware/RateLimiting/src/DefaultKeyType.cs +++ b/src/Middleware/RateLimiting/src/DefaultKeyType.cs @@ -5,14 +5,14 @@ namespace Microsoft.AspNetCore.RateLimiting; internal struct DefaultKeyType { - public DefaultKeyType(string policyName, object? key, object? factory = null) + public DefaultKeyType(string? policyName, object? key, object? factory = null) { PolicyName = policyName; Key = key; Factory = factory; } - public string PolicyName { get; } + public string? PolicyName { get; } public object? Key { get; } diff --git a/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs index 242c6ba69634..121373feb174 100644 --- a/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs +++ b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs @@ -15,11 +15,23 @@ public sealed class EnableRateLimitingAttribute : Attribute /// The name of the policy which needs to be applied. public EnableRateLimitingAttribute(string policyName) { + ArgumentNullException.ThrowIfNull(policyName); + PolicyName = policyName; } + internal EnableRateLimitingAttribute(DefaultRateLimiterPolicy policy) + { + Policy = policy; + } + /// /// The name of the policy which needs to be applied. /// - public string PolicyName { get; } + public string? PolicyName { get; } + + /// + /// The policy which needs to be applied, if present. + /// + internal DefaultRateLimiterPolicy? Policy { get; } } diff --git a/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj b/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj index b35039fe3ace..0c8628098e46 100644 --- a/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj +++ b/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj @@ -17,4 +17,7 @@ + + + diff --git a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt index 5adea6d5b898..9e18f7083b21 100644 --- a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt @@ -5,7 +5,7 @@ Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute.DisableRateLimitingAttribute() -> void Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.EnableRateLimitingAttribute(string! policyName) -> void -Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string! +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string? Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.GetPartition(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.RateLimiting.RateLimitPartition Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.OnRejected.get -> System.Func? @@ -30,6 +30,7 @@ Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.DisableRateLimiting(this TBuilder builder) -> TBuilder +static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.RequireRateLimiting(this TBuilder builder, Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy! policy) -> TBuilder static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.RequireRateLimiting(this TBuilder builder, string! policyName) -> TBuilder static Microsoft.AspNetCore.Builder.RateLimiterServiceCollectionExtensions.AddRateLimiter(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddConcurrencyLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! diff --git a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs index 9851689447e6..1a5fb0e180bf 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs @@ -30,6 +30,25 @@ public static TBuilder RequireRateLimiting(this TBuilder builder, stri return builder; } + /// + /// Adds the specified rate limiting policy to the endpoint(s). + /// + /// The endpoint convention builder. + /// The rate limiting policy to add to the endpoint. + /// The original convention builder parameter. + public static TBuilder RequireRateLimiting(this TBuilder builder, IRateLimiterPolicy policy) where TBuilder : IEndpointConventionBuilder + { + ArgumentNullException.ThrowIfNull(builder); + + ArgumentNullException.ThrowIfNull(policy); + + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new EnableRateLimitingAttribute(new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected))); + }); + return builder; + } + /// /// Disables rate limiting on the endpoint(s). /// diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs index 7652613506ed..9872e1ca788a 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs @@ -107,7 +107,7 @@ public RateLimiterOptions AddPolicy(string policyName, IRateLimit } // Converts a Partition to a Partition> to prevent accidental collisions with the keys we create in the the RateLimiterOptionsExtensions. - private static Func> ConvertPartitioner(string policyName, Func> partitioner) + internal static Func> ConvertPartitioner(string? policyName, Func> partitioner) { return context => { diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 1f47668975a1..45d4df979e81 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -93,12 +93,20 @@ private async Task InvokeInternal(HttpContext context, Endpoint? endpoint) if (leaseContext.GlobalRejected == false) { DefaultRateLimiterPolicy? policy; - var policyName = endpoint?.Metadata.GetMetadata()?.PolicyName; // Use custom policy OnRejected if available, else use OnRejected from the Options if available. - if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) + policy = endpoint?.Metadata.GetMetadata()?.Policy; + if (policy is not null) { thisRequestOnRejected = policy.OnRejected; } + else + { + var policyName = endpoint?.Metadata.GetMetadata()?.PolicyName; + if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) + { + thisRequestOnRejected = policy.OnRejected; + } + } } if (thisRequestOnRejected is not null) { @@ -187,10 +195,21 @@ private PartitionedRateLimiter CreateEndpointLimiter() // If we have a policy for this endpoint, use its partitioner. Else use a NoLimiter. return PartitionedRateLimiter.Create(context => { - var name = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; + DefaultRateLimiterPolicy? policy; + var enableRateLimitingAttribute = context.GetEndpoint()?.Metadata.GetMetadata(); + if (enableRateLimitingAttribute is null) + { + return RateLimitPartition.GetNoLimiter(_defaultPolicyKey); + } + policy = enableRateLimitingAttribute.Policy; + if (policy is not null) + { + return policy.GetPartition(context); + } + var name = enableRateLimitingAttribute.PolicyName; if (name is not null) { - if (_policyMap.TryGetValue(name, out var policy)) + if (_policyMap.TryGetValue(name, out policy)) { return policy.GetPartition(context); } diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index 49116dd9e1e5..e74e3d5161c0 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -532,6 +532,72 @@ public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } + [Fact] + public async Task PolicyDirectlyOnEndpoint_GetsUsed() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + // Policy will disallow + var policy = new TestRateLimiterPolicy("myKey", 404, false); + var defaultRateLimiterPolicy = new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected); + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(defaultRateLimiterPolicy)), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); + } + + [Fact] + public async Task MultipleEndpointPolicies_LastOneWins() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + // Policy will disallow + var policy = new TestRateLimiterPolicy("myKey1", 404, false); + var defaultRateLimiterPolicy = new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected); + + var name = "myEndpoint"; + options.Value.AddPolicy(name, new TestRateLimiterPolicy("myKey2", 403, false)); + + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(defaultRateLimiterPolicy), new EnableRateLimitingAttribute(name)), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); + } + private IOptions CreateOptionsAccessor() => Options.Create(new RateLimiterOptions()); } From 279b8b7edfb64e763f915f8c1d2b73767ad35e26 Mon Sep 17 00:00:00 2001 From: Will Godbe Date: Fri, 5 Aug 2022 15:03:59 -0700 Subject: [PATCH 04/10] Fix sample --- .../samples/RateLimitingSample/Program.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs index ec532bf7c5de..2709f8df075b 100644 --- a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs +++ b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs @@ -23,13 +23,13 @@ builder.Services.AddRateLimiter(options => { // Define endpoint limiters and a global limiter. - options.AddTokenBucketLimiter(todoName, new TokenBucketRateLimiterOptions + options.AddTokenBucketLimiter(todoName, options => { - TokenLimit = 1, - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, - QueueLimit = 1, - ReplenishmentPeriod = TimeSpan.FromSeconds(10), - TokensPerPeriod = 1 + options.TokenLimit = 1; + options.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + options.QueueLimit = 1; + options.ReplenishmentPeriod = TimeSpan.FromSeconds(10); + options.TokensPerPeriod = 1; }) .AddPolicy(completeName, new SampleRateLimiterPolicy(NullLogger.Instance)) .AddPolicy(helloName); From 5a6c107556c87fcb1998a25d2ebb733ddb26ceee Mon Sep 17 00:00:00 2001 From: Will Godbe Date: Tue, 9 Aug 2022 14:54:16 -0700 Subject: [PATCH 05/10] Feedback --- .../RateLimiting/src/DisableRateLimitingAttribute.cs | 3 +++ .../RateLimiting/src/EnableRateLimitingAttribute.cs | 4 ++++ src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs | 6 +++++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs index bdb40b8cd451..55d888f05ca9 100644 --- a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs +++ b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs @@ -6,6 +6,9 @@ namespace Microsoft.AspNetCore.RateLimiting; /// /// Metadata that disables request rate limiting on an endpoint. /// +/// +/// Completely disables the rate limiting middleware from applying to this endpoint. +/// [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public sealed class DisableRateLimitingAttribute : Attribute { diff --git a/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs index 121373feb174..6486f6c89813 100644 --- a/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs +++ b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs @@ -6,6 +6,10 @@ namespace Microsoft.AspNetCore.RateLimiting; /// /// Metadata that provides endpoint-specific request rate limiting. /// +/// +/// Replaces any policies currently applied to the endpoint. +/// The global limiter will still run on endpoints with this attribute applied. +/// [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] public sealed class EnableRateLimitingAttribute : Attribute { diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 45d4df979e81..2aa0b6e75332 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -218,7 +218,11 @@ private PartitionedRateLimiter CreateEndpointLimiter() throw new InvalidOperationException($"This endpoint requires a rate limiting policy with name {name}, but no such policy exists."); } } - return RateLimitPartition.GetNoLimiter(_defaultPolicyKey); + // Should be impossible for both name & policy to be null, but throw in that scenario just in case. + else + { + throw new InvalidOperationException("This endpoint requested a rate limiting policy with a null name."); + } }, new DefaultKeyTypeEqualityComparer()); } From df7ffbaf7e4a627c758e79605c9b9c27945ca2a0 Mon Sep 17 00:00:00 2001 From: wtgodbe Date: Wed, 10 Aug 2022 08:42:11 -0700 Subject: [PATCH 06/10] Add tests --- ...ndpointConventionBuilderExtensionsTests.cs | 93 +++++++++++++++++++ .../test/RateLimitingMiddlewareTests.cs | 9 ++ 2 files changed, 102 insertions(+) create mode 100644 src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs diff --git a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 000000000000..a48307ebab33 --- /dev/null +++ b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,93 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.RateLimiting; + +public class RateLimiterEndpointConventionBuilderExtensionsTests : LoggedTest +{ + [Fact] + public void RequireRateLimiting_Name_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.RequireRequireRateLimiting("TestPolicyName"); + + // Assert + var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointBuilder(); + addEnableRateLimitingAttribute(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.Equal("TestPolicyName", metadata.PolicyName); + Assert.Null(metadata.Policy); + } + + [Fact] + public void RequireRateLimiting_Policy_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.RequireRequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); + + // Assert + var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointBuilder = new TestEndpointBuilder(); + addEnableRateLimitingAttribute(endpointBuilder); + var endpoint = endpointBuilder.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.NotNull(metadata.Policy); + Assert.Null(metadata.PolicyName); + } + + [Fact] + public void DisableRateLimiting_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.DisableRateLimiting(); + + // Assert + var addDisableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointBuilder(); + addDisableRateLimitingAttribute(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + } + + private class TestEndpointBuilder : EndpointBuilder + { + public override Endpoint Build() + { + return new Endpoint(RequestDelegate, new EndpointMetadataCollection(Metadata), DisplayName); + } + } + + private class TestEndpointConventionBuilder : IEndpointConventionBuilder + { + public IList> Conventions { get; } = new List>(); + + public void Add(Action convention) + { + Conventions.Add(convention); + } + } +} \ No newline at end of file diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index e74e3d5161c0..fed64ecd240f 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -525,11 +525,20 @@ public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() Mock.Of()); var context = new DefaultHttpContext(); + // DisableRateLimitingAttribute last context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name), new DisableRateLimitingAttribute()), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // DisableRateLimitingAttribute first + context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableRateLimitingAttribute()), new EnableRateLimitingAttribute(name), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } [Fact] From 9c52cd971d71a88898a120a96c325f4a7a62ce07 Mon Sep 17 00:00:00 2001 From: wtgodbe Date: Wed, 10 Aug 2022 09:07:07 -0700 Subject: [PATCH 07/10] Fixup --- .../test/RateLimiterEndpointConventionBuilderExtensionsTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs index a48307ebab33..c48619d8b259 100644 --- a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs @@ -4,6 +4,7 @@ using System.Threading.RateLimiting; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; namespace Microsoft.AspNetCore.RateLimiting; From c3ee13176efb5a5a603a022b8c56c2c2be213439 Mon Sep 17 00:00:00 2001 From: wtgodbe Date: Wed, 10 Aug 2022 09:50:25 -0700 Subject: [PATCH 08/10] Fixup2 --- .../RateLimiterEndpointConventionBuilderExtensionsTests.cs | 4 ++-- .../RateLimiting/test/RateLimitingMiddlewareTests.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs index c48619d8b259..78606e100e4f 100644 --- a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs @@ -17,7 +17,7 @@ public void RequireRateLimiting_Name_MetadataAdded() var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - testConventionBuilder.RequireRequireRateLimiting("TestPolicyName"); + testConventionBuilder.RequireRateLimiting("TestPolicyName"); // Assert var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); @@ -39,7 +39,7 @@ public void RequireRateLimiting_Policy_MetadataAdded() var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - testConventionBuilder.RequireRequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); + testConventionBuilder.RequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); // Assert var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index fed64ecd240f..e095d2ffc185 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -534,7 +534,7 @@ public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() // DisableRateLimitingAttribute first context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableRateLimitingAttribute()), new EnableRateLimitingAttribute(name), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableRateLimitingAttribute(), new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); From 1b535f1bd784f7809810376df71a7e667623fa7e Mon Sep 17 00:00:00 2001 From: wtgodbe Date: Wed, 10 Aug 2022 10:20:41 -0700 Subject: [PATCH 09/10] disambiguate --- .../test/RateLimiterEndpointConventionBuilderExtensionsTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs index 78606e100e4f..a4c600314454 100644 --- a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs @@ -39,7 +39,7 @@ public void RequireRateLimiting_Policy_MetadataAdded() var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - testConventionBuilder.RequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); + testConventionBuilder.RequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); // Assert var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); From e8339d47c370d240364b759adf348fa2c95ccc2f Mon Sep 17 00:00:00 2001 From: wtgodbe Date: Fri, 12 Aug 2022 10:05:17 -0700 Subject: [PATCH 10/10] Feedback --- .../RateLimiting/src/RateLimitingMiddleware.cs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 2aa0b6e75332..ccd9a8029449 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -66,15 +66,16 @@ public Task Invoke(HttpContext context) { return _next(context); } + var enableRateLimitingAttribute = endpoint?.Metadata.GetMetadata(); // If this endpoint has no EnableRateLimitingAttribute & there's no global limiter, don't apply any rate limits. - if (endpoint?.Metadata.GetMetadata() is null && _globalLimiter is null) + if (enableRateLimitingAttribute is null && _globalLimiter is null) { return _next(context); } - return InvokeInternal(context, endpoint); + return InvokeInternal(context, enableRateLimitingAttribute); } - private async Task InvokeInternal(HttpContext context, Endpoint? endpoint) + private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute) { using var leaseContext = await TryAcquireAsync(context); if (leaseContext.Lease.IsAcquired) @@ -94,14 +95,14 @@ private async Task InvokeInternal(HttpContext context, Endpoint? endpoint) { DefaultRateLimiterPolicy? policy; // Use custom policy OnRejected if available, else use OnRejected from the Options if available. - policy = endpoint?.Metadata.GetMetadata()?.Policy; + policy = enableRateLimitingAttribute?.Policy; if (policy is not null) { thisRequestOnRejected = policy.OnRejected; } else { - var policyName = endpoint?.Metadata.GetMetadata()?.PolicyName; + var policyName = enableRateLimitingAttribute?.PolicyName; if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) { thisRequestOnRejected = policy.OnRejected;