Redefine endpoint strategy according to IPv4/6
This commit is contained in:
parent
40eb98fd72
commit
4bdf6b7006
@ -41,25 +41,38 @@ import SwiftyBeaver
|
|||||||
private let log = SwiftyBeaver.self
|
private let log = SwiftyBeaver.self
|
||||||
|
|
||||||
class ConnectionStrategy {
|
class ConnectionStrategy {
|
||||||
|
struct Endpoint: CustomStringConvertible {
|
||||||
|
let record: DNSRecord
|
||||||
|
|
||||||
|
let proto: EndpointProtocol
|
||||||
|
|
||||||
|
var isValid: Bool {
|
||||||
|
if record.isIPv6 {
|
||||||
|
return proto.socketType != .udp4 && proto.socketType != .tcp4
|
||||||
|
} else {
|
||||||
|
return proto.socketType != .udp6 && proto.socketType != .tcp6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: CustomStringConvertible
|
||||||
|
|
||||||
|
var description: String {
|
||||||
|
return "\(record.address.maskedDescription):\(proto)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private let hostname: String?
|
private let hostname: String?
|
||||||
|
|
||||||
private let prefersResolvedAddresses: Bool
|
|
||||||
|
|
||||||
private var resolvedAddresses: [String]?
|
|
||||||
|
|
||||||
private let endpointProtocols: [EndpointProtocol]
|
private let endpointProtocols: [EndpointProtocol]
|
||||||
|
|
||||||
private var currentProtocolIndex = 0
|
private var endpoints: [Endpoint]
|
||||||
|
|
||||||
|
private var currentEndpointIndex: Int
|
||||||
|
|
||||||
|
private let resolvedAddresses: [String]
|
||||||
|
|
||||||
init(configuration: OpenVPNTunnelProvider.Configuration) {
|
init(configuration: OpenVPNTunnelProvider.Configuration) {
|
||||||
hostname = configuration.sessionConfiguration.hostname
|
hostname = configuration.sessionConfiguration.hostname
|
||||||
prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses
|
|
||||||
resolvedAddresses = configuration.resolvedAddresses
|
|
||||||
if prefersResolvedAddresses {
|
|
||||||
guard !(resolvedAddresses?.isEmpty ?? true) else {
|
|
||||||
fatalError("Either hostname or resolved addresses provided")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
|
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
|
||||||
fatalError("No endpoints provided")
|
fatalError("No endpoints provided")
|
||||||
}
|
}
|
||||||
@ -67,101 +80,130 @@ class ConnectionStrategy {
|
|||||||
endpointProtocols.shuffle()
|
endpointProtocols.shuffle()
|
||||||
}
|
}
|
||||||
self.endpointProtocols = endpointProtocols
|
self.endpointProtocols = endpointProtocols
|
||||||
|
|
||||||
|
currentEndpointIndex = 0
|
||||||
|
if let resolvedAddresses = configuration.resolvedAddresses {
|
||||||
|
if configuration.prefersResolvedAddresses {
|
||||||
|
endpoints = ConnectionStrategy.unrolledEndpoints(
|
||||||
|
records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) },
|
||||||
|
protos: endpointProtocols
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
endpoints = []
|
||||||
|
}
|
||||||
|
self.resolvedAddresses = resolvedAddresses
|
||||||
|
} else {
|
||||||
|
guard hostname != nil else {
|
||||||
|
fatalError("Either configuration.hostname or resolvedRecords required")
|
||||||
|
}
|
||||||
|
endpoints = []
|
||||||
|
resolvedAddresses = []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] {
|
||||||
|
return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] {
|
||||||
|
guard !records.isEmpty else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
var endpoints: [Endpoint] = []
|
||||||
|
for r in records {
|
||||||
|
for p in protos {
|
||||||
|
let endpoint = Endpoint(record: r, proto: p)
|
||||||
|
guard endpoint.isValid else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
endpoints.append(endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
|
||||||
|
return endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasEndpoint() -> Bool {
|
||||||
|
return currentEndpointIndex < endpoints.count
|
||||||
|
}
|
||||||
|
|
||||||
|
func currentEndpoint() -> Endpoint {
|
||||||
|
guard hasEndpoint() else {
|
||||||
|
fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))")
|
||||||
|
}
|
||||||
|
return endpoints[currentEndpointIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
func tryNextEndpoint() -> Bool {
|
||||||
|
guard hasEndpoint() else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
currentEndpointIndex += 1
|
||||||
|
guard currentEndpointIndex < endpoints.count else {
|
||||||
|
log.debug("Exhausted endpoints")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func createSocket(
|
func createSocket(
|
||||||
from provider: NEProvider,
|
from provider: NEProvider,
|
||||||
timeout: Int,
|
timeout: Int,
|
||||||
preferredAddress: String? = nil,
|
|
||||||
queue: DispatchQueue,
|
queue: DispatchQueue,
|
||||||
completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
|
completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
|
||||||
|
|
||||||
// reuse preferred address
|
if hasEndpoint() {
|
||||||
if let preferredAddress = preferredAddress {
|
let endpoint = currentEndpoint()
|
||||||
log.debug("Pick preferred address: \(preferredAddress.maskedDescription)")
|
log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
|
||||||
let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol())
|
let socket = provider.createSocket(to: endpoint)
|
||||||
completionHandler(socket, nil)
|
completionHandler(socket, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
log.debug("No endpoints available, will resort to DNS resolution")
|
||||||
// use any resolved address
|
|
||||||
if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
|
|
||||||
log.debug("Pick resolved address: \(resolvedAddress.maskedDescription)")
|
|
||||||
let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
|
|
||||||
completionHandler(socket, nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// fall back to DNS
|
|
||||||
guard let hostname = hostname else {
|
guard let hostname = hostname else {
|
||||||
log.error("DNS resolution unavailable: no hostname provided!")
|
log.error("DNS resolution unavailable: no hostname provided!")
|
||||||
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
|
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
|
log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
|
||||||
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in
|
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in
|
||||||
|
self.currentEndpointIndex = 0
|
||||||
// refresh resolved addresses
|
if let records = records, !records.isEmpty {
|
||||||
if let resolved = addresses, !resolved.isEmpty {
|
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
|
||||||
self.resolvedAddresses = resolved
|
self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols)
|
||||||
|
|
||||||
log.debug("DNS resolved addresses: \(resolved.map { $0.maskedDescription })")
|
|
||||||
} else {
|
} else {
|
||||||
log.error("DNS resolution failed!")
|
log.error("DNS resolution failed!")
|
||||||
|
log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)")
|
||||||
|
self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols)
|
||||||
}
|
}
|
||||||
|
|
||||||
guard let targetAddress = self.resolvedAddress(from: addresses) else {
|
guard self.hasEndpoint() else {
|
||||||
log.error("No resolved or fallback address available")
|
log.error("No endpoints available")
|
||||||
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
|
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol())
|
let targetEndpoint = self.currentEndpoint()
|
||||||
|
log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)")
|
||||||
|
let socket = provider.createSocket(to: targetEndpoint)
|
||||||
completionHandler(socket, nil)
|
completionHandler(socket, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryNextProtocol() -> Bool {
|
|
||||||
let next = currentProtocolIndex + 1
|
|
||||||
guard next < endpointProtocols.count else {
|
|
||||||
log.debug("No more protocols available")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
currentProtocolIndex = next
|
|
||||||
log.debug("Fall back to next protocol: \(currentProtocol())")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
private func currentProtocol() -> EndpointProtocol {
|
|
||||||
return endpointProtocols[currentProtocolIndex]
|
|
||||||
}
|
|
||||||
|
|
||||||
private func resolvedAddress(from addresses: [String]?) -> String? {
|
|
||||||
guard let resolved = addresses, !resolved.isEmpty else {
|
|
||||||
return anyResolvedAddress()
|
|
||||||
}
|
|
||||||
return resolved[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
private func anyResolvedAddress() -> String? {
|
|
||||||
guard let addresses = resolvedAddresses, !addresses.isEmpty else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
let n = Int(arc4random() % UInt32(addresses.count))
|
|
||||||
return addresses[n]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private extension NEProvider {
|
private extension NEProvider {
|
||||||
func createSocket(to address: String, protocol endpointProtocol: EndpointProtocol) -> GenericSocket {
|
func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket {
|
||||||
let endpoint = NWHostEndpoint(hostname: address, port: "\(endpointProtocol.port)")
|
let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)")
|
||||||
switch endpointProtocol.socketType {
|
switch endpoint.proto.socketType {
|
||||||
case .udp, .udp4, .udp6:
|
case .udp, .udp4, .udp6:
|
||||||
let impl = createUDPSession(to: endpoint, from: nil)
|
let impl = createUDPSession(to: ep, from: nil)
|
||||||
return NEUDPSocket(impl: impl)
|
return NEUDPSocket(impl: impl)
|
||||||
|
|
||||||
case .tcp, .tcp4, .tcp6:
|
case .tcp, .tcp4, .tcp6:
|
||||||
let impl = createTCPConnection(to: endpoint, enableTLS: false, tlsParameters: nil, delegate: nil)
|
let impl = createTCPConnection(to: ep, enableTLS: false, tlsParameters: nil, delegate: nil)
|
||||||
return NETCPSocket(impl: impl)
|
return NETCPSocket(impl: impl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ extension OpenVPNTunnelProvider {
|
|||||||
/// - Seealso: `fallbackServerAddresses`
|
/// - Seealso: `fallbackServerAddresses`
|
||||||
public var prefersResolvedAddresses: Bool
|
public var prefersResolvedAddresses: Bool
|
||||||
|
|
||||||
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true`.
|
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only).
|
||||||
public var resolvedAddresses: [String]?
|
public var resolvedAddresses: [String]?
|
||||||
|
|
||||||
/// The MTU of the link.
|
/// The MTU of the link.
|
||||||
|
@ -311,7 +311,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
|
|||||||
|
|
||||||
// MARK: Connection (tunnel queue)
|
// MARK: Connection (tunnel queue)
|
||||||
|
|
||||||
private func connectTunnel(upgradedSocket: GenericSocket? = nil, preferredAddress: String? = nil) {
|
private func connectTunnel(upgradedSocket: GenericSocket? = nil) {
|
||||||
log.info("Creating link session")
|
log.info("Creating link session")
|
||||||
|
|
||||||
// reuse upgraded socket
|
// reuse upgraded socket
|
||||||
@ -321,7 +321,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy.createSocket(from: self, timeout: dnsTimeout, preferredAddress: preferredAddress, queue: tunnelQueue) { (socket, error) in
|
strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in
|
||||||
guard let socket = socket else {
|
guard let socket = socket else {
|
||||||
self.disposeTunnel(error: error)
|
self.disposeTunnel(error: error)
|
||||||
return
|
return
|
||||||
@ -424,7 +424,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
|
|||||||
|
|
||||||
// fallback: TCP connection timeout suggests falling back
|
// fallback: TCP connection timeout suggests falling back
|
||||||
if let _ = socket as? NETCPSocket {
|
if let _ = socket as? NETCPSocket {
|
||||||
guard tryNextProtocol() else {
|
guard tryNextEndpoint() else {
|
||||||
// disposeTunnel
|
// disposeTunnel
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -471,7 +471,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
|
|||||||
|
|
||||||
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout
|
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout
|
||||||
if didTimeoutNegotiation {
|
if didTimeoutNegotiation {
|
||||||
guard tryNextProtocol() else {
|
guard tryNextEndpoint() else {
|
||||||
// disposeTunnel
|
// disposeTunnel
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -489,7 +489,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress)
|
self.connectTunnel(upgradedSocket: upgradedSocket)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -775,8 +775,8 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
extension OpenVPNTunnelProvider {
|
extension OpenVPNTunnelProvider {
|
||||||
private func tryNextProtocol() -> Bool {
|
private func tryNextEndpoint() -> Bool {
|
||||||
guard strategy.tryNextProtocol() else {
|
guard strategy.tryNextEndpoint() else {
|
||||||
disposeTunnel(error: ProviderError.exhaustedProtocols)
|
disposeTunnel(error: ProviderError.exhaustedProtocols)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -94,7 +94,7 @@ class AppExtensionTests: XCTestCase {
|
|||||||
|
|
||||||
func testDNSResolver() {
|
func testDNSResolver() {
|
||||||
let exp = expectation(description: "DNS")
|
let exp = expectation(description: "DNS")
|
||||||
DNSResolver.resolve("djsbjhcbjzhbxjnvsd.com", timeout: 1000, queue: DispatchQueue.main) { (addrs, error) in
|
DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in
|
||||||
defer {
|
defer {
|
||||||
exp.fulfill()
|
exp.fulfill()
|
||||||
}
|
}
|
||||||
@ -126,4 +126,99 @@ class AppExtensionTests: XCTestCase {
|
|||||||
XCTAssertEqual(string, expString)
|
XCTAssertEqual(string, expString)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testEndpointCycling() {
|
||||||
|
CoreConfiguration.masksPrivateData = false
|
||||||
|
|
||||||
|
var builder1 = OpenVPN.ConfigurationBuilder()
|
||||||
|
builder1.hostname = "italy.privateinternetaccess.com"
|
||||||
|
builder1.endpointProtocols = [
|
||||||
|
EndpointProtocol(.tcp6, 2222),
|
||||||
|
EndpointProtocol(.udp, 1111),
|
||||||
|
EndpointProtocol(.udp4, 3333)
|
||||||
|
]
|
||||||
|
var builder2 = OpenVPNTunnelProvider.ConfigurationBuilder(sessionConfiguration: builder1.build())
|
||||||
|
builder2.prefersResolvedAddresses = true
|
||||||
|
builder2.resolvedAddresses = [
|
||||||
|
"82.102.21.218",
|
||||||
|
"82.102.21.214",
|
||||||
|
"82.102.21.213",
|
||||||
|
]
|
||||||
|
let strategy = ConnectionStrategy(configuration: builder2.build())
|
||||||
|
|
||||||
|
let expected = [
|
||||||
|
"82.102.21.218:UDP:1111",
|
||||||
|
"82.102.21.218:UDP4:3333",
|
||||||
|
"82.102.21.214:UDP:1111",
|
||||||
|
"82.102.21.214:UDP4:3333",
|
||||||
|
"82.102.21.213:UDP:1111",
|
||||||
|
"82.102.21.213:UDP4:3333",
|
||||||
|
]
|
||||||
|
var i = 0
|
||||||
|
while strategy.hasEndpoint() {
|
||||||
|
let endpoint = strategy.currentEndpoint()
|
||||||
|
print("\(endpoint)")
|
||||||
|
XCTAssertEqual(endpoint.description, expected[i])
|
||||||
|
i += 1
|
||||||
|
strategy.tryNextEndpoint()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// func testEndpointCycling4() {
|
||||||
|
// CoreConfiguration.masksPrivateData = false
|
||||||
|
//
|
||||||
|
// var builder = OpenVPN.ConfigurationBuilder()
|
||||||
|
// builder.hostname = "italy.privateinternetaccess.com"
|
||||||
|
// builder.endpointProtocols = [
|
||||||
|
// EndpointProtocol(.tcp4, 2222),
|
||||||
|
// ]
|
||||||
|
// let strategy = ConnectionStrategy(
|
||||||
|
// configuration: builder.build(),
|
||||||
|
// resolvedRecords: [
|
||||||
|
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
|
||||||
|
// DNSRecord(address: "11.22.33.44", isIPv6: false),
|
||||||
|
// ]
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// let expected = [
|
||||||
|
// "11.22.33.44:TCP4:2222"
|
||||||
|
// ]
|
||||||
|
// var i = 0
|
||||||
|
// while strategy.hasEndpoint() {
|
||||||
|
// let endpoint = strategy.currentEndpoint()
|
||||||
|
// print("\(endpoint)")
|
||||||
|
// XCTAssertEqual(endpoint.description, expected[i])
|
||||||
|
// i += 1
|
||||||
|
// strategy.tryNextEndpoint()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func testEndpointCycling6() {
|
||||||
|
// CoreConfiguration.masksPrivateData = false
|
||||||
|
//
|
||||||
|
// var builder = OpenVPN.ConfigurationBuilder()
|
||||||
|
// builder.hostname = "italy.privateinternetaccess.com"
|
||||||
|
// builder.endpointProtocols = [
|
||||||
|
// EndpointProtocol(.udp6, 2222),
|
||||||
|
// ]
|
||||||
|
// let strategy = ConnectionStrategy(
|
||||||
|
// configuration: builder.build(),
|
||||||
|
// resolvedRecords: [
|
||||||
|
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
|
||||||
|
// DNSRecord(address: "11.22.33.44", isIPv6: false),
|
||||||
|
// ]
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// let expected = [
|
||||||
|
// "111:bbbb:ffff::eeee:UDP6:2222"
|
||||||
|
// ]
|
||||||
|
// var i = 0
|
||||||
|
// while strategy.hasEndpoint() {
|
||||||
|
// let endpoint = strategy.currentEndpoint()
|
||||||
|
// print("\(endpoint)")
|
||||||
|
// XCTAssertEqual(endpoint.description, expected[i])
|
||||||
|
// i += 1
|
||||||
|
// strategy.tryNextEndpoint()
|
||||||
|
// }
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user