c4b0964c3c
*Socket + *Link
127 lines
4.3 KiB
Swift
127 lines
4.3 KiB
Swift
//
|
|
// ConnectionStrategy.swift
|
|
// TunnelKit
|
|
//
|
|
// Created by Davide De Rosa on 6/18/18.
|
|
// Copyright © 2018 London Trust Media. All rights reserved.
|
|
//
|
|
|
|
import Foundation
|
|
import NetworkExtension
|
|
import SwiftyBeaver
|
|
|
|
private let log = SwiftyBeaver.self
|
|
|
|
/// Decides which address and endpoint protocol to use when opening a socket
/// towards `hostname`, and keeps track of the protocol currently in use so
/// that callers can fall back to the next one with `tryNextProtocol()`.
class ConnectionStrategy {

    /// Hostname handed to `DNSResolver` when no usable address is available.
    private let hostname: String

    /// When `true`, a previously resolved address is tried before any DNS lookup.
    private let prefersResolvedAddresses: Bool

    /// Known addresses for `hostname`; replaced after every successful DNS lookup.
    private var resolvedAddresses: [String]?

    /// Endpoint protocols to try, in order.
    private let endpointProtocols: [TunnelKitProvider.EndpointProtocol]

    /// Index into `endpointProtocols` of the protocol currently in use.
    private var currentProtocolIndex = 0

    /// Creates a strategy from the relevant fields of `configuration`.
    ///
    /// - Parameter hostname: The hostname to resolve when falling back to DNS.
    /// - Parameter configuration: Source of `prefersResolvedAddresses`,
    ///   `resolvedAddresses` and `endpointProtocols`.
    /// - Precondition: If `configuration.prefersResolvedAddresses` is set,
    ///   `configuration.resolvedAddresses` must be non-nil and non-empty.
    init(hostname: String, configuration: TunnelKitProvider.Configuration) {
        precondition(!configuration.prefersResolvedAddresses || !(configuration.resolvedAddresses?.isEmpty ?? true))

        self.hostname = hostname
        prefersResolvedAddresses = configuration.prefersResolvedAddresses
        resolvedAddresses = configuration.resolvedAddresses
        endpointProtocols = configuration.endpointProtocols
    }

    /// Creates a socket for the current endpoint protocol and reports it
    /// through `completionHandler`.
    ///
    /// The target address is chosen in this order:
    /// 1. `preferredAddress`, when provided;
    /// 2. a random stored resolved address, when `prefersResolvedAddresses` is set;
    /// 3. the first address returned by a DNS lookup of `hostname`, or a
    ///    random stored resolved address if the lookup yields nothing.
    ///
    /// In cases 1 and 2 `completionHandler` is invoked before this method
    /// returns. In case 3 it is invoked from the `DNSResolver.resolve` callback.
    ///
    /// - Parameter provider: The provider used to create the socket.
    /// - Parameter timeout: Forwarded as-is to `DNSResolver.resolve`.
    /// - Parameter preferredAddress: An address to connect to directly, bypassing
    ///   stored addresses and DNS.
    /// - Parameter queue: Forwarded as-is to `DNSResolver.resolve`.
    /// - Parameter completionHandler: Receives either a socket and a `nil` error,
    ///   or a `nil` socket and `TunnelKitProvider.ProviderError.dnsFailure` when
    ///   neither DNS nor the stored addresses yield an address.
    func createSocket(
        from provider: NEProvider,
        timeout: Int,
        preferredAddress: String? = nil,
        queue: DispatchQueue,
        completionHandler: @escaping (GenericSocket?, Error?) -> Void) {

        // reuse preferred address
        if let preferredAddress = preferredAddress {
            log.debug("Pick preferred address: \(preferredAddress)")
            let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol())
            completionHandler(socket, nil)
            return
        }

        // use any resolved address
        if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
            log.debug("Pick resolved address: \(resolvedAddress)")
            let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
            completionHandler(socket, nil)
            return
        }

        // fall back to DNS
        log.debug("DNS resolve hostname: \(hostname)")
        // NOTE: the callback references `self` without a capture list, so it
        // holds a strong reference to this strategy for as long as the
        // resolver retains the closure.
        DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in

            // refresh resolved addresses
            if let resolved = addresses, !resolved.isEmpty {
                self.resolvedAddresses = resolved

                log.debug("DNS resolved addresses: \(resolved)")
            } else {
                log.error("DNS resolution failed!")
            }

            // prefer the fresh DNS result, otherwise fall back to a stored address
            guard let targetAddress = self.resolvedAddress(from: addresses) else {
                log.error("No resolved or fallback address available")
                completionHandler(nil, TunnelKitProvider.ProviderError.dnsFailure)
                return
            }

            let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol())
            completionHandler(socket, nil)
        }
    }

    /// Advances to the next endpoint protocol, if any.
    ///
    /// - Returns: `true` if the current protocol was advanced, `false` if the
    ///   current protocol is already the last one (state is left unchanged).
    func tryNextProtocol() -> Bool {
        let next = currentProtocolIndex + 1
        guard next < endpointProtocols.count else {
            log.debug("No more protocols available")
            return false
        }
        currentProtocolIndex = next
        log.debug("Fall back to next protocol: \(currentProtocol())")
        return true
    }

    /// The endpoint protocol at `currentProtocolIndex`.
    ///
    /// Traps if `endpointProtocols` is empty, since the index starts at 0.
    private func currentProtocol() -> TunnelKitProvider.EndpointProtocol {
        return endpointProtocols[currentProtocolIndex]
    }

    /// Returns the first of `addresses` or, when `addresses` is `nil` or empty,
    /// a random stored resolved address (`nil` if there is none).
    private func resolvedAddress(from addresses: [String]?) -> String? {
        guard let resolved = addresses, !resolved.isEmpty else {
            return anyResolvedAddress()
        }
        return resolved[0]
    }

    /// Returns a uniformly random element of `resolvedAddresses`, or `nil`
    /// when there are no stored addresses.
    private func anyResolvedAddress() -> String? {
        guard let addresses = resolvedAddresses, !addresses.isEmpty else {
            return nil
        }
        // arc4random_uniform avoids the modulo bias of `arc4random() % count`.
        let n = Int(arc4random_uniform(UInt32(addresses.count)))
        return addresses[n]
    }
}
|
|
|
|
private extension NEProvider {

    /// Opens a socket from this provider towards `address`, using the port
    /// and socket type described by `endpointProtocol`.
    ///
    /// - Parameter address: Hostname or address of the remote endpoint.
    /// - Parameter endpointProtocol: Supplies the remote port and whether a
    ///   UDP session or a TCP connection (TLS disabled, no delegate) is created.
    /// - Returns: An `NEUDPSocket` or `NETCPSocket` wrapping the created
    ///   session/connection.
    func createSocket(to address: String, protocol endpointProtocol: TunnelKitProvider.EndpointProtocol) -> GenericSocket {
        let port = "\(endpointProtocol.port)"
        let remote = NWHostEndpoint(hostname: address, port: port)

        switch endpointProtocol.socketType {
        case .tcp:
            return NETCPSocket(impl: createTCPConnection(to: remote, enableTLS: false, tlsParameters: nil, delegate: nil))

        case .udp:
            return NEUDPSocket(impl: createUDPSession(to: remote, from: nil))
        }
    }
}
|