diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 744c8cc31183..4a0e67e0fb3a 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -38,7 +38,8 @@ public void GlobalSetup() serviceScopeFactory, new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), enableDetailedErrors: false, - new Logger>(NullLoggerFactory.Instance)); + new Logger>(NullLoggerFactory.Instance), + hubFilters: null); var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); diff --git a/src/SignalR/samples/SignalRSamples/Startup.cs b/src/SignalR/samples/SignalRSamples/Startup.cs index b82743e678e3..60a34eef11fc 100644 --- a/src/SignalR/samples/SignalRSamples/Startup.cs +++ b/src/SignalR/samples/SignalRSamples/Startup.cs @@ -1,16 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.IO; using System.Reflection; -using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using System.Text.Json; -using System.Text.Json.Serialization; using SignalRSamples.ConnectionHandlers; using SignalRSamples.Hubs; @@ -18,7 +14,6 @@ namespace SignalRSamples { public class Startup { - private readonly JsonWriterOptions _jsonWriterOptions = new JsonWriterOptions { Indented = true }; // This method gets called by the runtime. Use this method to add services to the container. @@ -27,11 +22,7 @@ public void ConfigureServices(IServiceCollection services) { services.AddConnections(); - services.AddSignalR(options => - { - // Faster pings for testing - options.KeepAliveInterval = TimeSpan.FromSeconds(5); - }) + services.AddSignalR() .AddMessagePackProtocol(); //.AddStackExchangeRedis(); } diff --git a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp.cs b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp.cs index f23da84b5248..f68a44862e3c 100644 --- a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp.cs +++ b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp.cs @@ -181,12 +181,23 @@ public void Reset() { } } public partial class HubInvocationContext { + public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub, System.Reflection.MethodInfo hubMethod, System.Collections.Generic.IReadOnlyList hubMethodArguments) { } + [System.ObsoleteAttribute("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")] public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { } - public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.Type hubType, string hubMethodName, object[] hubMethodArguments) { } public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + public System.Reflection.MethodInfo HubMethod { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } public System.Collections.Generic.IReadOnlyList HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + [System.ObsoleteAttribute("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")] public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } - public System.Type HubType { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + } + public sealed partial class HubLifetimeContext + { + public HubLifetimeContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub) { } + public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } + public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } } public abstract partial class HubLifetimeManager where THub : Microsoft.AspNetCore.SignalR.Hub { @@ -227,6 +238,12 @@ public HubOptions() { } public int? StreamBufferCapacity { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } public System.Collections.Generic.IList SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } } } + public static partial class HubOptionsExtensions + { + public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, Microsoft.AspNetCore.SignalR.IHubFilter hubFilter) { } + public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, System.Type filterType) { } + public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options) where TFilter : Microsoft.AspNetCore.SignalR.IHubFilter { } + } public partial class HubOptionsSetup : Microsoft.Extensions.Options.IConfigureOptions { public HubOptionsSetup(System.Collections.Generic.IEnumerable protocols) { } @@ -294,6 +311,12 @@ public partial interface IHubContext where THub : Microsoft.AspNetCore. Microsoft.AspNetCore.SignalR.IHubClients Clients { get; } Microsoft.AspNetCore.SignalR.IGroupManager Groups { get; } } + public partial interface IHubFilter + { + System.Threading.Tasks.ValueTask InvokeMethodAsync(Microsoft.AspNetCore.SignalR.HubInvocationContext invocationContext, System.Func> next); + System.Threading.Tasks.Task OnConnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Func next) { throw null; } + System.Threading.Tasks.Task OnDisconnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Exception exception, System.Func next) { throw null; } + } public partial interface IHubProtocolResolver { System.Collections.Generic.IReadOnlyList AllProtocols { get; } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 666789049677..7aea15a5d9f0 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -64,22 +64,37 @@ IServiceScopeFactory serviceScopeFactory _userIdProvider = userIdProvider; _enableDetailedErrors = false; + + List hubFilters = null; if (_hubOptions.UserHasSetValues) { _maximumMessageSize = _hubOptions.MaximumReceiveMessageSize; _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors; + + if (_hubOptions.HubFilters != null) + { + hubFilters = new List(); + hubFilters.AddRange(_hubOptions.HubFilters); + } } else { _maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize; _enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors; + + if (_globalHubOptions.HubFilters != null) + { + hubFilters = new List(); + hubFilters.AddRange(_globalHubOptions.HubFilters); + } } _dispatcher = new DefaultHubDispatcher( serviceScopeFactory, new HubContext(lifetimeManager), _enableDetailedErrors, - new Logger>(loggerFactory)); + new Logger>(loggerFactory), + hubFilters); } /// diff --git a/src/SignalR/server/Core/src/HubInvocationContext.cs b/src/SignalR/server/Core/src/HubInvocationContext.cs index 967db6ebb544..1161b030ce57 100644 --- a/src/SignalR/server/Core/src/HubInvocationContext.cs +++ b/src/SignalR/server/Core/src/HubInvocationContext.cs @@ -3,7 +3,9 @@ using System; using System.Collections.Generic; -using Microsoft.AspNetCore.Authorization; +using System.Linq; +using System.Reflection; +using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.SignalR { @@ -12,16 +14,27 @@ namespace Microsoft.AspNetCore.SignalR /// public class HubInvocationContext { + internal ObjectMethodExecutor ObjectMethodExecutor { get; } + /// /// Instantiates a new instance of the class. /// /// Context for the active Hub connection and caller. - /// The type of the Hub. - /// The name of the Hub method being invoked. + /// The specific to the scope of this Hub method invocation. + /// The instance of the Hub. + /// The for the Hub method being invoked. /// The arguments provided by the client. - public HubInvocationContext(HubCallerContext context, Type hubType, string hubMethodName, object[] hubMethodArguments): this(context, hubMethodName, hubMethodArguments) + public HubInvocationContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub, MethodInfo hubMethod, IReadOnlyList hubMethodArguments) { - HubType = hubType; + Hub = hub; + ServiceProvider = serviceProvider; + HubMethod = hubMethod; + HubMethodArguments = hubMethodArguments; + Context = context; + +#pragma warning disable CS0618 // Type or member is obsolete + HubMethodName = HubMethod.Name; +#pragma warning restore CS0618 // Type or member is obsolete } /// @@ -30,11 +43,16 @@ public HubInvocationContext(HubCallerContext context, Type hubType, string hubMe /// Context for the active Hub connection and caller. /// The name of the Hub method being invoked. /// The arguments provided by the client. + [Obsolete("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")] public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { - HubMethodName = hubMethodName; - HubMethodArguments = hubMethodArguments; - Context = context; + throw new NotSupportedException("This constructor no longer works. Use the other constructor."); + } + + internal HubInvocationContext(ObjectMethodExecutor objectMethodExecutor, HubCallerContext context, IServiceProvider serviceProvider, Hub hub, object[] hubMethodArguments) + : this(context, serviceProvider, hub, objectMethodExecutor.MethodInfo, hubMethodArguments) + { + ObjectMethodExecutor = objectMethodExecutor; } /// @@ -43,18 +61,29 @@ public HubInvocationContext(HubCallerContext context, string hubMethodName, obje public HubCallerContext Context { get; } /// - /// Gets the Hub type. + /// Gets the Hub instance. /// - public Type HubType { get; } + public Hub Hub { get; } /// /// Gets the name of the Hub method being invoked. /// + [Obsolete("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")] public string HubMethodName { get; } /// /// Gets the arguments provided by the client. /// public IReadOnlyList HubMethodArguments { get; } + + /// + /// The specific to the scope of this Hub method invocation. + /// + public IServiceProvider ServiceProvider { get; } + + /// + /// The for the Hub method being invoked. + /// + public MethodInfo HubMethod { get; } } } diff --git a/src/SignalR/server/Core/src/HubLifetimeContext.cs b/src/SignalR/server/Core/src/HubLifetimeContext.cs new file mode 100644 index 000000000000..5bb2ed9f034c --- /dev/null +++ b/src/SignalR/server/Core/src/HubLifetimeContext.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; + +namespace Microsoft.AspNetCore.SignalR +{ + /// + /// Context for the hub lifetime events and . + /// + public sealed class HubLifetimeContext + { + /// + /// Instantiates a new instance of the class. + /// + /// Context for the active Hub connection and caller. + /// The specific to the scope of this Hub method invocation. + /// The instance of the Hub. + public HubLifetimeContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub) + { + Hub = hub; + ServiceProvider = serviceProvider; + Context = context; + } + + /// + /// Gets the context for the active Hub connection and caller. + /// + public HubCallerContext Context { get; } + + /// + /// Gets the Hub instance. + /// + public Hub Hub { get; } + + /// + /// The specific to the scope of this Hub method invocation. + /// + public IServiceProvider ServiceProvider { get; } + } +} diff --git a/src/SignalR/server/Core/src/HubOptions.cs b/src/SignalR/server/Core/src/HubOptions.cs index e92a3686ca46..5ff5df5045e8 100644 --- a/src/SignalR/server/Core/src/HubOptions.cs +++ b/src/SignalR/server/Core/src/HubOptions.cs @@ -51,5 +51,7 @@ public class HubOptions /// Gets or sets the max buffer size for client upload streams. The default size is 10. /// public int? StreamBufferCapacity { get; set; } = null; + + internal List HubFilters { get; set; } = null; } } diff --git a/src/SignalR/server/Core/src/HubOptionsExtensions.cs b/src/SignalR/server/Core/src/HubOptionsExtensions.cs new file mode 100644 index 000000000000..6db3e87e1f8c --- /dev/null +++ b/src/SignalR/server/Core/src/HubOptionsExtensions.cs @@ -0,0 +1,60 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR.Internal; + +namespace Microsoft.AspNetCore.SignalR +{ + /// + /// Methods to add 's to Hubs. + /// + public static class HubOptionsExtensions + { + /// + /// Adds an instance of an to the . + /// + /// The options to add a filter to. + /// The filter instance to add to the options. + public static void AddFilter(this HubOptions options, IHubFilter hubFilter) + { + _ = options ?? throw new ArgumentNullException(nameof(options)); + _ = hubFilter ?? throw new ArgumentNullException(nameof(hubFilter)); + + if (options.HubFilters == null) + { + options.HubFilters = new List(); + } + + options.HubFilters.Add(hubFilter); + } + + /// + /// Adds an type to the that will be resolved via DI or type activated. + /// + /// The type that will be added to the options. + /// The options to add a filter to. + public static void AddFilter(this HubOptions options) where TFilter : IHubFilter + { + _ = options ?? throw new ArgumentNullException(nameof(options)); + + options.AddFilter(typeof(TFilter)); + } + + /// + /// Adds an type to the that will be resolved via DI or type activated. + /// + /// The options to add a filter to. + /// The type that will be added to the options. + public static void AddFilter(this HubOptions options, Type filterType) + { + _ = options ?? throw new ArgumentNullException(nameof(options)); + _ = filterType ?? throw new ArgumentNullException(nameof(filterType)); + + options.AddFilter(new HubFilterFactory(filterType)); + } + } +} diff --git a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs index 9f4fb17c0ad5..26d55750de45 100644 --- a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs +++ b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs @@ -17,11 +17,7 @@ public HubOptionsSetup(IOptions options) public void Configure(HubOptions options) { // Do a deep copy, otherwise users modifying the HubOptions list would be changing the global options list - options.SupportedProtocols = new List(_hubOptions.SupportedProtocols.Count); - foreach (var protocol in _hubOptions.SupportedProtocols) - { - options.SupportedProtocols.Add(protocol); - } + options.SupportedProtocols = new List(_hubOptions.SupportedProtocols); options.KeepAliveInterval = _hubOptions.KeepAliveInterval; options.HandshakeTimeout = _hubOptions.HandshakeTimeout; options.ClientTimeoutInterval = _hubOptions.ClientTimeoutInterval; @@ -30,6 +26,11 @@ public void Configure(HubOptions options) options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity; options.UserHasSetValues = true; + + if (_hubOptions.HubFilters != null) + { + options.HubFilters = new List(_hubOptions.HubFilters); + } } } } diff --git a/src/SignalR/server/Core/src/IHubFilter.cs b/src/SignalR/server/Core/src/IHubFilter.cs new file mode 100644 index 000000000000..398d24a253a2 --- /dev/null +++ b/src/SignalR/server/Core/src/IHubFilter.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR +{ + /// + /// The filter abstraction for hub method invocations. + /// + public interface IHubFilter + { + /// + /// Allows handling of all Hub method invocations. + /// + /// The context for the method invocation that holds all the important information about the invoke. + /// The next filter to run, and for the final one, the Hub invocation. + /// Returns the result of the Hub method invoke. + ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next); + + /// + /// Allows handling of the method. + /// + /// The context for OnConnectedAsync. + /// The next filter to run, and for the final one, the Hub invocation. + /// + Task OnConnectedAsync(HubLifetimeContext context, Func next) => next(context); + + /// + /// Allows handling of the method. + /// + /// The context for OnDisconnectedAsync. + /// The exception, if any, for the connection closing. + /// The next filter to run, and for the final one, the Hub invocation. + /// + Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) => next(context, exception); + } +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index e625acfff8ae..8a040b7d9260 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -16,7 +16,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.SignalR.Internal { @@ -27,14 +26,48 @@ internal partial class DefaultHubDispatcher : HubDispatcher where TH private readonly IHubContext _hubContext; private readonly ILogger> _logger; private readonly bool _enableDetailedErrors; + private readonly Func> _invokeMiddleware; + private readonly Func _onConnectedMiddleware; + private readonly Func _onDisconnectedMiddleware; - public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, ILogger> logger) + public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, + ILogger> logger, List hubFilters) { _serviceScopeFactory = serviceScopeFactory; _hubContext = hubContext; _enableDetailedErrors = enableDetailedErrors; _logger = logger; DiscoverHubMethods(); + + var count = hubFilters?.Count ?? 0; + if (count != 0) + { + _invokeMiddleware = (invocationContext) => + { + var arguments = invocationContext.HubMethodArguments as object[] ?? invocationContext.HubMethodArguments.ToArray(); + if (invocationContext.ObjectMethodExecutor != null) + { + return ExecuteMethod(invocationContext.ObjectMethodExecutor, invocationContext.Hub, arguments); + } + return ExecuteMethod(invocationContext.HubMethod.Name, invocationContext.Hub, arguments); + }; + + _onConnectedMiddleware = (context) => context.Hub.OnConnectedAsync(); + _onDisconnectedMiddleware = (context, exception) => context.Hub.OnDisconnectedAsync(exception); + + for (var i = count - 1; i > -1; i--) + { + var resolvedFilter = hubFilters[i]; + var nextFilter = _invokeMiddleware; + _invokeMiddleware = (context) => resolvedFilter.InvokeMethodAsync(context, nextFilter); + + var connectedFilter = _onConnectedMiddleware; + _onConnectedMiddleware = (context) => resolvedFilter.OnConnectedAsync(context, connectedFilter); + + var disconnectedFilter = _onDisconnectedMiddleware; + _onDisconnectedMiddleware = (context, exception) => resolvedFilter.OnDisconnectedAsync(context, exception, disconnectedFilter); + } + } } public override async Task OnConnectedAsync(HubConnectionContext connection) @@ -50,7 +83,16 @@ public override async Task OnConnectedAsync(HubConnectionContext connection) try { InitializeHub(hub, connection); - await hub.OnConnectedAsync(); + + if (_onConnectedMiddleware != null) + { + var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub); + await _onConnectedMiddleware(context); + } + else + { + await hub.OnConnectedAsync(); + } } finally { @@ -76,7 +118,16 @@ public override async Task OnDisconnectedAsync(HubConnectionContext connection, try { InitializeHub(hub, connection); - await hub.OnDisconnectedAsync(exception); + + if (_onDisconnectedMiddleware != null) + { + var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub); + await _onDisconnectedMiddleware(context, exception); + } + else + { + await hub.OnDisconnectedAsync(exception); + } } finally { @@ -220,7 +271,10 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext c THub hub = null; try { - if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments)) + hubActivator = scope.ServiceProvider.GetRequiredService>(); + hub = hubActivator.Create(); + + if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub)) { Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, @@ -233,9 +287,6 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, return; } - hubActivator = scope.ServiceProvider.GetRequiredService>(); - hub = hubActivator.Create(); - try { var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0; @@ -298,7 +349,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, if (isStreamResponse) { - var result = await ExecuteHubMethod(methodExecutor, hub, arguments); + var result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider); if (result == null) { @@ -315,7 +366,6 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage); } - else { // Invoke or Send @@ -324,7 +374,7 @@ async Task ExecuteInvocation() object result; try { - result = await ExecuteHubMethod(methodExecutor, hub, arguments); + result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider); Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); } catch (Exception ex) @@ -447,13 +497,36 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext } } - private static async Task ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments) + private ValueTask ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments, HubConnectionContext connection, IServiceProvider serviceProvider) + { + if (_invokeMiddleware != null) + { + var invocationContext = new HubInvocationContext(methodExecutor, connection.HubCallerContext, serviceProvider, hub, arguments); + return _invokeMiddleware(invocationContext); + } + + // If no Hub filters are registered + return ExecuteMethod(methodExecutor, hub, arguments); + } + + private ValueTask ExecuteMethod(string hubMethodName, Hub hub, object[] arguments) + { + if (!_methods.TryGetValue(hubMethodName, out var methodDescriptor)) + { + throw new HubException($"Unknown hub method '{hubMethodName}'"); + } + var methodExecutor = methodDescriptor.MethodExecutor; + return ExecuteMethod(methodExecutor, hub, arguments); + } + + private async ValueTask ExecuteMethod(ObjectMethodExecutor methodExecutor, Hub hub, object[] arguments) { if (methodExecutor.IsMethodAsync) { if (methodExecutor.MethodReturnType == typeof(Task)) { await (Task)methodExecutor.Execute(hub, arguments); + return null; } else { @@ -464,8 +537,6 @@ private static async Task ExecuteHubMethod(ObjectMethodExecutor methodEx { return methodExecutor.Execute(hub, arguments); } - - return null; } private async Task SendInvocationError(string invocationId, @@ -486,15 +557,15 @@ private void InitializeHub(THub hub, HubConnectionContext connection) hub.Groups = _hubContext.Groups; } - private Task IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList policies, string hubMethodName, object[] hubMethodArguments) + private Task IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, HubMethodDescriptor descriptor, object[] hubMethodArguments, Hub hub) { // If there are no policies we don't need to run auth - if (policies.Count == 0) + if (descriptor.Policies.Count == 0) { return TaskCache.True; } - return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, typeof(THub), hubMethodName, hubMethodArguments)); + return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, descriptor.Policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, provider, hub, descriptor.MethodExecutor.MethodInfo, hubMethodArguments)); } private static async Task IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList policies, HubInvocationContext resource) diff --git a/src/SignalR/server/Core/src/Internal/HubFilterFactory.cs b/src/SignalR/server/Core/src/Internal/HubFilterFactory.cs new file mode 100644 index 000000000000..da8b32e72f1f --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/HubFilterFactory.cs @@ -0,0 +1,100 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + internal class HubFilterFactory : IHubFilter + { + private readonly ObjectFactory _objectFactory; + private readonly Type _filterType; + + public HubFilterFactory(Type filterType) + { + _objectFactory = ActivatorUtilities.CreateFactory(filterType, Array.Empty()); + _filterType = filterType; + } + + public async ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + var (filter, owned) = GetFilter(invocationContext.ServiceProvider); + + try + { + return await filter.InvokeMethodAsync(invocationContext, next); + } + finally + { + if (owned) + { + await DisposeFilter(filter); + } + } + } + + public async Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + var (filter, owned) = GetFilter(context.ServiceProvider); + + try + { + await filter.OnConnectedAsync(context, next); + } + finally + { + if (owned) + { + await DisposeFilter(filter); + } + } + } + + public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + var (filter, owned) = GetFilter(context.ServiceProvider); + + try + { + await filter.OnDisconnectedAsync(context, exception, next); + } + finally + { + if (owned) + { + await DisposeFilter(filter); + } + } + } + + private ValueTask DisposeFilter(IHubFilter filter) + { + if (filter is IAsyncDisposable asyncDispsable) + { + return asyncDispsable.DisposeAsync(); + } + if (filter is IDisposable disposable) + { + disposable.Dispose(); + } + return default; + } + + private (IHubFilter, bool) GetFilter(IServiceProvider serviceProvider) + { + var owned = false; + var filter = (IHubFilter?)serviceProvider.GetService(_filterType); + if (filter == null) + { + filter = (IHubFilter)_objectFactory.Invoke(serviceProvider, null); + owned = true; + } + + return (filter, owned); + } + } +} diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 0de55bacef33..a33cb809d034 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -1048,8 +1048,19 @@ public async ValueTask MoveNextAsync() public class TcsService { - public TaskCompletionSource StartedMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public TaskCompletionSource EndMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public TaskCompletionSource StartedMethod; + public TaskCompletionSource EndMethod; + + public TcsService() + { + Reset(); + } + + public void Reset() + { + StartedMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + EndMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } } public interface ITypedHubClient diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index eb6d11a9dc0f..b511d2e6277a 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2232,14 +2232,18 @@ public Task HandleAsync(AuthorizationHandlerContext context) { Assert.NotNull(context.Resource); var resource = Assert.IsType(context.Resource); - Assert.Equal(typeof(MethodHub), resource.HubType); + Assert.Equal(typeof(MethodHub), resource.Hub.GetType()); +#pragma warning disable CS0618 // Type or member is obsolete Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName); +#pragma warning restore CS0618 // Type or member is obsolete Assert.Equal(2, resource.HubMethodArguments?.Count); Assert.Equal("Hello", resource.HubMethodArguments[0]); Assert.Equal("World!", resource.HubMethodArguments[1]); Assert.NotNull(resource.Context); Assert.Equal(context.User, resource.Context.User); Assert.NotNull(resource.Context.GetHttpContext()); + Assert.NotNull(resource.ServiceProvider); + Assert.Equal(typeof(MethodHub).GetMethod(nameof(MethodHub.MultiParamAuthMethod)), resource.HubMethod); return Task.CompletedTask; } diff --git a/src/SignalR/server/SignalR/test/HubFilterTests.cs b/src/SignalR/server/SignalR/test/HubFilterTests.cs new file mode 100644 index 000000000000..91bc5d835c6c --- /dev/null +++ b/src/SignalR/server/SignalR/test/HubFilterTests.cs @@ -0,0 +1,867 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Internal; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Testing; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class HubFilterTests : VerifiableLoggedTest + { + [Fact] + public async Task GlobalHubFilterByType_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + [Fact] + public async Task GlobalHubFilterByInstance_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new VerifyMethodFilter(tcsService)); + }); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + [Fact] + public async Task PerHubFilterByInstance_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR().AddHubOptions(options => + { + options.AddFilter(new VerifyMethodFilter(tcsService)); + }); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + [Fact] + public async Task PerHubFilterByCompileTimeType_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR().AddHubOptions(options => + { + options.AddFilter(); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + [Fact] + public async Task PerHubFilterByRuntimeType_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR().AddHubOptions(options => + { + options.AddFilter(typeof(VerifyMethodFilter)); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + private async Task AssertMethodsCalled(IServiceProvider serviceProvider, TcsService tcsService) + { + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await tcsService.StartedMethod.Task.OrTimeout(); + await client.Connected.OrTimeout(); + await tcsService.EndMethod.Task.OrTimeout(); + + tcsService.Reset(); + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + await tcsService.EndMethod.Task.OrTimeout(); + tcsService.Reset(); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + await tcsService.EndMethod.Task.OrTimeout(); + } + } + + [Fact] + public async Task MutlipleFilters_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService1 = new TcsService(); + var tcsService2 = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new VerifyMethodFilter(tcsService1)); + options.AddFilter(new VerifyMethodFilter(tcsService2)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await tcsService1.StartedMethod.Task.OrTimeout(); + await tcsService2.StartedMethod.Task.OrTimeout(); + await client.Connected.OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + + tcsService1.Reset(); + tcsService2.Reset(); + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + tcsService1.Reset(); + tcsService2.Reset(); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + } + } + } + + [Fact] + public async Task MixingTypeAndInstanceGlobalFilters_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService1 = new TcsService(); + var tcsService2 = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new VerifyMethodFilter(tcsService1)); + options.AddFilter(); + }); + + services.AddSingleton(tcsService2); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await tcsService1.StartedMethod.Task.OrTimeout(); + await tcsService2.StartedMethod.Task.OrTimeout(); + await client.Connected.OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + + tcsService1.Reset(); + tcsService2.Reset(); + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + tcsService1.Reset(); + tcsService2.Reset(); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + } + } + } + + [Fact] + public async Task MixingTypeAndInstanceHubSpecificFilters_MethodsAreCalled() + { + using (StartVerifiableLog()) + { + var tcsService1 = new TcsService(); + var tcsService2 = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR() + .AddHubOptions(options => + { + options.AddFilter(new VerifyMethodFilter(tcsService1)); + options.AddFilter(); + }); + + services.AddSingleton(tcsService2); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await tcsService1.StartedMethod.Task.OrTimeout(); + await tcsService2.StartedMethod.Task.OrTimeout(); + await client.Connected.OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + + tcsService1.Reset(); + tcsService2.Reset(); + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + tcsService1.Reset(); + tcsService2.Reset(); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + await tcsService1.EndMethod.Task.OrTimeout(); + await tcsService2.EndMethod.Task.OrTimeout(); + } + } + } + + [Fact] + public async Task GlobalFiltersRunInOrder() + { + using (StartVerifiableLog()) + { + var syncPoint1 = SyncPoint.Create(3, out var syncPoints1); + var syncPoint2 = SyncPoint.Create(3, out var syncPoints2); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new SyncPointFilter(syncPoints1)); + options.AddFilter(new SyncPointFilter(syncPoints2)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await syncPoints1[0].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted); + syncPoints1[0].Continue(); + + await syncPoints2[0].WaitForSyncPoint().OrTimeout(); + syncPoints2[0].Continue(); + await client.Connected.OrTimeout(); + + var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!"); + + await syncPoints1[1].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted); + syncPoints1[1].Continue(); + + await syncPoints2[1].WaitForSyncPoint().OrTimeout(); + syncPoints2[1].Continue(); + var message = await invokeTask.OrTimeout(); + + Assert.Null(message.Error); + + client.Dispose(); + + await syncPoints1[2].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted); + syncPoints1[2].Continue(); + + await syncPoints2[2].WaitForSyncPoint().OrTimeout(); + syncPoints2[2].Continue(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task HubSpecificFiltersRunInOrder() + { + using (StartVerifiableLog()) + { + var syncPoint1 = SyncPoint.Create(3, out var syncPoints1); + var syncPoint2 = SyncPoint.Create(3, out var syncPoints2); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR() + .AddHubOptions(options => + { + options.AddFilter(new SyncPointFilter(syncPoints1)); + options.AddFilter(new SyncPointFilter(syncPoints2)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await syncPoints1[0].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted); + syncPoints1[0].Continue(); + + await syncPoints2[0].WaitForSyncPoint().OrTimeout(); + syncPoints2[0].Continue(); + await client.Connected.OrTimeout(); + + var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!"); + + await syncPoints1[1].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted); + syncPoints1[1].Continue(); + + await syncPoints2[1].WaitForSyncPoint().OrTimeout(); + syncPoints2[1].Continue(); + var message = await invokeTask.OrTimeout(); + + Assert.Null(message.Error); + + client.Dispose(); + + await syncPoints1[2].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted); + syncPoints1[2].Continue(); + + await syncPoints2[2].WaitForSyncPoint().OrTimeout(); + syncPoints2[2].Continue(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task GlobalFiltersRunBeforeHubSpecificFilters() + { + using (StartVerifiableLog()) + { + var syncPoint1 = SyncPoint.Create(3, out var syncPoints1); + var syncPoint2 = SyncPoint.Create(3, out var syncPoints2); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new SyncPointFilter(syncPoints1)); + }) + .AddHubOptions(options => + { + options.AddFilter(new SyncPointFilter(syncPoints2)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await syncPoints1[0].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted); + syncPoints1[0].Continue(); + + await syncPoints2[0].WaitForSyncPoint().OrTimeout(); + syncPoints2[0].Continue(); + await client.Connected.OrTimeout(); + + var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!"); + + await syncPoints1[1].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted); + syncPoints1[1].Continue(); + + await syncPoints2[1].WaitForSyncPoint().OrTimeout(); + syncPoints2[1].Continue(); + var message = await invokeTask.OrTimeout(); + + Assert.Null(message.Error); + + client.Dispose(); + + await syncPoints1[2].WaitForSyncPoint().OrTimeout(); + // Second filter wont run yet because first filter is waiting on SyncPoint + Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted); + syncPoints1[2].Continue(); + + await syncPoints2[2].WaitForSyncPoint().OrTimeout(); + syncPoints2[2].Continue(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task FilterCanBeResolvedFromDI() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(); + }); + + // If this instance wasn't resolved, then the tcsService.StartedMethod waits would never trigger and fail the test + services.AddSingleton(new VerifyMethodFilter(tcsService)); + }, LoggerFactory); + + await AssertMethodsCalled(serviceProvider, tcsService); + } + } + + [Fact] + public async Task FiltersHaveTransientScopeByDefault() + { + using (StartVerifiableLog()) + { + var counter = new FilterCounter(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(); + }); + + services.AddSingleton(counter); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + // Filter is transient, so these counts are reset every time the filter is created + Assert.Equal(1, counter.OnConnectedAsyncCount); + Assert.Equal(0, counter.InvokeMethodAsyncCount); + Assert.Equal(0, counter.OnDisconnectedAsyncCount); + + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + // Filter is transient, so these counts are reset every time the filter is created + Assert.Equal(0, counter.OnConnectedAsyncCount); + Assert.Equal(1, counter.InvokeMethodAsyncCount); + Assert.Equal(0, counter.OnDisconnectedAsyncCount); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + // Filter is transient, so these counts are reset every time the filter is created + Assert.Equal(0, counter.OnConnectedAsyncCount); + Assert.Equal(0, counter.InvokeMethodAsyncCount); + Assert.Equal(1, counter.OnDisconnectedAsyncCount); + } + } + } + + [Fact] + public async Task FiltersCanBeSingletonIfAddedToDI() + { + using (StartVerifiableLog()) + { + var counter = new FilterCounter(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(); + }); + + services.AddSingleton(); + services.AddSingleton(counter); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + Assert.Equal(1, counter.OnConnectedAsyncCount); + Assert.Equal(0, counter.InvokeMethodAsyncCount); + Assert.Equal(0, counter.OnDisconnectedAsyncCount); + + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + Assert.Equal(1, counter.OnConnectedAsyncCount); + Assert.Equal(1, counter.InvokeMethodAsyncCount); + Assert.Equal(0, counter.OnDisconnectedAsyncCount); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + Assert.Equal(1, counter.OnConnectedAsyncCount); + Assert.Equal(1, counter.InvokeMethodAsyncCount); + Assert.Equal(1, counter.OnDisconnectedAsyncCount); + } + } + } + + [Fact] + public async Task ConnectionContinuesIfOnConnectedAsyncThrowsAndFilterDoesNot() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Verify connection still connected, can't invoke a method if the connection is disconnected + var message = await client.InvokeAsync("Method"); + Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task ConnectionContinuesIfOnConnectedAsyncNotCalledByFilter() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(new SkipNextFilter(skipOnConnected: true)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Verify connection still connected, can't invoke a method if the connection is disconnected + var message = await client.InvokeAsync("Method"); + Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task FilterCanSkipCallingHubMethod() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.AddFilter(new SkipNextFilter(skipInvoke: true)); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + + var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout(); + + Assert.Null(message.Error); + Assert.Null(message.Result); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task FiltersWithIDisposableAreDisposed() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // OnConnectedAsync creates and destroys the filter + await tcsService.StartedMethod.Task.OrTimeout(); + tcsService.Reset(); + + var message = await client.InvokeAsync("Echo", "Hello"); + Assert.Equal("Hello", message.Result); + await tcsService.StartedMethod.Task.OrTimeout(); + tcsService.Reset(); + + client.Dispose(); + + // OnDisconnectedAsync creates and destroys the filter + await tcsService.StartedMethod.Task.OrTimeout(); + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task InstanceFiltersWithIDisposableAreNotDisposed() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(new DisposableFilter(tcsService)); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + var message = await client.InvokeAsync("Echo", "Hello"); + Assert.Equal("Hello", message.Result); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + Assert.False(tcsService.StartedMethod.Task.IsCompleted); + } + } + } + + [Fact] + public async Task FiltersWithIAsyncDisposableAreDisposed() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // OnConnectedAsync creates and destroys the filter + await tcsService.StartedMethod.Task.OrTimeout(); + tcsService.Reset(); + + var message = await client.InvokeAsync("Echo", "Hello"); + Assert.Equal("Hello", message.Result); + await tcsService.StartedMethod.Task.OrTimeout(); + tcsService.Reset(); + + client.Dispose(); + + // OnDisconnectedAsync creates and destroys the filter + await tcsService.StartedMethod.Task.OrTimeout(); + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task InstanceFiltersWithIAsyncDisposableAreNotDisposed() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(new AsyncDisposableFilter(tcsService)); + }); + + services.AddSingleton(tcsService); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + var message = await client.InvokeAsync("Echo", "Hello"); + Assert.Equal("Hello", message.Result); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + + Assert.False(tcsService.StartedMethod.Task.IsCompleted); + } + } + } + + [Fact] + public async Task InvokeFailsWhenFilterCallsNonExistantMethod() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + options.AddFilter(); + }); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + var message = await client.InvokeAsync("Echo", "Hello"); + Assert.Equal("An unexpected error occurred invoking 'Echo' on the server. HubException: Unknown hub method 'BaseMethod'", message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + } +} diff --git a/src/SignalR/server/SignalR/test/TestFilters.cs b/src/SignalR/server/SignalR/test/TestFilters.cs new file mode 100644 index 000000000000..e32c48284a2a --- /dev/null +++ b/src/SignalR/server/SignalR/test/TestFilters.cs @@ -0,0 +1,236 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Internal; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class VerifyMethodFilter : IHubFilter + { + private readonly TcsService _service; + public VerifyMethodFilter(TcsService tcsService) + { + _service = tcsService; + } + + public async Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + _service.StartedMethod.TrySetResult(null); + await next(context); + _service.EndMethod.TrySetResult(null); + } + + public async ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + _service.StartedMethod.TrySetResult(null); + var result = await next(invocationContext); + _service.EndMethod.TrySetResult(null); + + return result; + } + + public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + _service.StartedMethod.TrySetResult(null); + await next(context, exception); + _service.EndMethod.TrySetResult(null); + } + } + + public class SyncPointFilter : IHubFilter + { + private readonly SyncPoint[] _syncPoint; + public SyncPointFilter(SyncPoint[] syncPoints) + { + Debug.Assert(syncPoints.Length == 3); + _syncPoint = syncPoints; + } + + public async Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + await _syncPoint[0].WaitToContinue(); + await next(context); + } + + public async ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + await _syncPoint[1].WaitToContinue(); + var result = await next(invocationContext); + + return result; + } + + public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + await _syncPoint[2].WaitToContinue(); + await next(context, exception); + } + } + + public class FilterCounter + { + public int OnConnectedAsyncCount; + public int InvokeMethodAsyncCount; + public int OnDisconnectedAsyncCount; + } + + public class CounterFilter : IHubFilter + { + private readonly FilterCounter _counter; + public CounterFilter(FilterCounter counter) + { + _counter = counter; + _counter.OnConnectedAsyncCount = 0; + _counter.InvokeMethodAsyncCount = 0; + _counter.OnDisconnectedAsyncCount = 0; + } + + public Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + _counter.OnConnectedAsyncCount++; + return next(context); + } + + public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + _counter.OnDisconnectedAsyncCount++; + return next(context, exception); + } + + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + _counter.InvokeMethodAsyncCount++; + return next(invocationContext); + } + } + + public class NoExceptionFilter : IHubFilter + { + public async Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + try + { + await next(context); + } + catch { } + } + + public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + try + { + await next(context, exception); + } + catch { } + } + + public async ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + try + { + return await next(invocationContext); + } + catch { } + + return null; + } + } + + public class SkipNextFilter : IHubFilter + { + private readonly bool _skipOnConnected; + private readonly bool _skipInvoke; + private readonly bool _skipOnDisconnected; + + public SkipNextFilter(bool skipOnConnected = false, bool skipInvoke = false, bool skipOnDisconnected = false) + { + _skipOnConnected = skipOnConnected; + _skipInvoke = skipInvoke; + _skipOnDisconnected = skipOnDisconnected; + } + + public Task OnConnectedAsync(HubLifetimeContext context, Func next) + { + if (_skipOnConnected) + { + return Task.CompletedTask; + } + + return next(context); + } + + public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func next) + { + if (_skipOnDisconnected) + { + return Task.CompletedTask; + } + + return next(context, exception); + } + + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + if (_skipInvoke) + { + return new ValueTask(); + } + + return next(invocationContext); + } + } + + public class DisposableFilter : IHubFilter, IDisposable + { + private readonly TcsService _tcsService; + + public DisposableFilter(TcsService tcsService) + { + _tcsService = tcsService; + } + + public void Dispose() + { + _tcsService.StartedMethod.SetResult(null); + } + + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + return next(invocationContext); + } + } + + public class AsyncDisposableFilter : IHubFilter, IAsyncDisposable + { + private readonly TcsService _tcsService; + + public AsyncDisposableFilter(TcsService tcsService) + { + _tcsService = tcsService; + } + + public ValueTask DisposeAsync() + { + _tcsService.StartedMethod.SetResult(null); + return default; + } + + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + return next(invocationContext); + } + } + + public class ChangeMethodFilter : IHubFilter + { + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + var methodInfo = typeof(BaseHub).GetMethod(nameof(BaseHub.BaseMethod)); + var context = new HubInvocationContext(invocationContext.Context, invocationContext.ServiceProvider, invocationContext.Hub, methodInfo, invocationContext.HubMethodArguments); + return next(context); + } + } +}