@@ -27,8 +27,8 @@ func withMockServer<Result>(
27
27
_ body: ( _ port: Int ) async throws -> Result
28
28
) async throws -> Result {
29
29
let eventLoopGroup = NIOSingletons . posixEventLoopGroup
30
- let server = MockLambdaServer ( behavior: behaviour, port: port, keepAlive: keepAlive)
31
- let port = try await server. start ( ) . get ( )
30
+ let server = MockLambdaServer ( behavior: behaviour, port: port, keepAlive: keepAlive, eventLoopGroup : eventLoopGroup )
31
+ let port = try await server. start ( )
32
32
33
33
let result : Swift . Result < Result , any Error >
34
34
do {
@@ -37,13 +37,13 @@ func withMockServer<Result>(
37
37
result = . failure( error)
38
38
}
39
39
40
- try ? await server. stop ( ) . get ( )
40
+ try ? await server. stop ( )
41
41
return try result. get ( )
42
42
}
43
43
44
- final class MockLambdaServer {
44
+ final class MockLambdaServer < Behavior : LambdaServerBehavior > {
45
45
private let logger = Logger ( label: " MockLambdaServer " )
46
- private let behavior : LambdaServerBehavior
46
+ private let behavior : Behavior
47
47
private let host : String
48
48
private let port : Int
49
49
private let keepAlive : Bool
@@ -52,7 +52,13 @@ final class MockLambdaServer {
52
52
private var channel : Channel ?
53
53
private var shutdown = false
54
54
55
- init ( behavior: LambdaServerBehavior , host: String = " 127.0.0.1 " , port: Int = 7000 , keepAlive: Bool = true ) {
55
+ init (
56
+ behavior: Behavior ,
57
+ host: String = " 127.0.0.1 " ,
58
+ port: Int = 7000 ,
59
+ keepAlive: Bool = true ,
60
+ eventLoopGroup: MultiThreadedEventLoopGroup
61
+ ) {
56
62
self . group = NIOSingletons . posixEventLoopGroup
57
63
self . behavior = behavior
58
64
self . host = host
@@ -64,39 +70,41 @@ final class MockLambdaServer {
64
70
assert ( shutdown)
65
71
}
66
72
67
- func start( ) -> EventLoopFuture < Int > {
68
- let bootstrap = ServerBootstrap ( group: group)
73
+ fileprivate func start( ) async throws -> Int {
74
+ let logger = self . logger
75
+ let keepAlive = self . keepAlive
76
+ let behavior = self . behavior
77
+
78
+ let channel = try await ServerBootstrap ( group: group)
69
79
. serverChannelOption ( ChannelOptions . socket ( SocketOptionLevel ( SOL_SOCKET) , SO_REUSEADDR) , value: 1 )
70
80
. childChannelInitializer { channel in
71
81
do {
72
82
try channel. pipeline. syncOperations. configureHTTPServerPipeline ( withErrorHandling: true )
73
83
try channel. pipeline. syncOperations. addHandler (
74
- HTTPHandler ( logger: self . logger, keepAlive: self . keepAlive, behavior: self . behavior)
84
+ HTTPHandler ( logger: logger, keepAlive: keepAlive, behavior: behavior)
75
85
)
76
86
return channel. eventLoop. makeSucceededVoidFuture ( )
77
87
} catch {
78
88
return channel. eventLoop. makeFailedFuture ( error)
79
89
}
80
90
}
81
- return bootstrap. bind ( host: self . host, port: self . port) . flatMap { channel in
82
- self . channel = channel
83
- guard let localAddress = channel. localAddress else {
84
- return channel. eventLoop. makeFailedFuture ( ServerError . cantBind)
85
- }
86
- self . logger. info ( " \( self ) started and listening on \( localAddress) " )
87
- return channel. eventLoop. makeSucceededFuture ( localAddress. port!)
91
+ . bind ( host: self . host, port: self . port)
92
+ . get ( )
93
+
94
+ self . channel = channel
95
+ guard let localAddress = channel. localAddress else {
96
+ throw ServerError . cantBind
88
97
}
98
+ self . logger. info ( " \( self ) started and listening on \( localAddress) " )
99
+ return localAddress. port!
89
100
}
90
101
91
- func stop( ) -> EventLoopFuture < Void > {
102
+ fileprivate func stop( ) async throws {
92
103
self . logger. info ( " stopping \( self ) " )
93
- guard let channel = self . channel else {
94
- return self . group. next ( ) . makeFailedFuture ( ServerError . notReady)
95
- }
96
- return channel. close ( ) . always { _ in
97
- self . shutdown = true
98
- self . logger. info ( " \( self ) stopped " )
99
- }
104
+ let channel = self . channel!
105
+ try ? await channel. close ( ) . get ( )
106
+ self . shutdown = true
107
+ self . logger. info ( " \( self ) stopped " )
100
108
}
101
109
}
102
110
@@ -221,32 +229,37 @@ final class HTTPHandler: ChannelInboundHandler {
221
229
}
222
230
let head = HTTPResponseHead ( version: HTTPVersion ( major: 1 , minor: 1 ) , status: status, headers: headers)
223
231
232
+ let logger = self . logger
224
233
context. write ( wrapOutboundOut ( . head( head) ) ) . whenFailure { error in
225
- self . logger. error ( " \( self ) write error \( error) " )
234
+ logger. error ( " write error \( error) " )
226
235
}
227
236
228
237
if let b = body {
229
238
var buffer = context. channel. allocator. buffer ( capacity: b. utf8. count)
230
239
buffer. writeString ( b)
231
240
context. write ( wrapOutboundOut ( . body( . byteBuffer( buffer) ) ) ) . whenFailure { error in
232
- self . logger. error ( " \( self ) write error \( error) " )
241
+ logger. error ( " write error \( error) " )
233
242
}
234
243
}
235
244
245
+ let loopBoundContext = NIOLoopBound ( context, eventLoop: context. eventLoop)
246
+
247
+ let keepAlive = self . keepAlive
236
248
context. writeAndFlush ( wrapOutboundOut ( . end( nil ) ) ) . whenComplete { result in
237
249
if case . failure( let error) = result {
238
- self . logger. error ( " \( self ) write error \( error) " )
250
+ logger. error ( " write error \( error) " )
239
251
}
240
- if !self . keepAlive {
252
+ if !keepAlive {
253
+ let context = loopBoundContext. value
241
254
context. close ( ) . whenFailure { error in
242
- self . logger. error ( " \( self ) close error \( error) " )
255
+ logger. error ( " close error \( error) " )
243
256
}
244
257
}
245
258
}
246
259
}
247
260
}
248
261
249
- protocol LambdaServerBehavior {
262
+ protocol LambdaServerBehavior: Sendable {
250
263
func getInvocation( ) -> GetInvocationResult
251
264
func processResponse( requestId: String , response: String ? ) -> Result < Void , ProcessResponseError >
252
265
func processError( requestId: String , error: ErrorResponse ) -> Result < Void , ProcessErrorError >
0 commit comments