tunnelkit/TunnelKit/Sources/AppExtension/ConnectionStrategy.swift

127 lines
4.3 KiB
Swift

//
// ConnectionStrategy.swift
// TunnelKit
//
// Created by Davide De Rosa on 6/18/18.
// Copyright © 2018 London Trust Media. All rights reserved.
//
import Foundation
import NetworkExtension
import SwiftyBeaver
/// File-scoped logger: the SwiftyBeaver type itself, whose logging entry points
/// (`debug`, `error`, ...) are invoked directly on the metatype below.
private let log: SwiftyBeaver.Type = SwiftyBeaver.self
/// Chooses the remote address and transport protocol used to open the tunnel socket.
///
/// Address selection, in order of priority:
/// 1. an explicit `preferredAddress` supplied by the caller;
/// 2. a random cached resolved address, when `prefersResolvedAddresses` is set;
/// 3. a fresh DNS resolution of `hostname`, falling back to a random cached
///    address if DNS returns nothing.
///
/// Protocols are taken from `endpointProtocols` in order, starting with the first;
/// `tryNextProtocol()` advances to the next one.
class ConnectionStrategy {
    /// Hostname resolved via DNS when neither a preferred nor a cached address is used.
    private let hostname: String

    /// When `true`, a cached resolved address is used instead of querying DNS.
    private let prefersResolvedAddresses: Bool

    /// Cached addresses: seeded from the configuration and replaced after every
    /// DNS resolution that returns a non-empty result.
    private var resolvedAddresses: [String]?

    /// Candidate endpoint protocols, tried in order.
    private let endpointProtocols: [TunnelKitProvider.EndpointProtocol]

    /// Index into `endpointProtocols` of the protocol currently in use.
    private var currentProtocolIndex = 0

    /// Creates a strategy for `hostname` using the address/protocol settings of `configuration`.
    ///
    /// - Parameters:
    ///   - hostname: The server hostname to resolve when DNS is needed.
    ///   - configuration: Supplies the resolved-address preference, the cached
    ///     resolved addresses and the ordered endpoint protocols.
    /// - Precondition: if `configuration.prefersResolvedAddresses` is `true`,
    ///   `configuration.resolvedAddresses` must be non-nil and non-empty.
    init(hostname: String, configuration: TunnelKitProvider.Configuration) {
        precondition(!configuration.prefersResolvedAddresses || !(configuration.resolvedAddresses?.isEmpty ?? true))
        self.hostname = hostname
        prefersResolvedAddresses = configuration.prefersResolvedAddresses
        resolvedAddresses = configuration.resolvedAddresses
        endpointProtocols = configuration.endpointProtocols
    }

    /// Creates a socket to the server using the current endpoint protocol.
    ///
    /// - Parameters:
    ///   - provider: The provider used to open the underlying UDP session / TCP connection.
    ///   - timeout: Forwarded to `DNSResolver.resolve` (NOTE(review): units are defined
    ///     by `DNSResolver` — confirm).
    ///   - preferredAddress: If given, used as-is and no resolution takes place.
    ///   - queue: Forwarded to `DNSResolver.resolve`.
    ///   - completionHandler: Receives either a socket and `nil`, or `nil` and
    ///     `TunnelKitProvider.ProviderError.dnsFailure` when no address can be found.
    ///     It is called synchronously, before this method returns, for the preferred and
    ///     cached-address paths; on the DNS path it is called from the `DNSResolver`
    ///     callback (presumably on `queue` — confirm against `DNSResolver`).
    func createSocket(
        from provider: NEProvider,
        timeout: Int,
        preferredAddress: String? = nil,
        queue: DispatchQueue,
        completionHandler: @escaping (GenericSocket?, Error?) -> Void) {

        // reuse preferred address
        if let preferredAddress = preferredAddress {
            log.debug("Pick preferred address: \(preferredAddress)")
            let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol())
            completionHandler(socket, nil)
            return
        }

        // use any resolved address
        if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
            log.debug("Pick resolved address: \(resolvedAddress)")
            let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
            completionHandler(socket, nil)
            return
        }

        // fall back to DNS
        log.debug("DNS resolve hostname: \(hostname)")
        // `self` is captured strongly, which keeps the strategy alive until the callback runs.
        DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in
            // refresh resolved addresses (only on a non-empty result, so a failed
            // lookup never wipes the cached fallback addresses)
            if let resolved = addresses, !resolved.isEmpty {
                self.resolvedAddresses = resolved
                log.debug("DNS resolved addresses: \(resolved)")
            } else {
                log.error("DNS resolution failed!")
            }
            guard let targetAddress = self.resolvedAddress(from: addresses) else {
                log.error("No resolved or fallback address available")
                completionHandler(nil, TunnelKitProvider.ProviderError.dnsFailure)
                return
            }
            let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol())
            completionHandler(socket, nil)
        }
    }

    /// Advances to the next endpoint protocol, if any.
    ///
    /// - Returns: `true` if the current protocol was advanced; `false` if the last
    ///   protocol is already in use, in which case nothing changes.
    func tryNextProtocol() -> Bool {
        let next = currentProtocolIndex + 1
        guard next < endpointProtocols.count else {
            log.debug("No more protocols available")
            return false
        }
        currentProtocolIndex = next
        log.debug("Fall back to next protocol: \(currentProtocol())")
        return true
    }

    /// The endpoint protocol currently in use (traps if `endpointProtocols` is empty).
    private func currentProtocol() -> TunnelKitProvider.EndpointProtocol {
        return endpointProtocols[currentProtocolIndex]
    }

    /// Picks the connection target after a DNS lookup.
    ///
    /// - Parameter addresses: The DNS result, possibly `nil` or empty.
    /// - Returns: The first DNS address if any, otherwise a random cached address,
    ///   otherwise `nil`.
    private func resolvedAddress(from addresses: [String]?) -> String? {
        if let first = addresses?.first {
            return first
        }
        return anyResolvedAddress()
    }

    /// A uniformly random element of the cached `resolvedAddresses`, or `nil` if none.
    private func anyResolvedAddress() -> String? {
        guard let addresses = resolvedAddresses, !addresses.isEmpty else {
            return nil
        }
        // arc4random_uniform avoids the modulo bias of `arc4random() % count`.
        let n = Int(arc4random_uniform(UInt32(addresses.count)))
        return addresses[n]
    }
}
private extension NEProvider {
    /// Opens a connection to `address` on the port and socket type described by
    /// `endpointProtocol`, and wraps it in the matching `GenericSocket` adapter.
    ///
    /// - Parameters:
    ///   - address: Remote host (name or literal address) to connect to.
    ///   - endpointProtocol: Supplies the remote port and whether to use UDP or TCP.
    /// - Returns: An `NEUDPInterface` over a UDP session, or an `NETCPInterface`
    ///   over a TCP connection opened with `enableTLS: false`.
    func createSocket(to address: String, protocol endpointProtocol: TunnelKitProvider.EndpointProtocol) -> GenericSocket {
        let portString = "\(endpointProtocol.port)"
        let remote = NWHostEndpoint(hostname: address, port: portString)

        let socket: GenericSocket
        switch endpointProtocol.socketType {
        case .tcp:
            // Plain TCP: no TLS, no TLS parameters, no delegate at this layer.
            let connection = createTCPConnection(to: remote, enableTLS: false, tlsParameters: nil, delegate: nil)
            socket = NETCPInterface(impl: connection)

        case .udp:
            // No explicit local endpoint is requested.
            let session = createUDPSession(to: remote, from: nil)
            socket = NEUDPInterface(impl: session)
        }
        return socket
    }
}