Replace hostname/endpointProtocols with remotes

Like official OpenVPN options.
This commit is contained in:
Davide De Rosa 2022-03-03 15:34:57 +01:00
parent 2bcd11fd7e
commit 133b4b2337
13 changed files with 278 additions and 249 deletions

View File

@ -9,11 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Parse OpenVPN authentication requirement from `--auth-user-pass`.
- OpenVPN: Parse authentication requirement from `--auth-user-pass`.
- OpenVPN: Handle multiple `--remote` options correctly.
### Changed
- Support multiple peers in WireGuard.
- WireGuard: Support multiple peers.
## 4.1.0 (2022-02-09)

View File

@ -178,8 +178,7 @@ M69t86apMrAxkUxVJAWLRBd9fbYyzJgTW61tFqXWTZpiz6bhuWApSEzaHcL3/f5l
sessionBuilder.digest = .sha1
sessionBuilder.compressionFraming = .compLZO
sessionBuilder.renegotiatesAfter = nil
sessionBuilder.hostname = hostname
sessionBuilder.endpointProtocols = [EndpointProtocol(socketType, port)]
sessionBuilder.remotes = [Endpoint(hostname, EndpointProtocol(socketType, port))]
sessionBuilder.clientCertificate = clientCertificate
sessionBuilder.clientKey = clientKey
sessionBuilder.mtu = 1350

View File

