From 8df53b7709adc3c6727f3681fdaa9ae20efdc2dd Mon Sep 17 00:00:00 2001 From: Roopesh Chander Date: Sun, 28 Oct 2018 15:34:07 +0530 Subject: [PATCH] VPN: DNSResolver: Resolve multiple endpoints in parallel --- WireGuard/WireGuard/VPN/DNSResolver.swift | 53 ++++++++++++++--------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/WireGuard/WireGuard/VPN/DNSResolver.swift b/WireGuard/WireGuard/VPN/DNSResolver.swift index e027852..9184394 100644 --- a/WireGuard/WireGuard/VPN/DNSResolver.swift +++ b/WireGuard/WireGuard/VPN/DNSResolver.swift @@ -6,40 +6,51 @@ import Foundation class DNSResolver { let endpoints: [Endpoint?] + let dispatchGroup: DispatchGroup + var dispatchWorkItems: [DispatchWorkItem] init(endpoints: [Endpoint?]) { self.endpoints = endpoints + self.dispatchWorkItems = [] + self.dispatchGroup = DispatchGroup() } func resolve(completionHandler: @escaping ([Endpoint?]?) -> Void) { let endpoints = self.endpoints - DispatchQueue.global(qos: .userInitiated).async { - var resolvedEndpoints: [Endpoint?] = [] - var isError = false - for endpoint in endpoints { - if let endpoint = endpoint { - if let resolvedEndpoint = DNSResolver.resolveSync(endpoint: endpoint) { - resolvedEndpoints.append(resolvedEndpoint) - } else { - isError = true - break - } - } else { - resolvedEndpoints.append(nil) - } + let dispatchGroup = self.dispatchGroup + dispatchWorkItems = [] + var resolvedEndpoints: [Endpoint?] = Array(repeating: nil, count: endpoints.count) + let numberOfEndpointsToResolve = endpoints.compactMap { $0 }.count + for (i, endpoint) in self.endpoints.enumerated() { + guard let endpoint = endpoint else { return } + let workItem = DispatchWorkItem { + resolvedEndpoints[i] = DNSResolver.resolveSync(endpoint: endpoint) } - if (isError) { - DispatchQueue.main.async { - completionHandler(nil) - } - return - } - DispatchQueue.main.async { + dispatchWorkItems.append(workItem) + DispatchQueue.global(qos: .userInitiated).async(group: dispatchGroup, execute: workItem) + } + dispatchGroup.notify(queue: .main) { + let numberOfResolvedEndpoints = resolvedEndpoints.compactMap { $0 }.count + if (numberOfResolvedEndpoints < numberOfEndpointsToResolve) { + completionHandler(nil) + } else { completionHandler(resolvedEndpoints) } } } + func cancel() { + for workItem in dispatchWorkItems { + workItem.cancel() + } + } + + deinit { + cancel() + } +} + +extension DNSResolver { // Based on DNS resolution code by Jason Donenfeld // in parse_endpoint() in src/tools/config.c in the WireGuard codebase private static func resolveSync(endpoint: Endpoint) -> Endpoint? {