diff --git a/WireGuard/WireGuardNetworkExtension/DNSResolver.swift b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift index 6b7b73b..4181f75 100644 --- a/WireGuard/WireGuardNetworkExtension/DNSResolver.swift +++ b/WireGuard/WireGuardNetworkExtension/DNSResolver.swift @@ -4,6 +4,10 @@ import Network import Foundation +enum DNSResolverError: Error { + case dnsResolutionFailed(hostnames: [String]) +} + class DNSResolver { let endpoints: [Endpoint?] let dispatchGroup: DispatchGroup @@ -32,10 +36,11 @@ class DNSResolver { return resolvedEndpoints } - func resolve(completionHandler: @escaping ([Endpoint?]?) -> Void) { + func resolveSync() throws -> [Endpoint?] { let endpoints = self.endpoints let dispatchGroup = self.dispatchGroup dispatchWorkItems = [] + var resolvedEndpoints: [Endpoint?] = Array(repeating: nil, count: endpoints.count) var isResolvedByDNSRequest: [Bool] = Array(repeating: false, count: endpoints.count) for (i, endpoint) in self.endpoints.enumerated() { @@ -54,26 +59,26 @@ class DNSResolver { DispatchQueue.global(qos: .userInitiated).async(group: dispatchGroup, execute: workItem) } } - dispatchGroup.notify(queue: .main) { - assert(endpoints.count == resolvedEndpoints.count) - for (i, endpoint) in endpoints.enumerated() { - guard let endpoint = endpoint, let resolvedEndpoint = resolvedEndpoints[i] else { - completionHandler(nil) - return + + dispatchGroup.wait() // TODO: Timeout? + + var hostnamesWithDnsResolutionFailure: [String] = [] + assert(endpoints.count == resolvedEndpoints.count) + for tuple in zip(endpoints, resolvedEndpoints) { + let endpoint = tuple.0 + let resolvedEndpoint = tuple.1 + if let endpoint = endpoint { + if (resolvedEndpoint == nil) { + // DNS resolution failed + guard let hostname = endpoint.hostname() else { fatalError() } + hostnamesWithDnsResolutionFailure.append(hostname) } - if (isResolvedByDNSRequest[i]) { - DNSResolver.cache.setObject(resolvedEndpoint.stringRepresentation() as NSString, - forKey: endpoint.stringRepresentation() as NSString) - } - } - let numberOfEndpointsToResolve = endpoints.compactMap { $0 }.count - let numberOfResolvedEndpoints = resolvedEndpoints.compactMap { $0 }.count - if (numberOfResolvedEndpoints < numberOfEndpointsToResolve) { - completionHandler(nil) - } else { - completionHandler(resolvedEndpoints) } } + if (!hostnamesWithDnsResolutionFailure.isEmpty) { + throw DNSResolverError.dnsResolutionFailed(hostnames: hostnamesWithDnsResolutionFailure) + } + return resolvedEndpoints } func cancel() {