NE: Handle bad domain names and Activate On Demand

This combination causes iOS to keep trying to bring up the tunnel,
leading to a lot of displayMessage() alerts.

In this fix, if we get a DNS resolution error in an Activate On Demand
enabled tunnel, we silently retry 9 times (with a 4-second delay before
each retry) and then show the displayMessage() alert.
This commit is contained in:
Roopesh Chander 2018-12-19 15:38:00 +05:30
parent 82ca9f7c5a
commit b8c331c72d
4 changed files with 59 additions and 19 deletions

View File

@ -4,7 +4,7 @@
import NetworkExtension import NetworkExtension
extension NETunnelProviderProtocol { extension NETunnelProviderProtocol {
convenience init?(tunnelConfiguration: TunnelConfiguration) { convenience init?(tunnelConfiguration: TunnelConfiguration, isActivateOnDemandEnabled: Bool) {
assert(!tunnelConfiguration.interface.name.isEmpty) assert(!tunnelConfiguration.interface.name.isEmpty)
guard let serializedTunnelConfiguration = try? JSONEncoder().encode(tunnelConfiguration) else { return nil } guard let serializedTunnelConfiguration = try? JSONEncoder().encode(tunnelConfiguration) else { return nil }
@ -14,7 +14,8 @@ extension NETunnelProviderProtocol {
providerBundleIdentifier = "\(appId).network-extension" providerBundleIdentifier = "\(appId).network-extension"
providerConfiguration = [ providerConfiguration = [
"tunnelConfiguration": serializedTunnelConfiguration, "tunnelConfiguration": serializedTunnelConfiguration,
"tunnelConfigurationVersion": 1 "tunnelConfigurationVersion": 1,
"isActivateOnDemandEnabled": isActivateOnDemandEnabled
] ]
let endpoints = tunnelConfiguration.peers.compactMap {$0.endpoint} let endpoints = tunnelConfiguration.peers.compactMap {$0.endpoint}
@ -32,4 +33,8 @@ extension NETunnelProviderProtocol {
guard let serializedTunnelConfiguration = providerConfiguration?["tunnelConfiguration"] as? Data else { return nil } guard let serializedTunnelConfiguration = providerConfiguration?["tunnelConfiguration"] as? Data else { return nil }
return try? JSONDecoder().decode(TunnelConfiguration.self, from: serializedTunnelConfiguration) return try? JSONDecoder().decode(TunnelConfiguration.self, from: serializedTunnelConfiguration)
} }
var isActivateOnDemandEnabled: Bool {
return (providerConfiguration?["isActivateOnDemandEnabled"] as? Bool) ?? false
}
} }

View File

@ -59,7 +59,7 @@ class TunnelsManager {
} }
let tunnelProviderManager = NETunnelProviderManager() let tunnelProviderManager = NETunnelProviderManager()
tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration) tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration, isActivateOnDemandEnabled: activateOnDemandSetting.isActivateOnDemandEnabled)
tunnelProviderManager.localizedDescription = tunnelName tunnelProviderManager.localizedDescription = tunnelName
tunnelProviderManager.isEnabled = true tunnelProviderManager.isEnabled = true
@ -115,7 +115,7 @@ class TunnelsManager {
} }
tunnel.name = tunnelName tunnel.name = tunnelName
} }
tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration) tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration, isActivateOnDemandEnabled: activateOnDemandSetting.isActivateOnDemandEnabled)
tunnelProviderManager.localizedDescription = tunnelName tunnelProviderManager.localizedDescription = tunnelName
tunnelProviderManager.isEnabled = true tunnelProviderManager.isEnabled = true

View File

