From c8fba951ad3b6379baba6cb4468aa709d9e8e478 Mon Sep 17 00:00:00 2001 From: Roopesh Chander Date: Thu, 25 Oct 2018 15:50:27 +0530 Subject: [PATCH] VPN: Adding, modifying and deleting actual NETunnelProviderManager instances Signed-off-by: Roopesh Chander --- .../iOS/TunnelDetailTableViewController.swift | 4 +- .../iOS/TunnelEditTableViewController.swift | 6 +- .../iOS/TunnelsListTableViewController.swift | 12 +- WireGuard/WireGuard/VPN/TunnelsManager.swift | 181 ++++++++++++++---- 4 files changed, 157 insertions(+), 46 deletions(-) diff --git a/WireGuard/WireGuard/UI/iOS/TunnelDetailTableViewController.swift b/WireGuard/WireGuard/UI/iOS/TunnelDetailTableViewController.swift index 769c348..061a8b2 100644 --- a/WireGuard/WireGuard/UI/iOS/TunnelDetailTableViewController.swift +++ b/WireGuard/WireGuard/UI/iOS/TunnelDetailTableViewController.swift @@ -25,7 +25,7 @@ class TunnelDetailTableViewController: UITableViewController { init(tunnelsManager tm: TunnelsManager, tunnel t: TunnelContainer) { tunnelsManager = tm tunnel = t - tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration) + tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration()) super.init(style: .grouped) } @@ -56,7 +56,7 @@ class TunnelDetailTableViewController: UITableViewController { extension TunnelDetailTableViewController: TunnelEditTableViewControllerDelegate { func tunnelSaved(tunnel: TunnelContainer) { - tunnelViewModel = TunnelViewModel(tunnelConfiguration: tunnel.tunnelConfiguration) + tunnelViewModel = TunnelViewModel(tunnelConfiguration: tunnel.tunnelConfiguration()) self.tableView.reloadData() } func tunnelEditingCancelled() { diff --git a/WireGuard/WireGuard/UI/iOS/TunnelEditTableViewController.swift b/WireGuard/WireGuard/UI/iOS/TunnelEditTableViewController.swift index e155a61..fdd42d4 100644 --- a/WireGuard/WireGuard/UI/iOS/TunnelEditTableViewController.swift +++ b/WireGuard/WireGuard/UI/iOS/TunnelEditTableViewController.swift @@ -34,7 +34,7 @@ class TunnelEditTableViewController: UITableViewController { // Use this initializer to edit an existing tunnel. tunnelsManager = tm tunnel = t - tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration) + tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration()) super.init(style: .grouped) } @@ -92,7 +92,9 @@ class TunnelEditTableViewController: UITableViewController { self?.showErrorAlert(title: "Could not save", message: "Internal error") } else { self?.dismiss(animated: true, completion: nil) - self?.delegate?.tunnelSaved(tunnel: tunnel) + if let tunnel = tunnel { + self?.delegate?.tunnelSaved(tunnel: tunnel) + } } } } diff --git a/WireGuard/WireGuard/UI/iOS/TunnelsListTableViewController.swift b/WireGuard/WireGuard/UI/iOS/TunnelsListTableViewController.swift index 26e70ab..45e7280 100644 --- a/WireGuard/WireGuard/UI/iOS/TunnelsListTableViewController.swift +++ b/WireGuard/WireGuard/UI/iOS/TunnelsListTableViewController.swift @@ -165,8 +165,16 @@ extension TunnelsListTableViewController { // MARK: TunnelsManagerDelegate extension TunnelsListTableViewController: TunnelsManagerDelegate { - func tunnelsAdded(atIndex index: Int, numberOfTunnels: Int) { - self.tableView.insertRows(at: [IndexPath(row: index, section: 0)], with: .automatic) + func tunnelAdded(at index: Int) { + tableView.insertRows(at: [IndexPath(row: index, section: 0)], with: .automatic) + } + + func tunnelModified(at index: Int) { + tableView.reloadRows(at: [IndexPath(row: index, section: 0)], with: .automatic) + } + + func tunnelsChanged() { + tableView.reloadData() } } diff --git a/WireGuard/WireGuard/VPN/TunnelsManager.swift b/WireGuard/WireGuard/VPN/TunnelsManager.swift index 48b2767..e6c9df0 100644 --- a/WireGuard/WireGuard/VPN/TunnelsManager.swift +++ b/WireGuard/WireGuard/VPN/TunnelsManager.swift @@ -2,32 +2,26 @@ // Copyright © 2018 WireGuard LLC. All rights reserved. import Foundation - -class TunnelProviderManager { - // Mock of NETunnelProviderManager - var name: String - fileprivate var tunnelConfiguration: TunnelConfiguration - init(tunnelConfiguration: TunnelConfiguration) { - self.name = tunnelConfiguration.interface.name - self.tunnelConfiguration = tunnelConfiguration - } -} +import NetworkExtension +import os.log class TunnelContainer { - var name: String { return tunnelProvider.name } - let tunnelProvider: TunnelProviderManager - var tunnelConfiguration: TunnelConfiguration { - get { return tunnelProvider.tunnelConfiguration } - } - var index: Int - init(tunnel: TunnelProviderManager, index: Int) { + var name: String { return tunnelProvider.localizedDescription ?? "" } + fileprivate let tunnelProvider: NETunnelProviderManager + fileprivate var index: Int + init(tunnel: NETunnelProviderManager, index: Int) { self.tunnelProvider = tunnel self.index = index } + func tunnelConfiguration() -> TunnelConfiguration? { + return (tunnelProvider.protocolConfiguration as! NETunnelProviderProtocol).tunnelConfiguration() + } } protocol TunnelsManagerDelegate: class { - func tunnelsAdded(atIndex: Int, numberOfTunnels: Int) + func tunnelAdded(at: Int) + func tunnelModified(at: Int) + func tunnelsChanged() } class TunnelsManager { @@ -35,46 +29,129 @@ class TunnelsManager { var tunnels: [TunnelContainer] weak var delegate: TunnelsManagerDelegate? = nil + private var isAddingTunnel: Bool = false + private var isModifyingTunnel: Bool = false + private var isDeletingTunnel: Bool = false + enum TunnelsManagerError: Error { case tunnelsUninitialized } - init(tunnelProviders: [TunnelProviderManager]) { - var tunnels: [TunnelContainer] = [] - for (i, tunnelProvider) in tunnelProviders.enumerated() { - let tunnel = TunnelContainer(tunnel: tunnelProvider, index: i) - tunnels.append(tunnel) + init(tunnelProviders: [NETunnelProviderManager]) { + var tunnels = tunnelProviders.map { TunnelContainer(tunnel: $0, index: 0) } + tunnels.sort { $0.name < $1.name } + for i in 0 ..< tunnels.count { + tunnels[i].index = i } self.tunnels = tunnels } - static func create(completionHandler: (TunnelsManager?) -> Void) { - completionHandler(TunnelsManager(tunnelProviders: [])) + static func create(completionHandler: @escaping (TunnelsManager?) -> Void) { + NETunnelProviderManager.loadAllFromPreferences { (managers, error) in + if let error = error { + os_log("Failed to load tunnel provider managers %{public}@", log: OSLog.default, type: .debug, "\(error)") + return + } + completionHandler(TunnelsManager(tunnelProviders: managers ?? [])) + } } - func add(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (TunnelContainer, Error?) -> Void) { - assert(!tunnelConfiguration.interface.name.isEmpty) - let tunnelProvider = TunnelProviderManager(tunnelConfiguration: tunnelConfiguration) - for tunnel in tunnels { - tunnel.index = tunnel.index + 1 + private func insertionIndexFor(tunnelName: String) -> Int { + // Wishlist: Use binary search instead + for i in 0 ..< tunnels.count { + if (tunnelName.lexicographicallyPrecedes(tunnels[i].name)) { return i } + } + return tunnels.count + } + + func add(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (TunnelContainer?, Error?) -> Void) { + let tunnelName = tunnelConfiguration.interface.name + assert(!tunnelName.isEmpty) + + isAddingTunnel = true + let tunnelProviderManager = NETunnelProviderManager() + tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration) + tunnelProviderManager.localizedDescription = tunnelName + tunnelProviderManager.isEnabled = true + + tunnelProviderManager.saveToPreferences { [weak self] (error) in + defer { self?.isAddingTunnel = false } + guard (error == nil) else { + completionHandler(nil, error) + return + } + if let s = self { + let index = s.insertionIndexFor(tunnelName: tunnelName) + let tunnel = TunnelContainer(tunnel: tunnelProviderManager, index: index) + for i in index ..< s.tunnels.count { + s.tunnels[i].index = s.tunnels[i].index + 1 + } + s.tunnels.insert(tunnel, at: index) + s.delegate?.tunnelAdded(at: index) + completionHandler(tunnel, nil) + } } - let tunnel = TunnelContainer(tunnel: tunnelProvider, index: 0) - tunnels.insert(tunnel, at: 0) - delegate?.tunnelsAdded(atIndex: 0, numberOfTunnels: 1) - completionHandler(tunnel, nil) } func modify(tunnel: TunnelContainer, with tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (Error?) -> Void) { - tunnel.tunnelProvider.tunnelConfiguration = tunnelConfiguration - completionHandler(nil) + let tunnelName = tunnelConfiguration.interface.name + assert(!tunnelName.isEmpty) + + isModifyingTunnel = true + + let tunnelProviderManager = tunnel.tunnelProvider + let isNameChanged = (tunnelName != tunnelProviderManager.localizedDescription) + tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration) + tunnelProviderManager.localizedDescription = tunnelName + tunnelProviderManager.isEnabled = true + + tunnelProviderManager.saveToPreferences { [weak self] (error) in + defer { self?.isModifyingTunnel = false } + guard (error != nil) else { + completionHandler(error) + return + } + if let s = self { + if (isNameChanged) { + s.tunnels.remove(at: tunnel.index) + for i in tunnel.index ..< s.tunnels.count { + s.tunnels[i].index = s.tunnels[i].index - 1 + } + let index = s.insertionIndexFor(tunnelName: tunnelName) + tunnel.index = index + for i in index ..< s.tunnels.count { + s.tunnels[i].index = s.tunnels[i].index + 1 + } + s.tunnels.insert(tunnel, at: index) + s.delegate?.tunnelsChanged() + } else { + s.delegate?.tunnelModified(at: tunnel.index) + } + completionHandler(nil) + } + } } func remove(tunnel: TunnelContainer, completionHandler: @escaping (Error?) -> Void) { - for i in ((tunnel.index + 1) ..< tunnels.count) { - tunnels[i].index = tunnels[i].index + 1 + let tunnelProviderManager = tunnel.tunnelProvider + let tunnelIndex = tunnel.index + + isDeletingTunnel = true + + tunnelProviderManager.removeFromPreferences { [weak self] (error) in + defer { self?.isDeletingTunnel = false } + guard (error != nil) else { + completionHandler(error) + return + } + if let s = self { + for i in ((tunnelIndex + 1) ..< s.tunnels.count) { + s.tunnels[i].index = s.tunnels[i].index + 1 + } + s.tunnels.remove(at: tunnelIndex) + } + completionHandler(nil) } - tunnels.remove(at: tunnel.index) - completionHandler(nil) } func numberOfTunnels() -> Int { @@ -85,3 +162,27 @@ class TunnelsManager { return tunnels[index] } } + +extension NETunnelProviderProtocol { + convenience init?(tunnelConfiguration: TunnelConfiguration) { + assert(!tunnelConfiguration.interface.name.isEmpty) + guard let serializedTunnelConfiguration = try? JSONEncoder().encode(tunnelConfiguration) else { return nil } + + self.init() + + let appId = Bundle.main.bundleIdentifier! + let firstValidEndpoint = tunnelConfiguration.peers.first(where: { $0.endpoint != nil })?.endpoint + + providerBundleIdentifier = "\(appId).WireGuardNetworkExtension" + providerConfiguration = [ + "tunnelConfiguration": serializedTunnelConfiguration + ] + serverAddress = firstValidEndpoint?.stringRepresentation() ?? "Unspecified" + username = tunnelConfiguration.interface.name + } + + func tunnelConfiguration() -> TunnelConfiguration? { + guard let serializedTunnelConfiguration = providerConfiguration?["tunnelConfiguration"] as? Data else { return nil } + return try? JSONDecoder().decode(TunnelConfiguration.self, from: serializedTunnelConfiguration) + } +}