@ -51,8 +51,16 @@ public struct DNSRecord {
}
}
/// Errors coming from `DNSResolver`.
public enum DNSError: Error {
case failure
case timeout
}
/// Convenient methods for DNS resolution.
public class DNSResolver {
private static let queue = DispatchQueue(label: "DNSResolver")
/**
@ -63,17 +71,17 @@ public class DNSResolver {
- Parameter queue: The queue to execute the `completionHandler` in.
- Parameter completionHandler: The completion handler with the resolved addresses and an optional error.
*/
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) {
var pendingHandler: (([DNSRecord]?, Error?) -> Void)? = completionHandler
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) {
var pendingHandler: ((Result<[DNSRecord], DNSError>) -> Void)? = completionHandler
let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue()
DNSResolver.queue.async {
CFHostStartInfoResolution(host, .addresses, nil)
guard let handler = pendingHandler else {
return
}
DNSResolver.didResolve(host: host) { (records, error) in
DNSResolver.didResolve(host: host) { result in
queue.async {
handler(records, error)
handler(result)
pendingHandler = nil
}
}
@ -83,15 +91,15 @@ public class DNSResolver {
return
}
CFHostCancelInfoResolution(host, .addresses)
handler(nil, nil)
handler(.failure(.timeout))
pendingHandler = nil
}
}
private static func didResolve(host: CFHost, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) {
private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) {
var success: DarwinBoolean = false
guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else {
completionHandler(nil, nil)
completionHandler(.failure(.failure))
return
}
@ -120,7 +128,11 @@ public class DNSResolver {
records.append(DNSRecord(address: address, isIPv6: true))
}
}
completionHandler(records, nil)
guard !records.isEmpty else {
completionHandler(.failure(.failure))
return
}
completionHandler(.success(records))
}
/**

View File

@ -1,5 +1,5 @@
//
// EndpointProtocol.swift
// Endpoint.swift
// TunnelKit
//
// Created by Davide De Rosa on 11/10/18.
@ -25,6 +25,31 @@
import Foundation
/// Represents an endpoint.
public struct Endpoint: Codable, Equatable, CustomStringConvertible {
public let address: String
public let proto: EndpointProtocol
/// :nodoc:
public init(_ address: String, _ proto: EndpointProtocol) {
self.address = address
self.proto = proto
}
// MARK: Equatable
public static func ==(lhs: Endpoint, rhs: Endpoint) -> Bool {
return lhs.address == rhs.address && lhs.proto == rhs.proto
}
// MARK: CustomStringConvertible
public var description: String {
return "\(address.maskedDescription):\(proto)"
}
}
/// Defines the communication protocol of an endpoint.
public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvertible {
@ -34,6 +59,7 @@ public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvert
/// The remote port.
public let port: UInt16
/// :nodoc:
public init(_ socketType: SocketType, _ port: UInt16) {
self.socketType = socketType
self.port = port

View File

@ -39,118 +39,56 @@ import NetworkExtension
import SwiftyBeaver
import TunnelKitCore
import TunnelKitAppExtension
import TunnelKitOpenVPNManager
import TunnelKitOpenVPN
private let log = SwiftyBeaver.self
class ConnectionStrategy {
struct Endpoint: CustomStringConvertible {
let record: DNSRecord
private var remotes: [ResolvedRemote]
let proto: EndpointProtocol
private var currentRemoteIndex: Int
var isValid: Bool {
if record.isIPv6 {
return proto.socketType != .udp4 && proto.socketType != .tcp4
} else {
return proto.socketType != .udp6 && proto.socketType != .tcp6
}
}
// MARK: CustomStringConvertible
var description: String {
return "\(record.address.maskedDescription):\(proto)"
var currentRemote: ResolvedRemote? {
guard currentRemoteIndex < remotes.count else {
return nil
}
return remotes[currentRemoteIndex]
}
private let hostname: String?
private let endpointProtocols: [EndpointProtocol]
private var endpoints: [Endpoint]
private var currentEndpointIndex: Int
private let resolvedAddresses: [String]
init(configuration: OpenVPNProvider.Configuration) {
hostname = configuration.sessionConfiguration.hostname
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
fatalError("No endpoints provided")
init(configuration: OpenVPN.Configuration) {
guard var remotes = configuration.remotes, !remotes.isEmpty else {
fatalError("No remotes provided")
}
if configuration.sessionConfiguration.randomizeEndpoint ?? false {
endpointProtocols.shuffle()
}
self.endpointProtocols = endpointProtocols
currentEndpointIndex = 0
if let resolvedAddresses = configuration.resolvedAddresses {
if configuration.prefersResolvedAddresses {
log.debug("Will use pre-resolved addresses only")
endpoints = ConnectionStrategy.unrolledEndpoints(
records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) },
protos: endpointProtocols
)
} else {
log.debug("Will use DNS resolution with fallback to pre-resolved addresses")
endpoints = []
}
self.resolvedAddresses = resolvedAddresses
} else {
log.debug("Will use DNS resolution")
guard hostname != nil else {
fatalError("Either configuration.sessionConfiguration.hostname or configuration.resolvedAddresses required")
}
endpoints = []
resolvedAddresses = []
if configuration.randomizeEndpoint ?? false {
remotes.shuffle()
}
self.remotes = remotes.map(ResolvedRemote.init)
currentRemoteIndex = 0
}
private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] {
return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos)
}
private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] {
guard !records.isEmpty else {
return []
func hasEndpoints() -> Bool {
guard let remote = currentRemote else {
return false
}
var endpoints: [Endpoint] = []
for r in records {
for p in protos {
let endpoint = Endpoint(record: r, proto: p)
guard endpoint.isValid else {
continue
}
endpoints.append(endpoint)
}
}
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
return endpoints
}
func hasEndpoint() -> Bool {
return currentEndpointIndex < endpoints.count
}
func currentEndpoint() -> Endpoint {
guard hasEndpoint() else {
fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))")
}
return endpoints[currentEndpointIndex]
return !remote.isResolved || remote.currentEndpoint != nil
}
@discardableResult
func tryNextEndpoint() -> Bool {
guard hasEndpoint() else {
guard let remote = currentRemote else {
return false
}
currentEndpointIndex += 1
guard currentEndpointIndex < endpoints.count else {
log.debug("Exhausted endpoints")
log.debug("Try next endpoint in current remote: \(remote.maskedDescription)")
if remote.nextEndpoint() {
return true
}
log.debug("Exhausted endpoints, try next remote")
currentRemoteIndex += 1
guard let _ = currentRemote else {
log.debug("Exhausted remotes, giving up")
return false
}
log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)")
return true
}
@ -158,51 +96,38 @@ class ConnectionStrategy {
from provider: NEProvider,
timeout: Int,
queue: DispatchQueue,
completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
if hasEndpoint() {
let endpoint = currentEndpoint()
completionHandler: @escaping (Result<GenericSocket, OpenVPNProviderError>) -> Void)
{
guard let remote = currentRemote else {
completionHandler(.failure(.exhaustedEndpoints))
return
}
if remote.isResolved, let endpoint = remote.currentEndpoint {
log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
let socket = provider.createSocket(to: endpoint)
completionHandler(socket, nil)
completionHandler(.success(socket))
return
}
log.debug("No endpoints available, will resort to DNS resolution")
guard let hostname = hostname else {
log.error("DNS resolution unavailable: no hostname provided!")
completionHandler(nil, OpenVPNProviderError.dnsFailure)
return
}
log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in
self.currentEndpointIndex = 0
if let records = records, !records.isEmpty {
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols)
} else {
log.error("DNS resolution failed!")
log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols)
}
log.debug("No resolved endpoints, will resort to DNS resolution")
log.debug("DNS resolve address: \(remote.maskedDescription)")
guard self.hasEndpoint() else {
remote.resolve(timeout: timeout, queue: queue) {
guard let endpoint = remote.currentEndpoint else {
log.error("No endpoints available")
completionHandler(nil, OpenVPNProviderError.dnsFailure)
completionHandler(.failure(.dnsFailure))
return
}
let targetEndpoint = self.currentEndpoint()
log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)")
let socket = provider.createSocket(to: targetEndpoint)
completionHandler(socket, nil)
log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
let socket = provider.createSocket(to: endpoint)
completionHandler(.success(socket))
}
}
}
private extension NEProvider {
func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket {
let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)")
func createSocket(to endpoint: Endpoint) -> GenericSocket {
let ep = NWHostEndpoint(hostname: endpoint.address, port: "\(endpoint.proto.port)")
switch endpoint.proto.socketType {
case .udp, .udp4, .udp6:
let impl = createUDPSession(to: ep, from: nil)

View File

@ -154,7 +154,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration")
}
guard let serverAddress = tunnelProtocol.serverAddress else {
guard let _ = tunnelProtocol.serverAddress else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.serverAddress")
}
guard let providerConfiguration = tunnelProtocol.providerConfiguration else {
@ -162,15 +162,6 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
}
try appGroup = OpenVPNProvider.Configuration.appGroup(from: providerConfiguration)
try cfg = OpenVPNProvider.Configuration.parsed(from: providerConfiguration)
// inject serverAddress into sessionConfiguration.hostname
if !serverAddress.isEmpty {
var sessionBuilder = cfg.sessionConfiguration.builder()
sessionBuilder.hostname = serverAddress
var cfgBuilder = cfg.builder()
cfgBuilder.sessionConfiguration = sessionBuilder.build()
cfg = cfgBuilder.build()
}
} catch let e {
var message: String?
if let te = e as? OpenVPNProviderConfigurationError {
@ -237,7 +228,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
cfg.print(appVersion: appVersion)
// prepare to pick endpoints
strategy = ConnectionStrategy(configuration: cfg)
strategy = ConnectionStrategy(configuration: cfg.sessionConfiguration)
let session: OpenVPNSession
do {
@ -337,12 +328,21 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
return
}
strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in
guard let socket = socket else {
strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) {
switch $0 {
case .success(let socket):
self.connectTunnel(via: socket)
case .failure(let error):
if case .dnsFailure = error {
self.tunnelQueue.async {
self.strategy.tryNextEndpoint()
self.connectTunnel()
}
return
}
self.disposeTunnel(error: error)
return
}
self.connectTunnel(via: socket)
}
}
@ -839,7 +839,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
extension OpenVPNTunnelProvider {
private func tryNextEndpoint() -> Bool {
guard strategy.tryNextEndpoint() else {
disposeTunnel(error: OpenVPNProviderError.exhaustedProtocols)
disposeTunnel(error: OpenVPNProviderError.exhaustedEndpoints)
return false
}
return true

View File

@ -0,0 +1,106 @@
//
// ResolvedRemote.swift
// TunnelKit
//
// Created by Davide De Rosa on 3/3/22.
// Copyright (c) 2022 Davide De Rosa. All rights reserved.
//
// https://github.com/passepartoutvpn
//
// This file is part of TunnelKit.
//
// TunnelKit is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// TunnelKit is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with TunnelKit. If not, see <http://www.gnu.org/licenses/>.
//
import Foundation
import TunnelKitCore
import SwiftyBeaver
private let log = SwiftyBeaver.self
class ResolvedRemote: CustomStringConvertible {
let originalEndpoint: Endpoint
private(set) var isResolved: Bool
private(set) var resolvedEndpoints: [Endpoint]
private var currentEndpointIndex: Int
var currentEndpoint: Endpoint? {
guard currentEndpointIndex < resolvedEndpoints.count else {
return nil
}
return resolvedEndpoints[currentEndpointIndex]
}
init(_ originalEndpoint: Endpoint) {
self.originalEndpoint = originalEndpoint
isResolved = false
resolvedEndpoints = []
currentEndpointIndex = 0
}
func nextEndpoint() -> Bool {
currentEndpointIndex += 1
return currentEndpointIndex < resolvedEndpoints.count
}
func resolve(timeout: Int, queue: DispatchQueue, completionHandler: @escaping () -> Void) {
DNSResolver.resolve(originalEndpoint.address, timeout: timeout, queue: queue) { [weak self] in
self?.handleResult($0)
completionHandler()
}
}
private func handleResult(_ result: Result<[DNSRecord], DNSError>) {
switch result {
case .success(let records):
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
isResolved = true
resolvedEndpoints = unrolledEndpoints(records: records)
case .failure:
log.error("DNS resolution failed!")
isResolved = false
resolvedEndpoints = []
}
}
private func unrolledEndpoints(records: [DNSRecord]) -> [Endpoint] {
let endpoints = records.filter {
$0.isCompatible(withProtocol: originalEndpoint.proto)
}.map {
Endpoint($0.address, originalEndpoint.proto)
}
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
return endpoints
}
// MARK: CustomStringConvertible
var description: String {
return "{\(originalEndpoint.maskedDescription), resolved: \(resolvedEndpoints.maskedDescription)}"
}
}
private extension DNSRecord {
func isCompatible(withProtocol proto: EndpointProtocol) -> Bool {
if isIPv6 {
return proto.socketType != .udp4 && proto.socketType != .tcp4
} else {
return proto.socketType != .udp6 && proto.socketType != .tcp6
}
}
}

View File

@ -218,11 +218,8 @@ extension OpenVPN {
// MARK: Client
/// The server hostname (picked from first remote).
public var hostname: String?
/// The list of server endpoints.
public var endpointProtocols: [EndpointProtocol]?
public var remotes: [Endpoint]?
/// If true, checks EKU of server certificate.
public var checksEKU: Bool?
@ -338,8 +335,7 @@ extension OpenVPN {
keepAliveTimeout: keepAliveTimeout,
renegotiatesAfter: renegotiatesAfter,
xorMask: xorMask,
hostname: hostname,
endpointProtocols: endpointProtocols,
remotes: remotes,
checksEKU: checksEKU,
checksSANHost: checksSANHost,
sanHost: sanHost,
@ -410,11 +406,8 @@ extension OpenVPN {
/// - Seealso: `ConfigurationBuilder.xorMask`
public let xorMask: UInt8?
/// - Seealso: `ConfigurationBuilder.hostname`
public let hostname: String?
/// - Seealso: `ConfigurationBuilder.endpointProtocols`
public let endpointProtocols: [EndpointProtocol]?
/// - Seealso: `ConfigurationBuilder.remotes`
public let remotes: [Endpoint]?
/// - Seealso: `ConfigurationBuilder.checksEKU`
public let checksEKU: Bool?
@ -520,8 +513,7 @@ extension OpenVPN.Configuration {
builder.keepAliveInterval = keepAliveInterval
builder.keepAliveTimeout = keepAliveTimeout
builder.renegotiatesAfter = renegotiatesAfter
builder.hostname = hostname
builder.endpointProtocols = endpointProtocols
builder.remotes = remotes
builder.checksEKU = checksEKU
builder.checksSANHost = checksSANHost
builder.sanHost = sanHost
@ -552,10 +544,10 @@ extension OpenVPN.Configuration {
extension OpenVPN.Configuration {
public func print() {
guard let endpointProtocols = endpointProtocols else {
fatalError("No sessionConfiguration.endpointProtocols set")
guard let remotes = remotes else {
fatalError("No sessionConfiguration.remotes set")
}
log.info("\tProtocols: \(endpointProtocols)")
log.info("\tRemotes: \(remotes)")
log.info("\tCipher: \(fallbackCipher)")
log.info("\tDigest: \(fallbackDigest)")
log.info("\tCompression framing: \(fallbackCompressionFraming)")

View File

@ -770,14 +770,9 @@ extension OpenVPN {
optDefaultProto = optDefaultProto ?? .udp
optDefaultPort = optDefaultPort ?? 1194
if !optRemotes.isEmpty {
sessionBuilder.hostname = optRemotes[0].0
var fullRemotes: [(String, UInt16, SocketType)] = []
let hostname = optRemotes[0].0
optRemotes.forEach {
guard $0.0 == hostname else {
return
}
let hostname = $0.0
guard let port = $0.1 ?? optDefaultPort else {
return
}
@ -786,9 +781,9 @@ extension OpenVPN {
}
fullRemotes.append((hostname, port, socketType))
}
sessionBuilder.endpointProtocols = fullRemotes.map { EndpointProtocol($0.2, $0.1) }
} else {
sessionBuilder.hostname = nil
sessionBuilder.remotes = fullRemotes.map {
Endpoint($0.0, .init($0.2, $0.1))
}
}
sessionBuilder.authUserPass = authUserPass

View File

@ -55,8 +55,6 @@ extension OpenVPNProvider {
public static let defaults = Configuration(
sessionConfiguration: OpenVPN.ConfigurationBuilder().build(),
prefersResolvedAddresses: false,
resolvedAddresses: nil,
shouldDebug: false,
debugLogFormat: nil,
masksPrivateData: true,
@ -66,17 +64,6 @@ extension OpenVPNProvider {
/// The session configuration.
public var sessionConfiguration: OpenVPN.Configuration
/// Prefers resolved addresses over DNS resolution. `resolvedAddresses` must be set and non-empty. Default is `false`.
///
/// - Seealso: `fallbackServerAddresses`
public var prefersResolvedAddresses: Bool
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only).
public var resolvedAddresses: [String]?
/// Optional version identifier about the client pushed to server in peer-info as `IV_UI_VER`.
public var versionIdentifier: String?
// MARK: Debugging
/// Enables debugging.
@ -88,6 +75,9 @@ extension OpenVPNProvider {
/// Mask private data in debug log (default is `true`).
public var masksPrivateData: Bool?
/// Optional version identifier about the client pushed to server in peer-info as `IV_UI_VER`.
public var versionIdentifier: String?
// MARK: Building
/**
@ -97,8 +87,6 @@ extension OpenVPNProvider {
*/
public init(sessionConfiguration: OpenVPN.Configuration) {
self.sessionConfiguration = sessionConfiguration
prefersResolvedAddresses = ConfigurationBuilder.defaults.prefersResolvedAddresses
resolvedAddresses = nil
shouldDebug = ConfigurationBuilder.defaults.shouldDebug
debugLogFormat = ConfigurationBuilder.defaults.debugLogFormat
masksPrivateData = ConfigurationBuilder.defaults.masksPrivateData
@ -113,8 +101,6 @@ extension OpenVPNProvider {
public func build() -> Configuration {
return Configuration(
sessionConfiguration: sessionConfiguration,
prefersResolvedAddresses: prefersResolvedAddresses,
resolvedAddresses: resolvedAddresses,
shouldDebug: shouldDebug,
debugLogFormat: shouldDebug ? debugLogFormat : nil,
masksPrivateData: masksPrivateData,
@ -129,12 +115,6 @@ extension OpenVPNProvider {
/// - Seealso: `OpenVPNProvider.ConfigurationBuilder.sessionConfiguration`
public let sessionConfiguration: OpenVPN.Configuration
/// - Seealso: `OpenVPNProvider.ConfigurationBuilder.prefersResolvedAddresses`
public let prefersResolvedAddresses: Bool
/// - Seealso: `OpenVPNProvider.ConfigurationBuilder.resolvedAddresses`
public let resolvedAddresses: [String]?
/// - Seealso: `OpenVPNProvider.ConfigurationBuilder.shouldDebug`
public let shouldDebug: Bool
@ -246,11 +226,7 @@ extension OpenVPNProvider {
- Throws: `OpenVPNProviderError.configuration` if `providerConfiguration` is incomplete.
*/
public static func parsed(from providerConfiguration: [String: Any]) throws -> Configuration {
let cfg = try fromDictionary(OpenVPNProvider.Configuration.self, providerConfiguration)
guard !cfg.prefersResolvedAddresses || !(cfg.resolvedAddresses?.isEmpty ?? true) else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[prefersResolvedAddresses] is true but no [resolvedAddresses]")
}
return cfg
return try fromDictionary(OpenVPNProvider.Configuration.self, providerConfiguration)
}
/**
@ -284,13 +260,16 @@ extension OpenVPNProvider {
withBundleIdentifier bundleIdentifier: String,
appGroup: String,
context: String,
credentials: OpenVPN.Credentials?) throws -> NETunnelProviderProtocol {
credentials: OpenVPN.Credentials?) throws -> NETunnelProviderProtocol
{
let protocolConfiguration = NETunnelProviderProtocol()
let keychain = Keychain(group: appGroup)
protocolConfiguration.providerBundleIdentifier = bundleIdentifier
protocolConfiguration.serverAddress = sessionConfiguration.hostname ?? resolvedAddresses?.first
guard let firstRemote = sessionConfiguration.remotes?.first else {
fatalError("No remotes set")
}
protocolConfiguration.serverAddress = "\(firstRemote.address):\(firstRemote.proto.port)"
if let username = credentials?.username {
protocolConfiguration.username = username
if let password = credentials?.password {
@ -324,8 +303,6 @@ extension OpenVPNProvider.Configuration {
*/
public func builder() -> OpenVPNProvider.ConfigurationBuilder {
var builder = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: sessionConfiguration)
builder.prefersResolvedAddresses = prefersResolvedAddresses
builder.resolvedAddresses = resolvedAddresses
builder.shouldDebug = shouldDebug
builder.debugLogFormat = debugLogFormat
builder.masksPrivateData = masksPrivateData

View File

@ -58,8 +58,8 @@ public enum OpenVPNProviderError: String, Error {
/// Socket endpoint could not be resolved.
case dnsFailure
/// No more protocols available to try.
case exhaustedProtocols
/// No more endpoints available to try.
case exhaustedEndpoints
/// Socket failed to reach active state.
case socketActivity

View File

@ -62,6 +62,8 @@ class AppExtensionTests: XCTestCase {
let identifier = "com.example.Provider"
let appGroup = "group.com.algoritmico.TunnelKit"
let hostname = "example.com"
let port: UInt16 = 1234
let serverAddress = "\(hostname):\(port)"
let context = "foobar"
let credentials = OpenVPN.Credentials("foo", "bar")
@ -69,8 +71,7 @@ class AppExtensionTests: XCTestCase {
sessionBuilder.ca = OpenVPN.CryptoContainer(pem: "abcdef")
sessionBuilder.cipher = .aes128cbc
sessionBuilder.digest = .sha256
sessionBuilder.hostname = hostname
sessionBuilder.endpointProtocols = []
sessionBuilder.remotes = [.init(hostname, .init(.udp, port))]
sessionBuilder.mtu = 1230
builder = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: sessionBuilder.build())
XCTAssertNotNil(builder)
@ -86,7 +87,7 @@ class AppExtensionTests: XCTestCase {
XCTAssertNotNil(proto)
XCTAssertEqual(proto?.providerBundleIdentifier, identifier)
XCTAssertEqual(proto?.serverAddress, hostname)
XCTAssertEqual(proto?.serverAddress, serverAddress)
XCTAssertEqual(proto?.username, credentials.username)
XCTAssertEqual(proto?.passwordReference, try? Keychain(group: appGroup).passwordReference(for: credentials.username, context: context))
@ -107,15 +108,17 @@ class AppExtensionTests: XCTestCase {
func testDNSResolver() {
let exp = expectation(description: "DNS")
DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in
DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) {
defer {
exp.fulfill()
}
guard let addrs = addrs else {
switch $0 {
case .success(let records):
print("\(records)")
case .failure:
print("Can't resolve")
return
}
print("\(addrs)")
}
waitForExpectations(timeout: 5.0, handler: nil)
}
@ -143,37 +146,31 @@ class AppExtensionTests: XCTestCase {
func testEndpointCycling() {
CoreConfiguration.masksPrivateData = false
var builder1 = OpenVPN.ConfigurationBuilder()
builder1.hostname = "italy.privateinternetaccess.com"
builder1.endpointProtocols = [
EndpointProtocol(.tcp6, 2222),
EndpointProtocol(.udp, 1111),
EndpointProtocol(.udp4, 3333)
var builder = OpenVPN.ConfigurationBuilder()
let hostname = "italy.privateinternetaccess.com"
builder.remotes = [
.init(hostname, .init(.tcp6, 2222)),
.init(hostname, .init(.udp, 1111)),
.init(hostname, .init(.udp4, 3333))
]
var builder2 = OpenVPNProvider.ConfigurationBuilder(sessionConfiguration: builder1.build())
builder2.prefersResolvedAddresses = true
builder2.resolvedAddresses = [
"82.102.21.218",
"82.102.21.214",
"82.102.21.213",
]
let strategy = ConnectionStrategy(configuration: builder2.build())
let strategy = ConnectionStrategy(configuration: builder.build())
let expected = [
"82.102.21.218:UDP:1111",
"82.102.21.218:UDP4:3333",
"82.102.21.214:UDP:1111",
"82.102.21.214:UDP4:3333",
"82.102.21.213:UDP:1111",
"82.102.21.213:UDP4:3333",
"italy.privateinternetaccess.com:TCP6:2222",
"italy.privateinternetaccess.com:UDP:1111",
"italy.privateinternetaccess.com:UDP4:3333"
]
var i = 0
while strategy.hasEndpoint() {
let endpoint = strategy.currentEndpoint()
print("\(endpoint)")
XCTAssertEqual(endpoint.description, expected[i])
while strategy.hasEndpoints() {
guard let remote = strategy.currentRemote else {
break
}
print("\(i): \(remote)")
XCTAssertEqual(remote.originalEndpoint.description, expected[i])
i += 1
strategy.tryNextEndpoint()
guard strategy.tryNextEndpoint() else {
break
}
}
}

View File

@ -105,13 +105,12 @@ class ConfigurationParserTests: XCTestCase {
func testPIA() throws {
let file = try OpenVPN.ConfigurationParser.parsed(fromURL: url(withName: "pia-hungary"))
XCTAssertEqual(file.configuration.hostname, "hungary.privateinternetaccess.com")
XCTAssertEqual(file.configuration.remotes, [
.init("hungary.privateinternetaccess.com", .init(.udp, 1198)),
.init("hungary.privateinternetaccess.com", .init(.tcp, 502)),
])
XCTAssertEqual(file.configuration.cipher, .aes128cbc)
XCTAssertEqual(file.configuration.digest, .sha1)
XCTAssertEqual(file.configuration.endpointProtocols, [
EndpointProtocol(.udp, 1198),
EndpointProtocol(.tcp, 502)
])
}
func testStripped() throws {