Redefine endpoint strategy according to IPv4/6

This commit is contained in:
Davide De Rosa 2020-04-14 22:53:15 +02:00
parent 40eb98fd72
commit 4bdf6b7006
4 changed files with 220 additions and 83 deletions

View File

@ -41,25 +41,38 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
class ConnectionStrategy { class ConnectionStrategy {
struct Endpoint: CustomStringConvertible {
let record: DNSRecord
let proto: EndpointProtocol
var isValid: Bool {
if record.isIPv6 {
return proto.socketType != .udp4 && proto.socketType != .tcp4
} else {
return proto.socketType != .udp6 && proto.socketType != .tcp6
}
}
// MARK: CustomStringConvertible
var description: String {
return "\(record.address.maskedDescription):\(proto)"
}
}
private let hostname: String? private let hostname: String?
private let prefersResolvedAddresses: Bool
private var resolvedAddresses: [String]?
private let endpointProtocols: [EndpointProtocol] private let endpointProtocols: [EndpointProtocol]
private var currentProtocolIndex = 0 private var endpoints: [Endpoint]
private var currentEndpointIndex: Int
private let resolvedAddresses: [String]
init(configuration: OpenVPNTunnelProvider.Configuration) { init(configuration: OpenVPNTunnelProvider.Configuration) {
hostname = configuration.sessionConfiguration.hostname hostname = configuration.sessionConfiguration.hostname
prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses
resolvedAddresses = configuration.resolvedAddresses
if prefersResolvedAddresses {
guard !(resolvedAddresses?.isEmpty ?? true) else {
fatalError("Either hostname or resolved addresses provided")
}
}
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else { guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
fatalError("No endpoints provided") fatalError("No endpoints provided")
} }
@ -67,101 +80,130 @@ class ConnectionStrategy {
endpointProtocols.shuffle() endpointProtocols.shuffle()
} }
self.endpointProtocols = endpointProtocols self.endpointProtocols = endpointProtocols
currentEndpointIndex = 0
if let resolvedAddresses = configuration.resolvedAddresses {
if configuration.prefersResolvedAddresses {
endpoints = ConnectionStrategy.unrolledEndpoints(
records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) },
protos: endpointProtocols
)
} else {
endpoints = []
}
self.resolvedAddresses = resolvedAddresses
} else {
guard hostname != nil else {
fatalError("Either configuration.hostname or resolvedRecords required")
}
endpoints = []
resolvedAddresses = []
}
}
private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] {
return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos)
} }
private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] {
guard !records.isEmpty else {
return []
}
var endpoints: [Endpoint] = []
for r in records {
for p in protos {
let endpoint = Endpoint(record: r, proto: p)
guard endpoint.isValid else {
continue
}
endpoints.append(endpoint)
}
}
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
return endpoints
}
func hasEndpoint() -> Bool {
return currentEndpointIndex < endpoints.count
}
func currentEndpoint() -> Endpoint {
guard hasEndpoint() else {
fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))")
}
return endpoints[currentEndpointIndex]
}
@discardableResult
func tryNextEndpoint() -> Bool {
guard hasEndpoint() else {
return false
}
currentEndpointIndex += 1
guard currentEndpointIndex < endpoints.count else {
log.debug("Exhausted endpoints")
return false
}
log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)")
return true
}
func createSocket( func createSocket(
from provider: NEProvider, from provider: NEProvider,
timeout: Int, timeout: Int,
preferredAddress: String? = nil,
queue: DispatchQueue, queue: DispatchQueue,
completionHandler: @escaping (GenericSocket?, Error?) -> Void) { completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
// reuse preferred address if hasEndpoint() {
if let preferredAddress = preferredAddress { let endpoint = currentEndpoint()
log.debug("Pick preferred address: \(preferredAddress.maskedDescription)") log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol()) let socket = provider.createSocket(to: endpoint)
completionHandler(socket, nil) completionHandler(socket, nil)
return return
} }
log.debug("No endpoints available, will resort to DNS resolution")
// use any resolved address
if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
log.debug("Pick resolved address: \(resolvedAddress.maskedDescription)")
let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
completionHandler(socket, nil)
return
}
// fall back to DNS
guard let hostname = hostname else { guard let hostname = hostname else {
log.error("DNS resolution unavailable: no hostname provided!") log.error("DNS resolution unavailable: no hostname provided!")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return return
} }
log.debug("DNS resolve hostname: \(hostname.maskedDescription)") log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in
self.currentEndpointIndex = 0
// refresh resolved addresses if let records = records, !records.isEmpty {
if let resolved = addresses, !resolved.isEmpty { log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
self.resolvedAddresses = resolved self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols)
log.debug("DNS resolved addresses: \(resolved.map { $0.maskedDescription })")
} else { } else {
log.error("DNS resolution failed!") log.error("DNS resolution failed!")
log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols)
} }
guard let targetAddress = self.resolvedAddress(from: addresses) else { guard self.hasEndpoint() else {
log.error("No resolved or fallback address available") log.error("No endpoints available")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return return
} }
let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol()) let targetEndpoint = self.currentEndpoint()
log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)")
let socket = provider.createSocket(to: targetEndpoint)
completionHandler(socket, nil) completionHandler(socket, nil)
} }
} }
func tryNextProtocol() -> Bool {
let next = currentProtocolIndex + 1
guard next < endpointProtocols.count else {
log.debug("No more protocols available")
return false
}
currentProtocolIndex = next
log.debug("Fall back to next protocol: \(currentProtocol())")
return true
}
private func currentProtocol() -> EndpointProtocol {
return endpointProtocols[currentProtocolIndex]
}
private func resolvedAddress(from addresses: [String]?) -> String? {
guard let resolved = addresses, !resolved.isEmpty else {
return anyResolvedAddress()
}
return resolved[0]
}
private func anyResolvedAddress() -> String? {
guard let addresses = resolvedAddresses, !addresses.isEmpty else {
return nil
}
let n = Int(arc4random() % UInt32(addresses.count))
return addresses[n]
}
} }
private extension NEProvider { private extension NEProvider {
func createSocket(to address: String, protocol endpointProtocol: EndpointProtocol) -> GenericSocket { func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket {
let endpoint = NWHostEndpoint(hostname: address, port: "\(endpointProtocol.port)") let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)")
switch endpointProtocol.socketType { switch endpoint.proto.socketType {
case .udp, .udp4, .udp6: case .udp, .udp4, .udp6:
let impl = createUDPSession(to: endpoint, from: nil) let impl = createUDPSession(to: ep, from: nil)
return NEUDPSocket(impl: impl) return NEUDPSocket(impl: impl)
case .tcp, .tcp4, .tcp6: case .tcp, .tcp4, .tcp6:
let impl = createTCPConnection(to: endpoint, enableTLS: false, tlsParameters: nil, delegate: nil) let impl = createTCPConnection(to: ep, enableTLS: false, tlsParameters: nil, delegate: nil)
return NETCPSocket(impl: impl) return NETCPSocket(impl: impl)
} }
} }

View File

@ -66,7 +66,7 @@ extension OpenVPNTunnelProvider {
/// - Seealso: `fallbackServerAddresses` /// - Seealso: `fallbackServerAddresses`
public var prefersResolvedAddresses: Bool public var prefersResolvedAddresses: Bool
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true`. /// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only).
public var resolvedAddresses: [String]? public var resolvedAddresses: [String]?
/// The MTU of the link. /// The MTU of the link.

View File

@ -311,7 +311,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// MARK: Connection (tunnel queue) // MARK: Connection (tunnel queue)
private func connectTunnel(upgradedSocket: GenericSocket? = nil, preferredAddress: String? = nil) { private func connectTunnel(upgradedSocket: GenericSocket? = nil) {
log.info("Creating link session") log.info("Creating link session")
// reuse upgraded socket // reuse upgraded socket
@ -321,7 +321,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
return return
} }
strategy.createSocket(from: self, timeout: dnsTimeout, preferredAddress: preferredAddress, queue: tunnelQueue) { (socket, error) in strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in
guard let socket = socket else { guard let socket = socket else {
self.disposeTunnel(error: error) self.disposeTunnel(error: error)
return return
@ -424,7 +424,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: TCP connection timeout suggests falling back // fallback: TCP connection timeout suggests falling back
if let _ = socket as? NETCPSocket { if let _ = socket as? NETCPSocket {
guard tryNextProtocol() else { guard tryNextEndpoint() else {
// disposeTunnel // disposeTunnel
return return
} }
@ -471,7 +471,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout // fallback: UDP is connection-less, treat negotiation timeout as socket timeout
if didTimeoutNegotiation { if didTimeoutNegotiation {
guard tryNextProtocol() else { guard tryNextEndpoint() else {
// disposeTunnel // disposeTunnel
return return
} }
@ -489,7 +489,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
return return
} }
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress) self.connectTunnel(upgradedSocket: upgradedSocket)
} }
return return
} }
@ -775,8 +775,8 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
} }
extension OpenVPNTunnelProvider { extension OpenVPNTunnelProvider {
private func tryNextProtocol() -> Bool { private func tryNextEndpoint() -> Bool {
guard strategy.tryNextProtocol() else { guard strategy.tryNextEndpoint() else {
disposeTunnel(error: ProviderError.exhaustedProtocols) disposeTunnel(error: ProviderError.exhaustedProtocols)
return false return false
} }

View File

@ -94,7 +94,7 @@ class AppExtensionTests: XCTestCase {
func testDNSResolver() { func testDNSResolver() {
let exp = expectation(description: "DNS") let exp = expectation(description: "DNS")
DNSResolver.resolve("djsbjhcbjzhbxjnvsd.com", timeout: 1000, queue: DispatchQueue.main) { (addrs, error) in DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in
defer { defer {
exp.fulfill() exp.fulfill()
} }
@ -126,4 +126,99 @@ class AppExtensionTests: XCTestCase {
XCTAssertEqual(string, expString) XCTAssertEqual(string, expString)
} }
} }
func testEndpointCycling() {
CoreConfiguration.masksPrivateData = false
var builder1 = OpenVPN.ConfigurationBuilder()
builder1.hostname = "italy.privateinternetaccess.com"
builder1.endpointProtocols = [
EndpointProtocol(.tcp6, 2222),
EndpointProtocol(.udp, 1111),
EndpointProtocol(.udp4, 3333)
]
var builder2 = OpenVPNTunnelProvider.ConfigurationBuilder(sessionConfiguration: builder1.build())
builder2.prefersResolvedAddresses = true
builder2.resolvedAddresses = [
"82.102.21.218",
"82.102.21.214",
"82.102.21.213",
]
let strategy = ConnectionStrategy(configuration: builder2.build())
let expected = [
"82.102.21.218:UDP:1111",
"82.102.21.218:UDP4:3333",
"82.102.21.214:UDP:1111",
"82.102.21.214:UDP4:3333",
"82.102.21.213:UDP:1111",
"82.102.21.213:UDP4:3333",
]
var i = 0
while strategy.hasEndpoint() {
let endpoint = strategy.currentEndpoint()
print("\(endpoint)")
XCTAssertEqual(endpoint.description, expected[i])
i += 1
strategy.tryNextEndpoint()
}
}
// func testEndpointCycling4() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.tcp4, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "11.22.33.44:TCP4:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
//
// func testEndpointCycling6() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.udp6, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "111:bbbb:ffff::eeee:UDP6:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
} }