diff --git a/Sources/WireGuardKit/DNSResolver.swift b/Sources/WireGuardKit/DNSResolver.swift index 5315c94..7a0f2e9 100644 --- a/Sources/WireGuardKit/DNSResolver.swift +++ b/Sources/WireGuardKit/DNSResolver.swift @@ -109,7 +109,7 @@ extension Endpoint { hints.ai_family = AF_UNSPEC hints.ai_socktype = SOCK_DGRAM hints.ai_protocol = IPPROTO_UDP - hints.ai_flags = AI_DEFAULT + hints.ai_flags = 0 // We set this to zero so that we actually resolve this using DNS64 var result: UnsafeMutablePointer? defer { diff --git a/Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift b/Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift index 9efe1fa..0ddc1b7 100644 --- a/Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift +++ b/Sources/WireGuardKit/PacketTunnelSettingsGenerator.swift @@ -9,6 +9,9 @@ import NetworkExtension import WireGuardKitC #endif +/// A type alias for `Result` type that holds a tuple with source and resolved endpoint. +typealias EndpointResolutionResult = Result<(Endpoint, Endpoint), DNSResolutionError> + class PacketTunnelSettingsGenerator { let tunnelConfiguration: TunnelConfiguration let resolvedEndpoints: [Endpoint?] @@ -18,31 +21,27 @@ class PacketTunnelSettingsGenerator { self.resolvedEndpoints = resolvedEndpoints } - func endpointUapiConfiguration() -> (String, [DNSResolutionError]) { - var resolutionErrors = [DNSResolutionError]() + func endpointUapiConfiguration() -> (String, [EndpointResolutionResult?]) { + var resolutionResults = [EndpointResolutionResult?]() var wgSettings = "" - for (index, peer) in tunnelConfiguration.peers.enumerated() { - wgSettings.append("public_key=\(peer.publicKey.hexKey)\n") - let result = Result { try resolvedEndpoints[index]?.withReresolvedIP() } - .mapError { error -> DNSResolutionError in - // swiftlint:disable:next force_cast - return error as! DNSResolutionError - } - switch result { - case .success(.some(let endpoint)): - if case .name = endpoint.host { assert(false, "Endpoint is not resolved") } - wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n") - case .success(.none): - break - case .failure(let error): - resolutionErrors.append(error) + assert(tunnelConfiguration.peers.count == resolvedEndpoints.count) + for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) { + wgSettings.append("public_key=\(peer.publicKey.hexKey)\n") + + let result = resolvedEndpoint.map(Self.reresolveEndpoint) + if case .success((_, let resolvedEndpoint)) = result { + if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") } + wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n") } + resolutionResults.append(result) } - return (wgSettings, resolutionErrors) + + return (wgSettings, resolutionResults) } - func uapiConfiguration() -> String { + func uapiConfiguration() -> (String, [EndpointResolutionResult?]) { + var resolutionResults = [EndpointResolutionResult?]() var wgSettings = "" wgSettings.append("private_key=\(tunnelConfiguration.interface.privateKey.hexKey)\n") if let listenPort = tunnelConfiguration.interface.listenPort { @@ -52,15 +51,19 @@ class PacketTunnelSettingsGenerator { wgSettings.append("replace_peers=true\n") } assert(tunnelConfiguration.peers.count == resolvedEndpoints.count) - for (index, peer) in tunnelConfiguration.peers.enumerated() { + for (peer, resolvedEndpoint) in zip(self.tunnelConfiguration.peers, self.resolvedEndpoints) { wgSettings.append("public_key=\(peer.publicKey.hexKey)\n") if let preSharedKey = peer.preSharedKey?.hexKey { wgSettings.append("preshared_key=\(preSharedKey)\n") } - if let endpoint = try? resolvedEndpoints[index]?.withReresolvedIP() { - if case .name = endpoint.host { assert(false, "Endpoint is not resolved") } - wgSettings.append("endpoint=\(endpoint.stringRepresentation)\n") + + let result = resolvedEndpoint.map(Self.reresolveEndpoint) + if case .success((_, let resolvedEndpoint)) = result { + if case .name = resolvedEndpoint.host { assert(false, "Endpoint is not resolved") } + wgSettings.append("endpoint=\(resolvedEndpoint.stringRepresentation)\n") } + resolutionResults.append(result) + let persistentKeepAlive = peer.persistentKeepAlive ?? 0 wgSettings.append("persistent_keepalive_interval=\(persistentKeepAlive)\n") if !peer.allowedIPs.isEmpty { @@ -68,7 +71,7 @@ class PacketTunnelSettingsGenerator { peer.allowedIPs.forEach { wgSettings.append("allowed_ip=\($0.stringRepresentation)\n") } } } - return wgSettings + return (wgSettings, resolutionResults) } func generateNetworkSettings() -> NEPacketTunnelNetworkSettings { @@ -163,4 +166,12 @@ class PacketTunnelSettingsGenerator { } return (ipv4IncludedRoutes, ipv6IncludedRoutes) } + + private class func reresolveEndpoint(endpoint: Endpoint) -> EndpointResolutionResult { + return Result { (endpoint, try endpoint.withReresolvedIP()) } + .mapError { error -> DNSResolutionError in + // swiftlint:disable:next force_cast + return error as! DNSResolutionError + } + } } diff --git a/Sources/WireGuardKit/WireGuardAdapter.swift b/Sources/WireGuardKit/WireGuardAdapter.swift index 113c06f..bf885c2 100644 --- a/Sources/WireGuardKit/WireGuardAdapter.swift +++ b/Sources/WireGuardKit/WireGuardAdapter.swift @@ -28,6 +28,18 @@ public enum WireGuardAdapterError: Error { case startWireGuardBackend(Int32) } +/// Enum representing internal state of the `WireGuardAdapter` +private enum State { + /// The tunnel is stopped + case stopped + + /// The tunnel is up and running + case started(_ handle: Int32, _ settingsGenerator: PacketTunnelSettingsGenerator) + + /// The tunnel is temporarily shutdown due to device going offline + case temporaryShutdown(_ settingsGenerator: PacketTunnelSettingsGenerator) +} + public class WireGuardAdapter { public typealias LogHandler = (WireGuardLogLevel, String) -> Void @@ -40,15 +52,11 @@ public class WireGuardAdapter { /// Log handler closure. private let logHandler: LogHandler - /// WireGuard internal handle returned by `wgTurnOn` that's used to associate the calls - /// with the specific WireGuard tunnel. - private var wireguardHandle: Int32? - /// Private queue used to synchronize access to `WireGuardAdapter` members. private let workQueue = DispatchQueue(label: "WireGuardAdapterWorkQueue") - /// Packet tunnel settings generator. - private var settingsGenerator: PacketTunnelSettingsGenerator? + /// Adapter state. + private var state: State = .stopped /// Tunnel device file descriptor. private var tunnelFileDescriptor: Int32? { @@ -108,7 +116,7 @@ public class WireGuardAdapter { networkMonitor?.cancel() // Shutdown the tunnel - if let handle = self.wireguardHandle { + if case .started(let handle, _) = self.state { wgTurnOff(handle) } } @@ -119,7 +127,7 @@ public class WireGuardAdapter { /// - Parameter completionHandler: completion handler. public func getRuntimeConfiguration(completionHandler: @escaping (String?) -> Void) { workQueue.async { - guard let handle = self.wireguardHandle else { + guard case .started(let handle, _) = self.state else { completionHandler(nil) return } @@ -139,16 +147,11 @@ public class WireGuardAdapter { /// - completionHandler: completion handler. public func start(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) { workQueue.async { - guard self.wireguardHandle == nil else { + guard case .stopped = self.state else { completionHandler(.invalidState) return } - guard let tunnelFileDescriptor = self.tunnelFileDescriptor else { - completionHandler(.cannotLocateTunnelFileDescriptor) - return - } - #if os(macOS) wgEnableRoaming(true) #endif @@ -157,25 +160,26 @@ public class WireGuardAdapter { networkMonitor.pathUpdateHandler = { [weak self] path in self?.didReceivePathUpdate(path: path) } - networkMonitor.start(queue: self.workQueue) - self.networkMonitor = networkMonitor - self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in - if let error = error { - completionHandler(error) - } else { - var returnError: WireGuardAdapterError? - let handle = wgTurnOn(settingsGenerator!.uapiConfiguration(), tunnelFileDescriptor) + do { + let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration) + try self.setNetworkSettings(settingsGenerator.generateNetworkSettings()) - if handle >= 0 { - self.wireguardHandle = handle - } else { - returnError = .startWireGuardBackend(handle) - } + let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration() + self.logEndpointResolutionResults(resolutionResults) - completionHandler(returnError) - } + self.state = .started( + try self.startWireGuardBackend(wgConfig: wgConfig), + settingsGenerator + ) + self.networkMonitor = networkMonitor + completionHandler(nil) + } catch let error as WireGuardAdapterError { + networkMonitor.cancel() + completionHandler(error) + } catch { + fatalError() } } } @@ -184,7 +188,14 @@ public class WireGuardAdapter { /// - Parameter completionHandler: completion handler. public func stop(completionHandler: @escaping (WireGuardAdapterError?) -> Void) { workQueue.async { - guard let handle = self.wireguardHandle else { + switch self.state { + case .started(let handle, _): + wgTurnOff(handle) + + case .temporaryShutdown: + break + + case .stopped: completionHandler(.invalidState) return } @@ -192,8 +203,7 @@ public class WireGuardAdapter { self.networkMonitor?.cancel() self.networkMonitor = nil - wgTurnOff(handle) - self.wireguardHandle = nil + self.state = .stopped completionHandler(nil) } @@ -205,7 +215,7 @@ public class WireGuardAdapter { /// - completionHandler: completion handler. public func update(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (WireGuardAdapterError?) -> Void) { workQueue.async { - guard let handle = self.wireguardHandle else { + if case .stopped = self.state { completionHandler(.invalidState) return } @@ -214,16 +224,35 @@ public class WireGuardAdapter { // configuration. // This will broadcast the `NEVPNStatusDidChange` notification to the GUI process. self.packetTunnelProvider?.reasserting = true + defer { + self.packetTunnelProvider?.reasserting = false + } - self.updateNetworkSettings(tunnelConfiguration: tunnelConfiguration) { settingsGenerator, error in - if let error = error { - completionHandler(error) - } else { - wgSetConfig(handle, settingsGenerator!.uapiConfiguration()) - completionHandler(nil) + do { + let settingsGenerator = try self.makeSettingsGenerator(with: tunnelConfiguration) + try self.setNetworkSettings(settingsGenerator.generateNetworkSettings()) + + switch self.state { + case .started(let handle, _): + let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration() + self.logEndpointResolutionResults(resolutionResults) + + wgSetConfig(handle, wgConfig) + + self.state = .started(handle, settingsGenerator) + + case .temporaryShutdown: + self.state = .temporaryShutdown(settingsGenerator) + + case .stopped: + fatalError() } - self.packetTunnelProvider?.reasserting = false + completionHandler(nil) + } catch let error as WireGuardAdapterError { + completionHandler(error) + } catch { + fatalError() } } } @@ -246,30 +275,15 @@ public class WireGuardAdapter { } } - /// Resolve endpoints and update network configuration. + /// Set network tunnel configuration. + /// This method ensures that the call to `setTunnelNetworkSettings` does not time out, as in + /// certain scenarios the completion handler given to it may not be invoked by the system. + /// /// - Parameters: - /// - tunnelConfiguration: tunnel configuration - /// - completionHandler: completion handler - private func updateNetworkSettings(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (PacketTunnelSettingsGenerator?, WireGuardAdapterError?) -> Void) { - let resolvedEndpoints: [Endpoint?] - - let resolvePeersResult = Result { try self.resolvePeers(for: tunnelConfiguration) } - .mapError { error -> WireGuardAdapterError in - // swiftlint:disable:next force_cast - return error as! WireGuardAdapterError - } - - switch resolvePeersResult { - case .success(let endpoints): - resolvedEndpoints = endpoints - case .failure(let error): - completionHandler(nil, error) - return - } - - let settingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints) - let networkSettings = settingsGenerator.generateNetworkSettings() - + /// - networkSettings: an instance of type `NEPacketTunnelNetworkSettings`. + /// - Throws: an error of type `WireGuardAdapterError`. + /// - Returns: `PacketTunnelSettingsGenerator`. + private func setNetworkSettings(_ networkSettings: NEPacketTunnelNetworkSettings) throws { var systemError: Error? let condition = NSCondition() @@ -287,16 +301,11 @@ public class WireGuardAdapter { let setTunnelNetworkSettingsTimeout: TimeInterval = 5 // seconds if condition.wait(until: Date().addingTimeInterval(setTunnelNetworkSettingsTimeout)) { - let returnError = systemError.map { WireGuardAdapterError.setNetworkSettings($0) } - - // Only assign `settingsGenerator` when `setTunnelNetworkSettings` succeeded. - if returnError == nil { - self.settingsGenerator = settingsGenerator + if let systemError = systemError { + throw WireGuardAdapterError.setNetworkSettings(systemError) } - - completionHandler(settingsGenerator, returnError) } else { - completionHandler(nil, .setNetworkSettingsTimeout) + throw WireGuardAdapterError.setNetworkSettingsTimeout } } @@ -327,24 +336,97 @@ public class WireGuardAdapter { return resolvedEndpoints } + /// Start WireGuard backend. + /// - Parameter wgConfig: WireGuard configuration + /// - Throws: an error of type `WireGuardAdapterError` + /// - Returns: tunnel handle + private func startWireGuardBackend(wgConfig: String) throws -> Int32 { + guard let tunnelFileDescriptor = self.tunnelFileDescriptor else { + throw WireGuardAdapterError.cannotLocateTunnelFileDescriptor + } + + let handle = wgTurnOn(wgConfig, tunnelFileDescriptor) + if handle >= 0 { + return handle + } else { + throw WireGuardAdapterError.startWireGuardBackend(handle) + } + } + + /// Resolves the hostnames in the given tunnel configuration and return settings generator. + /// - Parameter tunnelConfiguration: an instance of type `TunnelConfiguration`. + /// - Throws: an error of type `WireGuardAdapterError`. + /// - Returns: an instance of type `PacketTunnelSettingsGenerator`. + private func makeSettingsGenerator(with tunnelConfiguration: TunnelConfiguration) throws -> PacketTunnelSettingsGenerator { + return PacketTunnelSettingsGenerator( + tunnelConfiguration: tunnelConfiguration, + resolvedEndpoints: try self.resolvePeers(for: tunnelConfiguration) + ) + } + + /// Log DNS resolution results. + /// - Parameter resolutionErrors: an array of type `[DNSResolutionError]`. + private func logEndpointResolutionResults(_ resolutionResults: [EndpointResolutionResult?]) { + for case .some(let result) in resolutionResults { + switch result { + case .success((let sourceEndpoint, let resolvedEndpoint)): + if sourceEndpoint.host == resolvedEndpoint.host { + self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to itself.") + } else { + self.logHandler(.debug, "DNS64: mapped \(sourceEndpoint.host) to \(resolvedEndpoint.host)") + } + case .failure(let resolutionError): + self.logHandler(.error, "Failed to resolve endpoint \(resolutionError.address): \(resolutionError.errorDescription ?? "(nil)")") + } + } + } + /// Helper method used by network path monitor. /// - Parameter path: new network path private func didReceivePathUpdate(path: Network.NWPath) { - guard let handle = self.wireguardHandle else { return } - self.logHandler(.debug, "Network change detected with \(path.status) route and interface order \(path.availableInterfaces)") - #if os(iOS) - if let settingsGenerator = self.settingsGenerator { - let (wgSettings, resolutionErrors) = settingsGenerator.endpointUapiConfiguration() - for error in resolutionErrors { - self.logHandler(.error, "Failed to re-resolve \(error.address): \(error.errorDescription ?? "(nil)")") - } - wgSetConfig(handle, wgSettings) - } - #endif + switch self.state { + case .started(let handle, let settingsGenerator): + if path.status.isSatisfiable { + #if os(iOS) + let (wgConfig, resolutionResults) = settingsGenerator.endpointUapiConfiguration() + self.logEndpointResolutionResults(resolutionResults) - wgBumpSockets(handle) + wgSetConfig(handle, wgConfig) + #endif + + wgBumpSockets(handle) + } else { + self.logHandler(.info, "Connectivity offline, pausing backend.") + + self.state = .temporaryShutdown(settingsGenerator) + wgTurnOff(handle) + } + + case .temporaryShutdown(let settingsGenerator): + guard path.status.isSatisfiable else { return } + + self.logHandler(.info, "Connectivity online, resuming backend.") + + do { + try self.setNetworkSettings(settingsGenerator.generateNetworkSettings()) + + let (wgConfig, resolutionResults) = settingsGenerator.uapiConfiguration() + self.logEndpointResolutionResults(resolutionResults) + + self.state = .started( + try self.startWireGuardBackend(wgConfig: wgConfig), + settingsGenerator + ) + } catch { + self.logHandler(.error, "Failed to restart backend: \(error.localizedDescription)") + } + + case .stopped: + // no-op + break + } } } @@ -354,3 +436,17 @@ public enum WireGuardLogLevel: Int32 { case info = 1 case error = 2 } + +private extension Network.NWPath.Status { + /// Returns `true` if the path is potentially satisfiable. + var isSatisfiable: Bool { + switch self { + case .requiresConnection, .satisfied: + return true + case .unsatisfied: + return false + @unknown default: + return true + } + } +}