VPN: Adding, modifying and deleting actual NETunnelProviderManager instances

This commit is contained in:
Roopesh Chander 2018-10-25 15:50:27 +05:30
parent f6620fed9a
commit ac60a97dee
4 changed files with 157 additions and 46 deletions

View File

@ -25,7 +25,7 @@ class TunnelDetailTableViewController: UITableViewController {
init(tunnelsManager tm: TunnelsManager, tunnel t: TunnelContainer) { init(tunnelsManager tm: TunnelsManager, tunnel t: TunnelContainer) {
tunnelsManager = tm tunnelsManager = tm
tunnel = t tunnel = t
tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration) tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration())
super.init(style: .grouped) super.init(style: .grouped)
} }
@ -56,7 +56,7 @@ class TunnelDetailTableViewController: UITableViewController {
extension TunnelDetailTableViewController: TunnelEditTableViewControllerDelegate { extension TunnelDetailTableViewController: TunnelEditTableViewControllerDelegate {
func tunnelSaved(tunnel: TunnelContainer) { func tunnelSaved(tunnel: TunnelContainer) {
tunnelViewModel = TunnelViewModel(tunnelConfiguration: tunnel.tunnelConfiguration) tunnelViewModel = TunnelViewModel(tunnelConfiguration: tunnel.tunnelConfiguration())
self.tableView.reloadData() self.tableView.reloadData()
} }
func tunnelEditingCancelled() { func tunnelEditingCancelled() {

View File

@ -34,7 +34,7 @@ class TunnelEditTableViewController: UITableViewController {
// Use this initializer to edit an existing tunnel. // Use this initializer to edit an existing tunnel.
tunnelsManager = tm tunnelsManager = tm
tunnel = t tunnel = t
tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration) tunnelViewModel = TunnelViewModel(tunnelConfiguration: t.tunnelConfiguration())
super.init(style: .grouped) super.init(style: .grouped)
} }
@ -92,12 +92,14 @@ class TunnelEditTableViewController: UITableViewController {
self?.showErrorAlert(title: "Could not save", message: "Internal error") self?.showErrorAlert(title: "Could not save", message: "Internal error")
} else { } else {
self?.dismiss(animated: true, completion: nil) self?.dismiss(animated: true, completion: nil)
if let tunnel = tunnel {
self?.delegate?.tunnelSaved(tunnel: tunnel) self?.delegate?.tunnelSaved(tunnel: tunnel)
} }
} }
} }
} }
} }
}
@objc func cancelTapped() { @objc func cancelTapped() {
dismiss(animated: true, completion: nil) dismiss(animated: true, completion: nil)

View File

@ -165,8 +165,16 @@ extension TunnelsListTableViewController {
// MARK: TunnelsManagerDelegate // MARK: TunnelsManagerDelegate
extension TunnelsListTableViewController: TunnelsManagerDelegate { extension TunnelsListTableViewController: TunnelsManagerDelegate {
func tunnelsAdded(atIndex index: Int, numberOfTunnels: Int) { func tunnelAdded(at index: Int) {
self.tableView.insertRows(at: [IndexPath(row: index, section: 0)], with: .automatic) tableView.insertRows(at: [IndexPath(row: index, section: 0)], with: .automatic)
}
func tunnelModified(at index: Int) {
tableView.reloadRows(at: [IndexPath(row: index, section: 0)], with: .automatic)
}
func tunnelsChanged() {
tableView.reloadData()
} }
} }

View File

@ -2,32 +2,26 @@
// Copyright © 2018 WireGuard LLC. All rights reserved. // Copyright © 2018 WireGuard LLC. All rights reserved.
import Foundation import Foundation
import NetworkExtension
class TunnelProviderManager { import os.log
// Mock of NETunnelProviderManager
var name: String
fileprivate var tunnelConfiguration: TunnelConfiguration
init(tunnelConfiguration: TunnelConfiguration) {
self.name = tunnelConfiguration.interface.name
self.tunnelConfiguration = tunnelConfiguration
}
}
class TunnelContainer { class TunnelContainer {
var name: String { return tunnelProvider.name } var name: String { return tunnelProvider.localizedDescription ?? "" }
let tunnelProvider: TunnelProviderManager fileprivate let tunnelProvider: NETunnelProviderManager
var tunnelConfiguration: TunnelConfiguration { fileprivate var index: Int
get { return tunnelProvider.tunnelConfiguration } init(tunnel: NETunnelProviderManager, index: Int) {
}
var index: Int
init(tunnel: TunnelProviderManager, index: Int) {
self.tunnelProvider = tunnel self.tunnelProvider = tunnel
self.index = index self.index = index
} }
func tunnelConfiguration() -> TunnelConfiguration? {
return (tunnelProvider.protocolConfiguration as! NETunnelProviderProtocol).tunnelConfiguration()
}
} }
protocol TunnelsManagerDelegate: class { protocol TunnelsManagerDelegate: class {
func tunnelsAdded(atIndex: Int, numberOfTunnels: Int) func tunnelAdded(at: Int)
func tunnelModified(at: Int)
func tunnelsChanged()
} }
class TunnelsManager { class TunnelsManager {
@ -35,47 +29,130 @@ class TunnelsManager {
var tunnels: [TunnelContainer] var tunnels: [TunnelContainer]
weak var delegate: TunnelsManagerDelegate? = nil weak var delegate: TunnelsManagerDelegate? = nil
private var isAddingTunnel: Bool = false
private var isModifyingTunnel: Bool = false
private var isDeletingTunnel: Bool = false
enum TunnelsManagerError: Error { enum TunnelsManagerError: Error {
case tunnelsUninitialized case tunnelsUninitialized
} }
init(tunnelProviders: [TunnelProviderManager]) { init(tunnelProviders: [NETunnelProviderManager]) {
var tunnels: [TunnelContainer] = [] var tunnels = tunnelProviders.map { TunnelContainer(tunnel: $0, index: 0) }
for (i, tunnelProvider) in tunnelProviders.enumerated() { tunnels.sort { $0.name < $1.name }
let tunnel = TunnelContainer(tunnel: tunnelProvider, index: i) for i in 0 ..< tunnels.count {
tunnels.append(tunnel) tunnels[i].index = i
} }
self.tunnels = tunnels self.tunnels = tunnels
} }
static func create(completionHandler: (TunnelsManager?) -> Void) { static func create(completionHandler: @escaping (TunnelsManager?) -> Void) {
completionHandler(TunnelsManager(tunnelProviders: [])) NETunnelProviderManager.loadAllFromPreferences { (managers, error) in
if let error = error {
os_log("Failed to load tunnel provider managers %{public}@", log: OSLog.default, type: .debug, "\(error)")
return
}
completionHandler(TunnelsManager(tunnelProviders: managers ?? []))
}
} }
func add(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (TunnelContainer, Error?) -> Void) { private func insertionIndexFor(tunnelName: String) -> Int {
assert(!tunnelConfiguration.interface.name.isEmpty) // Wishlist: Use binary search instead
let tunnelProvider = TunnelProviderManager(tunnelConfiguration: tunnelConfiguration) for i in 0 ..< tunnels.count {
for tunnel in tunnels { if (tunnelName.lexicographicallyPrecedes(tunnels[i].name)) { return i }
tunnel.index = tunnel.index + 1
} }
let tunnel = TunnelContainer(tunnel: tunnelProvider, index: 0) return tunnels.count
tunnels.insert(tunnel, at: 0) }
delegate?.tunnelsAdded(atIndex: 0, numberOfTunnels: 1)
func add(tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (TunnelContainer?, Error?) -> Void) {
let tunnelName = tunnelConfiguration.interface.name
assert(!tunnelName.isEmpty)
isAddingTunnel = true
let tunnelProviderManager = NETunnelProviderManager()
tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration)
tunnelProviderManager.localizedDescription = tunnelName
tunnelProviderManager.isEnabled = true
tunnelProviderManager.saveToPreferences { [weak self] (error) in
defer { self?.isAddingTunnel = false }
guard (error == nil) else {
completionHandler(nil, error)
return
}
if let s = self {
let index = s.insertionIndexFor(tunnelName: tunnelName)
let tunnel = TunnelContainer(tunnel: tunnelProviderManager, index: index)
for i in index ..< s.tunnels.count {
s.tunnels[i].index = s.tunnels[i].index + 1
}
s.tunnels.insert(tunnel, at: index)
s.delegate?.tunnelAdded(at: index)
completionHandler(tunnel, nil) completionHandler(tunnel, nil)
} }
}
}
func modify(tunnel: TunnelContainer, with tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (Error?) -> Void) { func modify(tunnel: TunnelContainer, with tunnelConfiguration: TunnelConfiguration, completionHandler: @escaping (Error?) -> Void) {
tunnel.tunnelProvider.tunnelConfiguration = tunnelConfiguration let tunnelName = tunnelConfiguration.interface.name
assert(!tunnelName.isEmpty)
isModifyingTunnel = true
let tunnelProviderManager = tunnel.tunnelProvider
let isNameChanged = (tunnelName != tunnelProviderManager.localizedDescription)
tunnelProviderManager.protocolConfiguration = NETunnelProviderProtocol(tunnelConfiguration: tunnelConfiguration)
tunnelProviderManager.localizedDescription = tunnelName
tunnelProviderManager.isEnabled = true
tunnelProviderManager.saveToPreferences { [weak self] (error) in
defer { self?.isModifyingTunnel = false }
guard (error != nil) else {
completionHandler(error)
return
}
if let s = self {
if (isNameChanged) {
s.tunnels.remove(at: tunnel.index)
for i in tunnel.index ..< s.tunnels.count {
s.tunnels[i].index = s.tunnels[i].index - 1
}
let index = s.insertionIndexFor(tunnelName: tunnelName)
tunnel.index = index
for i in index ..< s.tunnels.count {
s.tunnels[i].index = s.tunnels[i].index + 1
}
s.tunnels.insert(tunnel, at: index)
s.delegate?.tunnelsChanged()
} else {
s.delegate?.tunnelModified(at: tunnel.index)
}
completionHandler(nil) completionHandler(nil)
} }
}
}
func remove(tunnel: TunnelContainer, completionHandler: @escaping (Error?) -> Void) { func remove(tunnel: TunnelContainer, completionHandler: @escaping (Error?) -> Void) {
for i in ((tunnel.index + 1) ..< tunnels.count) { let tunnelProviderManager = tunnel.tunnelProvider
tunnels[i].index = tunnels[i].index + 1 let tunnelIndex = tunnel.index
isDeletingTunnel = true
tunnelProviderManager.removeFromPreferences { [weak self] (error) in
defer { self?.isDeletingTunnel = false }
guard (error != nil) else {
completionHandler(error)
return
}
if let s = self {
for i in ((tunnelIndex + 1) ..< s.tunnels.count) {
s.tunnels[i].index = s.tunnels[i].index + 1
}
s.tunnels.remove(at: tunnelIndex)
} }
tunnels.remove(at: tunnel.index)
completionHandler(nil) completionHandler(nil)
} }
}
func numberOfTunnels() -> Int { func numberOfTunnels() -> Int {
return tunnels.count return tunnels.count
@ -85,3 +162,27 @@ class TunnelsManager {
return tunnels[index] return tunnels[index]
} }
} }
extension NETunnelProviderProtocol {
convenience init?(tunnelConfiguration: TunnelConfiguration) {
assert(!tunnelConfiguration.interface.name.isEmpty)
guard let serializedTunnelConfiguration = try? JSONEncoder().encode(tunnelConfiguration) else { return nil }
self.init()
let appId = Bundle.main.bundleIdentifier!
let firstValidEndpoint = tunnelConfiguration.peers.first(where: { $0.endpoint != nil })?.endpoint
providerBundleIdentifier = "\(appId).WireGuardNetworkExtension"
providerConfiguration = [
"tunnelConfiguration": serializedTunnelConfiguration
]
serverAddress = firstValidEndpoint?.stringRepresentation() ?? "Unspecified"
username = tunnelConfiguration.interface.name
}
func tunnelConfiguration() -> TunnelConfiguration? {
guard let serializedTunnelConfiguration = providerConfiguration?["tunnelConfiguration"] as? Data else { return nil }
return try? JSONDecoder().decode(TunnelConfiguration.self, from: serializedTunnelConfiguration)
}
}