Make hostname optional in ConnectionStrategy

Assume preferring resolved addresses.
This commit is contained in:
Davide De Rosa 2019-04-09 19:40:17 +02:00
parent f4683bd337
commit 3fe9c6de6d
4 changed files with 32 additions and 12 deletions

View File

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- Make `hostname` optional and pick `resolvedAddresses` if nil.
## 1.6.1 (2019-04-07) ## 1.6.1 (2019-04-07)
### Fixed ### Fixed

View File

@ -42,7 +42,7 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
class ConnectionStrategy { class ConnectionStrategy {
private let hostname: String private let hostname: String?
private let prefersResolvedAddresses: Bool private let prefersResolvedAddresses: Bool
@ -52,15 +52,17 @@ class ConnectionStrategy {
private var currentProtocolIndex = 0 private var currentProtocolIndex = 0
init(hostname: String, configuration: TunnelKitProvider.Configuration) { init(configuration: TunnelKitProvider.Configuration) {
precondition(!configuration.prefersResolvedAddresses || !(configuration.resolvedAddresses?.isEmpty ?? true)) hostname = configuration.sessionConfiguration.hostname
prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses
self.hostname = hostname
prefersResolvedAddresses = configuration.prefersResolvedAddresses
resolvedAddresses = configuration.resolvedAddresses resolvedAddresses = configuration.resolvedAddresses
if prefersResolvedAddresses {
guard !(resolvedAddresses?.isEmpty ?? true) else {
fatalError("Either hostname or resolved addresses provided")
}
}
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else { guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
fatalError("No endpoints defined") fatalError("No endpoints provided")
} }
if configuration.sessionConfiguration.randomizeEndpoint ?? false { if configuration.sessionConfiguration.randomizeEndpoint ?? false {
endpointProtocols.shuffle() endpointProtocols.shuffle()
@ -92,6 +94,11 @@ class ConnectionStrategy {
} }
// fall back to DNS // fall back to DNS
guard let hostname = hostname else {
log.error("DNS resolution unavailable: no hostname provided!")
completionHandler(nil, TunnelKitProvider.ProviderError.dnsFailure)
return
}
log.debug("DNS resolve hostname: \(hostname.maskedDescription)") log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in

View File

@ -466,7 +466,7 @@ extension TunnelKitProvider {
- Returns: The generated `NETunnelProviderProtocol` object. - Returns: The generated `NETunnelProviderProtocol` object.
- Throws: `ProviderError.credentials` if unable to store `credentials.password` to the `appGroup` keychain. - Throws: `ProviderError.credentials` if unable to store `credentials.password` to the `appGroup` keychain.
*/ */
public func generatedTunnelProtocol(withBundleIdentifier bundleIdentifier: String, appGroup: String, hostname: String, credentials: SessionProxy.Credentials? = nil) throws -> NETunnelProviderProtocol { public func generatedTunnelProtocol(withBundleIdentifier bundleIdentifier: String, appGroup: String, hostname: String?, credentials: SessionProxy.Credentials? = nil) throws -> NETunnelProviderProtocol {
let protocolConfiguration = NETunnelProviderProtocol() let protocolConfiguration = NETunnelProviderProtocol()
protocolConfiguration.providerBundleIdentifier = bundleIdentifier protocolConfiguration.providerBundleIdentifier = bundleIdentifier

View File

@ -123,7 +123,6 @@ open class TunnelKitProvider: NEPacketTunnelProvider {
open override func startTunnel(options: [String : NSObject]? = nil, completionHandler: @escaping (Error?) -> Void) { open override func startTunnel(options: [String : NSObject]? = nil, completionHandler: @escaping (Error?) -> Void) {
// required configuration // required configuration
let hostname: String
do { do {
guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else { guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else {
throw ProviderConfigurationError.parameter(name: "protocolConfiguration") throw ProviderConfigurationError.parameter(name: "protocolConfiguration")
@ -134,9 +133,17 @@ open class TunnelKitProvider: NEPacketTunnelProvider {
guard let providerConfiguration = tunnelProtocol.providerConfiguration else { guard let providerConfiguration = tunnelProtocol.providerConfiguration else {
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration") throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration")
} }
hostname = serverAddress
try appGroup = Configuration.appGroup(from: providerConfiguration) try appGroup = Configuration.appGroup(from: providerConfiguration)
try cfg = Configuration.parsed(from: providerConfiguration) try cfg = Configuration.parsed(from: providerConfiguration)
// inject serverAddress into sessionConfiguration.hostname
if !serverAddress.isEmpty {
var sessionBuilder = cfg.sessionConfiguration.builder()
sessionBuilder.hostname = serverAddress
var cfgBuilder = cfg.builder()
cfgBuilder.sessionConfiguration = sessionBuilder.build()
cfg = cfgBuilder.build()
}
} catch let e { } catch let e {
var message: String? var message: String?
if let te = e as? ProviderConfigurationError { if let te = e as? ProviderConfigurationError {
@ -162,7 +169,7 @@ open class TunnelKitProvider: NEPacketTunnelProvider {
credentials = nil credentials = nil
} }
strategy = ConnectionStrategy(hostname: hostname, configuration: cfg) strategy = ConnectionStrategy(configuration: cfg)
if let content = cfg.existingLog(in: appGroup) { if let content = cfg.existingLog(in: appGroup) {
var existingLog = content.components(separatedBy: "\n") var existingLog = content.components(separatedBy: "\n")