diff --git a/CHANGELOG.md b/CHANGELOG.md index 8eee030..e14e529 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,11 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Parse OpenVPN authentication requirement from `--auth-user-pass`. +- OpenVPN: Parse authentication requirement from `--auth-user-pass`. +- OpenVPN: Handle multiple `--remote` options correctly. ### Changed -- Support multiple peers in WireGuard. +- WireGuard: Support multiple peers. ## 4.1.0 (2022-02-09) diff --git a/Demo/Demo/Configuration.swift b/Demo/Demo/Configuration.swift index ed1a2f1..ad1b858 100644 --- a/Demo/Demo/Configuration.swift +++ b/Demo/Demo/Configuration.swift @@ -178,8 +178,7 @@ M69t86apMrAxkUxVJAWLRBd9fbYyzJgTW61tFqXWTZpiz6bhuWApSEzaHcL3/f5l sessionBuilder.digest = .sha1 sessionBuilder.compressionFraming = .compLZO sessionBuilder.renegotiatesAfter = nil - sessionBuilder.hostname = hostname - sessionBuilder.endpointProtocols = [EndpointProtocol(socketType, port)] + sessionBuilder.remotes = [Endpoint(hostname, EndpointProtocol(socketType, port))] sessionBuilder.clientCertificate = clientCertificate sessionBuilder.clientKey = clientKey sessionBuilder.mtu = 1350 diff --git a/Sources/TunnelKitCore/DNSResolver.swift b/Sources/TunnelKitCore/DNSResolver.swift index ec560fe..32acac1 100644 --- a/Sources/TunnelKitCore/DNSResolver.swift +++ b/Sources/TunnelKitCore/DNSResolver.swift @@ -51,8 +51,16 @@ public struct DNSRecord { } } +/// Errors coming from `DNSResolver`. +public enum DNSError: Error { + case failure + + case timeout +} + /// Convenient methods for DNS resolution. public class DNSResolver { + private static let queue = DispatchQueue(label: "DNSResolver") /** @@ -63,17 +71,17 @@ public class DNSResolver { - Parameter queue: The queue to execute the `completionHandler` in. - Parameter completionHandler: The completion handler with the resolved addresses and an optional error. */ - public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) { - var pendingHandler: (([DNSRecord]?, Error?) -> Void)? = completionHandler + public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) { + var pendingHandler: ((Result<[DNSRecord], DNSError>) -> Void)? = completionHandler let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue() DNSResolver.queue.async { CFHostStartInfoResolution(host, .addresses, nil) guard let handler = pendingHandler else { return } - DNSResolver.didResolve(host: host) { (records, error) in + DNSResolver.didResolve(host: host) { result in queue.async { - handler(records, error) + handler(result) pendingHandler = nil } } @@ -83,15 +91,15 @@ public class DNSResolver { return } CFHostCancelInfoResolution(host, .addresses) - handler(nil, nil) + handler(.failure(.timeout)) pendingHandler = nil } } - private static func didResolve(host: CFHost, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) { + private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) { var success: DarwinBoolean = false guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else { - completionHandler(nil, nil) + completionHandler(.failure(.failure)) return } @@ -120,7 +128,11 @@ public class DNSResolver { records.append(DNSRecord(address: address, isIPv6: true)) } } - completionHandler(records, nil) + guard !records.isEmpty else { + completionHandler(.failure(.failure)) + return + } + completionHandler(.success(records)) } /** diff --git a/Sources/TunnelKitCore/EndpointProtocol.swift b/Sources/TunnelKitCore/Endpoint.swift similarity index 79% rename from Sources/TunnelKitCore/EndpointProtocol.swift rename to Sources/TunnelKitCore/Endpoint.swift index 209770c..0f2302b 100644 --- a/Sources/TunnelKitCore/EndpointProtocol.swift +++ b/Sources/TunnelKitCore/Endpoint.swift @@ -1,5 +1,5 @@ // -// EndpointProtocol.swift +// Endpoint.swift // TunnelKit // // Created by Davide De Rosa on 11/10/18. @@ -25,6 +25,31 @@ import Foundation +/// Represents an endpoint. +public struct Endpoint: Codable, Equatable, CustomStringConvertible { + public let address: String + + public let proto: EndpointProtocol + + /// :nodoc: + public init(_ address: String, _ proto: EndpointProtocol) { + self.address = address + self.proto = proto + } + + // MARK: Equatable + + public static func ==(lhs: Endpoint, rhs: Endpoint) -> Bool { + return lhs.address == rhs.address && lhs.proto == rhs.proto + } + + // MARK: CustomStringConvertible + + public var description: String { + return "\(address.maskedDescription):\(proto)" + } +} + /// Defines the communication protocol of an endpoint. public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvertible { @@ -33,7 +58,8 @@ public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvert /// The remote port. public let port: UInt16 - + + /// :nodoc: public init(_ socketType: SocketType, _ port: UInt16) { self.socketType = socketType self.port = port diff --git a/Sources/TunnelKitOpenVPNAppExtension/ConnectionStrategy.swift b/Sources/TunnelKitOpenVPNAppExtension/ConnectionStrategy.swift index 3aa76d5..46772d3 100644 --- a/Sources/TunnelKitOpenVPNAppExtension/ConnectionStrategy.swift +++ b/Sources/TunnelKitOpenVPNAppExtension/ConnectionStrategy.swift @@ -39,118 +39,56 @@ import NetworkExtension import SwiftyBeaver import TunnelKitCore import TunnelKitAppExtension -import TunnelKitOpenVPNManager +import TunnelKitOpenVPN private let log = SwiftyBeaver.self class ConnectionStrategy { - struct Endpoint: CustomStringConvertible { - let record: DNSRecord - - let proto: EndpointProtocol - - var isValid: Bool { - if record.isIPv6 { - return proto.socketType != .udp4 && proto.socketType != .tcp4 - } else { - return proto.socketType != .udp6 && proto.socketType != .tcp6 - } - } - - // MARK: CustomStringConvertible + private var remotes: [ResolvedRemote] - var description: String { - return "\(record.address.maskedDescription):\(proto)" - } - } - - private let hostname: String? - - private let endpointProtocols: [EndpointProtocol] + private var currentRemoteIndex: Int - private var endpoints: [Endpoint] - - private var currentEndpointIndex: Int - - private let resolvedAddresses: [String] - - init(configuration: OpenVPNProvider.Configuration) { - hostname = configuration.sessionConfiguration.hostname - guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else { - fatalError("No endpoints provided") + var currentRemote: ResolvedRemote? { + guard currentRemoteIndex < remotes.count else { + return nil } - if configuration.sessionConfiguration.randomizeEndpoint ?? false { - endpointProtocols.shuffle() - } - self.endpointProtocols = endpointProtocols - - currentEndpointIndex = 0 - if let resolvedAddresses = configuration.resolvedAddresses { - if configuration.prefersResolvedAddresses { - log.debug("Will use pre-resolved addresses only") - endpoints = ConnectionStrategy.unrolledEndpoints( - records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) }, - protos: endpointProtocols - ) - } else { - log.debug("Will use DNS resolution with fallback to pre-resolved addresses") - endpoints = [] - } - self.resolvedAddresses = resolvedAddresses - } else { - log.debug("Will use DNS resolution") - guard hostname != nil else { - fatalError("Either configuration.sessionConfiguration.hostname or configuration.resolvedAddresses required") - } - endpoints = [] - resolvedAddresses = [] - } - } - - private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] { - return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos) + return remotes[currentRemoteIndex] } - private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] { - guard !records.isEmpty else { - return [] + init(configuration: OpenVPN.Configuration) { + guard var remotes = configuration.remotes, !remotes.isEmpty else { + fatalError("No remotes provided") } - var endpoints: [Endpoint] = [] - for r in records { - for p in protos { - let endpoint = Endpoint(record: r, proto: p) - guard endpoint.isValid else { - continue - } - endpoints.append(endpoint) - } + if configuration.randomizeEndpoint ?? false { + remotes.shuffle() } - log.debug("Unrolled endpoints: \(endpoints.maskedDescription)") - return endpoints - } - - func hasEndpoint() -> Bool { - return currentEndpointIndex < endpoints.count + self.remotes = remotes.map(ResolvedRemote.init) + currentRemoteIndex = 0 } - func currentEndpoint() -> Endpoint { - guard hasEndpoint() else { - fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))") + func hasEndpoints() -> Bool { + guard let remote = currentRemote else { + return false } - return endpoints[currentEndpointIndex] + return !remote.isResolved || remote.currentEndpoint != nil } @discardableResult func tryNextEndpoint() -> Bool { - guard hasEndpoint() else { + guard let remote = currentRemote else { return false } - currentEndpointIndex += 1 - guard currentEndpointIndex < endpoints.count else { - log.debug("Exhausted endpoints") + log.debug("Try next endpoint in current remote: \(remote.maskedDescription)") + if remote.nextEndpoint() { + return true + } + + log.debug("Exhausted endpoints, try next remote") + currentRemoteIndex += 1 + guard let _ = currentRemote else { + log.debug("Exhausted remotes, giving up") return false } - log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)") return true } @@ -158,51 +96,38 @@ class ConnectionStrategy { from provider: NEProvider, timeout: Int, queue: DispatchQueue, - completionHandler: @escaping (GenericSocket?, Error?) -> Void) { - - if hasEndpoint() { - let endpoint = currentEndpoint() + completionHandler: @escaping (Result) -> Void) + { + guard let remote = currentRemote else { + completionHandler(.failure(.exhaustedEndpoints)) + return + } + if remote.isResolved, let endpoint = remote.currentEndpoint { log.debug("Pick current endpoint: \(endpoint.maskedDescription)") let socket = provider.createSocket(to: endpoint) - completionHandler(socket, nil) + completionHandler(.success(socket)) return } - log.debug("No endpoints available, will resort to DNS resolution") - guard let hostname = hostname else { - log.error("DNS resolution unavailable: no hostname provided!") - completionHandler(nil, OpenVPNProviderError.dnsFailure) - return - } - log.debug("DNS resolve hostname: \(hostname.maskedDescription)") - DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in - self.currentEndpointIndex = 0 - if let records = records, !records.isEmpty { - log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)") - self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols) - } else { - log.error("DNS resolution failed!") - log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)") - self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols) - } - - guard self.hasEndpoint() else { + log.debug("No resolved endpoints, will resort to DNS resolution") + log.debug("DNS resolve address: \(remote.maskedDescription)") + + remote.resolve(timeout: timeout, queue: queue) { + guard let endpoint = remote.currentEndpoint else { log.error("No endpoints available") - completionHandler(nil, OpenVPNProviderError.dnsFailure) + completionHandler(.failure(.dnsFailure)) return } - - let targetEndpoint = self.currentEndpoint() - log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)") - let socket = provider.createSocket(to: targetEndpoint) - completionHandler(socket, nil) + log.debug("Pick current endpoint: \(endpoint.maskedDescription)") + let socket = provider.createSocket(to: endpoint) + completionHandler(.success(socket)) } } } private extension NEProvider { - func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket { - let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)") + func createSocket(to endpoint: Endpoint) -> GenericSocket { + let ep = NWHostEndpoint(hostname: endpoint.address, port: "\(endpoint.proto.port)") switch endpoint.proto.socketType { case .udp, .udp4, .udp6: let impl = createUDPSession(to: ep, from: nil) diff --git a/Sources/TunnelKitOpenVPNAppExtension/OpenVPNTunnelProvider.swift b/Sources/TunnelKitOpenVPNAppExtension/OpenVPNTunnelProvider.swift index 0e925eb..927e20f 100644 --- a/Sources/TunnelKitOpenVPNAppExtension/OpenVPNTunnelProvider.swift +++ b/Sources/TunnelKitOpenVPNAppExtension/OpenVPNTunnelProvider.swift @@ -154,7 +154,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else { throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration") } - guard let serverAddress = tunnelProtocol.serverAddress else { + guard let _ = tunnelProtocol.serverAddress else { throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.serverAddress") } guard let providerConfiguration = tunnelProtocol.providerConfiguration else { @@ -162,15 +162,6 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { } try appGroup = OpenVPNProvider.Configuration.appGroup(from: providerConfiguration) try cfg = OpenVPNProvider.Configuration.parsed(from: providerConfiguration) - - // inject serverAddress into sessionConfiguration.hostname - if !serverAddress.isEmpty { - var sessionBuilder = cfg.sessionConfiguration.builder() - sessionBuilder.hostname = serverAddress - var cfgBuilder = cfg.builder() - cfgBuilder.sessionConfiguration = sessionBuilder.build() - cfg = cfgBuilder.build() - } } catch let e { var message: String? if let te = e as? OpenVPNProviderConfigurationError { @@ -237,7 +228,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { cfg.print(appVersion: appVersion) // prepare to pick endpoints - strategy = ConnectionStrategy(configuration: cfg) + strategy = ConnectionStrategy(configuration: cfg.sessionConfiguration) let session: OpenVPNSession do { @@ -337,12 +328,21 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider { return } - strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in - guard let socket = socket else { + strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { + switch $0 { + case .success(let socket): + self.connectTunnel(via: socket) + + case .failure(let error): + if case .dnsFailure = error { + self.tunnelQueue.async { + self.strategy.tryNextEndpoint() + self.connectTunnel() + } + return + } self.disposeTunnel(error: error) - return } - self.connectTunnel(via: socket) } } @@ -839,7 +839,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate { extension OpenVPNTunnelProvider { private func tryNextEndpoint() -> Bool { guard strategy.tryNextEndpoint() else { - disposeTunnel(error: OpenVPNProviderError.exhaustedProtocols) + disposeTunnel(error: OpenVPNProviderError.exhaustedEndpoints) return false } return true diff --git a/Sources/TunnelKitOpenVPNAppExtension/ResolvedRemote.swift b/Sources/TunnelKitOpenVPNAppExtension/ResolvedRemote.swift new file mode 100644 index 0000000..d315059 --- /dev/null +++ b/Sources/TunnelKitOpenVPNAppExtension/ResolvedRemote.swift @@ -0,0 +1,106 @@ +// +// ResolvedRemote.swift +// TunnelKit +// +// Created by Davide De Rosa on 3/3/22. +// Copyright (c) 2022 Davide De Rosa. All rights reserved. +// +// https://github.com/passepartoutvpn +// +// This file is part of TunnelKit. +// +// TunnelKit is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// TunnelKit is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with TunnelKit. If not, see . +// + +import Foundation +import TunnelKitCore +import SwiftyBeaver + +private let log = SwiftyBeaver.self + +class ResolvedRemote: CustomStringConvertible { + let originalEndpoint: Endpoint + + private(set) var isResolved: Bool + + private(set) var resolvedEndpoints: [Endpoint] + + private var currentEndpointIndex: Int + + var currentEndpoint: Endpoint? { + guard currentEndpointIndex < resolvedEndpoints.count else { + return nil + } + return resolvedEndpoints[currentEndpointIndex] + } + + init(_ originalEndpoint: Endpoint) { + self.originalEndpoint = originalEndpoint + isResolved = false + resolvedEndpoints = [] + currentEndpointIndex = 0 + } + + func nextEndpoint() -> Bool { + currentEndpointIndex += 1 + return currentEndpointIndex < resolvedEndpoints.count + } + + func resolve(timeout: Int, queue: DispatchQueue, completionHandler: @escaping () -> Void) { + DNSResolver.resolve(originalEndpoint.address, timeout: timeout, queue: queue) { [weak self] in + self?.handleResult($0) + completionHandler() + } + } + + private func handleResult(_ result: Result<[DNSRecord], DNSError>) { + switch result { + case .success(let records): + log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)") + isResolved = true + resolvedEndpoints = unrolledEndpoints(records: records) + + case .failure: + log.error("DNS resolution failed!") + isResolved = false + resolvedEndpoints = [] + } + } + + private func unrolledEndpoints(records: [DNSRecord]) -> [Endpoint] { + let endpoints = records.filter { + $0.isCompatible(withProtocol: originalEndpoint.proto) + }.map { + Endpoint($0.address, originalEndpoint.proto) + } + log.debug("Unrolled endpoints: \(endpoints.maskedDescription)") + return endpoints + } + + // MARK: CustomStringConvertible + + var description: String { + return "{\(originalEndpoint.maskedDescription), resolved: \(resolvedEndpoints.maskedDescription)}" + } +} + +private extension DNSRecord { + func isCompatible(withProtocol proto: EndpointProtocol) -> Bool { + if isIPv6 { + return proto.socketType != .udp4 && proto.socketType != .tcp4 + } else { + return proto.socketType != .udp6 && proto.socketType != .tcp6 + } + } +} diff --git a/Sources/TunnelKitOpenVPNCore/Configuration.swift b/Sources/TunnelKitOpenVPNCore/Configuration.swift index c13bde3..563c025 100644 --- a/Sources/TunnelKitOpenVPNCore/Configuration.swift +++ b/Sources/TunnelKitOpenVPNCore/Configuration.swift @@ -218,11 +218,8 @@ extension OpenVPN { // MARK: Client - /// The server hostname (picked from first remote). - public var hostname: String? - /// The list of server endpoints. - public var endpointProtocols: [EndpointProtocol]? + public var remotes: [Endpoint]? /// If true, checks EKU of server certificate. public var checksEKU: Bool? @@ -338,8 +335,7 @@ extension OpenVPN { keepAliveTimeout: keepAliveTimeout, renegotiatesAfter: renegotiatesAfter, xorMask: xorMask, - hostname: hostname, - endpointProtocols: endpointProtocols, + remotes: remotes, checksEKU: checksEKU, checksSANHost: checksSANHost, sanHost: sanHost, @@ -410,11 +406,8 @@ extension OpenVPN { /// - Seealso: `ConfigurationBuilder.xorMask` public let xorMask: UInt8? - /// - Seealso: `ConfigurationBuilder.hostname` - public let hostname: String? - - /// - Seealso: `ConfigurationBuilder.endpointProtocols` - public let endpointProtocols: [EndpointProtocol]? + /// - Seealso: `ConfigurationBuilder.remotes` + public let remotes: [Endpoint]? /// - Seealso: `ConfigurationBuilder.checksEKU` public let checksEKU: Bool? @@ -520,8 +513,7 @@ extension OpenVPN.Configuration { builder.keepAliveInterval = keepAliveInterval builder.keepAliveTimeout = keepAliveTimeout builder.renegotiatesAfter = renegotiatesAfter - builder.hostname = hostname - builder.endpointProtocols = endpointProtocols + builder.remotes = remotes builder.checksEKU = checksEKU builder.checksSANHost = checksSANHost builder.sanHost = sanHost @@ -552,10 +544,10 @@ extension OpenVPN.Configuration { extension OpenVPN.Configuration { public func print() { - guard let endpointProtocols = endpointProtocols else { - fatalError("No sessionConfiguration.endpointProtocols set") + guard let remotes = remotes else { + fatalError("No sessionConfiguration.remotes set") } - log.info("\tProtocols: \(endpointProtocols)") + log.info("\tRemotes: \(remotes)") log.info("\tCipher: \(fallbackCipher)") log.info("\tDigest: \(fallbackDigest)") log.info("\tCompression framing: \(fallbackCompressionFraming)") diff --git a/Sources/TunnelKitOpenVPNCore/ConfigurationParser.swift b/Sources/TunnelKitOpenVPNCore/ConfigurationParser.swift index 7970131..20d1dad 100644 --- a/Sources/TunnelKitOpenVPNCore/ConfigurationParser.swift +++ b/Sources/TunnelKitOpenVPNCore/ConfigurationParser.swift @@ -770,14 +770,9 @@ extension OpenVPN { optDefaultProto = optDefaultProto ?? .udp optDefaultPort = optDefaultPort ?? 1194 if !optRemotes.isEmpty { - sessionBuilder.hostname = optRemotes[0].0 - var fullRemotes: [(String, UInt16, SocketType)] = [] - let hostname = optRemotes[0].0 optRemotes.forEach { - guard $0.0 == hostname else { - return - } + let hostname = $0.0 guard let port = $0.1 ?? optDefaultPort else { return } @@ -786,9 +781,9 @@ extension OpenVPN { } fullRemotes.append((hostname, port, socketType)) } - sessionBuilder.endpointProtocols = fullRemotes.map { EndpointProtocol($0.2, $0.1) } - } else { - sessionBuilder.hostname = nil + sessionBuilder.remotes = fullRemotes.map { + Endpoint($0.0, .init($0.2, $0.1)) + } } sessionBuilder.authUserPass = authUserPass diff --git a/Sources/TunnelKitOpenVPNManager/OpenVPNProvider+Configuration.swift b/Sources/TunnelKitOpenVPNManager/OpenVPNProvider+Configuration.swift index 946943a..b70e720 100644 --- a/Sources/TunnelKitOpenVPNManager/OpenVPNProvider+Configuration.swift +++ b/Sources/TunnelKitOpenVPNManager/OpenVPNProvider+Configuration.swift @@ -55,8 +55,6 @@ extension OpenVPNProvider { public static let defaults = Configuration( sessionConfiguration: OpenVPN.ConfigurationBuilder().build(), - prefersResolvedAddresses: false, - resolvedAddresses: nil, shouldDebug: false, debugLogFormat: nil, masksPrivateData: true, @@ -66,17 +64,6 @@ extension OpenVPNProvider { /// The session configuration. public var sessionConfiguration: OpenVPN.Configuration - /// Prefers resolved addresses over DNS resolution. `resolvedAddresses` must be set and non-empty. Default is `false`. - /// - /// - Seealso: `fallbackServerAddresses` - public var prefersResolvedAddresses: Bool - - /// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only). - public var resolvedAddresses: [String]? - - /// Optional version identifier about the client pushed to server in peer-info as `IV_UI_VER`. - public var versionIdentifier: String? - // MARK: Debugging /// Enables debugging. @@ -88,6 +75,9 @@ extension OpenVPNProvider { /// Mask private data in debug log (default is `true`). public var masksPrivateData: Bool? + /// Optional version identifier about the client pushed to server in peer-info as `IV_UI_VER`. + public var versionIdentifier: String? + // MARK: Building /** @@ -97,8 +87,6 @@ extension OpenVPNProvider { */ public init(sessionConfiguration: OpenVPN.Configuration) { self.sessionConfiguration = sessionConfiguration - prefersResolvedAddresses = ConfigurationBuilder.defaults.prefersResolvedAddresses - resolvedAddresses = nil shouldDebug = ConfigurationBuilder.defaults.shouldDebug debugLogFormat = ConfigurationBuilder.defaults.debugLogFormat masksPrivateData = ConfigurationBuilder.defaults.masksPrivateData @@ -113,8 +101,6 @@ extension OpenVPNProvider { public func build() -> Configuration { return Configuration( sessionConfiguration: sessionConfiguration, - prefersResolvedAddresses: prefersResolvedAddresses, - resolvedAddresses: resolvedAddresses, shouldDebug: shouldDebug, debugLogFormat: shouldDebug ? debugLogFormat : nil, masksPrivateData: masksPrivateData, @@ -129,12 +115,6 @@ extension OpenVPNProvider { /// - Seealso: `OpenVPNProvider.ConfigurationBuilder.sessionConfiguration` public let sessionConfiguration: OpenVPN.Configuration - /// - Seealso: `OpenVPNProvider.ConfigurationBuilder.prefersResolvedAddresses` - public let prefersResolvedAddresses: Bool - - /// - Seealso: `OpenVPNProvider.ConfigurationBuilder.resolvedAddresses` - public let resolvedAddresses: [String]? - /// - Seealso: `OpenVPNProvider.ConfigurationBuilder.shouldDebug` public let shouldDebug: Bool @@ -246,11 +226,7 @@ extension OpenVPNProvider { - Throws: `OpenVPNProviderError.configuration` if `providerConfiguration` is incomplete. */ public static func parsed(from providerConfiguration: [String: Any]) throws -> Configuration { - let cfg = try fromDictionary(OpenVPNProvider.Configuration.self, providerConfiguration) - guard !cfg.prefersResolvedAddresses || !(cfg.resolvedAddresses?.isEmpty ?? true) else { - throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[prefersResolvedAddresses] is true but no [resolvedAddresses]") - } - return cfg + return try fromDictionary(OpenVPNProvider.Configuration.self, providerConfiguration) } /** @@ -284,13 +260,16 @@ extension OpenVPNProvider { withBundleIdentifier bundleIdentifier: String, appGroup: String, context: String, - credentials: OpenVPN.Credentials?) throws -> NETunnelProviderProtocol { - + credentials: OpenVPN.Credentials?) throws -> NETunnelProviderProtocol + { let protocolConfiguration = NETunnelProviderProtocol() let keychain = Keychain(group: appGroup) protocolConfiguration.providerBundleIdentifier = bundleIdentifier - protocolConfiguration.serverAddress = sessionConfiguration.hostname ?? resolvedAddresses?.first + guard let firstRemote = sessionConfiguration.remotes?.first else { + fatalError("No remotes set") + } + protocolConfiguration.serverAddress = "\(firstRemote.address):\(firstRemote.proto.port)" if let username = credentials?.username { protocolConfiguration.username = username if let password = credentials?.password { @@ -324,8 +303,6 @@ extension OpenVPNProvider.Configuration { */ public func builder() -> OpenVPNProvider.ConfigurationBuilder { var builder = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: sessionConfiguration) - builder.prefersResolvedAddresses = prefersResolvedAddresses - builder.resolvedAddresses = resolvedAddresses builder.shouldDebug = shouldDebug builder.debugLogFormat = debugLogFormat builder.masksPrivateData = masksPrivateData diff --git a/Sources/TunnelKitOpenVPNManager/OpenVPNProviderError.swift b/Sources/TunnelKitOpenVPNManager/OpenVPNProviderError.swift index d8e6891..e111110 100644 --- a/Sources/TunnelKitOpenVPNManager/OpenVPNProviderError.swift +++ b/Sources/TunnelKitOpenVPNManager/OpenVPNProviderError.swift @@ -58,8 +58,8 @@ public enum OpenVPNProviderError: String, Error { /// Socket endpoint could not be resolved. case dnsFailure - /// No more protocols available to try. - case exhaustedProtocols + /// No more endpoints available to try. + case exhaustedEndpoints /// Socket failed to reach active state. case socketActivity diff --git a/Tests/TunnelKitOpenVPNTests/AppExtensionTests.swift b/Tests/TunnelKitOpenVPNTests/AppExtensionTests.swift index ce1585e..a61d8b7 100644 --- a/Tests/TunnelKitOpenVPNTests/AppExtensionTests.swift +++ b/Tests/TunnelKitOpenVPNTests/AppExtensionTests.swift @@ -62,6 +62,8 @@ class AppExtensionTests: XCTestCase { let identifier = "com.example.Provider" let appGroup = "group.com.algoritmico.TunnelKit" let hostname = "example.com" + let port: UInt16 = 1234 + let serverAddress = "\(hostname):\(port)" let context = "foobar" let credentials = OpenVPN.Credentials("foo", "bar") @@ -69,8 +71,7 @@ class AppExtensionTests: XCTestCase { sessionBuilder.ca = OpenVPN.CryptoContainer(pem: "abcdef") sessionBuilder.cipher = .aes128cbc sessionBuilder.digest = .sha256 - sessionBuilder.hostname = hostname - sessionBuilder.endpointProtocols = [] + sessionBuilder.remotes = [.init(hostname, .init(.udp, port))] sessionBuilder.mtu = 1230 builder = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: sessionBuilder.build()) XCTAssertNotNil(builder) @@ -86,7 +87,7 @@ class AppExtensionTests: XCTestCase { XCTAssertNotNil(proto) XCTAssertEqual(proto?.providerBundleIdentifier, identifier) - XCTAssertEqual(proto?.serverAddress, hostname) + XCTAssertEqual(proto?.serverAddress, serverAddress) XCTAssertEqual(proto?.username, credentials.username) XCTAssertEqual(proto?.passwordReference, try? Keychain(group: appGroup).passwordReference(for: credentials.username, context: context)) @@ -107,15 +108,17 @@ class AppExtensionTests: XCTestCase { func testDNSResolver() { let exp = expectation(description: "DNS") - DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in + DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { defer { exp.fulfill() } - guard let addrs = addrs else { + switch $0 { + case .success(let records): + print("\(records)") + + case .failure: print("Can't resolve") - return } - print("\(addrs)") } waitForExpectations(timeout: 5.0, handler: nil) } @@ -143,37 +146,31 @@ class AppExtensionTests: XCTestCase { func testEndpointCycling() { CoreConfiguration.masksPrivateData = false - var builder1 = OpenVPN.ConfigurationBuilder() - builder1.hostname = "italy.privateinternetaccess.com" - builder1.endpointProtocols = [ - EndpointProtocol(.tcp6, 2222), - EndpointProtocol(.udp, 1111), - EndpointProtocol(.udp4, 3333) + var builder = OpenVPN.ConfigurationBuilder() + let hostname = "italy.privateinternetaccess.com" + builder.remotes = [ + .init(hostname, .init(.tcp6, 2222)), + .init(hostname, .init(.udp, 1111)), + .init(hostname, .init(.udp4, 3333)) ] - var builder2 = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: builder1.build()) - builder2.prefersResolvedAddresses = true - builder2.resolvedAddresses = [ - "82.102.21.218", - "82.102.21.214", - "82.102.21.213", - ] - let strategy = ConnectionStrategy(configuration: builder2.build()) + let strategy = ConnectionStrategy(configuration: builder.build()) let expected = [ - "82.102.21.218:UDP:1111", - "82.102.21.218:UDP4:3333", - "82.102.21.214:UDP:1111", - "82.102.21.214:UDP4:3333", - "82.102.21.213:UDP:1111", - "82.102.21.213:UDP4:3333", + "italy.privateinternetaccess.com:TCP6:2222", + "italy.privateinternetaccess.com:UDP:1111", + "italy.privateinternetaccess.com:UDP4:3333" ] var i = 0 - while strategy.hasEndpoint() { - let endpoint = strategy.currentEndpoint() - print("\(endpoint)") - XCTAssertEqual(endpoint.description, expected[i]) + while strategy.hasEndpoints() { + guard let remote = strategy.currentRemote else { + break + } + print("\(i): \(remote)") + XCTAssertEqual(remote.originalEndpoint.description, expected[i]) i += 1 - strategy.tryNextEndpoint() + guard strategy.tryNextEndpoint() else { + break + } } } diff --git a/Tests/TunnelKitOpenVPNTests/ConfigurationParserTests.swift b/Tests/TunnelKitOpenVPNTests/ConfigurationParserTests.swift index 699281c..df4bfcd 100644 --- a/Tests/TunnelKitOpenVPNTests/ConfigurationParserTests.swift +++ b/Tests/TunnelKitOpenVPNTests/ConfigurationParserTests.swift @@ -105,13 +105,12 @@ class ConfigurationParserTests: XCTestCase { func testPIA() throws { let file = try OpenVPN.ConfigurationParser.parsed(fromURL: url(withName: "pia-hungary")) - XCTAssertEqual(file.configuration.hostname, "hungary.privateinternetaccess.com") + XCTAssertEqual(file.configuration.remotes, [ + .init("hungary.privateinternetaccess.com", .init(.udp, 1198)), + .init("hungary.privateinternetaccess.com", .init(.tcp, 502)), + ]) XCTAssertEqual(file.configuration.cipher, .aes128cbc) XCTAssertEqual(file.configuration.digest, .sha1) - XCTAssertEqual(file.configuration.endpointProtocols, [ - EndpointProtocol(.udp, 1198), - EndpointProtocol(.tcp, 502) - ]) } func testStripped() throws {