@ -18,8 +18,12 @@ class ErrorNotifier {
switch error { switch error {
case .savedProtocolConfigurationIsInvalid: case .savedProtocolConfigurationIsInvalid:
return ("Activation failure", "Could not retrieve tunnel information from the saved configuration") return ("Activation failure", "Could not retrieve tunnel information from the saved configuration")
case .dnsResolutionFailure: case .dnsResolutionFailure(let tunnelName, let isActivateOnDemandEnabled):
return ("DNS resolution failure", "One or more endpoint domains could not be resolved") if isActivateOnDemandEnabled {
return ("DNS resolution failure", "This tunnel has Activate On Demand enabled, so activation might be retried. You may turn off Activate On Demand in the WireGuard app by navigating to: '\(tunnelName)' > Edit")
} else {
return ("DNS resolution failure", "One or more endpoint domains could not be resolved")
}
case .couldNotStartWireGuard: case .couldNotStartWireGuard:
return ("Activation failure", "WireGuard backend could not be started") return ("Activation failure", "WireGuard backend could not be started")
case .coultNotSetNetworkSettings: case .coultNotSetNetworkSettings:

View File

@ -8,7 +8,7 @@ import os.log
enum PacketTunnelProviderError: Error { enum PacketTunnelProviderError: Error {
case savedProtocolConfigurationIsInvalid case savedProtocolConfigurationIsInvalid
case dnsResolutionFailure(hostnames: [String]) case dnsResolutionFailure(tunnelName: String, isActivateOnDemandEnabled: Bool)
case couldNotStartWireGuard case couldNotStartWireGuard
case coultNotSetNetworkSettings case coultNotSetNetworkSettings
} }
@ -38,21 +38,22 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
configureLogger() configureLogger()
wg_log(.info, message: "Starting tunnel '\(tunnelConfiguration.interface.name)'") let tunnelName = tunnelConfiguration.interface.name
wg_log(.info, message: "Starting tunnel '\(tunnelName)'")
let isActivateOnDemandEnabled = tunnelProviderProtocol.isActivateOnDemandEnabled
if isActivateOnDemandEnabled {
wg_log(.info, staticMessage: "Tunnel has Activate On Demand enabled")
} else {
wg_log(.info, staticMessage: "Tunnel has Activate On Demand disabled")
}
let endpoints = tunnelConfiguration.peers.map { $0.endpoint } let endpoints = tunnelConfiguration.peers.map { $0.endpoint }
var resolvedEndpoints = [Endpoint?]() guard let resolvedEndpoints = resolveDomainNames(endpoints: endpoints, isActivateOnDemandEnabled: isActivateOnDemandEnabled) else {
do { let dnsError = PacketTunnelProviderError.dnsResolutionFailure(tunnelName: tunnelName, isActivateOnDemandEnabled: isActivateOnDemandEnabled)
resolvedEndpoints = try DNSResolver.resolveSync(endpoints: endpoints) errorNotifier.notify(dnsError)
} catch DNSResolverError.dnsResolutionFailed(let hostnames) { startTunnelCompletionHandler(dnsError)
wg_log(.error, staticMessage: "Starting tunnel failed: DNS resolution failure")
wg_log(.error, message: "Hostnames for which DNS resolution failed: \(hostnames.joined(separator: ", "))")
errorNotifier.notify(PacketTunnelProviderError.dnsResolutionFailure(hostnames: hostnames))
startTunnelCompletionHandler(PacketTunnelProviderError.dnsResolutionFailure(hostnames: hostnames))
return return
} catch {
// There can be no other errors from DNSResolver.resolveSync()
fatalError()
} }
assert(endpoints.count == resolvedEndpoints.count) assert(endpoints.count == resolvedEndpoints.count)
@ -143,6 +144,36 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
} }
} }
private func resolveDomainNames(endpoints: [Endpoint?], isActivateOnDemandEnabled: Bool) -> [Endpoint?]? {
var resolvedEndpoints = [Endpoint?]()
let dnsResolutionAttemptsCount = isActivateOnDemandEnabled ? 10 : 1
var isDNSResolved = false
for attemptIndex in 0 ..< dnsResolutionAttemptsCount {
do {
resolvedEndpoints = try DNSResolver.resolveSync(endpoints: endpoints)
isDNSResolved = true
} catch DNSResolverError.dnsResolutionFailed(let hostnames) {
wg_log(.error, staticMessage: "Starting tunnel failed: DNS resolution failure")
wg_log(.error, message: "Hostnames for which DNS resolution failed: \(hostnames.joined(separator: ", "))")
} catch {
// There can be no other errors from DNSResolver.resolveSync()
fatalError()
}
if isDNSResolved {
break
} else {
let isLastAttempt = attemptIndex == dnsResolutionAttemptsCount - 1
if !isLastAttempt {
Thread.sleep(forTimeInterval: 4 /* seconds */)
wg_log(.error, message: "Retrying DNS resolution (Attempt \(attemptIndex + 2))")
}
}
}
return isDNSResolved ? resolvedEndpoints : nil
}
private func connect(interfaceName: String, settings: String, fileDescriptor: Int32) -> Int32 { private func connect(interfaceName: String, settings: String, fileDescriptor: Int32) -> Int32 {
return withStringsAsGoStrings(interfaceName, settings) { return wgTurnOn($0.0, $0.1, fileDescriptor) } return withStringsAsGoStrings(interfaceName, settings) { return wgTurnOn($0.0, $0.1, fileDescriptor) }
} }