Redefine endpoint strategy according to IPv4/6

This commit is contained in:
Davide De Rosa 2020-04-14 22:53:15 +02:00
parent 40eb98fd72
commit 4bdf6b7006
4 changed files with 220 additions and 83 deletions

View File

@ -41,25 +41,38 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self
class ConnectionStrategy {
struct Endpoint: CustomStringConvertible {
let record: DNSRecord
let proto: EndpointProtocol
var isValid: Bool {
if record.isIPv6 {
return proto.socketType != .udp4 && proto.socketType != .tcp4
} else {
return proto.socketType != .udp6 && proto.socketType != .tcp6
}
}
// MARK: CustomStringConvertible
var description: String {
return "\(record.address.maskedDescription):\(proto)"
}
}
private let hostname: String?
private let prefersResolvedAddresses: Bool
private var resolvedAddresses: [String]?
private let endpointProtocols: [EndpointProtocol]
private var currentProtocolIndex = 0
private var endpoints: [Endpoint]
private var currentEndpointIndex: Int
private let resolvedAddresses: [String]
init(configuration: OpenVPNTunnelProvider.Configuration) {
hostname = configuration.sessionConfiguration.hostname
prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses
resolvedAddresses = configuration.resolvedAddresses
if prefersResolvedAddresses {
guard !(resolvedAddresses?.isEmpty ?? true) else {
fatalError("Either hostname or resolved addresses provided")
}
}
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
fatalError("No endpoints provided")
}
@ -67,101 +80,130 @@ class ConnectionStrategy {
endpointProtocols.shuffle()
}
self.endpointProtocols = endpointProtocols
currentEndpointIndex = 0
if let resolvedAddresses = configuration.resolvedAddresses {
if configuration.prefersResolvedAddresses {
endpoints = ConnectionStrategy.unrolledEndpoints(
records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) },
protos: endpointProtocols
)
} else {
endpoints = []
}
self.resolvedAddresses = resolvedAddresses
} else {
guard hostname != nil else {
fatalError("Either configuration.hostname or resolvedRecords required")
}
endpoints = []
resolvedAddresses = []
}
}
private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] {
return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos)
}
private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] {
guard !records.isEmpty else {
return []
}
var endpoints: [Endpoint] = []
for r in records {
for p in protos {
let endpoint = Endpoint(record: r, proto: p)
guard endpoint.isValid else {
continue
}
endpoints.append(endpoint)
}
}
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
return endpoints
}
func hasEndpoint() -> Bool {
return currentEndpointIndex < endpoints.count
}
func currentEndpoint() -> Endpoint {
guard hasEndpoint() else {
fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))")
}
return endpoints[currentEndpointIndex]
}
@discardableResult
func tryNextEndpoint() -> Bool {
guard hasEndpoint() else {
return false
}
currentEndpointIndex += 1
guard currentEndpointIndex < endpoints.count else {
log.debug("Exhausted endpoints")
return false
}
log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)")
return true
}
func createSocket(
from provider: NEProvider,
timeout: Int,
preferredAddress: String? = nil,
queue: DispatchQueue,
completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
// reuse preferred address
if let preferredAddress = preferredAddress {
log.debug("Pick preferred address: \(preferredAddress.maskedDescription)")
let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol())
if hasEndpoint() {
let endpoint = currentEndpoint()
log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
let socket = provider.createSocket(to: endpoint)
completionHandler(socket, nil)
return
}
// use any resolved address
if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
log.debug("Pick resolved address: \(resolvedAddress.maskedDescription)")
let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
completionHandler(socket, nil)
return
}
// fall back to DNS
log.debug("No endpoints available, will resort to DNS resolution")
guard let hostname = hostname else {
log.error("DNS resolution unavailable: no hostname provided!")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return
}
log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in
// refresh resolved addresses
if let resolved = addresses, !resolved.isEmpty {
self.resolvedAddresses = resolved
log.debug("DNS resolved addresses: \(resolved.map { $0.maskedDescription })")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in
self.currentEndpointIndex = 0
if let records = records, !records.isEmpty {
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols)
} else {
log.error("DNS resolution failed!")
log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols)
}
guard let targetAddress = self.resolvedAddress(from: addresses) else {
log.error("No resolved or fallback address available")
guard self.hasEndpoint() else {
log.error("No endpoints available")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return
}
let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol())
let targetEndpoint = self.currentEndpoint()
log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)")
let socket = provider.createSocket(to: targetEndpoint)
completionHandler(socket, nil)
}
}
func tryNextProtocol() -> Bool {
let next = currentProtocolIndex + 1
guard next < endpointProtocols.count else {
log.debug("No more protocols available")
return false
}
currentProtocolIndex = next
log.debug("Fall back to next protocol: \(currentProtocol())")
return true
}
private func currentProtocol() -> EndpointProtocol {
return endpointProtocols[currentProtocolIndex]
}
private func resolvedAddress(from addresses: [String]?) -> String? {
guard let resolved = addresses, !resolved.isEmpty else {
return anyResolvedAddress()
}
return resolved[0]
}
private func anyResolvedAddress() -> String? {
guard let addresses = resolvedAddresses, !addresses.isEmpty else {
return nil
}
let n = Int(arc4random() % UInt32(addresses.count))
return addresses[n]
}
}
private extension NEProvider {
func createSocket(to address: String, protocol endpointProtocol: EndpointProtocol) -> GenericSocket {
let endpoint = NWHostEndpoint(hostname: address, port: "\(endpointProtocol.port)")
switch endpointProtocol.socketType {
func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket {
let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)")
switch endpoint.proto.socketType {
case .udp, .udp4, .udp6:
let impl = createUDPSession(to: endpoint, from: nil)
let impl = createUDPSession(to: ep, from: nil)
return NEUDPSocket(impl: impl)
case .tcp, .tcp4, .tcp6:
let impl = createTCPConnection(to: endpoint, enableTLS: false, tlsParameters: nil, delegate: nil)
let impl = createTCPConnection(to: ep, enableTLS: false, tlsParameters: nil, delegate: nil)
return NETCPSocket(impl: impl)
}
}

View File

@ -66,7 +66,7 @@ extension OpenVPNTunnelProvider {
/// - Seealso: `fallbackServerAddresses`
public var prefersResolvedAddresses: Bool
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true`.
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only).
public var resolvedAddresses: [String]?
/// The MTU of the link.

View File

@ -311,7 +311,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// MARK: Connection (tunnel queue)
private func connectTunnel(upgradedSocket: GenericSocket? = nil, preferredAddress: String? = nil) {
private func connectTunnel(upgradedSocket: GenericSocket? = nil) {
log.info("Creating link session")
// reuse upgraded socket
@ -321,7 +321,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
return
}
strategy.createSocket(from: self, timeout: dnsTimeout, preferredAddress: preferredAddress, queue: tunnelQueue) { (socket, error) in
strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in
guard let socket = socket else {
self.disposeTunnel(error: error)
return
@ -424,7 +424,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: TCP connection timeout suggests falling back
if let _ = socket as? NETCPSocket {
guard tryNextProtocol() else {
guard tryNextEndpoint() else {
// disposeTunnel
return
}
@ -471,7 +471,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout
if didTimeoutNegotiation {
guard tryNextProtocol() else {
guard tryNextEndpoint() else {
// disposeTunnel
return
}
@ -489,7 +489,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
return
}
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress)
self.connectTunnel(upgradedSocket: upgradedSocket)
}
return
}
@ -775,8 +775,8 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
}
extension OpenVPNTunnelProvider {
private func tryNextProtocol() -> Bool {
guard strategy.tryNextProtocol() else {
private func tryNextEndpoint() -> Bool {
guard strategy.tryNextEndpoint() else {
disposeTunnel(error: ProviderError.exhaustedProtocols)
return false
}

View File

@ -94,7 +94,7 @@ class AppExtensionTests: XCTestCase {
func testDNSResolver() {
let exp = expectation(description: "DNS")
DNSResolver.resolve("djsbjhcbjzhbxjnvsd.com", timeout: 1000, queue: DispatchQueue.main) { (addrs, error) in
DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in
defer {
exp.fulfill()
}
@ -126,4 +126,99 @@ class AppExtensionTests: XCTestCase {
XCTAssertEqual(string, expString)
}
}
func testEndpointCycling() {
CoreConfiguration.masksPrivateData = false
var builder1 = OpenVPN.ConfigurationBuilder()
builder1.hostname = "italy.privateinternetaccess.com"
builder1.endpointProtocols = [
EndpointProtocol(.tcp6, 2222),
EndpointProtocol(.udp, 1111),
EndpointProtocol(.udp4, 3333)
]
var builder2 = OpenVPNTunnelProvider.ConfigurationBuilder(sessionConfiguration: builder1.build())
builder2.prefersResolvedAddresses = true
builder2.resolvedAddresses = [
"82.102.21.218",
"82.102.21.214",
"82.102.21.213",
]
let strategy = ConnectionStrategy(configuration: builder2.build())
let expected = [
"82.102.21.218:UDP:1111",
"82.102.21.218:UDP4:3333",
"82.102.21.214:UDP:1111",
"82.102.21.214:UDP4:3333",
"82.102.21.213:UDP:1111",
"82.102.21.213:UDP4:3333",
]
var i = 0
while strategy.hasEndpoint() {
let endpoint = strategy.currentEndpoint()
print("\(endpoint)")
XCTAssertEqual(endpoint.description, expected[i])
i += 1
strategy.tryNextEndpoint()
}
}
// func testEndpointCycling4() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.tcp4, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "11.22.33.44:TCP4:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
//
// func testEndpointCycling6() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.udp6, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "111:bbbb:ffff::eeee:UDP6:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
}