VPN: DNSResolver: Resolve multiple endpoints in parallel

Signed-off-by: Roopesh Chander <roop@roopc.net>
This commit is contained in:
Roopesh Chander 2018-10-28 15:34:07 +05:30
parent dfbdcf3c28
commit 493166bd70
1 changed files with 32 additions and 21 deletions

View File

@ -6,40 +6,51 @@ import Foundation
class DNSResolver {
let endpoints: [Endpoint?]
let dispatchGroup: DispatchGroup
var dispatchWorkItems: [DispatchWorkItem]
init(endpoints: [Endpoint?]) {
self.endpoints = endpoints
self.dispatchWorkItems = []
self.dispatchGroup = DispatchGroup()
}
func resolve(completionHandler: @escaping ([Endpoint?]?) -> Void) {
let endpoints = self.endpoints
DispatchQueue.global(qos: .userInitiated).async {
var resolvedEndpoints: [Endpoint?] = []
var isError = false
for endpoint in endpoints {
if let endpoint = endpoint {
if let resolvedEndpoint = DNSResolver.resolveSync(endpoint: endpoint) {
resolvedEndpoints.append(resolvedEndpoint)
} else {
isError = true
break
let dispatchGroup = self.dispatchGroup
dispatchWorkItems = []
var resolvedEndpoints: [Endpoint?] = Array<Endpoint?>(repeating: nil, count: endpoints.count)
let numberOfEndpointsToResolve = endpoints.compactMap { $0 }.count
for (i, endpoint) in self.endpoints.enumerated() {
guard let endpoint = endpoint else { return }
let workItem = DispatchWorkItem {
resolvedEndpoints[i] = DNSResolver.resolveSync(endpoint: endpoint)
}
} else {
resolvedEndpoints.append(nil)
dispatchWorkItems.append(workItem)
DispatchQueue.global(qos: .userInitiated).async(group: dispatchGroup, execute: workItem)
}
}
if (isError) {
DispatchQueue.main.async {
dispatchGroup.notify(queue: .main) {
let numberOfResolvedEndpoints = resolvedEndpoints.compactMap { $0 }.count
if (numberOfResolvedEndpoints < numberOfEndpointsToResolve) {
completionHandler(nil)
}
return
}
DispatchQueue.main.async {
} else {
completionHandler(resolvedEndpoints)
}
}
}
func cancel() {
for workItem in dispatchWorkItems {
workItem.cancel()
}
}
deinit {
cancel()
}
}
extension DNSResolver {
// Based on DNS resolution code by Jason Donenfeld <jason@zx2c4.com>
// in parse_endpoint() in src/tools/config.c in the WireGuard codebase
private static func resolveSync(endpoint: Endpoint) -> Endpoint? {