diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index c896791cf..3dc47c5ae 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -68,13 +68,8 @@ extension HTTPConnectionPool.ConnectionFactory { var logger = logger logger[metadataKey: "ahc-connection-id"] = "\(connectionID)" - self.makeChannel( - requester: requester, - connectionID: connectionID, - deadline: deadline, - eventLoop: eventLoop, - logger: logger - ).whenComplete { [logger] result in + let promise = eventLoop.makePromise(of: NegotiatedProtocol.self) + promise.futureResult.whenComplete { [logger] result in switch result { case .success(.http1_1(let channel)): do { @@ -143,10 +138,26 @@ extension HTTPConnectionPool.ConnectionFactory { } } - case .failure(let error): + case .failure(var error): + // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` + switch error { + case ChannelError.connectTimeout: + error = HTTPClientError.connectTimeout + default: + () + } requester.failedToCreateHTTPConnection(connectionID, error: error) } } + + self.makeChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger, + promise: promise + ) } enum NegotiatedProtocol { @@ -159,50 +170,42 @@ extension HTTPConnectionPool.ConnectionFactory { connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { - let channelFuture: EventLoopFuture - + logger: Logger, + promise: EventLoopPromise + ) { if self.key.scheme.isProxyable, let proxy = self.clientConfiguration.proxy { switch proxy.type { case .socks: - channelFuture = self.makeSOCKSProxyChannel( + self.makeSOCKSProxyChannel( proxy, requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger + logger: logger, + promise: promise ) case .http: - channelFuture = self.makeHTTPProxyChannel( + self.makeHTTPProxyChannel( proxy, requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger + logger: logger, + promise: promise ) } } else { - channelFuture = self.makeNonProxiedChannel( + self.makeNonProxiedChannel( requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger + logger: logger, + promise: promise ) } - - // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` - return channelFuture.flatMapErrorThrowing { error throws -> NegotiatedProtocol in - switch error { - case ChannelError.connectTimeout: - throw HTTPClientError.connectTimeout - default: - throw error - } - } } private func makeNonProxiedChannel( @@ -210,29 +213,27 @@ extension HTTPConnectionPool.ConnectionFactory { connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { switch self.key.scheme { case .http, .httpUnix, .unix: - return self.makePlainChannel( + self.makePlainChannel( requester: requester, connectionID: connectionID, deadline: deadline, - eventLoop: eventLoop - ).map { .http1_1($0) } + eventLoop: eventLoop, + promise: promise + ) case .https, .httpsUnix: - return self.makeTLSChannel( + self.makeTLSChannel( requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, - logger: logger - ).flatMapThrowing { - channel, - negotiated in - - try self.matchALPNToHTTPVersion(negotiated, channel: channel) - } + logger: logger, + promise: promise + ) } } @@ -240,15 +241,18 @@ extension HTTPConnectionPool.ConnectionFactory { requester: Requester, connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, - eventLoop: EventLoop - ) -> EventLoopFuture { + eventLoop: EventLoop, + promise: EventLoopPromise + ) { precondition(!self.key.scheme.usesTLS, "Unexpected scheme") return self.makePlainBootstrap( requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop - ).connect(target: self.key.connectionTarget) + ).connect(target: self.key.connectionTarget).map { + .http1_1($0) + }.cascade(to: promise) } private func makeHTTPProxyChannel( @@ -257,8 +261,9 @@ extension HTTPConnectionPool.ConnectionFactory { connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. @@ -268,34 +273,39 @@ extension HTTPConnectionPool.ConnectionFactory { deadline: deadline, eventLoop: eventLoop ) - return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in - let encoder = HTTPRequestEncoder() - let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) - let proxyHandler = HTTP1ProxyConnectHandler( - target: self.key.connectionTarget, - proxyAuthorization: proxy.authorization, - deadline: deadline - ) + bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in + switch result { + case .success(let channel): + let encoder = HTTPRequestEncoder() + let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) + let proxyHandler = HTTP1ProxyConnectHandler( + target: self.key.connectionTarget, + proxyAuthorization: proxy.authorization, + deadline: deadline + ) - do { - try channel.pipeline.syncOperations.addHandler(encoder) - try channel.pipeline.syncOperations.addHandler(decoder) - try channel.pipeline.syncOperations.addHandler(proxyHandler) - } catch { - return channel.eventLoop.makeFailedFuture(error) - } + do { + try channel.pipeline.syncOperations.addHandler(encoder) + try channel.pipeline.syncOperations.addHandler(decoder) + try channel.pipeline.syncOperations.addHandler(proxyHandler) + } catch { + return promise.fail(error) + } - // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a - // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. - return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { - channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { - channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { - channel.pipeline.syncOperations.removeHandler(encoder) + // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a + // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. + return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(encoder) + }.nonisolated() }.nonisolated() - }.nonisolated() - }.flatMap { - self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - }.nonisolated() + }.flatMap { + self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + }.nonisolated().cascade(to: promise) + case .failure(let error): + promise.fail(error) + } } } @@ -305,8 +315,9 @@ extension HTTPConnectionPool.ConnectionFactory { connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture { + logger: Logger, + promise: EventLoopPromise + ) { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. @@ -316,26 +327,32 @@ extension HTTPConnectionPool.ConnectionFactory { deadline: deadline, eventLoop: eventLoop ) - return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in - let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) - let socksEventHandler = SOCKSEventsHandler(deadline: deadline) - - do { - try channel.pipeline.syncOperations.addHandler(socksConnectHandler) - try channel.pipeline.syncOperations.addHandler(socksEventHandler) - } catch { - return channel.eventLoop.makeFailedFuture(error) + bootstrap.connect(host: proxy.host, port: proxy.port).whenComplete { result in + switch result { + case .success(let channel): + let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) + let socksEventHandler = SOCKSEventsHandler(deadline: deadline) + + do { + try channel.pipeline.syncOperations.addHandler(socksConnectHandler) + try channel.pipeline.syncOperations.addHandler(socksEventHandler) + } catch { + return promise.fail(error) + } + + // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a + // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. + socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksConnectHandler) + }.nonisolated() + }.flatMap { + self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) + }.nonisolated().cascade(to: promise) + case .failure(let error): + promise.fail(error) } - // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a - // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. - return socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { - channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { - channel.pipeline.syncOperations.removeHandler(socksConnectHandler) - }.nonisolated() - }.flatMap { - self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - }.nonisolated() } } @@ -390,7 +407,7 @@ extension HTTPConnectionPool.ConnectionFactory { let sync = channel.pipeline.syncOperations let context = try sync.context(handlerType: TLSEventsHandler.self) return sync.removeHandler(context: context).flatMapThrowing { - try self.matchALPNToHTTPVersion(negotiated, channel: channel) + try Self.matchALPNToHTTPVersion(negotiated, channel: channel) } } catch { return channel.eventLoop.makeFailedFuture(error) @@ -449,8 +466,9 @@ extension HTTPConnectionPool.ConnectionFactory { connectionID: HTTPConnectionPool.Connection.ID, deadline: NIODeadline, eventLoop: EventLoop, - logger: Logger - ) -> EventLoopFuture<(Channel, String?)> { + logger: Logger, + promise: EventLoopPromise + ) { precondition(self.key.scheme.usesTLS, "Unexpected scheme") let bootstrapFuture = self.makeTLSBootstrap( requester: requester, @@ -460,36 +478,42 @@ extension HTTPConnectionPool.ConnectionFactory { logger: logger ) - var channelFuture = bootstrapFuture.flatMap { bootstrap -> EventLoopFuture in - bootstrap.connect(target: self.key.connectionTarget) - }.flatMap { channel -> EventLoopFuture<(Channel, String?)> in - do { - // if the channel is closed before flatMap is executed, all ChannelHandler are removed - // and TLSEventsHandler is therefore not present either - let tlsEventHandler = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) - - // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a - // pipeline. It is created in TLSEventsHandler's handlerAdded method. - return tlsEventHandler.tlsEstablishedFuture!.assumeIsolated().flatMap { negotiated in - channel.pipeline.syncOperations.removeHandler(tlsEventHandler).map { (channel, negotiated) } - }.nonisolated() - } catch { - assert( - channel.isActive == false, - "if the channel is still active then TLSEventsHandler must be present but got error \(error)" - ) - return channel.eventLoop.makeFailedFuture(HTTPClientError.remoteConnectionClosed) + bootstrapFuture.whenComplete { result in + switch result { + case .success(let bootstrap): + bootstrap.connect(target: self.key.connectionTarget).flatMap { + channel -> EventLoopFuture<(Channel, String?)> in + do { + // if the channel is closed before flatMap is executed, all ChannelHandler are removed + // and TLSEventsHandler is therefore not present either + let tlsEventHandler = try channel.pipeline.syncOperations.handler(type: TLSEventsHandler.self) + + // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a + // pipeline. It is created in TLSEventsHandler's handlerAdded method. + return tlsEventHandler.tlsEstablishedFuture!.assumeIsolated().flatMap { negotiated in + channel.pipeline.syncOperations.removeHandler(tlsEventHandler).map { (channel, negotiated) } + }.nonisolated() + } catch { + assert( + channel.isActive == false, + "if the channel is still active then TLSEventsHandler must be present but got error \(error)" + ) + return channel.eventLoop.makeFailedFuture(HTTPClientError.remoteConnectionClosed) + } + }.flatMapThrowing { channel, alpn in + try Self.matchALPNToHTTPVersion(alpn, channel: channel) + }.flatMapErrorThrowing { error in + // If NIOTransportSecurity is used, we want to map NWErrors into NWPOsixErrors or NWTLSError. + #if canImport(Network) + throw HTTPClient.NWErrorHandler.translateError(error) + #else + throw error + #endif + }.cascade(to: promise) + case .failure(let error): + promise.fail(error) } } - - #if canImport(Network) - // If NIOTransportSecurity is used, we want to map NWErrors into NWPOsixErrors or NWTLSError. - channelFuture = channelFuture.flatMapErrorThrowing { error in - throw HTTPClient.NWErrorHandler.translateError(error) - } - #endif - - return channelFuture } private func makeTLSBootstrap( @@ -582,7 +606,7 @@ extension HTTPConnectionPool.ConnectionFactory { } } - private func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { + private static func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { switch negotiated { case .none, .some("http/1.1"): return .http1_1(channel) diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index 15cc9e7e9..37ff3a1ef 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -57,7 +57,10 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { logger: .init(label: "test") ).wait() ) { - XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + guard let error = $0 as? ChannelError, case .connectTimeout = error else { + XCTFail("Unexpected error: \($0)") + return + } } } @@ -210,3 +213,24 @@ final class ExplodingRequester: HTTPConnectionRequester { XCTFail("waitingForConnectivity called unexpectedly") } } + +extension HTTPConnectionPool.ConnectionFactory { + fileprivate func makeChannel( + requester: Requester, + connectionID: HTTPConnectionPool.Connection.ID, + deadline: NIODeadline, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { + let promise = eventLoop.makePromise(of: NegotiatedProtocol.self) + self.makeChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger, + promise: promise + ) + return promise.futureResult + } +}