From cab80f8fc0622a78932e92b86ec020ac8e3d590e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 21 Dec 2018 15:56:03 +0100 Subject: [PATCH] NE: simplify logic --- .../PacketTunnelProvider.swift | 63 +++++++------------ .../PacketTunnelSettingsGenerator.swift | 10 +-- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift index c418ebc..3a9066d 100644 --- a/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift +++ b/WireGuard/WireGuardNetworkExtension/PacketTunnelProvider.swift @@ -16,8 +16,9 @@ enum PacketTunnelProviderError: Error { class PacketTunnelProvider: NEPacketTunnelProvider { private var wgHandle: Int32? - private var networkMonitor: NWPathMonitor? + private var lastFirstInterface: NWInterface? + private var packetTunnelSettingsGenerator: PacketTunnelSettingsGenerator? deinit { networkMonitor?.cancel() @@ -65,7 +66,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } assert(endpoints.count == resolvedEndpoints.count) - let packetTunnelSettingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints) + packetTunnelSettingsGenerator = PacketTunnelSettingsGenerator(tunnelConfiguration: tunnelConfiguration, resolvedEndpoints: resolvedEndpoints) let fileDescriptor = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32 //swiftlint:disable:this force_cast if fileDescriptor < 0 { @@ -75,52 +76,23 @@ class PacketTunnelProvider: NEPacketTunnelProvider { return } - let wireguardSettings = packetTunnelSettingsGenerator.uapiConfiguration() - - var handle: Int32 = -1 - - func interfaceDescription(_ interface: NWInterface?) -> String { - if let interface = interface { - return "\(interface.name) (\(interface.type))" - } else { - return "None" - } - } + let wireguardSettings = packetTunnelSettingsGenerator!.uapiConfiguration() networkMonitor = NWPathMonitor() - var previousPrimaryNetworkPathInterface = networkMonitor?.currentPath.availableInterfaces.first - wg_log(.debug, message: "Network path primary interface: \(interfaceDescription(previousPrimaryNetworkPathInterface))") - networkMonitor?.pathUpdateHandler = { path in - guard handle >= 0 else { return } - if path.status == .satisfied { - wg_log(.debug, message: "Network change detected, re-establishing sockets and IPs: \(path.availableInterfaces)") - let primaryNetworkPathInterface = path.availableInterfaces.first - wg_log(.debug, message: "Network path primary interface: \(interfaceDescription(primaryNetworkPathInterface))") - let shouldIncludeListenPort = previousPrimaryNetworkPathInterface != primaryNetworkPathInterface - let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(shouldIncludeListenPort: shouldIncludeListenPort, currentListenPort: wgGetListenPort(handle)) - let err = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) }) - if err == -EADDRINUSE { - // We expect this to happen only if shouldIncludeListenPort is true - let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(shouldIncludeListenPort: shouldIncludeListenPort, currentListenPort: 0) - _ = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) }) - } - previousPrimaryNetworkPathInterface = primaryNetworkPathInterface - } - } - networkMonitor?.start(queue: DispatchQueue(label: "NetworkMonitor")) - - handle = connect(interfaceName: tunnelConfiguration.interface.name, settings: wireguardSettings, fileDescriptor: fileDescriptor) + lastFirstInterface = networkMonitor!.currentPath.availableInterfaces.first + networkMonitor!.pathUpdateHandler = pathUpdate + networkMonitor!.start(queue: DispatchQueue(label: "NetworkMonitor")) + let handle = withStringsAsGoStrings(tunnelConfiguration.interface.name, wireguardSettings) { return wgTurnOn($0.0, $0.1, fileDescriptor) } if handle < 0 { wg_log(.error, staticMessage: "Starting tunnel failed: Could not start WireGuard") errorNotifier.notify(PacketTunnelProviderError.couldNotStartWireGuard) startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard) return } - wgHandle = handle - let networkSettings: NEPacketTunnelNetworkSettings = packetTunnelSettingsGenerator.generateNetworkSettings() + let networkSettings: NEPacketTunnelNetworkSettings = packetTunnelSettingsGenerator!.generateNetworkSettings() setTunnelNetworkSettings(networkSettings) { error in if let error = error { wg_log(.error, staticMessage: "Starting tunnel failed: Error setting network settings.") @@ -165,8 +137,21 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } } - private func connect(interfaceName: String, settings: String, fileDescriptor: Int32) -> Int32 { - return withStringsAsGoStrings(interfaceName, settings) { return wgTurnOn($0.0, $0.1, fileDescriptor) } + private func pathUpdate(path: Network.NWPath) { + guard let handle = wgHandle, let packetTunnelSettingsGenerator = packetTunnelSettingsGenerator else { return } + var listenPort: UInt16? + if path.availableInterfaces.isEmpty || lastFirstInterface != path.availableInterfaces.first { + listenPort = wgGetListenPort(handle) + lastFirstInterface = path.availableInterfaces.first + } + guard path.status == .satisfied else { return } + wg_log(.debug, message: "Network change detected, re-establishing sockets and IPs: \(path.availableInterfaces)") + let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: listenPort) + let err = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) }) + if err == -EADDRINUSE && listenPort != nil { + let endpointString = packetTunnelSettingsGenerator.endpointUapiConfiguration(currentListenPort: 0) + _ = withStringsAsGoStrings(endpointString, call: { return wgSetConfig(handle, $0.0) }) + } } } diff --git a/WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift b/WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift index 888769d..fd706d9 100644 --- a/WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift +++ b/WireGuard/WireGuardNetworkExtension/PacketTunnelSettingsGenerator.swift @@ -15,15 +15,11 @@ class PacketTunnelSettingsGenerator { self.resolvedEndpoints = resolvedEndpoints } - func endpointUapiConfiguration(shouldIncludeListenPort: Bool, currentListenPort: UInt16?) -> String { + func endpointUapiConfiguration(currentListenPort: UInt16?) -> String { var wgSettings = "" - if shouldIncludeListenPort { - if let tunnelListenPort = tunnelConfiguration.interface.listenPort { - wgSettings.append("listen_port=\(tunnelListenPort)\n") - } else if let currentListenPort = currentListenPort { - wgSettings.append("listen_port=\(currentListenPort)\n") - } + if let currentListenPort = currentListenPort { + wgSettings.append("listen_port=\(tunnelConfiguration.interface.listenPort ?? currentListenPort)\n") } for (index, peer) in tunnelConfiguration.peers.enumerated() {