WireGuardKit: Conditionally turn on/off wireguard-go

Signed-off-by: Andrej Mihajlov <and@mullvad.net>
This commit is contained in:
Andrej Mihajlov 2020-12-01 11:18:31 +01:00
parent 3de7c99301
commit 9f8d0e24df
3 changed files with 215 additions and 108 deletions

View File

@ -109,7 +109,7 @@ extension Endpoint {
hints.ai_family = AF_UNSPEC hints.ai_family = AF_UNSPEC
hints.ai_socktype = SOCK_DGRAM hints.ai_socktype = SOCK_DGRAM
hints.ai_protocol = IPPROTO_UDP hints.ai_protocol = IPPROTO_UDP
hints.ai_flags = AI_DEFAULT hints.ai_flags = 0 // We set this to zero so that we actually resolve this using DNS64
var result: UnsafeMutablePointer<addrinfo>? var result: UnsafeMutablePointer<addrinfo>?
defer { defer {

View File

@ -9,6 +9,9 @@ import NetworkExtension
import WireGuardKitC import WireGuardKitC
#endif #endif
/// A type alias for `Result` type that holds a tuple with source and resolved endpoint.
typealias EndpointResolutionResult = Result<(Endpoint, Endpoint), DNSResolutionError>
class PacketTunnelSettingsGenerator { class PacketTunnelSettingsGenerator {
let tunnelConfiguration: TunnelConfiguration let tunnelConfiguration: TunnelConfiguration
let resolvedEndpoints: [Endpoint?] let resolvedEndpoints: [Endpoint?]
@ -18,31 +21,27 @@ class PacketTunnelSettingsGenerator {
self.resolvedEndpoints = resolvedEndpoints self.resolvedEndpoints = resolvedEndpoints
} }
func endpointUapiConfiguration() -> (String, [DNSResolutionError]) { func endpointUapiConfiguration() -> (String, [EndpointResolutionResult?]) {
var resolutionErrors = [DNSResolutionError]() var resolutionResults = [EndpointResolutionResult?]()
var wgSettings = "" var wgSettings = ""
for (index, peer) in tunnelConfiguration.peers.enumerated() {
wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
let result = Result { try resolvedEndpoints[index]?.withReresolvedIP() }
.mapError { error -> DNSResolutionError in
// swiftlint:disable:next force_cast
return error as! DNSResolutionError
}
switch result { assert(tunnelConfiguration.peers.count == resolvedEndpoints.count)
case .success(.some(let endpoint)): for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) {
if case .name = endpoint.host { assert(false, "Endpoint is not resolved") } wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n")
case .success(.none): let result = resolvedEndpoint.map(Self.reresolveEndpoint)
break if case .success((_, let resolvedEndpoint)) = result {
case .failure(let error): if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") }
resolutionErrors.append(error) wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n")
} }
resolutionResults.append(result)
} }
return (wgSettings, resolutionErrors)
return (wgSettings, resolutionResults)
} }
func uapiConfiguration() -> String { func uapiConfiguration() -> (String, [EndpointResolutionResult?]) {
var resolutionResults = [EndpointResolutionResult?]()
var wgSettings = "" var wgSettings = ""
wgSettings.append("private_key=\(tunnelConfiguration.interface.privateKey.hexKey)\n") wgSettings.append("private_key=\(tunnelConfiguration.interface.privateKey.hexKey)\n")
if let listenPort = tunnelConfiguration.interface.listenPort { if let listenPort = tunnelConfiguration.interface.listenPort {
@ -52,15 +51,19 @@ class PacketTunnelSettingsGenerator {
wgSettings.append("replace_peers=true\n") wgSettings.append("replace_peers=true\n")
} }
assert(tunnelConfiguration.peers.count == resolvedEndpoints.count) assert(tunnelConfiguration.peers.count == resolvedEndpoints.count)
for (index, peer) in tunnelConfiguration.peers.enumerated() { for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) {
wgSettings.append("public_key=\(peer.publicKey.hexKey)\n") wgSettings.append("public_key=\(peer.publicKey.hexKey)\n")
if let preSharedKey = peer.preSharedKey?.hexKey { if let preSharedKey = peer.preSharedKey?.hexKey {
wgSettings.append("preshared_key=\(preSharedKey)\n") wgSettings.append("preshared_key=\(preSharedKey)\n")
} }
if let endpoint = try? resolvedEndpoints[index]?.withReresolvedIP() {
if case .name = endpoint.host { assert(false, "Endpoint is not resolved") } let result = resolvedEndpoint.map(Self.reresolveEndpoint)
wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n") if case .success((_, let resolvedEndpoint)) = result {
if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") }
wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n")
} }
resolutionResults.append(result)
let persistentKeepAlive = peer.persistentKeepAlive ?? 0 let persistentKeepAlive = peer.persistentKeepAlive ?? 0
wgSettings.append("persistent_keepalive_interval=\(persistentKeepAlive)\n") wgSettings.append("persistent_keepalive_interval=\(persistentKeepAlive)\n")
if !peer.allowedIPs.isEmpty { if !peer.allowedIPs.isEmpty {
@ -68,7 +71,7 @@ class PacketTunnelSettingsGenerator {
peer.allowedIPs.forEach { wgSettings.append("allowed_ip=\($0.stringRepresentation)\n") } peer.allowedIPs.forEach { wgSettings.append("allowed_ip=\($0.stringRepresentation)\n") }
} }
} }
return wgSettings return (wgSettings, resolutionResults)
} }
func generateNetworkSettings() -> NEPacketTunnelNetworkSettings { func generateNetworkSettings() -> NEPacketTunnelNetworkSettings {
@ -163,4 +166,12 @@ class PacketTunnelSettingsGenerator {
} }
return (ipv4IncludedRoutes, ipv6IncludedRoutes) return (ipv4IncludedRoutes, ipv6IncludedRoutes)
} }
private class func reresolveEndpoint(endpoint: Endpoint) -> EndpointResolutionResult {
return Result { (endpoint, try endpoint.withReresolvedIP()) }
.mapError { error -> DNSResolutionError in
// swiftlint:disable:next force_cast
return error as! DNSResolutionError
}
}
} }

