diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index 94f15749..be1a6b3e 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -9,10 +9,11 @@ on: jobs: test-on-macOS-and-iOS: runs-on: macos-11 - steps: - uses: actions/checkout@v2 - + - uses: maxim-lobanov/setup-xcode@v1 + with: + xcode-version: 12.5.1 - name: Test on iOS Simulator run: > xcodebuild test @@ -31,12 +32,29 @@ jobs: -skip-testing:RSocketCorePerformanceTests -parallel-testing-enabled -destination 'platform=macOS' + + test-on-macOS-with-Xcode-13-Beta: + runs-on: macos-11 + steps: + - uses: actions/checkout@v2 + - uses: maxim-lobanov/setup-xcode@v1 + with: + xcode-version: 13.0 + + - name: Test on macOS + run: swift test performance-tests-on-macOS: runs-on: macos-11 + strategy: + matrix: + xcode: ['12.5.1', '13.0'] steps: - uses: actions/checkout@v2 + - uses: maxim-lobanov/setup-xcode@v1 + with: + xcode-version: ${{ matrix.xcode }} - name: Build & Run Performance Tests on macOS run: > swift test diff --git a/Package.swift b/Package.swift index d8119b14..81129546 100644 --- a/Package.swift +++ b/Package.swift @@ -19,11 +19,13 @@ let package = Package( // Transport protocol .library(name: "RSocketWSTransport", targets: ["RSocketWSTransport"]), .library(name: "RSocketTCPTransport", targets: ["RSocketTCPTransport"]), + .library(name: "RSocketAsync", targets: ["RSocketAsync"]), // Examples .executable(name: "timer-client-example", targets: ["TimerClientExample"]), .executable(name: "twitter-client-example", targets: ["TwitterClientExample"]), .executable(name: "vanilla-client-example", targets: ["VanillaClientExample"]), + .executable(name: "async-twitter-client-example", targets: ["AsyncTwitterClientExample"]), ], dependencies: [ .package(url: "https://github.com/ReactiveCocoa/ReactiveSwift.git", from: "6.6.0"), @@ -46,6 +48,11 @@ let package = Package( "RSocketCore", .product(name: "ReactiveSwift", package: "ReactiveSwift") ]), + .target(name: "RSocketAsync", dependencies: [ + "RSocketCore", + .product(name: "NIO", package: "swift-nio"), + .product(name: "_NIOConcurrency", package: "swift-nio"), + ]), // Channel .target(name: "RSocketTSChannel", dependencies: [ @@ -135,6 +142,19 @@ let package = Package( ], path: "Sources/Examples/VanillaClient" ), + .executableTarget( + name: "AsyncTwitterClientExample", + dependencies: [ + "RSocketCore", + "RSocketNIOChannel", + "RSocketWSTransport", + "RSocketAsync", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + .product(name: "NIO", package: "swift-nio"), + .product(name: "_NIOConcurrency", package: "swift-nio"), + ], + path: "Sources/Examples/AsyncTwitterClient" + ), ], swiftLanguageVersions: [.v5] ) diff --git a/Sources/Examples/AsyncTwitterClient/main.swift b/Sources/Examples/AsyncTwitterClient/main.swift new file mode 100644 index 00000000..62368180 --- /dev/null +++ b/Sources/Examples/AsyncTwitterClient/main.swift @@ -0,0 +1,82 @@ +#if compiler(>=5.5) +import ArgumentParser +import Foundation +import NIO +import RSocketAsync +import RSocketCore +import RSocketNIOChannel +import RSocketWSTransport + +struct Tweet: Decodable { + struct User: Decodable { + let screen_name, name: String + let followers_count: Int + } + let user: User + let text: String + let reply_count, retweet_count, favorite_count: Int +} + +extension URL: ExpressibleByArgument { + public init?(argument: String) { + guard let url = URL(string: argument) else { return nil } + self = url + } + public var defaultValueDescription: String { description } +} + +/// the server-side code can be found here -> https://github.com/rsocket/rsocket-demo/tree/master/src/main/kotlin/io/rsocket/demo/twitter +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +struct TwitterClientExample: ParsableCommand { + static var configuration = CommandConfiguration( + abstract: "connects to an RSocket endpoint using WebSocket transport, requests a stream at the route `searchTweets` to search for tweets that match the `searchString` and logs all events." + ) + + @Argument(help: "used to find tweets that match the given string") + var searchString = "swift" + + @Option + var url = URL(string: "wss://demo.rsocket.io/rsocket")! + + @Option(help: "maximum number of tweets that are taken before it cancels the stream") + var limit = 1000 + + func run() throws { + let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { try! eventLoop.syncShutdownGracefully() } + let promise = eventLoop.next().makePromise(of: Void.self) + promise.completeWithAsync { + try await self.runAsync() + } + try promise.futureResult.wait() + } + func runAsync() async throws { + let bootstrap = ClientBootstrap( + transport: WSTransport(), + config: .mobileToServer + .set(\.encoding.metadata, to: .messageXRSocketRoutingV0) + .set(\.encoding.data, to: .applicationJson), + timeout: .seconds(30) + ) + let client = try await bootstrap.connect(to: .init(url: url), payload: .empty) + + let stream = try client.requester( + RequestStream { + Encoder() + .encodeStaticMetadata("searchTweets", using: .routing) + .mapData { (string: String) in Data(string.utf8) } + Decoder() + .decodeData(using: JSONDataDecoder(type: Tweet.self)) + }, + request: searchString + ) + + for try await tweet in stream.prefix(limit) { + dump(tweet) + } + } +} +if #available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) { + TwitterClientExample.main() +} +#endif diff --git a/Sources/Examples/TimerClient/main.swift b/Sources/Examples/TimerClient/main.swift index 038364ca..e8a3eeac 100644 --- a/Sources/Examples/TimerClient/main.swift +++ b/Sources/Examples/TimerClient/main.swift @@ -6,14 +6,6 @@ import RSocketNIOChannel import RSocketReactiveSwift import RSocketWSTransport -func route(_ route: String) -> Data { - let encodedRoute = Data(route.utf8) - precondition(encodedRoute.count <= Int(UInt8.max), "route is to long to be encoded") - let encodedRouteLength = Data([UInt8(encodedRoute.count)]) - - return encodedRouteLength + encodedRoute -} - extension URL: ExpressibleByArgument { public init?(argument: String) { guard let url = URL(string: argument) else { return nil } @@ -43,12 +35,12 @@ struct TimerClientExample: ParsableCommand { ) let client = try bootstrap.connect(to: .init(url: url)).first()!.get() - - try client.requester.requestStream(payload: Payload( - metadata: route("timer"), - data: Data() - )) - .map() { String.init(decoding: $0.data, as: UTF8.self) } + try client.requester(RequestStream { + Encoder() + .encodeStaticMetadata("timer", using: RoutingEncoder()) + Decoder() + .mapData { String(decoding: $0, as: UTF8.self) } + }, request: Data()) .logEvents(identifier: "route.timer") .take(first: limit) .wait() diff --git a/Sources/Examples/TwitterClient/main.swift b/Sources/Examples/TwitterClient/main.swift index 58fd340f..8fa0bc3e 100644 --- a/Sources/Examples/TwitterClient/main.swift +++ b/Sources/Examples/TwitterClient/main.swift @@ -6,14 +6,6 @@ import RSocketNIOChannel import RSocketReactiveSwift import RSocketWSTransport -func route(_ route: String) -> Data { - let encodedRoute = Data(route.utf8) - precondition(encodedRoute.count <= Int(UInt8.max), "route is to long to be encoded") - let encodedRouteLength = Data([UInt8(encodedRoute.count)]) - - return encodedRouteLength + encodedRoute -} - extension URL: ExpressibleByArgument { public init?(argument: String) { guard let url = URL(string: argument) else { return nil } @@ -40,23 +32,26 @@ struct TwitterClientExample: ParsableCommand { func run() throws { let bootstrap = ClientBootstrap( transport: WSTransport(), - config: ClientConfiguration.mobileToServer + config: .mobileToServer .set(\.encoding.metadata, to: .messageXRSocketRoutingV0) .set(\.encoding.data, to: .applicationJson) ) let client = try bootstrap.connect(to: .init(url: url)).first()!.get() - - try client.requester.requestStream(payload: Payload( - metadata: route("searchTweets"), - data: Data(searchString.utf8) - )) - .attemptMap { payload -> String in - // pretty print json - let json = try JSONSerialization.jsonObject(with: payload.data, options: []) - let data = try JSONSerialization.data(withJSONObject: json, options: [.prettyPrinted]) - return String(decoding: data, as: UTF8.self) - } + try client.requester(RequestStream { + Encoder() + .encodeStaticMetadata("searchTweets", using: RoutingEncoder()) + .mapData { (string: String) in + Data(string.utf8) + } + Decoder() + .mapData { data -> String in + // pretty print json + let json = try JSONSerialization.jsonObject(with: data, options: []) + let data = try JSONSerialization.data(withJSONObject: json, options: [.prettyPrinted]) + return String(decoding: data, as: UTF8.self) + } + }, request: searchString) .logEvents(identifier: "route.searchTweets") .take(first: limit) .wait() diff --git a/Sources/Examples/VanillaClient/main.swift b/Sources/Examples/VanillaClient/main.swift index 5563c42e..352456b1 100644 --- a/Sources/Examples/VanillaClient/main.swift +++ b/Sources/Examples/VanillaClient/main.swift @@ -22,8 +22,11 @@ struct VanillaClientExample: ParsableCommand { let client = try bootstrap.connect(to: .init(host: host, port: port)).first()!.get() - let streamProducer = client.requester.requestStream(payload: .empty) - let requestProducer = client.requester.requestResponse(payload: Payload(data: Data("HelloWorld".utf8))) + let streamProducer = client.requester( + RequestStream(), + request: Data() + ) + let requestProducer = client.requester(RequestResponse(), request: Data("HelloWorld".utf8)) streamProducer.logEvents(identifier: "stream1").take(first: 1).start() streamProducer.logEvents(identifier: "stream3").take(first: 10).start() diff --git a/Sources/RSocketAsync/Client.swift b/Sources/RSocketAsync/Client.swift new file mode 100644 index 00000000..65ed51e7 --- /dev/null +++ b/Sources/RSocketAsync/Client.swift @@ -0,0 +1,38 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if compiler(>=5.5) +import RSocketCore + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +public struct AsyncClient { + private let coreClient: RSocketCore.CoreClient + + public var requester: RequesterRSocket { RequesterRSocket(requester: coreClient.requester) } + + public init(_ coreClient: RSocketCore.CoreClient) { + self.coreClient = coreClient + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension RSocketCore.ClientBootstrap where Client == CoreClient, Responder == RSocketCore.RSocket { + public func connect(to endpoint: Transport.Endpoint, payload: Payload) async throws -> AsyncClient { + AsyncClient(try await connect(to: endpoint, payload: payload, responder: nil).get()) + } +} + +#endif diff --git a/Sources/RSocketAsync/RSocket.swift b/Sources/RSocketAsync/RSocket.swift new file mode 100644 index 00000000..01894796 --- /dev/null +++ b/Sources/RSocketAsync/RSocket.swift @@ -0,0 +1,34 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if compiler(>=5.5) +import Foundation +import RSocketCore + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +public protocol RSocket { + func metadataPush(metadata: Data) + func fireAndForget(payload: Payload) + func requestResponse(payload: Payload) async throws -> Payload + func requestStream(payload: Payload) -> AsyncThrowingStream + func requestChannel( + initialPayload: Payload, + payloadStream: PayloadSequence? + ) -> AsyncThrowingStream + where PayloadSequence: AsyncSequence, PayloadSequence.Element == Payload +} + +#endif diff --git a/Sources/RSocketAsync/Requester.swift b/Sources/RSocketAsync/Requester.swift new file mode 100644 index 00000000..04b58e52 --- /dev/null +++ b/Sources/RSocketAsync/Requester.swift @@ -0,0 +1,219 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if compiler(>=5.5) +import Foundation +import NIO +import RSocketCore +import _NIOConcurrency + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +extension RequesterRSocket { + public func callAsFunction(_ metadataPush: MetadataPush, metadata: Metadata) throws { + self.metadataPush(metadata: try metadataPush.encoder.encode(metadata)) + } + + public func callAsFunction(_ fireAndForget: FireAndForget, request: Request) throws { + var encoder = fireAndForget.encoder + self.fireAndForget(payload: try encoder.encode(request, encoding: encoding)) + } + + public func callAsFunction( + _ requestResponse: RequestResponse, + request: Request + ) async throws -> Response { + var encoder = requestResponse.encoder + let response = try await self.requestResponse(payload: encoder.encode(request, encoding: encoding)) + var decoder = requestResponse.decoder + return try decoder.decode(response, encoding: encoding) + } + + public func callAsFunction( + _ requestStream: RequestStream, + request: Request + ) throws -> AsyncThrowingMapSequence, Response> { + /// TODO: this method should not throw but rather the async sequence should throw an error + /// TODO: result type of this method should be an opaque result type with where clause (e.g. `some AsyncSequence where _.Element == Response`) once they are available in Swift + var encoder = requestStream.encoder + var decoder = requestStream.decoder + let a = self.requestStream(payload: try encoder.encode(request, encoding: encoding)).map { response throws -> Response in + try decoder.decode(response, encoding: encoding) + } + return a + } + + public func callAsFunction( + _ requestChannel: RequestChannel, + initialRequest: Request, + producer: Producer? + ) throws -> AsyncThrowingMapSequence, Response> + where Producer: AsyncSequence, Producer.Element == Request { + /// TODO: this method should not throw but rather the async sequence should throw an error + /// TODO: result type of this method should be an opaque result type with where clause (e.g. `some AsyncSequence where _.Element == Response`) once they are available in Swift + var encoder = requestChannel.encoder + var decoder = requestChannel.decoder + + return self.requestChannel( + initialPayload: try encoder.encode(initialRequest, encoding: encoding), + payloadStream: producer?.map { try encoder.encode($0, encoding: encoding) } + ).map { + try decoder.decode($0, encoding: encoding) + } + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +public struct RequesterRSocket { + private let requester: RSocketCore.RSocket + + internal var encoding: ConnectionEncoding { requester.encoding } + public init(requester: RSocketCore.RSocket) { + self.requester = requester + } + internal func metadataPush(metadata: Data) { + requester.metadataPush(metadata: metadata) + } + internal func fireAndForget(payload: Payload) { + requester.fireAndForget(payload: payload) + } + internal func requestResponse(payload: Payload) async throws -> Payload { + struct RequestResponseOperator: UnidirectionalStream { + var continuation: CheckedContinuation + func onNext(_ payload: Payload, isCompletion: Bool) { + assert(isCompletion) + continuation.resume(returning: payload) + } + + func onComplete() { + assertionFailure("request response does not support \(#function)") + } + + func onRequestN(_ requestN: Int32) { + assertionFailure("request response does not support \(#function)") + } + + func onCancel() { + continuation.resume(throwing: Error.canceled(message: "onCancel")) + } + + func onError(_ error: Error) { + continuation.resume(throwing: error) + } + + func onExtension(extendedType: Int32, payload: Payload, canBeIgnored: Bool) { + assertionFailure("request response does not support \(#function)") + } + } + var cancelable: Cancellable? + defer { cancelable?.onCancel() } + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let stream = RequestResponseOperator(continuation: continuation) + cancelable = requester.requestResponse(payload: payload, responderStream: stream) + } + } + + internal func requestStream(payload: Payload) -> AsyncThrowingStream { + AsyncThrowingStream(Payload.self, bufferingPolicy: .unbounded) { continuation in + let adapter = AsyncStreamAdapter(continuation: continuation) + let subscription = requester.stream(payload: payload, initialRequestN: .max, responderStream: adapter) + continuation.onTermination = { @Sendable (reason: AsyncThrowingStream.Continuation.Termination) -> Void in + switch reason { + case .cancelled: + subscription.onCancel() + case .finished: break + // TODO: `Termination` should probably be @frozen so we do not have to deal with the @unknown default case + @unknown default: break + } + } + } + } + + internal func requestChannel( + initialPayload: Payload, + payloadStream: PayloadSequence? + ) -> AsyncThrowingStream where PayloadSequence: AsyncSequence, PayloadSequence.Element == Payload { + AsyncThrowingStream(Payload.self, bufferingPolicy: .unbounded) { continuation in + let adapter = AsyncStreamAdapter(continuation: continuation) + let channel = requester.channel( + payload: initialPayload, + initialRequestN: .max, + isCompleted: payloadStream == nil, + responderStream: adapter + ) + + let task = Task.detached { + guard let payloadStream = payloadStream else { return } + do { + for try await payload in payloadStream { + channel.onNext(payload, isCompletion: false) + } + channel.onComplete() + } catch is CancellationError { + channel.onCancel() + } catch { + channel.onError(Error.applicationError(message: error.localizedDescription)) + } + } + + continuation.onTermination = { @Sendable (reason: AsyncThrowingStream.Continuation.Termination) -> Void in + switch reason { + case .cancelled: + channel.onCancel() + task.cancel() + case .finished: break + // TODO: `Termination` should probably be @frozen so we do not have to deal with the @unknown default case + @unknown default: break + } + } + } + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +internal final class AsyncStreamAdapter: UnidirectionalStream { + private var continuation: AsyncThrowingStream.Continuation + init(continuation: AsyncThrowingStream.Continuation) { + self.continuation = continuation + } + internal func onNext(_ payload: Payload, isCompletion: Bool) { + continuation.yield(payload) + if isCompletion { + continuation.finish() + } + } + + internal func onComplete() { + continuation.finish() + } + + internal func onRequestN(_ requestN: Int32) { + assertionFailure("request stream does not support \(#function)") + } + + internal func onCancel() { + continuation.finish() + } + + internal func onError(_ error: Error) { + continuation.yield(with: .failure(error)) + } + + internal func onExtension(extendedType: Int32, payload: Payload, canBeIgnored: Bool) { + assertionFailure("request stream does not support \(#function)") + } +} + +#endif diff --git a/Sources/RSocketAsync/Responder.swift b/Sources/RSocketAsync/Responder.swift new file mode 100644 index 00000000..fa724ee2 --- /dev/null +++ b/Sources/RSocketAsync/Responder.swift @@ -0,0 +1,220 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if compiler(>=5.5) +import RSocketCore +import Foundation + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +struct ResponderAdapter: RSocketCore.RSocket { + var responder: RSocket + let encoding: ConnectionEncoding + + func metadataPush(metadata: Data) { + responder.metadataPush(metadata: metadata) + } + + func fireAndForget(payload: Payload) { + responder.fireAndForget(payload: payload) + } + + func requestResponse( + payload: Payload, + responderStream: UnidirectionalStream + ) -> Cancellable { + let task = Task.init(priority: nil) { + do { + let response = try await responder.requestResponse(payload: payload) + responderStream.onNext(response, isCompletion: true) + } catch { + responderStream.onError(Error.applicationError(message: error.localizedDescription)) + } + } + return RequestResponseResponder(task: task) + } + + func stream( + payload: Payload, + initialRequestN: Int32, + responderStream: UnidirectionalStream + ) -> Subscription { + let task = Task.init(priority: nil) { + do { + let stream = responder.requestStream(payload: payload) + for try await responderPayload in stream { + responderStream.onNext(responderPayload, isCompletion: false) + } + responderStream.onComplete() + } catch is CancellationError { + responderStream.onCancel() + } catch { + responderStream.onError(Error.applicationError(message: error.localizedDescription)) + } + } + + return RequestStreamResponder(task: task) + } + + func channel(payload: Payload, initialRequestN: Int32, isCompleted: Bool, responderStream: UnidirectionalStream) -> UnidirectionalStream { + let requesterStream = RequestChannelAsyncSequence() + + let task = Task.init(priority: nil) { + do { + let responderPayloads = responder.requestChannel(initialPayload: payload, payloadStream: requesterStream) + for try await responderPayload in responderPayloads { + responderStream.onNext(responderPayload, isCompletion: false) + } + responderStream.onComplete() + } catch is CancellationError { + responderStream.onCancel() + } catch { + responderStream.onError(Error.applicationError(message: error.localizedDescription)) + } + } + requesterStream.task = task + return requesterStream + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +fileprivate class RequestResponseResponder: Cancellable { + private let task: Task + + internal init(task: Task) { + self.task = task + } + + deinit { + task.cancel() + } + + func onCancel() { + task.cancel() + } + + func onError(_ error: Error) { + // TODO: Can a request actually send an error? If yes, we should probably do something with the error + task.cancel() + } + + func onExtension(extendedType: Int32, payload: Payload, canBeIgnored: Bool) { + assertionFailure("\(Self.self) does not support \(#function)") + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +fileprivate class RequestStreamResponder: Subscription { + private let task: Task + + internal init(task: Task) { + self.task = task + } + + deinit { + task.cancel() + } + + func onCancel() { + task.cancel() + } + + func onError(_ error: Error) { + // TODO: Can a stream actually send an error? If yes, we should probably do something with the error + task.cancel() + } + + func onExtension(extendedType: Int32, payload: Payload, canBeIgnored: Bool) { + assertionFailure("\(Self.self) does not support \(#function)") + } + + func onRequestN(_ requestN: Int32) { + assertionFailure("\(Self.self) does not support \(#function)") + } +} + +@available(macOS 9999, iOS 9999, watchOS 9999, tvOS 9999, *) +fileprivate class RequestChannelAsyncSequence: AsyncSequence, UnidirectionalStream { + typealias AsyncIterator = AsyncThrowingStream.AsyncIterator + typealias Element = Payload + + private var iterator: AsyncThrowingStream.AsyncIterator! + private var continuation: AsyncThrowingStream.Continuation! + + internal var task: Task? + + internal init() { + let sequence = AsyncThrowingStream(Payload.self, bufferingPolicy: .unbounded) { continuation in + self.continuation = continuation + continuation.onTermination = { @Sendable [weak self] (reason) in + // TODO: `task` is not safe to access here because we set it late, after we give `self` to user code + // but just adding a lock is not enough because we could be terminated before `task` is even set and we then need to cancel `task` after it is set + // I hope we find a better solution. Maybe we can access the current task from within the `Task.init` closure which would solve both problems mentioned above + // UPDATE: Looks like it will be possible. The documentation of `withUnsafeCurrentTask(body:)` says that `UnsafeCurrentTask` has get a `task` property. But it does currently (Xcode 12 Beta 3) not have it. + switch reason { + case let .finished(.some(error)): + if error is CancellationError { + return + } + // only in the error case we cancel task + self?.task?.cancel() + case .finished(nil): break + case .cancelled: break + @unknown default: break + } + } + } + iterator = sequence.makeAsyncIterator() + } + + deinit { + task?.cancel() + continuation.finish(throwing: CancellationError()) + } + + func makeAsyncIterator() -> AsyncIterator { + iterator + } + + func onNext(_ payload: Payload, isCompletion: Bool) { + continuation.yield(payload) + if isCompletion { + continuation.finish() + } + } + + func onComplete() { + continuation.finish() + } + + func onCancel() { + continuation.finish(throwing: CancellationError()) + } + + func onError(_ error: Error) { + continuation.finish(throwing: error) + task?.cancel() + } + + func onExtension(extendedType: Int32, payload: Payload, canBeIgnored: Bool) { + assertionFailure("\(Self.self) does not support \(#function)") + } + + func onRequestN(_ requestN: Int32) { + assertionFailure("\(Self.self) does not support \(#function)") + } +} + +#endif diff --git a/Sources/RSocketCore/Channel Handler/ConnectionEstablishment.swift b/Sources/RSocketCore/Channel Handler/ConnectionEstablishment.swift index c5be7722..4ffd6455 100644 --- a/Sources/RSocketCore/Channel Handler/ConnectionEstablishment.swift +++ b/Sources/RSocketCore/Channel Handler/ConnectionEstablishment.swift @@ -44,27 +44,7 @@ public struct SetupInfo { /// Token used for client resume identification public let resumeIdentificationToken: Data? - /** - MIME Type for encoding of Metadata - - This SHOULD be a US-ASCII string that includes the Internet media type specified in RFC 2045. - Many are registered with IANA such as CBOR. - Suffix rules MAY be used for handling layout. - For example, `application/x.netflix+cbor` or `application/x.reactivesocket+cbor` or `application/x.netflix+json`. - The string MUST NOT be null terminated. - */ - public let metadataEncodingMimeType: String - - /** - MIME Type for encoding of Data - - This SHOULD be a US-ASCII string that includes the Internet media type specified in RFC 2045. - Many are registered with IANA such as CBOR. - Suffix rules MAY be used for handling layout. - For example, `application/x.netflix+cbor` or `application/x.reactivesocket+cbor` or `application/x.netflix+json`. - The string MUST NOT be null terminated. - */ - public let dataEncodingMimeType: String + public let encoding: ConnectionEncoding /// Payload of this frame describing connection capabilities of the endpoint sending the Setup header public let payload: Payload @@ -114,8 +94,10 @@ extension SetupInfo { self.timeBetweenKeepaliveFrames = setup.timeBetweenKeepaliveFrames self.maxLifetime = setup.maxLifetime self.resumeIdentificationToken = setup.resumeIdentificationToken - self.metadataEncodingMimeType = setup.metadataEncodingMimeType - self.dataEncodingMimeType = setup.dataEncodingMimeType + self.encoding = .init( + metadata: .init(rawValue: setup.metadataEncodingMimeType), + data: .init(rawValue: setup.dataEncodingMimeType) + ) self.payload = setup.payload } } diff --git a/Sources/RSocketCore/ChannelPipeline.swift b/Sources/RSocketCore/ChannelPipeline.swift index cda37057..21e20845 100644 --- a/Sources/RSocketCore/ChannelPipeline.swift +++ b/Sources/RSocketCore/ChannelPipeline.swift @@ -46,7 +46,13 @@ extension ChannelPipeline { self?.writeAndFlush(NIOAny(frame), promise: nil) } let promise = eventLoop.makePromise(of: Void.self) - let requester = Requester(streamIdGenerator: .client, eventLoop: eventLoop, sendFrame: sendFrame) + let requester = Requester( + streamIdGenerator: .client, + encoding: config.encoding, + eventLoop: eventLoop, + sendFrame: sendFrame + ) + let responder = responder ?? DefaultRSocket(encoding: config.encoding) promise.futureResult.map { requester as RSocket }.cascade(to: connectedPromise) let (timeBetweenKeepaliveFrames, maxLifetime): (Int32, Int32) do { @@ -108,14 +114,14 @@ extension ChannelPipeline { FrameEncoderHandler(maximumFrameSize: maximumFrameSize), ConnectionStateHandler(), ConnectionEstablishmentHandler(initializeConnection: { [unowned self] (info, channel) in - let responder = makeResponder?(info) + let responder = makeResponder?(info) ?? DefaultRSocket(encoding: info.encoding) let sendFrame: (Frame) -> () = { [weak self] frame in self?.writeAndFlush(NIOAny(frame), promise: nil) } return channel.pipeline.addHandlers([ DemultiplexerHandler( connectionSide: .server, - requester: Requester(streamIdGenerator: .server, eventLoop: eventLoop, sendFrame: sendFrame), + requester: Requester(streamIdGenerator: .server, encoding: info.encoding, eventLoop: eventLoop, sendFrame: sendFrame), responder: Responder(responderSocket: responder, eventLoop: eventLoop, sendFrame: sendFrame) ), KeepaliveHandler(timeBetweenKeepaliveFrames: info.timeBetweenKeepaliveFrames, maxLifetime: info.maxLifetime, connectionSide: ConnectionRole.server), diff --git a/Sources/RSocketCore/Client/ClientBootstrap.swift b/Sources/RSocketCore/Client/ClientBootstrap.swift index 8e9af665..1704333c 100644 --- a/Sources/RSocketCore/Client/ClientBootstrap.swift +++ b/Sources/RSocketCore/Client/ClientBootstrap.swift @@ -21,6 +21,8 @@ public protocol ClientBootstrap { associatedtype Responder associatedtype Transport: TransportChannelHandler + var config: ClientConfiguration { get } + /// Creates a new connection to the given `endpoint`. /// - Parameters: /// - endpoint: endpoint to connect to diff --git a/Sources/RSocketCore/Client/ClientConfiguration.swift b/Sources/RSocketCore/Client/ClientConfiguration.swift index eca3e704..85f46d21 100644 --- a/Sources/RSocketCore/Client/ClientConfiguration.swift +++ b/Sources/RSocketCore/Client/ClientConfiguration.swift @@ -57,27 +57,6 @@ public struct ClientConfiguration { } } - /// encoding configuration of metadata and data which is send to the server during setup - public struct Encoding { - - /// default encoding uses `.octetStream` for metadata and data - public static let `default` = Encoding() - - /// MIME Type for encoding of Metadata - public var metadata: MIMEType - - /// MIME Type for encoding of Data - public var data: MIMEType - - public init( - metadata: MIMEType = .default, - data: MIMEType = .default - ) { - self.metadata = metadata - self.data = data - } - } - /// local fragmentation configuration which are **not** send to the server public struct Fragmentation { @@ -109,14 +88,14 @@ public struct ClientConfiguration { public var timeout: Timeout /// encoding configuration of metadata and data which is send to the server during setup - public var encoding: Encoding + public var encoding: ConnectionEncoding /// local fragmentation configuration which are **not** send to the server public var fragmentation: Fragmentation public init( timeout: Timeout, - encoding: Encoding = .default, + encoding: ConnectionEncoding = .default, fragmentation: Fragmentation = .default ) { self.timeout = timeout diff --git a/Sources/RSocketCore/DefaultRSocket.swift b/Sources/RSocketCore/DefaultRSocket.swift index 94200882..87a6235d 100644 --- a/Sources/RSocketCore/DefaultRSocket.swift +++ b/Sources/RSocketCore/DefaultRSocket.swift @@ -29,6 +29,7 @@ fileprivate final class NoOpStream: UnidirectionalStream { /// An RSocket which rejects all incoming requests (requestResponse, stream and channel) and ignores metadataPush and fireAndForget events. internal struct DefaultRSocket: RSocket { + let encoding: ConnectionEncoding func metadataPush(metadata: Data) {} func fireAndForget(payload: Payload) {} func requestResponse(payload: Payload, responderStream: UnidirectionalStream) -> Cancellable { diff --git a/Sources/RSocketCore/Extensions/Coder/Decoder/Decoder.swift b/Sources/RSocketCore/Extensions/Coder/Decoder/Decoder.swift index fb577d97..4e4a1d11 100644 --- a/Sources/RSocketCore/Extensions/Coder/Decoder/Decoder.swift +++ b/Sources/RSocketCore/Extensions/Coder/Decoder/Decoder.swift @@ -38,5 +38,15 @@ public struct Decoder: DecoderProtocol { } } +extension DecoderProtocol where Metadata == Void { + @inlinable + public mutating func decode( + _ payload: Payload, + encoding: ConnectionEncoding + ) throws -> Data { + try decode(payload, encoding: encoding).1 + } +} + /// Namespace for types conforming to the ``DecoderProtocol`` protocol public enum Decoders {} diff --git a/Sources/RSocketCore/Extensions/Coder/Decoder/MultiDataDecoder.swift b/Sources/RSocketCore/Extensions/Coder/Decoder/MultiDataDecoder.swift index b3755278..321f3fe1 100644 --- a/Sources/RSocketCore/Extensions/Coder/Decoder/MultiDataDecoder.swift +++ b/Sources/RSocketCore/Extensions/Coder/Decoder/MultiDataDecoder.swift @@ -38,11 +38,7 @@ extension MultiDataDecoderProtocol { @inlinable internal func decodeMIMEType(_ mimeType: MIMEType, from data: Foundation.Data) throws -> Data { var buffer = ByteBuffer(data: data) - let data = try self.decodeMIMEType(mimeType, from: &buffer) - guard buffer.readableBytes == 0 else { - throw Error.invalid(message: "\(Decoder.self) did not read all bytes") - } - return data + return try self.decodeMIMEType(mimeType, from: &buffer) } } diff --git a/Sources/RSocketCore/Extensions/Coder/Encoder/DataEncoder.swift b/Sources/RSocketCore/Extensions/Coder/Encoder/DataEncoder.swift index 78dc0c58..6ee2da2b 100644 --- a/Sources/RSocketCore/Extensions/Coder/Encoder/DataEncoder.swift +++ b/Sources/RSocketCore/Extensions/Coder/Encoder/DataEncoder.swift @@ -39,11 +39,7 @@ extension DataDecoderProtocol { @inlinable internal func decode(from data: Foundation.Data) throws -> Data { var buffer = ByteBuffer(data: data) - let data = try self.decode(from: &buffer) - guard buffer.readableBytes == 0 else { - throw Error.invalid(message: "\(Decoder.self) did not read all bytes") - } - return data + return try self.decode(from: &buffer) } } diff --git a/Sources/RSocketCore/Extensions/Coder/Encoder/Encoder.swift b/Sources/RSocketCore/Extensions/Coder/Encoder/Encoder.swift index ada5ddb1..22e50819 100644 --- a/Sources/RSocketCore/Extensions/Coder/Encoder/Encoder.swift +++ b/Sources/RSocketCore/Extensions/Coder/Encoder/Encoder.swift @@ -36,5 +36,15 @@ public struct Encoder: EncoderProtocol { } } +extension EncoderProtocol where Metadata == Void { + @inlinable + public mutating func encode( + _ data: Data, + encoding: ConnectionEncoding + ) throws -> Payload { + try encode(metadata: (), data: data, encoding: encoding) + } +} + /// Namespace for types conforming to the ``EncoderProtocol`` protocol public enum Encoders {} diff --git a/Sources/RSocketCore/Extensions/Coder/Encoder/MetadataEncoder.swift b/Sources/RSocketCore/Extensions/Coder/Encoder/MetadataEncoder.swift index 98347198..9b6e2844 100644 --- a/Sources/RSocketCore/Extensions/Coder/Encoder/MetadataEncoder.swift +++ b/Sources/RSocketCore/Extensions/Coder/Encoder/MetadataEncoder.swift @@ -29,7 +29,7 @@ extension MetadataDecoder { var buffer = ByteBuffer(data: data) let metadata = try self.decode(from: &buffer) guard buffer.readableBytes == 0 else { - throw Error.invalid(message: "\(Decoder.self) did not read all bytes") + throw Error.invalid(message: "\(Self.self) did not read all bytes") } return metadata } diff --git a/Sources/RSocketCore/Extensions/Coder/Encoder/OctetStreamMetadataEncoder.swift b/Sources/RSocketCore/Extensions/Coder/Encoder/OctetStreamMetadataEncoder.swift new file mode 100644 index 00000000..68a18937 --- /dev/null +++ b/Sources/RSocketCore/Extensions/Coder/Encoder/OctetStreamMetadataEncoder.swift @@ -0,0 +1,39 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import Foundation +import NIO + +public struct OctetStreamMetadataEncoder: MetadataEncoder { + public typealias Metadata = Data? + + @inlinable + public init() {} + + @inlinable + public var mimeType: MIMEType { .applicationOctetStream } + + @inlinable + public func encode(_ metadata: Data?, into buffer: inout ByteBuffer) throws { + guard let metadata = metadata else { return } + buffer.writeData(metadata) + } +} + +extension MetadataEncoder where Self == OctetStreamMetadataEncoder { + @inlinable + public static var octetStream: Self { .init() } +} diff --git a/Sources/RSocketCore/Extensions/RootCompositeMetadataDecoder.swift b/Sources/RSocketCore/Extensions/RootCompositeMetadataDecoder.swift index 5fb5e1e7..944ca267 100644 --- a/Sources/RSocketCore/Extensions/RootCompositeMetadataDecoder.swift +++ b/Sources/RSocketCore/Extensions/RootCompositeMetadataDecoder.swift @@ -25,10 +25,10 @@ public struct RootCompositeMetadataDecoder: MetadataDecoder { public var mimeType: MIMEType { .messageXRSocketCompositeMetadataV0 } @usableFromInline - internal let mimeTypeDecoder: MIMETypeEncoder + internal let mimeTypeDecoder: MIMETypeDecoder @inlinable - public init(mimeTypeDecoder: MIMETypeEncoder = MIMETypeEncoder()) { + public init(mimeTypeDecoder: MIMETypeDecoder = MIMETypeDecoder()) { self.mimeTypeDecoder = mimeTypeDecoder } diff --git a/Sources/RSocketCore/Extensions/RoutingDecoder.swift b/Sources/RSocketCore/Extensions/RoutingDecoder.swift index a6421e9f..0a54336b 100644 --- a/Sources/RSocketCore/Extensions/RoutingDecoder.swift +++ b/Sources/RSocketCore/Extensions/RoutingDecoder.swift @@ -18,7 +18,10 @@ import NIO public struct RoutingDecoder: MetadataDecoder { public typealias Metadata = RouteMetadata - + + @inlinable + public init() {} + @inlinable public var mimeType: MIMEType { .messageXRSocketRoutingV0 } diff --git a/Sources/RSocketCore/Extensions/RoutingEncoder.swift b/Sources/RSocketCore/Extensions/RoutingEncoder.swift index 985ac6dd..7dad0ea8 100644 --- a/Sources/RSocketCore/Extensions/RoutingEncoder.swift +++ b/Sources/RSocketCore/Extensions/RoutingEncoder.swift @@ -18,6 +18,9 @@ import NIO public struct RoutingEncoder: MetadataEncoder { public typealias Metadata = RouteMetadata + + @inlinable + public init() {} @inlinable public var mimeType: MIMEType { .messageXRSocketRoutingV0 } diff --git a/Sources/RSocketCore/RSocket.swift b/Sources/RSocketCore/RSocket.swift index bb8d297a..dd004eeb 100644 --- a/Sources/RSocketCore/RSocket.swift +++ b/Sources/RSocketCore/RSocket.swift @@ -17,6 +17,7 @@ import Foundation public protocol RSocket { + var encoding: ConnectionEncoding { get } func metadataPush(metadata: Data) func fireAndForget(payload: Payload) diff --git a/Sources/RSocketCore/RequestDescription.swift b/Sources/RSocketCore/RequestDescription.swift index a4b1df5c..6de34f93 100644 --- a/Sources/RSocketCore/RequestDescription.swift +++ b/Sources/RSocketCore/RequestDescription.swift @@ -14,11 +14,17 @@ * limitations under the License. */ +import Foundation + public struct MetadataPush { public let encoder: AnyMetadataEncoder } extension MetadataPush { + @inlinable + public init() where Metadata == Data? { + self.init(using: OctetStreamMetadataEncoder()) + } @inlinable public init( using metadataEncoder: Encoder @@ -30,7 +36,7 @@ extension MetadataPush { public init( @CompositeMetadataEncoderBuilder _ makeEncoder: () -> Encoder ) where - Encoder: MetadataEncoder, + Encoder: MetadataEncoder, Encoder.Metadata == Metadata { encoder = makeEncoder().eraseToAnyMetadataEncoder() @@ -42,6 +48,12 @@ public struct FireAndForget { } extension FireAndForget { + @inlinable + public init() where Request == Data { + self.init { Encoder() } + } + + @inlinable public init( @EncoderBuilder _ makeEncoder: () -> Encoder ) where @@ -59,6 +71,11 @@ public struct RequestResponse { } extension RequestResponse { + @inlinable + public init() where Request == Data, Response == Data { + self.init { Coder() } + } + @inlinable public init( @CoderBuilder _ makeCoder: () -> Coder @@ -78,6 +95,11 @@ public struct RequestStream { } extension RequestStream { + @inlinable + public init() where Request == Data, Response == Data { + self.init { Coder() } + } + @inlinable public init( @CoderBuilder _ makeCoder: () -> Coder @@ -97,6 +119,11 @@ public struct RequestChannel { } extension RequestChannel { + @inlinable + public init() where Request == Data, Response == Data { + self.init { Coder() } + } + @inlinable public init( @CoderBuilder _ makeCoder: () -> Coder diff --git a/Sources/RSocketCore/Streams/Requester.swift b/Sources/RSocketCore/Streams/Requester.swift index 51cfe3ce..73bb0a4d 100644 --- a/Sources/RSocketCore/Streams/Requester.swift +++ b/Sources/RSocketCore/Streams/Requester.swift @@ -18,6 +18,7 @@ import Foundation import NIO internal final class Requester { + internal let encoding: ConnectionEncoding private let sendFrame: (Frame) -> Void private let eventLoop: EventLoop private var streamIdGenerator: StreamIDGenerator @@ -26,11 +27,13 @@ internal final class Requester { internal init( streamIdGenerator: StreamIDGenerator, + encoding: ConnectionEncoding, eventLoop: EventLoop, sendFrame: @escaping (Frame) -> Void, lateFrameHandler: ((Frame) -> ())? = nil ) { self.streamIdGenerator = streamIdGenerator + self.encoding = encoding self.eventLoop = eventLoop self.sendFrame = sendFrame self.lateFrameHandler = lateFrameHandler diff --git a/Sources/RSocketCore/Streams/Responder.swift b/Sources/RSocketCore/Streams/Responder.swift index 371263ee..b215a1e6 100644 --- a/Sources/RSocketCore/Streams/Responder.swift +++ b/Sources/RSocketCore/Streams/Responder.swift @@ -23,12 +23,12 @@ internal final class Responder { private let eventLoop: EventLoop private let lateFrameHandler: ((Frame) -> ())? internal init( - responderSocket: RSocket? = nil, + responderSocket: RSocket, eventLoop: EventLoop, sendFrame: @escaping (Frame) -> Void, lateFrameHandler: ((Frame) -> ())? = nil ) { - self.responderSocket = responderSocket ?? DefaultRSocket() + self.responderSocket = responderSocket self.sendFrame = sendFrame self.eventLoop = eventLoop self.lateFrameHandler = lateFrameHandler diff --git a/Sources/RSocketNIOChannel/ClientBootstrap.swift b/Sources/RSocketNIOChannel/ClientBootstrap.swift index 46a52e6e..d23cabb0 100644 --- a/Sources/RSocketNIOChannel/ClientBootstrap.swift +++ b/Sources/RSocketNIOChannel/ClientBootstrap.swift @@ -21,7 +21,7 @@ import RSocketCore final public class ClientBootstrap { private let group: EventLoopGroup private let bootstrap: NIO.ClientBootstrap - private let config: ClientConfiguration + public let config: ClientConfiguration private let transport: Transport private let sslContext: NIOSSLContext? diff --git a/Sources/RSocketReactiveSwift/Client/ReactiveSwiftClient.swift b/Sources/RSocketReactiveSwift/Client/ReactiveSwiftClient.swift index a70a8e75..b66a42a2 100644 --- a/Sources/RSocketReactiveSwift/Client/ReactiveSwiftClient.swift +++ b/Sources/RSocketReactiveSwift/Client/ReactiveSwiftClient.swift @@ -20,7 +20,7 @@ import ReactiveSwift public struct ReactiveSwiftClient: Client { private let coreClient: CoreClient - public var requester: RSocketReactiveSwift.RSocket { RequesterAdapter(requester: coreClient.requester) } + public var requester: RSocketReactiveSwift.RequesterRSocket { RequesterAdapter(requester: coreClient.requester) } public init(_ coreClient: CoreClient) { self.coreClient = coreClient @@ -31,10 +31,11 @@ extension ClientBootstrap where Client == CoreClient, Responder == RSocketCore.R public func connect( to endpoint: Transport.Endpoint, payload: Payload = .empty, - responder: RSocketReactiveSwift.RSocket? = nil + responder: RSocketReactiveSwift.ResponderRSocket? = nil ) -> SignalProducer { SignalProducer { observer, lifetime in - let future = connect(to: endpoint, payload: payload, responder: responder?.coreAdapter) + let responder = responder.map { ResponderAdapter(responder: $0, encoding: config.encoding) } + let future = connect(to: endpoint, payload: payload, responder: responder) .map(ReactiveSwiftClient.init) future.whenComplete { result in switch result { @@ -48,9 +49,3 @@ extension ClientBootstrap where Client == CoreClient, Responder == RSocketCore.R } } } - -private extension RSocketReactiveSwift.RSocket { - var coreAdapter: RSocketCore.RSocket { - ResponderAdapter(responder: self) - } -} diff --git a/Sources/RSocketReactiveSwift/Requester.swift b/Sources/RSocketReactiveSwift/Requester.swift index 0e47577c..80de6e53 100644 --- a/Sources/RSocketReactiveSwift/Requester.swift +++ b/Sources/RSocketReactiveSwift/Requester.swift @@ -18,12 +18,78 @@ import ReactiveSwift import RSocketCore import Foundation -internal struct RequesterAdapter: RSocket { +internal struct RequesterAdapter { internal let requester: RSocketCore.RSocket internal init(requester: RSocketCore.RSocket) { self.requester = requester } +} + +extension RequesterAdapter: RequesterRSocket { + func callAsFunction(_ metadataPush: MetadataPush, metadata: Metadata) throws { + let metadata = try metadataPush.encoder.encode(metadata) + self.metadataPush(metadata: metadata) + } + + func callAsFunction(_ fireAndForget: FireAndForget, request: Request) throws { + var encoder = fireAndForget.encoder + let payload = try encoder.encode(request, encoding: encoding) + self.fireAndForget(payload: payload) + } + + func callAsFunction( + _ requestResponse: RequestResponse, + request: Request + ) -> SignalProducer { + SignalProducer { () throws -> SignalProducer in + var encoder = requestResponse.encoder + var decoder = requestResponse.decoder + let payload = try encoder.encode(request, encoding: encoding) + return self.requestResponse(payload: payload).attemptMap { response in + try decoder.decode(response, encoding: encoding) + } + }.flatten(.latest) + } + + func callAsFunction( + _ requestStream: RequestStream, + request: Request + ) -> SignalProducer { + SignalProducer { () throws -> SignalProducer in + var encoder = requestStream.encoder + var decoder = requestStream.decoder + let payload = try encoder.encode(request, encoding: encoding) + return self.requestStream(payload: payload).attemptMap { response in + try decoder.decode(response, encoding: encoding) + } + }.flatten(.latest) + } + + func callAsFunction( + _ requestChannel: RequestChannel, + initialRequest: Request, + producer: SignalProducer? + ) -> SignalProducer { + SignalProducer { () throws -> SignalProducer in + var encoder = requestChannel.encoder + var decoder = requestChannel.decoder + let payload = try encoder.encode(initialRequest, encoding: encoding) + let payloadProducer = producer?.attemptMap { data in + try encoder.encode(data, encoding: encoding) + } + return self.requestChannel(payload: payload, payloadProducer: payloadProducer).attemptMap{ response in + try decoder.decode(response, encoding: encoding) + } + }.flatten(.latest) + } +} + + +extension RequesterAdapter { + internal var encoding: ConnectionEncoding { + requester.encoding + } internal func metadataPush(metadata: Data) { requester.metadataPush(metadata: metadata) diff --git a/Sources/RSocketReactiveSwift/RequesterRSocket.swift b/Sources/RSocketReactiveSwift/RequesterRSocket.swift new file mode 100644 index 00000000..3fb5c49e --- /dev/null +++ b/Sources/RSocketReactiveSwift/RequesterRSocket.swift @@ -0,0 +1,40 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import ReactiveSwift +import RSocketCore + +public protocol RequesterRSocket { + func callAsFunction(_ metadataPush: MetadataPush, metadata: Metadata) throws + + func callAsFunction(_ fireAndForget: FireAndForget, request: Request) throws + + func callAsFunction( + _ requestResponse: RequestResponse, + request: Request + ) -> SignalProducer + + func callAsFunction( + _ requestStream: RequestStream, + request: Request + ) -> SignalProducer + + func callAsFunction( + _ requestChannel: RequestChannel, + initialRequest: Request, + producer: SignalProducer? + ) -> SignalProducer +} diff --git a/Sources/RSocketReactiveSwift/Responder.swift b/Sources/RSocketReactiveSwift/Responder.swift index 769f58dc..c201532b 100644 --- a/Sources/RSocketReactiveSwift/Responder.swift +++ b/Sources/RSocketReactiveSwift/Responder.swift @@ -18,11 +18,14 @@ import ReactiveSwift import RSocketCore import Foundation -internal struct ResponderAdapter: RSocketCore.RSocket { - private let responder: RSocket +internal struct ResponderAdapter: RSocketCore.RSocket { + private let responder: ResponderRSocket + + internal var encoding: ConnectionEncoding - internal init(responder: RSocket) { + internal init(responder: ResponderRSocket, encoding: ConnectionEncoding) { self.responder = responder + self.encoding = encoding } func metadataPush(metadata: Data) { diff --git a/Sources/RSocketReactiveSwift/RSocket.swift b/Sources/RSocketReactiveSwift/ResponderRSocket.swift similarity index 97% rename from Sources/RSocketReactiveSwift/RSocket.swift rename to Sources/RSocketReactiveSwift/ResponderRSocket.swift index f570d5de..200fe7f9 100644 --- a/Sources/RSocketReactiveSwift/RSocket.swift +++ b/Sources/RSocketReactiveSwift/ResponderRSocket.swift @@ -19,7 +19,7 @@ import ReactiveSwift import RSocketCore import Foundation -public protocol RSocket { +public protocol ResponderRSocket { func metadataPush(metadata: Data) func fireAndForget(payload: Payload) func requestResponse(payload: Payload) -> SignalProducer diff --git a/Sources/RSocketTSChannel/ClientBootstrap.swift b/Sources/RSocketTSChannel/ClientBootstrap.swift index b4b036c9..b47c424c 100644 --- a/Sources/RSocketTSChannel/ClientBootstrap.swift +++ b/Sources/RSocketTSChannel/ClientBootstrap.swift @@ -25,7 +25,7 @@ import RSocketCore final public class ClientBootstrap { private let group = NIOTSEventLoopGroup() private let bootstrap: NIOTSConnectionBootstrap - private let config: ClientConfiguration + public let config: ClientConfiguration private let transport: Transport private let tlsOptions: NWProtocolTLS.Options? public init( diff --git a/Sources/RSocketTestUtilities/TestRSocket.swift b/Sources/RSocketTestUtilities/TestRSocket.swift index 8591ec2f..38e5d4f9 100644 --- a/Sources/RSocketTestUtilities/TestRSocket.swift +++ b/Sources/RSocketTestUtilities/TestRSocket.swift @@ -19,6 +19,7 @@ import NIO import RSocketCore public final class TestRSocket: RSocket { + public var encoding: ConnectionEncoding public var metadataPush: ((Data) -> ())? = nil public var fireAndForget: ((_ payload: Payload) -> ())? = nil public var requestResponse: ((_ payload: Payload, _ responderOutput: UnidirectionalStream) -> Cancellable)? = nil @@ -34,9 +35,11 @@ public final class TestRSocket: RSocket { requestResponse: ((Payload, UnidirectionalStream) -> Cancellable)? = nil, stream: ((Payload, Int32, UnidirectionalStream) -> Subscription)? = nil, channel: ((Payload, Int32, Bool, UnidirectionalStream) -> UnidirectionalStream)? = nil, + encoding: ConnectionEncoding = .default, file: StaticString = #file, line: UInt = #line ) { + self.encoding = encoding self.metadataPush = metadataPush self.fireAndForget = fireAndForget self.requestResponse = requestResponse diff --git a/Tests/RSocketCoreTests/EndToEndTests.swift b/Tests/RSocketCoreTests/EndToEndTests.swift index ae45e5b8..c3ac16ee 100644 --- a/Tests/RSocketCoreTests/EndToEndTests.swift +++ b/Tests/RSocketCoreTests/EndToEndTests.swift @@ -157,8 +157,8 @@ class EndToEndTests: XCTestCase { let server = makeServerBootstrap(shouldAcceptClient: { clientInfo in XCTAssertEqual(clientInfo.timeBetweenKeepaliveFrames, Int32(setup.timeout.timeBetweenKeepaliveFrames)) XCTAssertEqual(clientInfo.maxLifetime, Int32(setup.timeout.maxLifetime)) - XCTAssertEqual(clientInfo.metadataEncodingMimeType, setup.encoding.metadata.rawValue) - XCTAssertEqual(clientInfo.dataEncodingMimeType, setup.encoding.data.rawValue) + XCTAssertEqual(clientInfo.encoding.metadata, setup.encoding.metadata) + XCTAssertEqual(clientInfo.encoding.data, setup.encoding.data) clientDidConnect.fulfill() return .accept }) diff --git a/Tests/RSocketReactiveSwiftTests/RSocketReactiveSwiftTests.swift b/Tests/RSocketReactiveSwiftTests/RSocketReactiveSwiftTests.swift index 94d3a7db..3f9dc9f7 100644 --- a/Tests/RSocketReactiveSwiftTests/RSocketReactiveSwiftTests.swift +++ b/Tests/RSocketReactiveSwiftTests/RSocketReactiveSwiftTests.swift @@ -20,13 +20,20 @@ import RSocketCore import RSocketTestUtilities @testable import RSocketReactiveSwift +extension Data: ExpressibleByStringLiteral { + public init(stringLiteral value: String) { + self.init(value.utf8) + } +} + func setup( - server: RSocketReactiveSwift.RSocket? = nil, - client: RSocketReactiveSwift.RSocket? = nil + server: RSocketReactiveSwift.ResponderRSocket? = nil, + client: RSocketReactiveSwift.ResponderRSocket? = nil ) -> (server: ReactiveSwiftClient, client: ReactiveSwiftClient) { let (server, client) = TestDemultiplexer.pipe( - serverResponder: server.map(ResponderAdapter.init(responder:)), - clientResponder: client.map(ResponderAdapter.init(responder:))) + serverResponder: server.map { ResponderAdapter(responder:$0, encoding: .default) }, + clientResponder: client.map { ResponderAdapter(responder:$0, encoding: .default) } + ) return ( ReactiveSwiftClient(CoreClient(requester: server.requester)), ReactiveSwiftClient(CoreClient(requester: client.requester)) @@ -34,25 +41,25 @@ func setup( } final class RSocketReactiveSwiftTests: XCTestCase { - func testMetadataPush() { - let metadata = Data(String("Hello World").utf8) + func testMetadataPush() throws { + let metadata: Data = "Hello World" let didReceiveRequest = expectation(description: "did receive request") let serverResponder = TestRSocket(metadataPush: { data in didReceiveRequest.fulfill() XCTAssertEqual(data, metadata) }) let (_, client) = setup(server: serverResponder) - client.requester.metadataPush(metadata: metadata) + try client.requester(MetadataPush(), metadata: metadata) self.wait(for: [didReceiveRequest], timeout: 0.1) } - func testFireAndForget() { + func testFireAndForget() throws { let didReceiveRequest = expectation(description: "did receive request") let serverResponder = TestRSocket(fireAndForget: { payload in didReceiveRequest.fulfill() XCTAssertEqual(payload, "Hello World") }) let (_, client) = setup(server: serverResponder) - client.requester.fireAndForget(payload: "Hello World") + try client.requester(FireAndForget(), request: "Hello World") self.wait(for: [didReceiveRequest], timeout: 0.1) } func testRequestResponse() { @@ -69,7 +76,10 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestResponse(payload: "Hello World").startWithSignal { signal, _ in + let disposable = client.requester( + RequestResponse(), + request: "Hello World" + ).startWithSignal { signal, _ in signal.flatMapError({ error in XCTFail("\(error)") return .empty @@ -95,7 +105,10 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestResponse(payload: "Hello World").startWithSignal { signal, _ in + let disposable = client.requester( + RequestResponse(), + request: "Hello World" + ).startWithSignal { signal, _ in signal.flatMapError({ error in XCTFail("\(error)") return .empty @@ -126,7 +139,7 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestStream(payload: "Hello World").startWithSignal { signal, _ in + let disposable = client.requester(RequestStream(), request: "Hello World").startWithSignal { signal, _ in signal.flatMapError({ error in XCTFail("\(error)") return .empty @@ -172,7 +185,7 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestChannel(payload: "Hello Responder", payloadProducer: .init({ observer, _ in + let disposable = client.requester(RequestChannel(), initialRequest: "Hello Responder", producer: .init({ observer, _ in requesterDidSendChannelMessages.fulfill() observer.send(value: "Hello") observer.send(value: "from") @@ -216,9 +229,9 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestResponse(payload: "Hello World").startWithSignal { signal, _ -> Disposable? in + let disposable = client.requester(RequestResponse(), request: "Hello World").startWithSignal { signal, _ -> Disposable? in didStartRequestSignal.fulfill() - return signal.flatMapError({ error -> Signal in + return signal.flatMapError({ error -> Signal in XCTFail("\(error)") return .empty }).materialize().collect().observeValues { values in @@ -245,9 +258,9 @@ final class RSocketReactiveSwiftTests: XCTestCase { } }) let (_, client) = setup(server: serverResponder) - let disposable = client.requester.requestStream(payload: "Hello World").startWithSignal { signal, _ -> Disposable? in + let disposable = client.requester(RequestStream(), request: "Hello World").startWithSignal { signal, _ -> Disposable? in didStartRequestSignal.fulfill() - return signal.flatMapError({ error -> Signal in + return signal.flatMapError({ error -> Signal in XCTFail("\(error)") return .empty }).materialize().collect().observeValues { values in @@ -290,7 +303,7 @@ final class RSocketReactiveSwiftTests: XCTestCase { let requesterDidStartListeningChannelMessages = expectation(description: "responder did start listening to channel messages") let payloadProducerLifetimeEnded = expectation(description: "payload producer lifetime ended") let requesterDidStartPayloadProducer = expectation(description: "requester did start payload producer") - let disposable = client.requester.requestChannel(payload: "Hello", payloadProducer: .init({ observer, lifetime in + let disposable = client.requester(RequestChannel(), initialRequest: "Hello", producer: .init({ observer, lifetime in requesterDidStartPayloadProducer.fulfill() lifetime.observeEnded { _ = observer @@ -298,7 +311,7 @@ final class RSocketReactiveSwiftTests: XCTestCase { } })).startWithSignal { signal, _ -> Disposable? in requesterDidStartListeningChannelMessages.fulfill() - return signal.flatMapError({ error -> Signal in + return signal.flatMapError({ error -> Signal in XCTFail("\(error)") return .empty }).materialize().collect().observeValues { values in diff --git a/Tests/RSocketReactiveSwiftTests/TestDemultiplexer.swift b/Tests/RSocketReactiveSwiftTests/TestDemultiplexer.swift index a9718ae6..f6b7b007 100644 --- a/Tests/RSocketReactiveSwiftTests/TestDemultiplexer.swift +++ b/Tests/RSocketReactiveSwiftTests/TestDemultiplexer.swift @@ -48,11 +48,13 @@ extension TestDemultiplexer { serverResponder: RSocketCore.RSocket?, clientResponder: RSocketCore.RSocket? ) -> (server: TestDemultiplexer, client: TestDemultiplexer) { + let serverResponder = serverResponder ?? DefaultRSocket(encoding: .default) + let clientResponder = clientResponder ?? DefaultRSocket(encoding: .default) var client: TestDemultiplexer! let eventLoop = EmbeddedEventLoop() let server = TestDemultiplexer( connectionSide: .server, - requester: .init(streamIdGenerator: .server, eventLoop: eventLoop, sendFrame: { frame in + requester: .init(streamIdGenerator: .server, encoding: .default, eventLoop: eventLoop, sendFrame: { frame in client.receiveFrame(frame: frame) }), responder: .init(responderSocket: serverResponder, eventLoop: eventLoop, sendFrame: { frame in @@ -60,7 +62,7 @@ extension TestDemultiplexer { })) client = TestDemultiplexer( connectionSide: .client, - requester: .init(streamIdGenerator: .client, eventLoop: eventLoop, sendFrame: { frame in + requester: .init(streamIdGenerator: .client, encoding: .default, eventLoop: eventLoop, sendFrame: { frame in server.receiveFrame(frame: frame) }), responder: .init(responderSocket: clientResponder, eventLoop: eventLoop, sendFrame: { frame in diff --git a/Tests/RSocketReactiveSwiftTests/TestRSocket.swift b/Tests/RSocketReactiveSwiftTests/TestRSocket.swift index dffb9edc..cff46bba 100644 --- a/Tests/RSocketReactiveSwiftTests/TestRSocket.swift +++ b/Tests/RSocketReactiveSwiftTests/TestRSocket.swift @@ -19,7 +19,7 @@ import ReactiveSwift import Foundation import RSocketReactiveSwift -final class TestRSocket: RSocketReactiveSwift.RSocket { +final class TestRSocket: RSocketReactiveSwift.ResponderRSocket { var metadataPushCallback: (Data) -> () var fireAndForgetCallback: (Payload) -> () var requestResponseCallback: (Payload) -> SignalProducer