diff --git a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs index d8a6fe95832e..2709f8df075b 100644 --- a/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs +++ b/src/Middleware/RateLimiting/samples/RateLimitingSample/Program.cs @@ -16,35 +16,38 @@ // Inject an ILogger builder.Services.AddLogging(); -var app = builder.Build(); - var todoName = "todoPolicy"; var completeName = "completePolicy"; var helloName = "helloPolicy"; -// Define endpoint limiters and a global limiter. -var options = new RateLimiterOptions() - .AddTokenBucketLimiter(todoName, new TokenBucketRateLimiterOptions - { - TokenLimit = 1, - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, - QueueLimit = 1, - ReplenishmentPeriod = TimeSpan.FromSeconds(10), - TokensPerPeriod = 1 - }) - .AddPolicy(completeName, new SampleRateLimiterPolicy(NullLogger.Instance)) - .AddPolicy(helloName); -// The global limiter will be a concurrency limiter with a max permit count of 10 and a queue depth of 5. -options.GlobalLimiter = PartitionedRateLimiter.Create(context => +builder.Services.AddRateLimiter(options => +{ + // Define endpoint limiters and a global limiter. + options.AddTokenBucketLimiter(todoName, options => + { + options.TokenLimit = 1; + options.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + options.QueueLimit = 1; + options.ReplenishmentPeriod = TimeSpan.FromSeconds(10); + options.TokensPerPeriod = 1; + }) + .AddPolicy(completeName, new SampleRateLimiterPolicy(NullLogger.Instance)) + .AddPolicy(helloName); + // The global limiter will be a concurrency limiter with a max permit count of 10 and a queue depth of 5. + options.GlobalLimiter = PartitionedRateLimiter.Create(context => + { + return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions { - return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions - { - PermitLimit = 10, - QueueProcessingOrder = QueueProcessingOrder.NewestFirst, - QueueLimit = 5 - }); + PermitLimit = 10, + QueueProcessingOrder = QueueProcessingOrder.NewestFirst, + QueueLimit = 5 }); -app.UseRateLimiter(options); + }); +}); + +var app = builder.Build(); + +app.UseRateLimiter(); // The limiter on this endpoint allows 1 request every 5 seconds app.MapGet("/", () => "Hello World!").RequireRateLimiting(helloName); diff --git a/src/Middleware/RateLimiting/src/DefaultKeyType.cs b/src/Middleware/RateLimiting/src/DefaultKeyType.cs index ea9b2b0d6c36..fd599833f537 100644 --- a/src/Middleware/RateLimiting/src/DefaultKeyType.cs +++ b/src/Middleware/RateLimiting/src/DefaultKeyType.cs @@ -5,14 +5,14 @@ namespace Microsoft.AspNetCore.RateLimiting; internal struct DefaultKeyType { - public DefaultKeyType(string policyName, object? key, object? factory = null) + public DefaultKeyType(string? policyName, object? key, object? factory = null) { PolicyName = policyName; Key = key; Factory = factory; } - public string PolicyName { get; } + public string? PolicyName { get; } public object? Key { get; } diff --git a/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs new file mode 100644 index 000000000000..55d888f05ca9 --- /dev/null +++ b/src/Middleware/RateLimiting/src/DisableRateLimitingAttribute.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.RateLimiting; + +/// +/// Metadata that disables request rate limiting on an endpoint. +/// +/// +/// Completely disables the rate limiting middleware from applying to this endpoint. +/// +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public sealed class DisableRateLimitingAttribute : Attribute +{ + internal static DisableRateLimitingAttribute Instance { get; } = new DisableRateLimitingAttribute(); +} diff --git a/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs new file mode 100644 index 000000000000..6486f6c89813 --- /dev/null +++ b/src/Middleware/RateLimiting/src/EnableRateLimitingAttribute.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.RateLimiting; + +/// +/// Metadata that provides endpoint-specific request rate limiting. +/// +/// +/// Replaces any policies currently applied to the endpoint. +/// The global limiter will still run on endpoints with this attribute applied. +/// +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] +public sealed class EnableRateLimitingAttribute : Attribute +{ + /// + /// Creates a new instance of using the specified policy. + /// + /// The name of the policy which needs to be applied. + public EnableRateLimitingAttribute(string policyName) + { + ArgumentNullException.ThrowIfNull(policyName); + + PolicyName = policyName; + } + + internal EnableRateLimitingAttribute(DefaultRateLimiterPolicy policy) + { + Policy = policy; + } + + /// + /// The name of the policy which needs to be applied. + /// + public string? PolicyName { get; } + + /// + /// The policy which needs to be applied, if present. + /// + internal DefaultRateLimiterPolicy? Policy { get; } +} diff --git a/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs b/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs deleted file mode 100644 index 54b8306cd5a7..000000000000 --- a/src/Middleware/RateLimiting/src/IRateLimiterMetadata.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.RateLimiting; - -/// -/// An interface which can be used to identify a type which provides metadata needed for enabling request rate limiting support. -/// -internal interface IRateLimiterMetadata -{ - /// - /// The name of the policy which needs to be applied. - /// - string PolicyName { get; } -} diff --git a/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj b/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj index b35039fe3ace..0c8628098e46 100644 --- a/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj +++ b/src/Middleware/RateLimiting/src/Microsoft.AspNetCore.RateLimiting.csproj @@ -17,4 +17,7 @@ + + + diff --git a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt index 3bbcc5cb4c20..9e18f7083b21 100644 --- a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt @@ -1,5 +1,11 @@ Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions +Microsoft.AspNetCore.Builder.RateLimiterServiceCollectionExtensions +Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute +Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute.DisableRateLimitingAttribute() -> void +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.EnableRateLimitingAttribute(string! policyName) -> void +Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string? Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.GetPartition(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.RateLimiting.RateLimitPartition Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.OnRejected.get -> System.Func? @@ -23,9 +29,11 @@ Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.RejectionStatusCode.set -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.RateLimiterApplicationBuilderExtensions.UseRateLimiter(this Microsoft.AspNetCore.Builder.IApplicationBuilder! app, Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! +static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.DisableRateLimiting(this TBuilder builder) -> TBuilder +static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.RequireRateLimiting(this TBuilder builder, Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy! policy) -> TBuilder static Microsoft.AspNetCore.Builder.RateLimiterEndpointConventionBuilderExtensions.RequireRateLimiting(this TBuilder builder, string! policyName) -> TBuilder -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddConcurrencyLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.ConcurrencyLimiterOptions! concurrencyLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddFixedWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.FixedWindowRateLimiterOptions! fixedWindowRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddNoLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddSlidingWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.SlidingWindowRateLimiterOptions! slidingWindowRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! -static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddTokenBucketLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Threading.RateLimiting.TokenBucketRateLimiterOptions! tokenBucketRateLimiterOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.Builder.RateLimiterServiceCollectionExtensions.AddRateLimiter(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddConcurrencyLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddFixedWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddSlidingWindowLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! +static Microsoft.AspNetCore.RateLimiting.RateLimiterOptionsExtensions.AddTokenBucketLimiter(this Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! options, string! policyName, System.Action! configureOptions) -> Microsoft.AspNetCore.RateLimiting.RateLimiterOptions! diff --git a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs index 333499a7dea3..1a5fb0e180bf 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterEndpointConventionBuilderExtensions.cs @@ -24,7 +24,44 @@ public static TBuilder RequireRateLimiting(this TBuilder builder, stri builder.Add(endpointBuilder => { - endpointBuilder.Metadata.Add(new RateLimiterMetadata(policyName)); + endpointBuilder.Metadata.Add(new EnableRateLimitingAttribute(policyName)); + }); + + return builder; + } + + /// + /// Adds the specified rate limiting policy to the endpoint(s). + /// + /// The endpoint convention builder. + /// The rate limiting policy to add to the endpoint. + /// The original convention builder parameter. + public static TBuilder RequireRateLimiting(this TBuilder builder, IRateLimiterPolicy policy) where TBuilder : IEndpointConventionBuilder + { + ArgumentNullException.ThrowIfNull(builder); + + ArgumentNullException.ThrowIfNull(policy); + + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new EnableRateLimitingAttribute(new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected))); + }); + return builder; + } + + /// + /// Disables rate limiting on the endpoint(s). + /// + /// The endpoint convention builder. + /// The original convention builder parameter. + /// Will skip both the global limiter, and any endpoint-specific limiters that apply to the endpoint(s). + public static TBuilder DisableRateLimiting(this TBuilder builder) where TBuilder : IEndpointConventionBuilder + { + ArgumentNullException.ThrowIfNull(builder); + + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(DisableRateLimitingAttribute.Instance); }); return builder; diff --git a/src/Middleware/RateLimiting/src/RateLimiterMetadata.cs b/src/Middleware/RateLimiting/src/RateLimiterMetadata.cs deleted file mode 100644 index 08498eb77199..000000000000 --- a/src/Middleware/RateLimiting/src/RateLimiterMetadata.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.RateLimiting; - -/// -/// Metadata that provides endpoint-specific request rate limiting. -/// -internal sealed class RateLimiterMetadata : IRateLimiterMetadata -{ - /// - /// Creates a new instance of using the specified policy. - /// - /// The name of the policy which needs to be applied. - public RateLimiterMetadata(string policyName) - { - PolicyName = policyName; - } - - /// - /// The name of the policy which needs to be applied. - /// - public string PolicyName { get; } -} diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs index 7652613506ed..9872e1ca788a 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs @@ -107,7 +107,7 @@ public RateLimiterOptions AddPolicy(string policyName, IRateLimit } // Converts a Partition to a Partition> to prevent accidental collisions with the keys we create in the the RateLimiterOptionsExtensions. - private static Func> ConvertPartitioner(string policyName, Func> partitioner) + internal static Func> ConvertPartitioner(string? policyName, Func> partitioner) { return context => { diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs index 2a6f40dc7058..55260072812c 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptionsExtensions.cs @@ -15,11 +15,17 @@ public static class RateLimiterOptionsExtensions /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions options, string policyName, TokenBucketRateLimiterOptions tokenBucketRateLimiterOptions) + public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + ArgumentNullException.ThrowIfNull(configureOptions); + var key = new PolicyNameKey() { PolicyName = policyName }; + var tokenBucketRateLimiterOptions = new TokenBucketRateLimiterOptions(); + configureOptions.Invoke(tokenBucketRateLimiterOptions); + // Saves an allocation in GetTokenBucketLimiter, which would have created a new set of options if this was true. + tokenBucketRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetTokenBucketLimiter(key, @@ -32,11 +38,17 @@ public static RateLimiterOptions AddTokenBucketLimiter(this RateLimiterOptions o /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions options, string policyName, FixedWindowRateLimiterOptions fixedWindowRateLimiterOptions) + public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + ArgumentNullException.ThrowIfNull(configureOptions); + var key = new PolicyNameKey() { PolicyName = policyName }; + var fixedWindowRateLimiterOptions = new FixedWindowRateLimiterOptions(); + configureOptions.Invoke(fixedWindowRateLimiterOptions); + // Saves an allocation in GetFixedWindowLimiter, which would have created a new set of options if this was true. + fixedWindowRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetFixedWindowLimiter(key, @@ -49,11 +61,17 @@ public static RateLimiterOptions AddFixedWindowLimiter(this RateLimiterOptions o /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions options, string policyName, SlidingWindowRateLimiterOptions slidingWindowRateLimiterOptions) + public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + ArgumentNullException.ThrowIfNull(configureOptions); + var key = new PolicyNameKey() { PolicyName = policyName }; + var slidingWindowRateLimiterOptions = new SlidingWindowRateLimiterOptions(); + configureOptions.Invoke(slidingWindowRateLimiterOptions); + // Saves an allocation in GetSlidingWindowLimiter, which would have created a new set of options if this was true. + slidingWindowRateLimiterOptions.AutoReplenishment = false; return options.AddPolicy(policyName, context => { return RateLimitPartition.GetSlidingWindowLimiter(key, @@ -66,30 +84,19 @@ public static RateLimiterOptions AddSlidingWindowLimiter(this RateLimiterOptions /// /// The to add a limiter to. /// The name that will be associated with the limiter. - /// The to be used for the limiter. + /// A callback to configure the to be used for the limiter. /// This . - public static RateLimiterOptions AddConcurrencyLimiter(this RateLimiterOptions options, string policyName, ConcurrencyLimiterOptions concurrencyLimiterOptions) + public static RateLimiterOptions AddConcurrencyLimiter(this RateLimiterOptions options, string policyName, Action configureOptions) { + ArgumentNullException.ThrowIfNull(configureOptions); + var key = new PolicyNameKey() { PolicyName = policyName }; + var concurrencyLimiterOptions = new ConcurrencyLimiterOptions(); + configureOptions.Invoke(concurrencyLimiterOptions); return options.AddPolicy(policyName, context => { return RateLimitPartition.GetConcurrencyLimiter(key, _ => concurrencyLimiterOptions); }); } - - /// - /// Adds a new no-op to the . - /// - /// The to add a limiter to. - /// The name that will be associated with the limiter. - /// This . - public static RateLimiterOptions AddNoLimiter(this RateLimiterOptions options, string policyName) - { - var key = new PolicyNameKey() { PolicyName = policyName }; - return options.AddPolicy(policyName, context => - { - return RateLimitPartition.GetNoLimiter(key); - }); - } } diff --git a/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs b/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs new file mode 100644 index 000000000000..ac3c6718000f --- /dev/null +++ b/src/Middleware/RateLimiting/src/RateLimiterServiceCollectionExtensions.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.RateLimiting; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for the RateLimiting middleware. +/// +public static class RateLimiterServiceCollectionExtensions +{ + /// + /// Add rate limiting services and configure the related options. + /// + /// The for adding services. + /// A delegate to configure the . + /// + public static IServiceCollection AddRateLimiter(this IServiceCollection services, Action configureOptions) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(configureOptions); + + services.Configure(configureOptions); + return services; + } +} diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index 9764f3a49fd1..ccd9a8029449 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -58,7 +58,24 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger /// The . /// A that completes when the request leaves. - public async Task Invoke(HttpContext context) + public Task Invoke(HttpContext context) + { + var endpoint = context.GetEndpoint(); + // If this endpoint has a DisableRateLimitingAttribute, don't apply any rate limits. + if (endpoint?.Metadata.GetMetadata() is not null) + { + return _next(context); + } + var enableRateLimitingAttribute = endpoint?.Metadata.GetMetadata(); + // If this endpoint has no EnableRateLimitingAttribute & there's no global limiter, don't apply any rate limits. + if (enableRateLimitingAttribute is null && _globalLimiter is null) + { + return _next(context); + } + return InvokeInternal(context, enableRateLimitingAttribute); + } + + private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute) { using var leaseContext = await TryAcquireAsync(context); if (leaseContext.Lease.IsAcquired) @@ -77,12 +94,20 @@ public async Task Invoke(HttpContext context) if (leaseContext.GlobalRejected == false) { DefaultRateLimiterPolicy? policy; - var policyName = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; // Use custom policy OnRejected if available, else use OnRejected from the Options if available. - if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) + policy = enableRateLimitingAttribute?.Policy; + if (policy is not null) { thisRequestOnRejected = policy.OnRejected; } + else + { + var policyName = enableRateLimitingAttribute?.PolicyName; + if (policyName is not null && _policyMap.TryGetValue(policyName, out policy) && policy.OnRejected is not null) + { + thisRequestOnRejected = policy.OnRejected; + } + } } if (thisRequestOnRejected is not null) { @@ -171,10 +196,21 @@ private PartitionedRateLimiter CreateEndpointLimiter() // If we have a policy for this endpoint, use its partitioner. Else use a NoLimiter. return PartitionedRateLimiter.Create(context => { - var name = context.GetEndpoint()?.Metadata.GetMetadata()?.PolicyName; + DefaultRateLimiterPolicy? policy; + var enableRateLimitingAttribute = context.GetEndpoint()?.Metadata.GetMetadata(); + if (enableRateLimitingAttribute is null) + { + return RateLimitPartition.GetNoLimiter(_defaultPolicyKey); + } + policy = enableRateLimitingAttribute.Policy; + if (policy is not null) + { + return policy.GetPartition(context); + } + var name = enableRateLimitingAttribute.PolicyName; if (name is not null) { - if (_policyMap.TryGetValue(name, out var policy)) + if (_policyMap.TryGetValue(name, out policy)) { return policy.GetPartition(context); } @@ -183,7 +219,11 @@ private PartitionedRateLimiter CreateEndpointLimiter() throw new InvalidOperationException($"This endpoint requires a rate limiting policy with name {name}, but no such policy exists."); } } - return RateLimitPartition.GetNoLimiter(_defaultPolicyKey); + // Should be impossible for both name & policy to be null, but throw in that scenario just in case. + else + { + throw new InvalidOperationException("This endpoint requested a rate limiting policy with a null name."); + } }, new DefaultKeyTypeEqualityComparer()); } diff --git a/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 000000000000..a4c600314454 --- /dev/null +++ b/src/Middleware/RateLimiting/test/RateLimiterEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.RateLimiting; + +public class RateLimiterEndpointConventionBuilderExtensionsTests : LoggedTest +{ + [Fact] + public void RequireRateLimiting_Name_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.RequireRateLimiting("TestPolicyName"); + + // Assert + var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointBuilder(); + addEnableRateLimitingAttribute(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.Equal("TestPolicyName", metadata.PolicyName); + Assert.Null(metadata.Policy); + } + + [Fact] + public void RequireRateLimiting_Policy_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.RequireRateLimiting(new TestRateLimiterPolicy("myKey", 404, false)); + + // Assert + var addEnableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointBuilder = new TestEndpointBuilder(); + addEnableRateLimitingAttribute(endpointBuilder); + var endpoint = endpointBuilder.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.NotNull(metadata.Policy); + Assert.Null(metadata.PolicyName); + } + + [Fact] + public void DisableRateLimiting_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.DisableRateLimiting(); + + // Assert + var addDisableRateLimitingAttribute = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointBuilder(); + addDisableRateLimitingAttribute(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + } + + private class TestEndpointBuilder : EndpointBuilder + { + public override Endpoint Build() + { + return new Endpoint(RequestDelegate, new EndpointMetadataCollection(Metadata), DisplayName); + } + } + + private class TestEndpointConventionBuilder : IEndpointConventionBuilder + { + public IList> Conventions { get; } = new List>(); + + public void Add(Action convention) + { + Conventions.Add(convention); + } + } +} \ No newline at end of file diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index aece21337bb6..e095d2ffc185 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -148,7 +148,7 @@ public async Task EndpointLimiterRequested_NoPolicy_Throws() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await Assert.ThrowsAsync(() => middleware.Invoke(context)).DefaultTimeout(); } @@ -181,7 +181,7 @@ public async Task EndpointLimiter_Rejects() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(onRejectedInvoked); Assert.Equal(StatusCodes.Status429TooManyRequests, context.Response.StatusCode); @@ -193,13 +193,13 @@ public async Task EndpointLimiterConvenienceMethod_Rejects() var onRejectedInvoked = false; var options = CreateOptionsAccessor(); var name = "myEndpoint"; - options.Value.AddFixedWindowLimiter(name, new FixedWindowRateLimiterOptions + options.Value.AddFixedWindowLimiter(name, options => { - PermitLimit = 1, - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, - QueueLimit = 0, - Window = TimeSpan.Zero, - AutoReplenishment = false + options.PermitLimit = 1; + options.QueueProcessingOrder = QueueProcessingOrder.OldestFirst; + options.QueueLimit = 0; + options.Window = TimeSpan.Zero; + options.AutoReplenishment = false; }); options.Value.OnRejected = (context, token) => { @@ -217,7 +217,7 @@ public async Task EndpointLimiterConvenienceMethod_Rejects() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(onRejectedInvoked); await middleware.Invoke(context).DefaultTimeout(); @@ -250,7 +250,7 @@ public async Task EndpointLimiterRejects_EndpointOnRejectedFires() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -283,7 +283,7 @@ public async Task GlobalAndEndpoint_GlobalRejects_GlobalWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(globalOnRejectedInvoked); @@ -316,7 +316,7 @@ public async Task GlobalAndEndpoint_EndpointRejects_EndpointWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -349,7 +349,7 @@ public async Task GlobalAndEndpoint_BothReject_GlobalWins() Mock.Of()); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.True(globalOnRejectedInvoked); @@ -393,7 +393,7 @@ public async Task EndpointLimiterRejects_EndpointOnRejectedFires_WithIRateLimite mockServiceProvider.Object); var context = new DefaultHttpContext(); - context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(name)), "Test endpoint")); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name)), "Test endpoint")); await middleware.Invoke(context).DefaultTimeout(); Assert.False(globalOnRejectedInvoked); @@ -428,8 +428,8 @@ public async Task EndpointLimiter_DuplicatePartitionKey_NoCollision() Mock.Of()); var context = new DefaultHttpContext(); - var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName1)), "Test endpoint 1"); - var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName2)), "Test endpoint 2"); + var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName1)), "Test endpoint 1"); + var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName2)), "Test endpoint 2"); context.SetEndpoint(endpoint1); await middleware.Invoke(context).DefaultTimeout(); @@ -483,8 +483,8 @@ public async Task EndpointLimiter_DuplicatePartitionKey_Lambda_NoCollision() Mock.Of()); var context = new DefaultHttpContext(); - var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName1)), "Test endpoint 1"); - var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new RateLimiterMetadata(endpointName2)), "Test endpoint 2"); + var endpoint1 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName1)), "Test endpoint 1"); + var endpoint2 = new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(endpointName2)), "Test endpoint 2"); context.SetEndpoint(endpoint1); await middleware.Invoke(context).DefaultTimeout(); @@ -499,6 +499,114 @@ public async Task EndpointLimiter_DuplicatePartitionKey_Lambda_NoCollision() Assert.False(globalOnRejectedInvoked); } + [Fact] + public async Task DisableRateLimitingAttribute_SkipsGlobalAndEndpoint() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + var name = "myEndpoint"; + // Endpoint never allows + options.Value.AddPolicy(name, new TestRateLimiterPolicy("myKey", 404, false)); + // Global never allows + options.Value.GlobalLimiter = new TestPartitionedRateLimiter(new TestRateLimiter(false)); + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + // DisableRateLimitingAttribute last + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(name), new DisableRateLimitingAttribute()), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // DisableRateLimitingAttribute first + context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableRateLimitingAttribute(), new EnableRateLimitingAttribute(name)), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + } + + [Fact] + public async Task PolicyDirectlyOnEndpoint_GetsUsed() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + // Policy will disallow + var policy = new TestRateLimiterPolicy("myKey", 404, false); + var defaultRateLimiterPolicy = new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected); + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(defaultRateLimiterPolicy)), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); + } + + [Fact] + public async Task MultipleEndpointPolicies_LastOneWins() + { + var globalOnRejectedInvoked = false; + var options = CreateOptionsAccessor(); + // Policy will disallow + var policy = new TestRateLimiterPolicy("myKey1", 404, false); + var defaultRateLimiterPolicy = new DefaultRateLimiterPolicy(RateLimiterOptions.ConvertPartitioner(null, policy.GetPartition), policy.OnRejected); + + var name = "myEndpoint"; + options.Value.AddPolicy(name, new TestRateLimiterPolicy("myKey2", 403, false)); + + options.Value.OnRejected = (context, token) => + { + globalOnRejectedInvoked = true; + context.HttpContext.Response.StatusCode = 429; + return ValueTask.CompletedTask; + }; + + var middleware = new RateLimitingMiddleware(c => + { + return Task.CompletedTask; + }, + new NullLoggerFactory().CreateLogger(), + options, + Mock.Of()); + + var context = new DefaultHttpContext(); + context.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableRateLimitingAttribute(defaultRateLimiterPolicy), new EnableRateLimitingAttribute(name)), "Test endpoint")); + await middleware.Invoke(context).DefaultTimeout(); + Assert.False(globalOnRejectedInvoked); + + Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); + } + private IOptions CreateOptionsAccessor() => Options.Create(new RateLimiterOptions()); }