Merge pull request #163 from passepartoutvpn/enforce-ipv4-ipv6-resolution

Enforce IPv4/6 endpoints
This commit is contained in:
Davide De Rosa 2020-04-15 11:13:31 +02:00 committed by GitHub
commit a35636b1b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 261 additions and 100 deletions

View File

@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Index out of range during negotiation (Grivus). [#143](https://github.com/passepartoutvpn/tunnelkit/pull/143) - Index out of range during negotiation (Grivus). [#143](https://github.com/passepartoutvpn/tunnelkit/pull/143)
- Handle server shutdown/restart (remote `--explicit-exit-notify`). [#131](https://github.com/passepartoutvpn/tunnelkit/issues/131) - Handle server shutdown/restart (remote `--explicit-exit-notify`). [#131](https://github.com/passepartoutvpn/tunnelkit/issues/131)
- Abrupt disconnection upon unknown packet key id (johankool). [#161](https://github.com/passepartoutvpn/tunnelkit/pull/161) - Abrupt disconnection upon unknown packet key id (johankool). [#161](https://github.com/passepartoutvpn/tunnelkit/pull/161)
- Handle explicit IPv4/IPv6 protocols (`4` or `6` suffix in `--proto`). [#153](https://github.com/passepartoutvpn/tunnelkit/issues/153)
- Pointer warnings from Xcode 11.4 upgrade. - Pointer warnings from Xcode 11.4 upgrade.
## 2.2.1 (2019-12-14) ## 2.2.1 (2019-12-14)

View File

@ -36,6 +36,16 @@
import Foundation import Foundation
/// Result of `DNSResolver`.
public struct DNSRecord {
/// Address string.
public let address: String
/// `true` if IPv6.
public let isIPv6: Bool
}
/// Convenient methods for DNS resolution. /// Convenient methods for DNS resolution.
public class DNSResolver { public class DNSResolver {
private static let queue = DispatchQueue(label: "DNSResolver") private static let queue = DispatchQueue(label: "DNSResolver")
@ -48,17 +58,17 @@ public class DNSResolver {
- Parameter queue: The queue to execute the `completionHandler` in. - Parameter queue: The queue to execute the `completionHandler` in.
- Parameter completionHandler: The completion handler with the resolved addresses and an optional error. - Parameter completionHandler: The completion handler with the resolved addresses and an optional error.
*/ */
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping ([String]?, Error?) -> Void) { public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) {
var pendingHandler: (([String]?, Error?) -> Void)? = completionHandler var pendingHandler: (([DNSRecord]?, Error?) -> Void)? = completionHandler
let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue() let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue()
DNSResolver.queue.async { DNSResolver.queue.async {
CFHostStartInfoResolution(host, .addresses, nil) CFHostStartInfoResolution(host, .addresses, nil)
guard let handler = pendingHandler else { guard let handler = pendingHandler else {
return return
} }
DNSResolver.didResolve(host: host) { (addrs, error) in DNSResolver.didResolve(host: host) { (records, error) in
queue.async { queue.async {
handler(addrs, error) handler(records, error)
pendingHandler = nil pendingHandler = nil
} }
} }
@ -73,14 +83,14 @@ public class DNSResolver {
} }
} }
private static func didResolve(host: CFHost, completionHandler: @escaping ([String]?, Error?) -> Void) { private static func didResolve(host: CFHost, completionHandler: @escaping ([DNSRecord]?, Error?) -> Void) {
var success: DarwinBoolean = false var success: DarwinBoolean = false
guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else { guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else {
completionHandler(nil, nil) completionHandler(nil, nil)
return return
} }
var ipAddresses: [String] = [] var records: [DNSRecord] = []
for case let rawAddress as Data in rawAddresses { for case let rawAddress as Data in rawAddresses {
var ipAddress = [CChar](repeating: 0, count: Int(NI_MAXHOST)) var ipAddress = [CChar](repeating: 0, count: Int(NI_MAXHOST))
let result: Int32 = rawAddress.withUnsafeBytes { let result: Int32 = rawAddress.withUnsafeBytes {
@ -98,9 +108,14 @@ public class DNSResolver {
guard result == 0 else { guard result == 0 else {
continue continue
} }
ipAddresses.append(String(cString: ipAddress)) let address = String(cString: ipAddress)
if rawAddress.count == 16 {
records.append(DNSRecord(address: address, isIPv6: false))
} else {
records.append(DNSRecord(address: address, isIPv6: true))
}
} }
completionHandler(ipAddresses, nil) completionHandler(records, nil)
} }
/** /**

View File

@ -33,4 +33,16 @@ public enum SocketType: String {
/// TCP socket type. /// TCP socket type.
case tcp = "TCP" case tcp = "TCP"
/// UDP socket type (IPv4).
case udp4 = "UDP4"
/// TCP socket type (IPv4).
case tcp4 = "TCP4"
/// UDP socket type (IPv6).
case udp6 = "UDP6"
/// TCP socket type (IPv6).
case tcp6 = "TCP6"
} }

View File

@ -41,25 +41,38 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
class ConnectionStrategy { class ConnectionStrategy {
struct Endpoint: CustomStringConvertible {
let record: DNSRecord
let proto: EndpointProtocol
var isValid: Bool {
if record.isIPv6 {
return proto.socketType != .udp4 && proto.socketType != .tcp4
} else {
return proto.socketType != .udp6 && proto.socketType != .tcp6
}
}
// MARK: CustomStringConvertible
var description: String {
return "\(record.address.maskedDescription):\(proto)"
}
}
private let hostname: String? private let hostname: String?
private let prefersResolvedAddresses: Bool
private var resolvedAddresses: [String]?
private let endpointProtocols: [EndpointProtocol] private let endpointProtocols: [EndpointProtocol]
private var currentProtocolIndex = 0 private var endpoints: [Endpoint]
private var currentEndpointIndex: Int
private let resolvedAddresses: [String]
init(configuration: OpenVPNTunnelProvider.Configuration) { init(configuration: OpenVPNTunnelProvider.Configuration) {
hostname = configuration.sessionConfiguration.hostname hostname = configuration.sessionConfiguration.hostname
prefersResolvedAddresses = (hostname == nil) || configuration.prefersResolvedAddresses
resolvedAddresses = configuration.resolvedAddresses
if prefersResolvedAddresses {
guard !(resolvedAddresses?.isEmpty ?? true) else {
fatalError("Either hostname or resolved addresses provided")
}
}
guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else { guard var endpointProtocols = configuration.sessionConfiguration.endpointProtocols else {
fatalError("No endpoints provided") fatalError("No endpoints provided")
} }
@ -67,101 +80,130 @@ class ConnectionStrategy {
endpointProtocols.shuffle() endpointProtocols.shuffle()
} }
self.endpointProtocols = endpointProtocols self.endpointProtocols = endpointProtocols
currentEndpointIndex = 0
if let resolvedAddresses = configuration.resolvedAddresses {
if configuration.prefersResolvedAddresses {
endpoints = ConnectionStrategy.unrolledEndpoints(
records: resolvedAddresses.map { DNSRecord(address: $0, isIPv6: false) },
protos: endpointProtocols
)
} else {
endpoints = []
}
self.resolvedAddresses = resolvedAddresses
} else {
guard hostname != nil else {
fatalError("Either configuration.hostname or resolvedRecords required")
}
endpoints = []
resolvedAddresses = []
}
}
private static func unrolledEndpoints(ipv4Addresses: [String], protos: [EndpointProtocol]) -> [Endpoint] {
return unrolledEndpoints(records: ipv4Addresses.map { DNSRecord(address: $0, isIPv6: false) }, protos: protos)
} }
private static func unrolledEndpoints(records: [DNSRecord], protos: [EndpointProtocol]) -> [Endpoint] {
guard !records.isEmpty else {
return []
}
var endpoints: [Endpoint] = []
for r in records {
for p in protos {
let endpoint = Endpoint(record: r, proto: p)
guard endpoint.isValid else {
continue
}
endpoints.append(endpoint)
}
}
log.debug("Unrolled endpoints: \(endpoints.maskedDescription)")
return endpoints
}
func hasEndpoint() -> Bool {
return currentEndpointIndex < endpoints.count
}
func currentEndpoint() -> Endpoint {
guard hasEndpoint() else {
fatalError("Endpoint index out of bounds (\(currentEndpointIndex) >= \(endpoints.count))")
}
return endpoints[currentEndpointIndex]
}
@discardableResult
func tryNextEndpoint() -> Bool {
guard hasEndpoint() else {
return false
}
currentEndpointIndex += 1
guard currentEndpointIndex < endpoints.count else {
log.debug("Exhausted endpoints")
return false
}
log.debug("Try next endpoint: \(currentEndpoint().maskedDescription)")
return true
}
func createSocket( func createSocket(
from provider: NEProvider, from provider: NEProvider,
timeout: Int, timeout: Int,
preferredAddress: String? = nil,
queue: DispatchQueue, queue: DispatchQueue,
completionHandler: @escaping (GenericSocket?, Error?) -> Void) { completionHandler: @escaping (GenericSocket?, Error?) -> Void) {
// reuse preferred address if hasEndpoint() {
if let preferredAddress = preferredAddress { let endpoint = currentEndpoint()
log.debug("Pick preferred address: \(preferredAddress.maskedDescription)") log.debug("Pick current endpoint: \(endpoint.maskedDescription)")
let socket = provider.createSocket(to: preferredAddress, protocol: currentProtocol()) let socket = provider.createSocket(to: endpoint)
completionHandler(socket, nil) completionHandler(socket, nil)
return return
} }
log.debug("No endpoints available, will resort to DNS resolution")
// use any resolved address
if prefersResolvedAddresses, let resolvedAddress = anyResolvedAddress() {
log.debug("Pick resolved address: \(resolvedAddress.maskedDescription)")
let socket = provider.createSocket(to: resolvedAddress, protocol: currentProtocol())
completionHandler(socket, nil)
return
}
// fall back to DNS
guard let hostname = hostname else { guard let hostname = hostname else {
log.error("DNS resolution unavailable: no hostname provided!") log.error("DNS resolution unavailable: no hostname provided!")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return return
} }
log.debug("DNS resolve hostname: \(hostname.maskedDescription)") log.debug("DNS resolve hostname: \(hostname.maskedDescription)")
DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (addresses, error) in DNSResolver.resolve(hostname, timeout: timeout, queue: queue) { (records, error) in
self.currentEndpointIndex = 0
// refresh resolved addresses if let records = records, !records.isEmpty {
if let resolved = addresses, !resolved.isEmpty { log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")
self.resolvedAddresses = resolved self.endpoints = ConnectionStrategy.unrolledEndpoints(records: records, protos: self.endpointProtocols)
log.debug("DNS resolved addresses: \(resolved.map { $0.maskedDescription })")
} else { } else {
log.error("DNS resolution failed!") log.error("DNS resolution failed!")
log.debug("Fall back to resolved addresses: \(self.resolvedAddresses.maskedDescription)")
self.endpoints = ConnectionStrategy.unrolledEndpoints(ipv4Addresses: self.resolvedAddresses, protos: self.endpointProtocols)
} }
guard let targetAddress = self.resolvedAddress(from: addresses) else { guard self.hasEndpoint() else {
log.error("No resolved or fallback address available") log.error("No endpoints available")
completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure) completionHandler(nil, OpenVPNTunnelProvider.ProviderError.dnsFailure)
return return
} }
let socket = provider.createSocket(to: targetAddress, protocol: self.currentProtocol()) let targetEndpoint = self.currentEndpoint()
log.debug("Pick current endpoint: \(targetEndpoint.maskedDescription)")
let socket = provider.createSocket(to: targetEndpoint)
completionHandler(socket, nil) completionHandler(socket, nil)
} }
} }
func tryNextProtocol() -> Bool {
let next = currentProtocolIndex + 1
guard next < endpointProtocols.count else {
log.debug("No more protocols available")
return false
}
currentProtocolIndex = next
log.debug("Fall back to next protocol: \(currentProtocol())")
return true
}
private func currentProtocol() -> EndpointProtocol {
return endpointProtocols[currentProtocolIndex]
}
private func resolvedAddress(from addresses: [String]?) -> String? {
guard let resolved = addresses, !resolved.isEmpty else {
return anyResolvedAddress()
}
return resolved[0]
}
private func anyResolvedAddress() -> String? {
guard let addresses = resolvedAddresses, !addresses.isEmpty else {
return nil
}
let n = Int(arc4random() % UInt32(addresses.count))
return addresses[n]
}
} }
private extension NEProvider { private extension NEProvider {
func createSocket(to address: String, protocol endpointProtocol: EndpointProtocol) -> GenericSocket { func createSocket(to endpoint: ConnectionStrategy.Endpoint) -> GenericSocket {
let endpoint = NWHostEndpoint(hostname: address, port: "\(endpointProtocol.port)") let ep = NWHostEndpoint(hostname: endpoint.record.address, port: "\(endpoint.proto.port)")
switch endpointProtocol.socketType { switch endpoint.proto.socketType {
case .udp: case .udp, .udp4, .udp6:
let impl = createUDPSession(to: endpoint, from: nil) let impl = createUDPSession(to: ep, from: nil)
return NEUDPSocket(impl: impl) return NEUDPSocket(impl: impl)
case .tcp: case .tcp, .tcp4, .tcp6:
let impl = createTCPConnection(to: endpoint, enableTLS: false, tlsParameters: nil, delegate: nil) let impl = createTCPConnection(to: ep, enableTLS: false, tlsParameters: nil, delegate: nil)
return NETCPSocket(impl: impl) return NETCPSocket(impl: impl)
} }
} }

View File

@ -66,7 +66,7 @@ extension OpenVPNTunnelProvider {
/// - Seealso: `fallbackServerAddresses` /// - Seealso: `fallbackServerAddresses`
public var prefersResolvedAddresses: Bool public var prefersResolvedAddresses: Bool
/// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true`. /// Resolved addresses in case DNS fails or `prefersResolvedAddresses` is `true` (IPv4 only).
public var resolvedAddresses: [String]? public var resolvedAddresses: [String]?
/// The MTU of the link. /// The MTU of the link.

View File

@ -311,7 +311,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// MARK: Connection (tunnel queue) // MARK: Connection (tunnel queue)
private func connectTunnel(upgradedSocket: GenericSocket? = nil, preferredAddress: String? = nil) { private func connectTunnel(upgradedSocket: GenericSocket? = nil) {
log.info("Creating link session") log.info("Creating link session")
// reuse upgraded socket // reuse upgraded socket
@ -321,7 +321,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
return return
} }
strategy.createSocket(from: self, timeout: dnsTimeout, preferredAddress: preferredAddress, queue: tunnelQueue) { (socket, error) in strategy.createSocket(from: self, timeout: dnsTimeout, queue: tunnelQueue) { (socket, error) in
guard let socket = socket else { guard let socket = socket else {
self.disposeTunnel(error: error) self.disposeTunnel(error: error)
return return
@ -424,7 +424,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: TCP connection timeout suggests falling back // fallback: TCP connection timeout suggests falling back
if let _ = socket as? NETCPSocket { if let _ = socket as? NETCPSocket {
guard tryNextProtocol() else { guard tryNextEndpoint() else {
// disposeTunnel // disposeTunnel
return return
} }
@ -471,7 +471,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout // fallback: UDP is connection-less, treat negotiation timeout as socket timeout
if didTimeoutNegotiation { if didTimeoutNegotiation {
guard tryNextProtocol() else { guard tryNextEndpoint() else {
// disposeTunnel // disposeTunnel
return return
} }
@ -489,7 +489,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
return return
} }
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress) self.connectTunnel(upgradedSocket: upgradedSocket)
} }
return return
} }
@ -775,8 +775,8 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
} }
extension OpenVPNTunnelProvider { extension OpenVPNTunnelProvider {
private func tryNextProtocol() -> Bool { private func tryNextEndpoint() -> Bool {
guard strategy.tryNextProtocol() else { guard strategy.tryNextEndpoint() else {
disposeTunnel(error: ProviderError.exhaustedProtocols) disposeTunnel(error: ProviderError.exhaustedProtocols)
return false return false
} }

View File

@ -62,11 +62,11 @@ extension OpenVPN {
// MARK: Client // MARK: Client
static let proto = NSRegularExpression("^proto +(udp6?|tcp6?)") static let proto = NSRegularExpression("^proto +(udp[46]?|tcp[46]?)")
static let port = NSRegularExpression("^port +\\d+") static let port = NSRegularExpression("^port +\\d+")
static let remote = NSRegularExpression("^remote +[^ ]+( +\\d+)?( +(udp6?|tcp6?))?") static let remote = NSRegularExpression("^remote +[^ ]+( +\\d+)?( +(udp[46]?|tcp[46]?))?")
static let eku = NSRegularExpression("^remote-cert-tls +server") static let eku = NSRegularExpression("^remote-cert-tls +server")
@ -817,10 +817,6 @@ private extension String {
private extension SocketType { private extension SocketType {
init?(protoString: String) { init?(protoString: String) {
var str = protoString self.init(rawValue: protoString.uppercased())
if str.hasSuffix("6") {
str.removeLast()
}
self.init(rawValue: str.uppercased())
} }
} }

View File

@ -94,7 +94,7 @@ class AppExtensionTests: XCTestCase {
func testDNSResolver() { func testDNSResolver() {
let exp = expectation(description: "DNS") let exp = expectation(description: "DNS")
DNSResolver.resolve("djsbjhcbjzhbxjnvsd.com", timeout: 1000, queue: DispatchQueue.main) { (addrs, error) in DNSResolver.resolve("www.google.com", timeout: 1000, queue: .main) { (addrs, error) in
defer { defer {
exp.fulfill() exp.fulfill()
} }
@ -126,4 +126,99 @@ class AppExtensionTests: XCTestCase {
XCTAssertEqual(string, expString) XCTAssertEqual(string, expString)
} }
} }
func testEndpointCycling() {
CoreConfiguration.masksPrivateData = false
var builder1 = OpenVPN.ConfigurationBuilder()
builder1.hostname = "italy.privateinternetaccess.com"
builder1.endpointProtocols = [
EndpointProtocol(.tcp6, 2222),
EndpointProtocol(.udp, 1111),
EndpointProtocol(.udp4, 3333)
]
var builder2 = OpenVPNTunnelProvider.ConfigurationBuilder(sessionConfiguration: builder1.build())
builder2.prefersResolvedAddresses = true
builder2.resolvedAddresses = [
"82.102.21.218",
"82.102.21.214",
"82.102.21.213",
]
let strategy = ConnectionStrategy(configuration: builder2.build())
let expected = [
"82.102.21.218:UDP:1111",
"82.102.21.218:UDP4:3333",
"82.102.21.214:UDP:1111",
"82.102.21.214:UDP4:3333",
"82.102.21.213:UDP:1111",
"82.102.21.213:UDP4:3333",
]
var i = 0
while strategy.hasEndpoint() {
let endpoint = strategy.currentEndpoint()
print("\(endpoint)")
XCTAssertEqual(endpoint.description, expected[i])
i += 1
strategy.tryNextEndpoint()
}
}
// func testEndpointCycling4() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.tcp4, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "11.22.33.44:TCP4:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
//
// func testEndpointCycling6() {
// CoreConfiguration.masksPrivateData = false
//
// var builder = OpenVPN.ConfigurationBuilder()
// builder.hostname = "italy.privateinternetaccess.com"
// builder.endpointProtocols = [
// EndpointProtocol(.udp6, 2222),
// ]
// let strategy = ConnectionStrategy(
// configuration: builder.build(),
// resolvedRecords: [
// DNSRecord(address: "111:bbbb:ffff::eeee", isIPv6: true),
// DNSRecord(address: "11.22.33.44", isIPv6: false),
// ]
// )
//
// let expected = [
// "111:bbbb:ffff::eeee:UDP6:2222"
// ]
// var i = 0
// while strategy.hasEndpoint() {
// let endpoint = strategy.currentEndpoint()
// print("\(endpoint)")
// XCTAssertEqual(endpoint.description, expected[i])
// i += 1
// strategy.tryNextEndpoint()
// }
// }
} }