diff --git a/WireGuard/WireGuard/Tunnel/TunnelsManager.swift b/WireGuard/WireGuard/Tunnel/TunnelsManager.swift index 5640e6c..3efadb5 100644 --- a/WireGuard/WireGuard/Tunnel/TunnelsManager.swift +++ b/WireGuard/WireGuard/Tunnel/TunnelsManager.swift @@ -25,10 +25,12 @@ class TunnelsManager { weak var activationDelegate: TunnelsManagerActivationDelegate? private var statusObservationToken: AnyObject? private var waiteeObservationToken: AnyObject? + private var configurationsObservationToken: AnyObject? init(tunnelProviders: [NETunnelProviderManager]) { tunnels = tunnelProviders.map { TunnelContainer(tunnel: $0) }.sorted { $0.name < $1.name } startObservingTunnelStatuses() + startObservingTunnelConfigurations() } static func create(completionHandler: @escaping (WireGuardResult) -> Void) { @@ -53,26 +55,33 @@ class TunnelsManager { #endif } - func reload(completionHandler: @escaping (Bool) -> Void) { - #if targetEnvironment(simulator) - completionHandler(false) - #else - NETunnelProviderManager.loadAllFromPreferences { managers, _ in - guard let managers = managers else { - completionHandler(false) - return - } + func reload() { + NETunnelProviderManager.loadAllFromPreferences { [weak self] managers, _ in + guard let self = self else { return } - let newTunnels = managers.map { TunnelContainer(tunnel: $0) }.sorted { $0.name < $1.name } - let hasChanges = self.tunnels.map { $0.tunnelConfiguration } != newTunnels.map { $0.tunnelConfiguration } - if hasChanges { - self.tunnels = newTunnels - completionHandler(true) - } else { - completionHandler(false) + let loadedTunnelProviders = managers ?? [] + + var numberOfRemovedTunnels = 0 + for (index, currentTunnel) in self.tunnels.enumerated() { + if !loadedTunnelProviders.contains(where: { $0.tunnelConfiguration == currentTunnel.tunnelConfiguration }) { + // Tunnel was deleted outside the app + self.tunnels.remove(at: index - numberOfRemovedTunnels) + self.tunnelsListDelegate?.tunnelRemoved(at: index - numberOfRemovedTunnels) + numberOfRemovedTunnels += 1 + } + } + for loadedTunnelProvider in loadedTunnelProviders { + if let matchingTunnel = self.tunnels.first(where: { $0.tunnelConfiguration == loadedTunnelProvider.tunnelConfiguration }) { + matchingTunnel.tunnelProvider = loadedTunnelProvider + } else { + // Tunnel was added outside the app + let tunnel = TunnelContainer(tunnel: loadedTunnelProvider) + self.tunnels.append(tunnel) + self.tunnels.sort { $0.name < $1.name } + self.tunnelsListDelegate?.tunnelAdded(at: self.tunnels.firstIndex(of: tunnel)!) + } } } - #endif } func add(tunnelConfiguration: TunnelConfiguration, activateOnDemandSetting: ActivateOnDemandSetting = ActivateOnDemandSetting.defaultSetting, completionHandler: @escaping (WireGuardResult) -> Void) { @@ -319,6 +328,12 @@ class TunnelsManager { } } + func startObservingTunnelConfigurations() { + configurationsObservationToken = NotificationCenter.default.addObserver(forName: .NEVPNConfigurationChange, object: nil, queue: OperationQueue.main) { [weak self] _ in + self?.reload() + } + } + } private func lastErrorTextFromNetworkExtension(for tunnel: TunnelContainer) -> (title: String, message: String)? { @@ -367,7 +382,7 @@ class TunnelContainer: NSObject { fileprivate var tunnelProvider: NETunnelProviderManager var tunnelConfiguration: TunnelConfiguration? { - return (tunnelProvider.protocolConfiguration as? NETunnelProviderProtocol)?.asTunnelConfiguration(called: tunnelProvider.localizedDescription) + return tunnelProvider.tunnelConfiguration } var activateOnDemandSetting: ActivateOnDemandSetting { @@ -461,3 +476,9 @@ class TunnelContainer: NSObject { (tunnelProvider.connection as? NETunnelProviderSession)?.stopTunnel() } } + +extension NETunnelProviderManager { + var tunnelConfiguration: TunnelConfiguration? { + return (protocolConfiguration as? NETunnelProviderProtocol)?.asTunnelConfiguration(called: localizedDescription) + } +}