View File

@ -28,6 +28,18 @@ public enum WireGuardAdapterError: Error {
case startWireGuardBackend(Int32) case startWireGuardBackend(Int32)
} }
/// Enum representing internal state of the `WireGuardAdapter`
private enum State {
/// The tunnel is stopped
case stopped
/// The tunnel is up and running
case started(_ handle: Int32, _ settingsGenerator: PacketTunnelSettingsGenerator)
/// The tunnel is temporarily shutdown due to device going offline
case temporaryShutdown(_ settingsGenerator: PacketTunnelSettingsGenerator)
}
public class WireGuardAdapter { public class WireGuardAdapter {
public typealias LogHandler = (WireGuardLogLevel, String) -> Void public typealias LogHandler = (WireGuardLogLevel, String) -> Void
@ -40,15 +52,11 @@ public class WireGuardAdapter {
/// Log handler closure. /// Log handler closure.
private let logHandler: LogHandler private let logHandler: LogHandler
/// WireGuard internal handle returned by `wgTurnOn` that's used to associate the calls
/// with the specific WireGuard tunnel.
private var wireguardHandle: Int32?
/// Private queue used to synchronize access to `WireGuardAdapter` members. /// Private queue used to synchronize access to `WireGuardAdapter` members.
private let workQueue = DispatchQueue(label: "WireGuardAdapterWorkQueue") private let workQueue = DispatchQueue(label: "WireGuardAdapterWorkQueue")
/// Packet tunnel settings generator. /// Adapter state.
private var settingsGenerator: PacketTunnelSettingsGenerator? private var state: State = .stopped
/// Tunnel device file descriptor. /// Tunnel device file descriptor.
private var tunnelFileDescriptor: Int32? { private var tunnelFileDescriptor: Int32? {
@ -108,7 +116,7 @@ public class WireGuardAdapter {
networkMonitor?.cancel() networkMonitor?.cancel()
// Shutdown the tunnel // Shutdown the tunnel
if let handle = self.wireguardHandle { if case .started(let handle, _) = self.state {
wgTurnOff(handle) wgTurnOff(handle)
} }
} }
@ -119,7 +127,7 @@ public class WireGuardAdapter {
/// - Parameter completionHandler: completion handler. /// - Parameter completionHandler: completion handler.
public func getRuntimeConfiguration(completionHandler: @escaping (String?) -> Void) { public func getRuntimeConfiguration(completionHandler: @escaping (String?) -> Void) {
workQueue.async { workQueue.async {
guard let handle = self.wireguardHandle else { guard case .started(let handle, _) = self.state else {
completionHandler(nil) completionHandler(nil)
return return
} }
@ -139,16 +147,11 @@ public class WireGuardAdapter {
/// - completionHandler: completion handler. /// - completionHandler: completion handler.
public func start(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) { public func start(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
workQueue.async { workQueue.async {
guard self.wireguardHandle == nil else { guard case .stopped = self.state else {
completionHandler(.invalidState) completionHandler(.invalidState)
return return
} }
guard let tunnelFileDescriptor = self.tunnelFileDescriptor else {
completionHandler(.cannotLocateTunnelFileDescriptor)
return
}
#if os(macOS) #if os(macOS)
wgEnableRoaming(true) wgEnableRoaming(true)
#endif #endif
@ -157,25 +160,26 @@ public class WireGuardAdapter {
networkMonitor.pathUpdateHandler = { [weak self] path in networkMonitor.pathUpdateHandler = { [weak self] path in
self?.didReceivePathUpdate(path: path) self?.didReceivePathUpdate(path: path)
} }
networkMonitor.start(queue: self.workQueue) networkMonitor.start(queue: self.workQueue)
self.networkMonitor = networkMonitor
self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in do {
if let error = error { let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration)
completionHandler(error) try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
} else {
var returnError: WireGuardAdapterError?
let handle = wgTurnOn(settingsGenerator!.uapiConfiguration(), tunnelFileDescriptor)
if handle >= 0 { let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
self.wireguardHandle = handle self.logEndpointResolutionResults(resolutionResults)
} else {
returnError = .startWireGuardBackend(handle)
}
completionHandler(returnError) self.state = .started(
} try self.startWireGuardBackend(wgConfig: wgConfig),
settingsGenerator
)
self.networkMonitor = networkMonitor
completionHandler(nil)
} catch let error as WireGuardAdapterError {
networkMonitor.cancel()
completionHandler(error)
} catch {
fatalError()
} }
} }
} }
@ -184,7 +188,14 @@ public class WireGuardAdapter {
/// - Parameter completionHandler: completion handler. /// - Parameter completionHandler: completion handler.
public func stop(completionHandler: @escaping (WireGuardAdapterError?) -> Void) { public func stop(completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
workQueue.async { workQueue.async {
guard let handle = self.wireguardHandle else { switch self.state {
case .started(let handle, _):
wgTurnOff(handle)
case .temporaryShutdown:
break
case .stopped:
completionHandler(.invalidState) completionHandler(.invalidState)
return return
} }
@ -192,8 +203,7 @@ public class WireGuardAdapter {
self.networkMonitor?.cancel() self.networkMonitor?.cancel()
self.networkMonitor = nil self.networkMonitor = nil
wgTurnOff(handle) self.state = .stopped
self.wireguardHandle = nil
completionHandler(nil) completionHandler(nil)
} }
@ -205,7 +215,7 @@ public class WireGuardAdapter {
/// - completionHandler: completion handler. /// - completionHandler: completion handler.
public func update(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) { public func update(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) {
workQueue.async { workQueue.async {
guard let handle = self.wireguardHandle else { if case .stopped = self.state {
completionHandler(.invalidState) completionHandler(.invalidState)
return return
} }
@ -214,16 +224,35 @@ public class WireGuardAdapter {
// configuration. // configuration.
// This will broadcast the `NEVPNStatusDidChange` notification to the GUI process. // This will broadcast the `NEVPNStatusDidChange` notification to the GUI process.
self.packetTunnelProvider?.reasserting = true self.packetTunnelProvider?.reasserting = true
defer {
self.packetTunnelProvider?.reasserting = false
}
self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in do {
if let error = error { let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration)
completionHandler(error) try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
} else {
wgSetConfig(handle, settingsGenerator!.uapiConfiguration()) switch self.state {
completionHandler(nil) case .started(let handle, _):
let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
self.logEndpointResolutionResults(resolutionResults)
wgSetConfig(handle, wgConfig)
self.state = .started(handle, settingsGenerator)
case .temporaryShutdown:
self.state = .temporaryShutdown(settingsGenerator)
case .stopped:
fatalError()
} }
self.packetTunnelProvider?.reasserting = false completionHandler(nil)
} catch let error as WireGuardAdapterError {
completionHandler(error)
} catch {
fatalError()
} }
} }
} }
@ -246,30 +275,15 @@ public class WireGuardAdapter {
} }
} }
/// Resolve endpoints and update network configuration. /// Set network tunnel configuration.
/// This method ensures that the call to `setTunnelNetworkSettings` does not time out, as in
/// certain scenarios the completion handler given to it may not be invoked by the system.
///
/// - Parameters: /// - Parameters:
/// - tunnelConfiguration: tunnel configuration /// - networkSettings: an instance of type `NEPacketTunnelNetworkSettings`.
/// - completionHandler: completion handler /// - Throws: an error of type `WireGuardAdapterError`.
private func updateNetworkSettings(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (PacketTunnelSettingsGenerator?, WireGuardAdapterError?) -> Void) { /// - Returns: `PacketTunnelSettingsGenerator`.
let resolvedEndpoints: [Endpoint?] private func setNetworkSettings(_ networkSettings: NEPacketTunnelNetworkSettings) throws {
let resolvePeersResult = Result { try self.resolvePeers(for: tunnelConfiguration) }
.mapError { error -> WireGuardAdapterError in
// swiftlint:disable:next force_cast
return error as! WireGuardAdapterError
}
switch resolvePeersResult {
case .success(let endpoints):
resolvedEndpoints = endpoints
case .failure(let error):
completionHandler(nil, error)
return
}
let settingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints)
let networkSettings = settingsGenerator.generateNetworkSettings()
var systemError: Error? var systemError: Error?
let condition = NSCondition() let condition = NSCondition()
@ -287,16 +301,11 @@ public class WireGuardAdapter {
let setTunnelNetworkSettingsTimeout: TimeInterval = 5 // seconds let setTunnelNetworkSettingsTimeout: TimeInterval = 5 // seconds
if condition.wait(until: Date().addingTimeInterval(setTunnelNetworkSettingsTimeout)) { if condition.wait(until: Date().addingTimeInterval(setTunnelNetworkSettingsTimeout)) {
let returnError = systemError.map { WireGuardAdapterError.setNetworkSettings($0) } if let systemError = systemError {
throw WireGuardAdapterError.setNetworkSettings(systemError)
// Only assign `settingsGenerator` when `setTunnelNetworkSettings` succeeded.
if returnError == nil {
self.settingsGenerator = settingsGenerator
} }
completionHandler(settingsGenerator, returnError)
} else { } else {
completionHandler(nil, .setNetworkSettingsTimeout) throw WireGuardAdapterError.setNetworkSettingsTimeout
} }
} }
@ -327,24 +336,97 @@ public class WireGuardAdapter {
return resolvedEndpoints return resolvedEndpoints
} }
/// Start WireGuard backend.
/// - Parameter wgConfig: WireGuard configuration
/// - Throws: an error of type `WireGuardAdapterError`
/// - Returns: tunnel handle
private func startWireGuardBackend(wgConfig: String) throws -> Int32 {
guard let tunnelFileDescriptor = self.tunnelFileDescriptor else {
throw WireGuardAdapterError.cannotLocateTunnelFileDescriptor
}
let handle = wgTurnOn(wgConfig, tunnelFileDescriptor)
if handle >= 0 {
return handle
} else {
throw WireGuardAdapterError.startWireGuardBackend(handle)
}
}
/// Resolves the hostnames in the given tunnel configuration and return settings generator.
/// - Parameter tunnelConfiguration: an instance of type `TunnelConfiguration`.
/// - Throws: an error of type `WireGuardAdapterError`.
/// - Returns: an instance of type `PacketTunnelSettingsGenerator`.
private func makeSettingsGenerator(with tunnelConfiguration: TunnelConfiguration) throws -> PacketTunnelSettingsGenerator {
return PacketTunnelSettingsGenerator(
tunnelConfiguration: tunnelConfiguration,
resolvedEndpoints: try self.resolvePeers(for: tunnelConfiguration)
)
}
/// Log DNS resolution results.
/// - Parameter resolutionErrors: an array of type `[DNSResolutionError]`.
private func logEndpointResolutionResults(_ resolutionResults: [EndpointResolutionResult?]) {
for case .some(let result) in resolutionResults {
switch result {
case .success((let sourceEndpoint, let resolvedEndpoint)):
if sourceEndpoint.host == resolvedEndpoint.host {
self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to itself.")
} else {
self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to \(resolvedEndpoint.host)")
}
case .failure(let resolutionError):
self.logHandler(.error, "Failed to resolve endpoint \(resolutionError.address): \(resolutionError.errorDescription ?? "(nil)")")
}
}
}
/// Helper method used by network path monitor. /// Helper method used by network path monitor.
/// - Parameter path: new network path /// - Parameter path: new network path
private func didReceivePathUpdate(path: Network.NWPath) { private func didReceivePathUpdate(path: Network.NWPath) {
guard let handle = self.wireguardHandle else { return }
self.logHandler(.debug, "Network change detected with \(path.status) route and interface order \(path.availableInterfaces)") self.logHandler(.debug, "Network change detected with \(path.status) route and interface order \(path.availableInterfaces)")
#if os(iOS) switch self.state {
if let settingsGenerator = self.settingsGenerator { case .started(let handle, let settingsGenerator):
let (wgSettings, resolutionErrors) = settingsGenerator.endpointUapiConfiguration() if path.status.isSatisfiable {
for error in resolutionErrors { #if os(iOS)
self.logHandler(.error, "Failed to re-resolve \(error.address): \(error.errorDescription ?? "(nil)")") let (wgConfig, resolutionResults) = settingsGenerator.endpointUapiConfiguration()
} self.logEndpointResolutionResults(resolutionResults)
wgSetConfig(handle, wgSettings)
}
#endif
wgBumpSockets(handle) wgSetConfig(handle, wgConfig)
#endif
wgBumpSockets(handle)
} else {
self.logHandler(.info, "Connectivity offline, pausing backend.")
self.state = .temporaryShutdown(settingsGenerator)
wgTurnOff(handle)
}
case .temporaryShutdown(let settingsGenerator):
guard path.status.isSatisfiable else { return }
self.logHandler(.info, "Connectivity online, resuming backend.")
do {
try self.setNetworkSettings(settingsGenerator.generateNetworkSettings())
let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration()
self.logEndpointResolutionResults(resolutionResults)
self.state = .started(
try self.startWireGuardBackend(wgConfig: wgConfig),
settingsGenerator
)
} catch {
self.logHandler(.error, "Failed to restart backend: \(error.localizedDescription)")
}
case .stopped:
// no-op
break
}
} }
} }
@ -354,3 +436,17 @@ public enum WireGuardLogLevel: Int32 {
case info = 1 case info = 1
case error = 2 case error = 2
} }
private extension Network.NWPath.Status {
/// Returns `true` if the path is potentially satisfiable.
var isSatisfiable: Bool {
switch self {
case .requiresConnection, .satisfied:
return true
case .unsatisfied:
return false
@unknown default:
return true
}
}
}