From 4bdf6b7006905e9ae1c2833d40b3215cc4f093ee Mon Sep 17 00:00:00 2001 From: Davide De Rosa Date: Tue, 14 Apr 2020 22:53:15 +0200 Subject: [PATCH] Redefine endpoint strategy according to IPv4/6 --- .../AppExtension/ConnectionStrategy.swift | 190 +++++++++++------- .../OpenVPNTunnelProvider+Configuration.swift | 2 +- .../AppExtension/OpenVPNTunnelProvider.swift | 14 +- .../OpenVPN/AppExtensionTests.swift | 97 ++++++++- 4 files changed, 220 insertions(+), 83 deletions(-) diff --git a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/ConnectionStrategy.swift b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/ConnectionStrategy.swift index e00b4cb..7ec91a6 100644 --- a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/ConnectionStrategy.swift +++ b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/ConnectionStrategy.swift @@ -41,25 +41,38 @@ import SwiftyBeaver private let log = SwiftyBeaver.self class ConnectionStrategy { + struct Endpoint: CustomStringConvertible { + let record: DNSRecord + + let proto: EndpointProtocol + + var isValid: Bool { + if record.isIPv6 { + return proto.socketType != .udp4 && proto.socketType != .tcp4 + } else { + return proto.socketType != .udp6 && proto.socketType != .tcp6 + } + } + + // MARK: CustomStringConvertible + + var description: String { + return "\(record.address.maskedDescription):\(proto)" + } + } + private let hostname: String? - private let prefersResolvedAddresses: Bool - - private var resolvedAddresses: [String]? - private let endpointProtocols: [EndpointProtocol] - private var currentProtocolIndex = 0 + private var endpoints: [Endpoint] + + private var currentEndpointIndex: Int + + private let resolvedAddresses: [String] init(configuration: OpenVPNTunnelProvider.Configuration) { hostname = configuration.sessionConfiguration.hostname - prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses - resolvedAddresses = configuration.resolvedAddresses - if prefersResolvedAddresses { - guard !(resolvedAddresses?.isEmpty ?? true) else { - fatalError("Either hostname or resolved addresses provided") - } - } guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else { fatalError("No endpoints provided") } @@ -67,101 +80,130 @@ class ConnectionStrategy { endpointProtocols.shuffle() } self.endpointProtocols = endpointProtocols + + currentEndpointIndex = 0 + if let resolvedAddresses = configuration.resolvedAddresses { + if configuration.prefersResolvedAddresses { + endpoints = ConnectionStrategy.unrolledEndpoints( + records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) }, + protos: endpointProtocols + ) + } else { + endpoints = [] + } + self.resolvedAddresses = resolvedAddresses + } else { + guard hostname != nil else { + fatalError("Either configuration.hostname or resolvedRecords required") + } + endpoints = [] + resolvedAddresses = [] + } + } + + private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] { + return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos) } + private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] { + guard !records.isEmpty else { + return [] + } + var endpoints: [Endpoint] = [] + for r in records { + for p in protos { + let endpoint = Endpoint(record: r, proto: p) + guard endpoint.isValid else { + continue + } + endpoints.append(endpoint) + } + } + log.debug("Unrolled endpoints: \(endpoints.maskedDescription)") + return endpoints + } + + func hasEndpoint() -> Bool { + return currentEndpointIndex < endpoints.count + } + + func currentEndpoint() -> Endpoint { + guard hasEndpoint() else { + fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))") + } + return endpoints[currentEndpointIndex] + } + + @discardableResult + func tryNextEndpoint() -> Bool { + guard hasEndpoint() else { + return false + } + currentEndpointIndex += 1 + guard currentEndpointIndex < endpoints.count else { + log.debug("Exhausted endpoints") + return false + } + log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)") + return true + } + func createSocket( from provider: NEProvider, timeout: Int, - preferredAddress: String? = nil, queue: DispatchQueue, completionHandler: @escaping (GenericSocket?, Error?) -> Void) { - - // reuse preferred address - if let preferredAddress = preferredAddress { - log.debug("Pick preferred address: \(preferredAddress.maskedDescription)") - let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol()) + + if hasEndpoint() { + let endpoint = currentEndpoint() + log.debug("Pick current endpoint: \(endpoint.maskedDescription)") + let socket = provider.createSocket(to: endpoint) completionHandler(socket, nil) return } - - // use any resolved address - if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() { - log.debug("Pick resolved address: \(resolvedAddress.maskedDescription)") - let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol()) - completionHandler(socket, nil) - return - } - - // fall back to DNS + log.debug("No endpoints available, will resort to DNS resolution") + guard let hostname = hostname else { log.error("DNS resolution unavailable: no hostname provided!") completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) return } log.debug("DNS resolve hostname: \(hostname.maskedDescription)") - DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in - - // refresh resolved addresses - if let resolved = addresses, !resolved.isEmpty { - self.resolvedAddresses = resolved - - log.debug("DNS resolved addresses: \(resolved.map { $0.maskedDescription })") + DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in + self.currentEndpointIndex = 0 + if let records = records, !records.isEmpty { + log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)") + self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols) } else { log.error("DNS resolution failed!") + log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)") + self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols) } - - guard let targetAddress = self.resolvedAddress(from: addresses) else { - log.error("No resolved or fallback address available") + + guard self.hasEndpoint() else { + log.error("No endpoints available") completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) return } - let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol()) + let targetEndpoint = self.currentEndpoint() + log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)") + let socket = provider.createSocket(to: targetEndpoint) completionHandler(socket, nil) } } - - func tryNextProtocol() -> Bool { - let next = currentProtocolIndex + 1 - guard next < endpointProtocols.count else { - log.debug("No more protocols available") - return false - } - currentProtocolIndex = next - log.debug("Fall back to next protocol: \(currentProtocol())") - return true - } - - private func currentProtocol() -> EndpointProtocol { - return endpointProtocols[currentProtocolIndex] - } - - private func resolvedAddress(from addresses: [String]?) -> String? { - guard let resolved = addresses, !resolved.isEmpty else { - return anyResolvedAddress() - } - return resolved[0] - } - - private func anyResolvedAddress() -> String? { - guard let addresses = resolvedAddresses, !addresses.isEmpty else { - return nil - } - let n = Int(arc4random() % UInt32(addresses.count)) - return addresses[n] - } } private extension NEProvider { - func createSocket(to address: String, protocol endpointProtocol: EndpointProtocol) -> GenericSocket { - let endpoint = NWHostEndpoint(hostname: address, port: "\(endpointProtocol.port)") - switch endpointProtocol.socketType { + func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket { + let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)") + switch endpoint.proto.socketType { case .udp, .udp4, .udp6: - let impl = createUDPSession(to: endpoint, from: nil) + let impl = createUDPSession(to: ep, from: nil) return NEUDPSocket(impl: impl) case .tcp, .tcp4, .tcp6: - let impl = createTCPConnection(to: endpoint, enableTLS: false, tlsParameters: nil, delegate: nil) + let impl = createTCPConnection(to: ep, enableTLS: false, tlsParameters: nil, delegate: nil) return NETCPSocket(impl: impl) } } diff --git a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider+Configuration.swift b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider+Configuration.swift index 6dcfd6a..69aa44d 100644 --- a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider+Configuration.swift +++ b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider+Configuration.swift @@ -66,7 +66,7 @@ extension OpenVPNTunnelProvider { /// - Seealso: `fallbackServerAddresses` public var prefersResolvedAddresses: Bool - /// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true`. + /// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only). public var resolvedAddresses: [String]? /// The MTU of the link. diff --git a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider.swift b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider.swift index 9c1f038..7b80a13 100644 --- a/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider.swift +++ b/TunnelKit/Sources/Protocols/OpenVPN/AppExtension/OpenVPNTunnelProvider.swift @@ -311,7 +311,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { // MARK: Connection (tunnel queue) - private func connectTunnel(upgradedSocket: GenericSocket? = nil, preferredAddress: String? = nil) { + private func connectTunnel(upgradedSocket: GenericSocket? = nil) { log.info("Creating link session") // reuse upgraded socket @@ -321,7 +321,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { return } - strategy.createSocket(from: self, timeout: dnsTimeout, preferredAddress: preferredAddress, queue: tunnelQueue) { (socket, error) in + strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in guard let socket = socket else { self.disposeTunnel(error: error) return @@ -424,7 +424,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate { // fallback: TCP connection timeout suggests falling back if let _ = socket as? NETCPSocket { - guard tryNextProtocol() else { + guard tryNextEndpoint() else { // disposeTunnel return } @@ -471,7 +471,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate { // fallback: UDP is connection-less, treat negotiation timeout as socket timeout if didTimeoutNegotiation { - guard tryNextProtocol() else { + guard tryNextEndpoint() else { // disposeTunnel return } @@ -489,7 +489,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate { return } - self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress) + self.connectTunnel(upgradedSocket: upgradedSocket) } return } @@ -775,8 +775,8 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate { } extension OpenVPNTunnelProvider { - private func tryNextProtocol() -> Bool { - guard strategy.tryNextProtocol() else { + private func tryNextEndpoint() -> Bool { + guard strategy.tryNextEndpoint() else { disposeTunnel(error: ProviderError.exhaustedProtocols) return false } diff --git a/TunnelKitTests/OpenVPN/AppExtensionTests.swift b/TunnelKitTests/OpenVPN/AppExtensionTests.swift index cb0c84d..c9484f1 100644 --- a/TunnelKitTests/OpenVPN/AppExtensionTests.swift +++ b/TunnelKitTests/OpenVPN/AppExtensionTests.swift @@ -94,7 +94,7 @@ class AppExtensionTests: XCTestCase { func testDNSResolver() { let exp = expectation(description: "DNS") - DNSResolver.resolve("djsbjhcbjzhbxjnvsd.com", timeout: 1000, queue: DispatchQueue.main) { (addrs, error) in + DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in defer { exp.fulfill() } @@ -126,4 +126,99 @@ class AppExtensionTests: XCTestCase { XCTAssertEqual(string, expString) } } + + func testEndpointCycling() { + CoreConfiguration.masksPrivateData = false + + var builder1 = OpenVPN.ConfigurationBuilder() + builder1.hostname = "italy.privateinternetaccess.com" + builder1.endpointProtocols = [ + EndpointProtocol(.tcp6, 2222), + EndpointProtocol(.udp, 1111), + EndpointProtocol(.udp4, 3333) + ] + var builder2 = OpenVPNTunnelProvider.ConfigurationBuilder(sessionConfiguration: builder1.build()) + builder2.prefersResolvedAddresses = true + builder2.resolvedAddresses = [ + "82.102.21.218", + "82.102.21.214", + "82.102.21.213", + ] + let strategy = ConnectionStrategy(configuration: builder2.build()) + + let expected = [ + "82.102.21.218:UDP:1111", + "82.102.21.218:UDP4:3333", + "82.102.21.214:UDP:1111", + "82.102.21.214:UDP4:3333", + "82.102.21.213:UDP:1111", + "82.102.21.213:UDP4:3333", + ] + var i = 0 + while strategy.hasEndpoint() { + let endpoint = strategy.currentEndpoint() + print("\(endpoint)") + XCTAssertEqual(endpoint.description, expected[i]) + i += 1 + strategy.tryNextEndpoint() + } + } + +// func testEndpointCycling4() { +// CoreConfiguration.masksPrivateData = false +// +// var builder = OpenVPN.ConfigurationBuilder() +// builder.hostname = "italy.privateinternetaccess.com" +// builder.endpointProtocols = [ +// EndpointProtocol(.tcp4, 2222), +// ] +// let strategy = ConnectionStrategy( +// configuration: builder.build(), +// resolvedRecords: [ +// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true), +// DNSRecord(address: "11.22.33.44", isIPv6: false), +// ] +// ) +// +// let expected = [ +// "11.22.33.44:TCP4:2222" +// ] +// var i = 0 +// while strategy.hasEndpoint() { +// let endpoint = strategy.currentEndpoint() +// print("\(endpoint)") +// XCTAssertEqual(endpoint.description, expected[i]) +// i += 1 +// strategy.tryNextEndpoint() +// } +// } +// +// func testEndpointCycling6() { +// CoreConfiguration.masksPrivateData = false +// +// var builder = OpenVPN.ConfigurationBuilder() +// builder.hostname = "italy.privateinternetaccess.com" +// builder.endpointProtocols = [ +// EndpointProtocol(.udp6, 2222), +// ] +// let strategy = ConnectionStrategy( +// configuration: builder.build(), +// resolvedRecords: [ +// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true), +// DNSRecord(address: "11.22.33.44", isIPv6: false), +// ] +// ) +// +// let expected = [ +// "111:bbbb:ffff::eeee:UDP6:2222" +// ] +// var i = 0 +// while strategy.hasEndpoint() { +// let endpoint = strategy.currentEndpoint() +// print("\(endpoint)") +// XCTAssertEqual(endpoint.description, expected[i]) +// i += 1 +// strategy.tryNextEndpoint() +// } +// } }