Hide errors behind façade TunnelKit*Error (#325)

This commit is contained in:
Davide De Rosa 2023-07-02 11:56:40 +02:00 committed by GitHub
parent d8563e7f15
commit 729e8973cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 292 additions and 246 deletions

View File

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- Hide errors behind façade TunnelKit\*Error. [#325](https://github.com/passepartoutvpn/tunnelkit/pull/325)
## 6.0.0 (2023-04-02) ## 6.0.0 (2023-04-02)
### Added ### Added

View File

@ -53,8 +53,11 @@ public struct DNSRecord {
/// Errors coming from `DNSResolver`. /// Errors coming from `DNSResolver`.
public enum DNSError: Error { public enum DNSError: Error {
/// Resolution failed.
case failure case failure
/// Resolution timed out.
case timeout case timeout
} }
@ -71,8 +74,8 @@ public class DNSResolver {
- Parameter queue: The queue to execute the `completionHandler` in. - Parameter queue: The queue to execute the `completionHandler` in.
- Parameter completionHandler: The completion handler with the resolved addresses and an optional error. - Parameter completionHandler: The completion handler with the resolved addresses and an optional error.
*/ */
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) { public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], Error>) -> Void) {
var pendingHandler: ((Result<[DNSRecord], DNSError>) -> Void)? = completionHandler var pendingHandler: ((Result<[DNSRecord], Error>) -> Void)? = completionHandler
let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue() let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue()
DNSResolver.queue.async { DNSResolver.queue.async {
CFHostStartInfoResolution(host, .addresses, nil) CFHostStartInfoResolution(host, .addresses, nil)
@ -91,15 +94,15 @@ public class DNSResolver {
return return
} }
CFHostCancelInfoResolution(host, .addresses) CFHostCancelInfoResolution(host, .addresses)
handler(.failure(.timeout)) handler(.failure(TunnelKitCoreError.dnsResolver(.timeout)))
pendingHandler = nil pendingHandler = nil
} }
} }
private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) { private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], Error>) -> Void) {
var success: DarwinBoolean = false var success: DarwinBoolean = false
guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else { guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else {
completionHandler(.failure(.failure)) completionHandler(.failure(TunnelKitCoreError.dnsResolver(.failure)))
return return
} }
@ -129,7 +132,7 @@ public class DNSResolver {
} }
} }
guard !records.isEmpty else { guard !records.isEmpty else {
completionHandler(.failure(.failure)) completionHandler(.failure(TunnelKitCoreError.dnsResolver(.failure)))
return return
} }
completionHandler(.success(records)) completionHandler(.success(records))

View File

@ -62,23 +62,10 @@ public struct IPv4Settings: Codable, Equatable, CustomStringConvertible {
/// The address of the default gateway. /// The address of the default gateway.
public let defaultGateway: String public let defaultGateway: String
/// The additional routes.
@available(*, deprecated, message: "Store routes separately")
public let routes: [Route]
public init(address: String, addressMask: String, defaultGateway: String) { public init(address: String, addressMask: String, defaultGateway: String) {
self.address = address self.address = address
self.addressMask = addressMask self.addressMask = addressMask
self.defaultGateway = defaultGateway self.defaultGateway = defaultGateway
self.routes = []
}
@available(*, deprecated, message: "Store routes separately")
public init(address: String, addressMask: String, defaultGateway: String, routes: [Route]) {
self.address = address
self.addressMask = addressMask
self.defaultGateway = defaultGateway
self.routes = routes
} }
// MARK: CustomStringConvertible // MARK: CustomStringConvertible

View File

@ -62,23 +62,10 @@ public struct IPv6Settings: Codable, Equatable, CustomStringConvertible {
/// The address of the default gateway. /// The address of the default gateway.
public let defaultGateway: String public let defaultGateway: String
/// The additional routes.
@available(*, deprecated, message: "Store routes separately")
public let routes: [Route]
public init(address: String, addressPrefixLength: UInt8, defaultGateway: String) { public init(address: String, addressPrefixLength: UInt8, defaultGateway: String) {
self.address = address self.address = address
self.addressPrefixLength = addressPrefixLength self.addressPrefixLength = addressPrefixLength
self.defaultGateway = defaultGateway self.defaultGateway = defaultGateway
self.routes = []
}
@available(*, deprecated, message: "Store routes separately")
public init(address: String, addressPrefixLength: UInt8, defaultGateway: String, routes: [Route]) {
self.address = address
self.addressPrefixLength = addressPrefixLength
self.defaultGateway = defaultGateway
self.routes = routes
} }
// MARK: CustomStringConvertible // MARK: CustomStringConvertible

View File

@ -54,7 +54,7 @@ public class SecureRandom {
var randomBuffer = [UInt8](repeating: 0, count: 4) var randomBuffer = [UInt8](repeating: 0, count: 4)
guard SecRandomCopyBytes(kSecRandomDefault, 4, &randomBuffer) == 0 else { guard SecRandomCopyBytes(kSecRandomDefault, 4, &randomBuffer) == 0 else {
throw SecureRandomError.randomGenerator throw TunnelKitCoreError.secureRandom(.randomGenerator)
} }
var randomNumber: UInt32 = 0 var randomNumber: UInt32 = 0
@ -71,7 +71,7 @@ public class SecureRandom {
try withUnsafeMutablePointer(to: &randomNumber) { try withUnsafeMutablePointer(to: &randomNumber) {
try $0.withMemoryRebound(to: UInt8.self, capacity: 4) { (randomBytes: UnsafeMutablePointer<UInt8>) -> Void in try $0.withMemoryRebound(to: UInt8.self, capacity: 4) { (randomBytes: UnsafeMutablePointer<UInt8>) -> Void in
guard SecRandomCopyBytes(kSecRandomDefault, 4, randomBytes) == 0 else { guard SecRandomCopyBytes(kSecRandomDefault, 4, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator throw TunnelKitCoreError.secureRandom(.randomGenerator)
} }
} }
} }
@ -85,7 +85,7 @@ public class SecureRandom {
try randomData.withUnsafeMutableBytes { try randomData.withUnsafeMutableBytes {
let randomBytes = $0.bytePointer let randomBytes = $0.bytePointer
guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else { guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator throw TunnelKitCoreError.secureRandom(.randomGenerator)
} }
} }
@ -101,7 +101,7 @@ public class SecureRandom {
} }
guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else { guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator throw TunnelKitCoreError.secureRandom(.randomGenerator)
} }
return Z(bytes: randomBytes, count: length) return Z(bytes: randomBytes, count: length)

View File

@ -1,8 +1,8 @@
// //
// Errors.swift // TunnelKitCoreError.swift
// TunnelKit // TunnelKit
// //
// Created by Davide De Rosa on 5/19/19. // Created by Davide De Rosa on 6/16/23.
// Copyright (c) 2023 Davide De Rosa. All rights reserved. // Copyright (c) 2023 Davide De Rosa. All rights reserved.
// //
// https://github.com/passepartoutvpn // https://github.com/passepartoutvpn
@ -24,19 +24,10 @@
// //
import Foundation import Foundation
import CTunnelKitOpenVPNCore
extension Error { /// Errors returned by Core library.
public func isOpenVPNError() -> Bool { public enum TunnelKitCoreError: Error {
let te = self as NSError case secureRandom(_ error: SecureRandomError)
return te.domain == OpenVPNErrorDomain
}
public func openVPNErrorCode() -> OpenVPNErrorCode? { case dnsResolver(_ error: DNSError)
let te = self as NSError
guard te.domain == OpenVPNErrorDomain else {
return nil
}
return OpenVPNErrorCode(rawValue: te.code)
}
} }

View File

@ -92,11 +92,11 @@ public class Keychain {
return try passwordReference(for: username, context: context) return try passwordReference(for: username, context: context)
} }
removePassword(for: username, context: context) removePassword(for: username, context: context)
} catch let e as KeychainError { } catch let error as KeychainError {
// rethrow cancelation // rethrow cancelation
if e == .userCancelled { if error == .userCancelled {
throw e throw error
} }
// otherwise, no pre-existing password // otherwise, no pre-existing password
@ -114,7 +114,7 @@ public class Keychain {
var ref: CFTypeRef? var ref: CFTypeRef?
let status = SecItemAdd(query as CFDictionary, &ref) let status = SecItemAdd(query as CFDictionary, &ref)
guard status == errSecSuccess, let refData = ref as? Data else { guard status == errSecSuccess, let refData = ref as? Data else {
throw KeychainError.add throw TunnelKitManagerError.keychain(.add)
} }
return refData return refData
} }
@ -160,16 +160,16 @@ public class Keychain {
break break
case errSecUserCanceled: case errSecUserCanceled:
throw KeychainError.userCancelled throw TunnelKitManagerError.keychain(.userCancelled)
default: default:
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
guard let data = result as? Data else { guard let data = result as? Data else {
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
guard let password = String(data: data, encoding: .utf8) else { guard let password = String(data: data, encoding: .utf8) else {
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
return password return password
} }
@ -197,13 +197,13 @@ public class Keychain {
break break
case errSecUserCanceled: case errSecUserCanceled:
throw KeychainError.userCancelled throw TunnelKitManagerError.keychain(.userCancelled)
default: default:
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
guard let data = result as? Data else { guard let data = result as? Data else {
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
return data return data
} }
@ -226,16 +226,16 @@ public class Keychain {
break break
case errSecUserCanceled: case errSecUserCanceled:
throw KeychainError.userCancelled throw TunnelKitManagerError.keychain(.userCancelled)
default: default:
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
guard let data = result as? Data else { guard let data = result as? Data else {
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
guard let password = String(data: data, encoding: .utf8) else { guard let password = String(data: data, encoding: .utf8) else {
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
return password return password
} }
@ -265,7 +265,7 @@ public class Keychain {
let status = SecItemAdd(query as CFDictionary, nil) let status = SecItemAdd(query as CFDictionary, nil)
guard status == errSecSuccess else { guard status == errSecSuccess else {
throw KeychainError.add throw TunnelKitManagerError.keychain(.add)
} }
return try publicKey(withIdentifier: identifier) return try publicKey(withIdentifier: identifier)
} }
@ -294,13 +294,13 @@ public class Keychain {
break break
case errSecUserCanceled: case errSecUserCanceled:
throw KeychainError.userCancelled throw TunnelKitManagerError.keychain(.userCancelled)
default: default:
throw KeychainError.notFound throw TunnelKitManagerError.keychain(.notFound)
} }
// guard let key = result as? SecKey else { // guard let key = result as? SecKey else {
// throw KeychainError.typeMismatch // throw TunnelKitManagerError.keychain(.typeMismatch)
// } // }
// return key // return key
return result as! SecKey return result as! SecKey

View File

@ -0,0 +1,31 @@
//
// TunnelKitManagerError.swift
// TunnelKit
//
// Created by Davide De Rosa on 6/16/23.
// Copyright (c) 2023 Davide De Rosa. All rights reserved.
//
// https://github.com/passepartoutvpn
//
// This file is part of TunnelKit.
//
// TunnelKit is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// TunnelKit is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with TunnelKit. If not, see <http://www.gnu.org/licenses/>.
//
import Foundation
/// Errors returned by Core library.
public enum TunnelKitManagerError: Error {
case keychain(_ error: KeychainError)
}

View File

@ -94,7 +94,7 @@ class ConnectionStrategy {
from provider: NEProvider, from provider: NEProvider,
timeout: Int, timeout: Int,
queue: DispatchQueue, queue: DispatchQueue,
completionHandler: @escaping (Result<GenericSocket, OpenVPNProviderError>) -> Void) { completionHandler: @escaping (Result<GenericSocket, TunnelKitOpenVPNError>) -> Void) {
guard let remote = currentRemote else { guard let remote = currentRemote else {
completionHandler(.failure(.exhaustedEndpoints)) completionHandler(.failure(.exhaustedEndpoints))
return return

View File

@ -145,28 +145,28 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// required configuration // required configuration
do { do {
guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else { guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration") throw ConfigurationError.parameter(name: "protocolConfiguration")
} }
guard let _ = tunnelProtocol.serverAddress else { guard let _ = tunnelProtocol.serverAddress else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.serverAddress") throw ConfigurationError.parameter(name: "protocolConfiguration.serverAddress")
} }
guard let providerConfiguration = tunnelProtocol.providerConfiguration else { guard let providerConfiguration = tunnelProtocol.providerConfiguration else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration") throw ConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration")
} }
cfg = try fromDictionary(OpenVPN.ProviderConfiguration.self, providerConfiguration) cfg = try fromDictionary(OpenVPN.ProviderConfiguration.self, providerConfiguration)
} catch let e { } catch let cfgError as ConfigurationError {
var message: String? switch cfgError {
if let te = e as? OpenVPNProviderConfigurationError { case .parameter(let name):
switch te { NSLog("Tunnel configuration incomplete: \(name)")
case .parameter(let name):
message = "Tunnel configuration incomplete: \(name)"
default: default:
break NSLog("Tunnel configuration error: \(cfgError)")
}
} }
NSLog(message ?? "Unexpected error in tunnel configuration: \(e)") completionHandler(cfgError)
completionHandler(e) return
} catch {
NSLog("Unexpected error in tunnel configuration: \(error)")
completionHandler(error)
return return
} }
@ -188,7 +188,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
let credentials: OpenVPN.Credentials? let credentials: OpenVPN.Credentials?
if let username = protocolConfiguration.username, let passwordReference = protocolConfiguration.passwordReference { if let username = protocolConfiguration.username, let passwordReference = protocolConfiguration.passwordReference {
guard let password = try? Keychain.password(forReference: passwordReference) else { guard let password = try? Keychain.password(forReference: passwordReference) else {
completionHandler(OpenVPNProviderConfigurationError.credentials(details: "Keychain.password(forReference:)")) completionHandler(ConfigurationError.credentials(details: "Keychain.password(forReference:)"))
return return
} }
credentials = OpenVPN.Credentials(username, password) credentials = OpenVPN.Credentials(username, password)
@ -200,7 +200,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
cfg._appexSetLastError(nil) cfg._appexSetLastError(nil)
guard OpenVPN.prepareRandomNumberGenerator(seedLength: prngSeedLength) else { guard OpenVPN.prepareRandomNumberGenerator(seedLength: prngSeedLength) else {
completionHandler(OpenVPNProviderConfigurationError.prngInitialization) completionHandler(ConfigurationError.prngInitialization)
return return
} }
@ -216,8 +216,8 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
do { do {
session = try OpenVPNSession(queue: tunnelQueue, configuration: cfg.configuration, cachesURL: cachesURL) session = try OpenVPNSession(queue: tunnelQueue, configuration: cfg.configuration, cachesURL: cachesURL)
refreshDataCount() refreshDataCount()
} catch let e { } catch {
completionHandler(e) completionHandler(error)
return return
} }
session.credentials = credentials session.credentials = credentials
@ -359,7 +359,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// from stopTunnel(), in which case we don't need to feed an error parameter to // from stopTunnel(), in which case we don't need to feed an error parameter to
// the stop completion handler // the stop completion handler
// //
pendingStartHandler?(error ?? OpenVPNProviderError.socketActivity) pendingStartHandler?(error ?? TunnelKitOpenVPNError.socketActivity)
pendingStartHandler = nil pendingStartHandler = nil
} }
// stopped intentionally // stopped intentionally
@ -434,9 +434,13 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// look for error causing shutdown // look for error causing shutdown
shutdownError = session.stopError shutdownError = session.stopError
if failure && (shutdownError == nil) { if failure && (shutdownError == nil) {
shutdownError = OpenVPNProviderError.linkError shutdownError = TunnelKitOpenVPNError.linkError
}
if case .negotiationTimeout = shutdownError as? OpenVPNError {
didTimeoutNegotiation = true
} else {
didTimeoutNegotiation = false
} }
didTimeoutNegotiation = (shutdownError as? OpenVPNError == .negotiationTimeout)
// only try upgrade on network errors // only try upgrade on network errors
if shutdownError as? OpenVPNError == nil { if shutdownError as? OpenVPNError == nil {
@ -479,7 +483,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
public func socketHasBetterPath(_ socket: GenericSocket) { public func socketHasBetterPath(_ socket: GenericSocket) {
log.debug("Stopping tunnel due to a new better path") log.debug("Stopping tunnel due to a new better path")
logCurrentSSID() logCurrentSSID()
session?.reconnect(error: OpenVPNProviderError.networkChanged) session?.reconnect(error: TunnelKitOpenVPNError.networkChanged)
} }
} }
@ -546,7 +550,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
let newSettings = NetworkSettingsBuilder(remoteAddress: remoteAddress, localOptions: localOptions, remoteOptions: remoteOptions) let newSettings = NetworkSettingsBuilder(remoteAddress: remoteAddress, localOptions: localOptions, remoteOptions: remoteOptions)
guard !newSettings.isGateway || newSettings.hasGateway else { guard !newSettings.isGateway || newSettings.hasGateway else {
session?.shutdown(error: OpenVPNProviderError.gatewayUnattainable) session?.shutdown(error: TunnelKitOpenVPNError.gatewayUnattainable)
return return
} }
@ -594,7 +598,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
extension OpenVPNTunnelProvider { extension OpenVPNTunnelProvider {
private func tryNextEndpoint() -> Bool { private func tryNextEndpoint() -> Bool {
guard strategy.tryNextEndpoint() else { guard strategy.tryNextEndpoint() else {
disposeTunnel(error: OpenVPNProviderError.exhaustedEndpoints) disposeTunnel(error: TunnelKitOpenVPNError.exhaustedEndpoints)
return false return false
} }
return true return true
@ -647,44 +651,41 @@ extension OpenVPNTunnelProvider {
// let anyObject = object as AnyObject // let anyObject = object as AnyObject
// return Unmanaged<AnyObject>.passUnretained(anyObject).toOpaque() // return Unmanaged<AnyObject>.passUnretained(anyObject).toOpaque()
// } // }
}
// MARK: Errors // MARK: Errors
private func setErrorStatus(with error: Error) { private extension OpenVPNTunnelProvider {
enum ConfigurationError: Error {
/// A field in the `OpenVPNProvider.Configuration` provided is incorrect or incomplete.
case parameter(name: String)
/// Credentials are missing or inaccessible.
case credentials(details: String)
/// The pseudo-random number generator could not be initialized.
case prngInitialization
/// The TLS certificate could not be serialized.
case certificateSerialization
}
func setErrorStatus(with error: Error) {
cfg._appexSetLastError(unifiedError(from: error)) cfg._appexSetLastError(unifiedError(from: error))
} }
private func unifiedError(from error: Error) -> OpenVPNProviderError { func unifiedError(from error: Error) -> TunnelKitOpenVPNError {
if let te = error.openVPNErrorCode() {
switch te {
case .cryptoRandomGenerator, .cryptoAlgorithm:
return .encryptionInitialization
case .cryptoEncryption, .cryptoHMAC: // XXX: error handling is limited by lastError serialization
return .encryptionData // requirement, cannot return a generic Error here
// openVPNError(from: error) ?? error
openVPNError(from: error) ?? .linkError
}
case .tlscaRead, .tlscaUse, .tlscaPeerVerification, func openVPNError(from error: Error) -> TunnelKitOpenVPNError? {
.tlsClientCertificateRead, .tlsClientCertificateUse, if let specificError = error as? OpenVPNError {
.tlsClientKeyRead, .tlsClientKeyUse: switch specificError.asNativeOpenVPNError ?? specificError {
return .tlsInitialization
case .tlsServerCertificate, .tlsServerEKU, .tlsServerHost:
return .tlsServerVerification
case .tlsHandshake:
return .tlsHandshake
case .dataPathOverflow, .dataPathPeerIdMismatch:
return .unexpectedReply
case .dataPathCompression:
return .serverCompression
default:
break
}
} else if let se = error as? OpenVPNError {
switch se {
case .negotiationTimeout, .pingTimeout, .staleSession: case .negotiationTimeout, .pingTimeout, .staleSession:
return .timeout return .timeout
@ -703,14 +704,45 @@ extension OpenVPNTunnelProvider {
case .serverShutdown: case .serverShutdown:
return .serverShutdown return .serverShutdown
case .native(let code):
switch code {
case .cryptoRandomGenerator, .cryptoAlgorithm:
return .encryptionInitialization
case .cryptoEncryption, .cryptoHMAC:
return .encryptionData
case .tlscaRead, .tlscaUse, .tlscaPeerVerification,
.tlsClientCertificateRead, .tlsClientCertificateUse,
.tlsClientKeyRead, .tlsClientKeyUse:
return .tlsInitialization
case .tlsServerCertificate, .tlsServerEKU, .tlsServerHost:
return .tlsServerVerification
case .tlsHandshake:
return .tlsHandshake
case .dataPathOverflow, .dataPathPeerIdMismatch:
return .unexpectedReply
case .dataPathCompression:
return .serverCompression
default:
break
}
default: default:
return .unexpectedReply return .unexpectedReply
} }
} }
return error as? OpenVPNProviderError ?? .linkError return nil
} }
} }
// MARK: Hacks
private extension NEPacketTunnelProvider { private extension NEPacketTunnelProvider {
func forceExitOnMac() { func forceExitOnMac() {
#if os(macOS) #if os(macOS)

View File

@ -64,7 +64,7 @@ class ResolvedRemote: CustomStringConvertible {
} }
} }
private func handleResult(_ result: Result<[DNSRecord], DNSError>) { private func handleResult(_ result: Result<[DNSRecord], Error>) {
switch result { switch result {
case .success(let records): case .success(let records):
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)") log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")

View File

@ -44,5 +44,8 @@ extension OpenVPN {
/// Encryption passphrase is incorrect or key is corrupt. /// Encryption passphrase is incorrect or key is corrupt.
case unableToDecrypt(error: Error) case unableToDecrypt(error: Error)
/// The PUSH_REPLY is multipart.
case continuationPushReply
} }
} }

View File

@ -339,7 +339,7 @@ extension OpenVPN {
isContinuation = ($0.first == "2") isContinuation = ($0.first == "2")
} }
guard !isContinuation else { guard !isContinuation else {
throw OpenVPNError.continuationPushReply throw ConfigurationError.continuationPushReply
} }
// MARK: Inline content // MARK: Inline content
@ -805,8 +805,8 @@ extension OpenVPN {
} }
do { do {
sessionBuilder.clientKey = try clientKey.decrypted(with: passphrase) sessionBuilder.clientKey = try clientKey.decrypted(with: passphrase)
} catch let e { } catch {
throw ConfigurationError.unableToDecrypt(error: e) throw ConfigurationError.unableToDecrypt(error: error)
} }
} else { } else {
sessionBuilder.clientKey = optClientKey sessionBuilder.clientKey = optClientKey

View File

@ -35,9 +35,10 @@
// //
import Foundation import Foundation
import CTunnelKitOpenVPNCore
/// The possible errors raised/thrown during `OpenVPNSession` operation. /// The possible errors raised/thrown during `OpenVPNSession` operation.
public enum OpenVPNError: String, Error { public enum OpenVPNError: Error {
/// The negotiation timed out. /// The negotiation timed out.
case negotiationTimeout case negotiationTimeout
@ -51,15 +52,15 @@ public enum OpenVPNError: String, Error {
/// The connection key is wrong or wasn't expected. /// The connection key is wrong or wasn't expected.
case badKey case badKey
/// Control channel failure.
case controlChannel(message: String)
/// The control packet has an incorrect prefix payload. /// The control packet has an incorrect prefix payload.
case wrongControlDataPrefix case wrongControlDataPrefix
/// The provided credentials failed authentication. /// The provided credentials failed authentication.
case badCredentials case badCredentials
/// The PUSH_REPLY is multipart.
case continuationPushReply
/// The reply to PUSH_REQUEST is malformed. /// The reply to PUSH_REQUEST is malformed.
case malformedPushReply case malformedPushReply
@ -80,4 +81,17 @@ public enum OpenVPNError: String, Error {
/// Remote server shut down (--explicit-exit-notify). /// Remote server shut down (--explicit-exit-notify).
case serverShutdown case serverShutdown
/// NSError from ObjC layer.
case native(code: OpenVPNErrorCode)
}
extension Error {
public var asNativeOpenVPNError: OpenVPNError? {
let nativeError = self as NSError
guard nativeError.domain == OpenVPNErrorDomain, let code = OpenVPNErrorCode(rawValue: nativeError.code) else {
return nil
}
return .native(code: code)
}
} }

View File

@ -139,7 +139,7 @@ extension OpenVPN.ProviderConfiguration {
/** /**
The last error reported by the tunnel, if any. The last error reported by the tunnel, if any.
*/ */
public var lastError: OpenVPNProviderError? { public var lastError: TunnelKitOpenVPNError? {
return defaults?.openVPNLastError return defaults?.openVPNLastError
} }
@ -164,7 +164,7 @@ extension OpenVPN.ProviderConfiguration {
defaults?.openVPNServerConfiguration = newValue defaults?.openVPNServerConfiguration = newValue
} }
public func _appexSetLastError(_ newValue: OpenVPNProviderError?) { public func _appexSetLastError(_ newValue: TunnelKitOpenVPNError?) {
defaults?.openVPNLastError = newValue defaults?.openVPNLastError = newValue
} }
@ -250,12 +250,12 @@ extension UserDefaults {
} }
} }
public fileprivate(set) var openVPNLastError: OpenVPNProviderError? { public fileprivate(set) var openVPNLastError: TunnelKitOpenVPNError? {
get { get {
guard let rawValue = string(forKey: OpenVPN.ProviderConfiguration.Keys.lastError.rawValue) else { guard let rawValue = string(forKey: OpenVPN.ProviderConfiguration.Keys.lastError.rawValue) else {
return nil return nil
} }
return OpenVPNProviderError(rawValue: rawValue) return TunnelKitOpenVPNError(rawValue: rawValue)
} }
set { set {
guard let newValue = newValue else { guard let newValue = newValue else {

View File

@ -1,5 +1,5 @@
// //
// OpenVPNProviderError.swift // TunnelKitOpenVPNError.swift
// TunnelKit // TunnelKit
// //
// Created by Davide De Rosa on 11/8/21. // Created by Davide De Rosa on 11/8/21.
@ -35,25 +35,10 @@
// //
import Foundation import Foundation
import TunnelKitOpenVPNCore
/// Mostly programming errors by host app.
public enum OpenVPNProviderConfigurationError: Error {
/// A field in the `OpenVPNProvider.Configuration` provided is incorrect or incomplete.
case parameter(name: String)
/// Credentials are missing or inaccessible.
case credentials(details: String)
/// The pseudo-random number generator could not be initialized.
case prngInitialization
/// The TLS certificate could not be serialized.
case certificateSerialization
}
/// The errors causing a tunnel disconnection. /// The errors causing a tunnel disconnection.
public enum OpenVPNProviderError: String, Error { public enum TunnelKitOpenVPNError: String, Error {
/// Socket endpoint could not be resolved. /// Socket endpoint could not be resolved.
case dnsFailure case dnsFailure

View File

@ -33,14 +33,6 @@ import CTunnelKitOpenVPNProtocol
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
extension OpenVPN { extension OpenVPN {
class ControlChannelError: Error, CustomStringConvertible {
let description: String
init(_ message: String) {
description = "\(String(describing: ControlChannelError.self))(\(message))"
}
}
class ControlChannel { class ControlChannel {
private let serializer: ControlChannelSerializer private let serializer: ControlChannelSerializer
@ -101,12 +93,17 @@ extension OpenVPN {
} }
func readInboundPacket(withData data: Data, offset: Int) throws -> ControlPacket { func readInboundPacket(withData data: Data, offset: Int) throws -> ControlPacket {
let packet = try serializer.deserialize(data: data, start: offset, end: nil) do {
log.debug("Control: Read packet \(packet)") let packet = try serializer.deserialize(data: data, start: offset, end: nil)
if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId { log.debug("Control: Read packet \(packet)")
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId) if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId {
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId)
}
return packet
} catch {
log.error("Control: Channel failure \(error)")
throw error
} }
return packet
} }
func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] { func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] {

View File

@ -54,11 +54,11 @@ extension OpenVPN.ControlChannel {
let end = end ?? packet.count let end = end ?? packet.count
guard end >= offset + PacketOpcodeLength else { guard end >= offset + PacketOpcodeLength else {
throw OpenVPN.ControlChannelError("Missing opcode") throw OpenVPNError.controlChannel(message: "Missing opcode")
} }
let codeValue = packet[offset] >> 3 let codeValue = packet[offset] >> 3
guard let code = PacketCode(rawValue: codeValue) else { guard let code = PacketCode(rawValue: codeValue) else {
throw OpenVPN.ControlChannelError("Unknown code: \(codeValue))") throw OpenVPNError.controlChannel(message: "Unknown code: \(codeValue))")
} }
let key = packet[offset] & 0b111 let key = packet[offset] & 0b111
offset += PacketOpcodeLength offset += PacketOpcodeLength
@ -66,13 +66,13 @@ extension OpenVPN.ControlChannel {
log.debug("Control: Try read packet with code \(code) and key \(key)") log.debug("Control: Try read packet with code \(code) and key \(key)")
guard end >= offset + PacketSessionIdLength else { guard end >= offset + PacketSessionIdLength else {
throw OpenVPN.ControlChannelError("Missing sessionId") throw OpenVPNError.controlChannel(message: "Missing sessionId")
} }
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength offset += PacketSessionIdLength
guard end >= offset + 1 else { guard end >= offset + 1 else {
throw OpenVPN.ControlChannelError("Missing ackSize") throw OpenVPNError.controlChannel(message: "Missing ackSize")
} }
let ackSize = packet[offset] let ackSize = packet[offset]
offset += 1 offset += 1
@ -81,7 +81,7 @@ extension OpenVPN.ControlChannel {
var ackRemoteSessionId: Data? var ackRemoteSessionId: Data?
if ackSize > 0 { if ackSize > 0 {
guard end >= (offset + Int(ackSize) * PacketIdLength) else { guard end >= (offset + Int(ackSize) * PacketIdLength) else {
throw OpenVPN.ControlChannelError("Missing acks") throw OpenVPNError.controlChannel(message: "Missing acks")
} }
var ids: [UInt32] = [] var ids: [UInt32] = []
for _ in 0..<ackSize { for _ in 0..<ackSize {
@ -91,7 +91,7 @@ extension OpenVPN.ControlChannel {
} }
guard end >= offset + PacketSessionIdLength else { guard end >= offset + PacketSessionIdLength else {
throw OpenVPN.ControlChannelError("Missing remoteSessionId") throw OpenVPNError.controlChannel(message: "Missing remoteSessionId")
} }
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength offset += PacketSessionIdLength
@ -102,16 +102,16 @@ extension OpenVPN.ControlChannel {
if code == .ackV1 { if code == .ackV1 {
guard let ackIds = ackIds else { guard let ackIds = ackIds else {
throw OpenVPN.ControlChannelError("Ack packet without ids") throw OpenVPNError.controlChannel(message: "Ack packet without ids")
} }
guard let ackRemoteSessionId = ackRemoteSessionId else { guard let ackRemoteSessionId = ackRemoteSessionId else {
throw OpenVPN.ControlChannelError("Ack packet without remoteSessionId") throw OpenVPNError.controlChannel(message: "Ack packet without remoteSessionId")
} }
return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId) return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
} }
guard end >= offset + PacketIdLength else { guard end >= offset + PacketIdLength else {
throw OpenVPN.ControlChannelError("Missing packetId") throw OpenVPNError.controlChannel(message: "Missing packetId")
} }
let packetId = packet.networkUInt32Value(from: offset) let packetId = packet.networkUInt32Value(from: offset)
offset += PacketIdLength offset += PacketIdLength
@ -192,7 +192,7 @@ extension OpenVPN.ControlChannel {
// data starts with (prefix=(header + sessionId) + auth=(hmac + replayId)) // data starts with (prefix=(header + sessionId) + auth=(hmac + replayId))
guard end >= preambleLength else { guard end >= preambleLength else {
throw OpenVPN.ControlChannelError("Missing HMAC") throw OpenVPNError.controlChannel(message: "Missing HMAC")
} }
// needs a copy for swapping // needs a copy for swapping
@ -206,7 +206,12 @@ extension OpenVPN.ControlChannel {
// TODO: validate replay packet id // TODO: validate replay packet id
return try plain.deserialize(data: authPacket, start: authLength, end: nil) do {
return try plain.deserialize(data: authPacket, start: authLength, end: nil)
} catch {
log.error("Control: Channel failure \(error)")
throw error
}
} }
} }
} }
@ -269,7 +274,7 @@ extension OpenVPN.ControlChannel {
// data starts with (ad=(header + sessionId + replayId) + tag) // data starts with (ad=(header + sessionId + replayId) + tag)
guard end >= start + adLength + tagLength else { guard end >= start + adLength + tagLength else {
throw OpenVPN.ControlChannelError("Missing AD+TAG") throw OpenVPNError.controlChannel(message: "Missing AD+TAG")
} }
let encryptedCount = packet.count - adLength let encryptedCount = packet.count - adLength
@ -288,7 +293,12 @@ extension OpenVPN.ControlChannel {
// TODO: validate replay packet id // TODO: validate replay packet id
return try plain.deserialize(data: decryptedPacket, start: 0, end: nil) do {
return try plain.deserialize(data: decryptedPacket, start: 0, end: nil)
} catch {
log.error("Control: Channel failure \(error)")
throw error
}
} }
} }
} }

View File

@ -477,8 +477,8 @@ public class OpenVPNSession: Session {
continue continue
} }
controlPacket = parsedPacket controlPacket = parsedPacket
} catch let e { } catch {
log.warning("Dropped malformed packet: \(e)") log.warning("Dropped malformed packet: \(error)")
continue continue
// deferStop(.shutdown, e) // deferStop(.shutdown, e)
// return // return
@ -573,8 +573,8 @@ public class OpenVPNSession: Session {
authenticator = nil authenticator = nil
do { do {
try controlChannel.reset(forNewSession: forNewSession) try controlChannel.reset(forNewSession: forNewSession)
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
} }
} }
@ -658,18 +658,18 @@ public class OpenVPNSession: Session {
authenticator = try OpenVPN.Authenticator(credentials?.username, pushReply?.options.authToken ?? credentials?.password) authenticator = try OpenVPN.Authenticator(credentials?.username, pushReply?.options.authToken ?? credentials?.password)
authenticator?.withLocalOptions = withLocalOptions authenticator?.withLocalOptions = withLocalOptions
try authenticator?.putAuth(into: negotiationKey.tls, options: configuration) try authenticator?.putAuth(into: negotiationKey.tls, options: configuration)
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
let cipherTextOut: Data let cipherTextOut: Data
do { do {
cipherTextOut = try negotiationKey.tls.pullCipherText() cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e { } catch {
if let _ = e.openVPNErrorCode() { if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.auth: Failed pulling ciphertext (error: \(e))") log.error("TLS.auth: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: e) shutdown(error: nativeError)
return return
} }
log.verbose("TLS.auth: Still can't pull ciphertext") log.verbose("TLS.auth: Still can't pull ciphertext")
@ -695,10 +695,10 @@ public class OpenVPNSession: Session {
let cipherTextOut: Data let cipherTextOut: Data
do { do {
cipherTextOut = try negotiationKey.tls.pullCipherText() cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e { } catch {
if let _ = e.openVPNErrorCode() { if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.auth: Failed pulling ciphertext (error: \(e))") log.error("TLS.auth: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: e) shutdown(error: nativeError)
return return
} }
log.verbose("TLS.ifconfig: Still can't pull ciphertext") log.verbose("TLS.ifconfig: Still can't pull ciphertext")
@ -789,21 +789,21 @@ public class OpenVPNSession: Session {
negotiationKey.tlsOptional = tls negotiationKey.tlsOptional = tls
do { do {
try negotiationKey.tls.start() try negotiationKey.tls.start()
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
let cipherTextOut: Data let cipherTextOut: Data
do { do {
cipherTextOut = try negotiationKey.tls.pullCipherText() cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e { } catch {
if let _ = e.openVPNErrorCode() { if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.connect: Failed pulling ciphertext (error: \(e))") log.error("TLS.connect: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: e) shutdown(error: nativeError)
return return
} }
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
@ -836,10 +836,10 @@ public class OpenVPNSession: Session {
cipherTextOut = try negotiationKey.tls.pullCipherText() cipherTextOut = try negotiationKey.tls.pullCipherText()
log.debug("TLS.connect: Send pulled ciphertext (\(cipherTextOut.count) bytes)") log.debug("TLS.connect: Send pulled ciphertext (\(cipherTextOut.count) bytes)")
enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut) enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut)
} catch let e { } catch {
if let _ = e.openVPNErrorCode() { if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.connect: Failed pulling ciphertext (error: \(e))") log.error("TLS.connect: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: e) shutdown(error: nativeError)
return return
} }
log.verbose("TLS.connect: No available ciphertext to pull") log.verbose("TLS.connect: No available ciphertext to pull")
@ -878,8 +878,8 @@ public class OpenVPNSession: Session {
guard try auth.parseAuthReply() else { guard try auth.parseAuthReply() else {
return return
} }
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
@ -962,12 +962,12 @@ public class OpenVPNSession: Session {
throw OpenVPNError.serverCompression throw OpenVPNError.serverCompression
} }
} }
} catch OpenVPNError.continuationPushReply { } catch OpenVPN.ConfigurationError.continuationPushReply {
continuatedPushReplyMessage = completeMessage.replacingOccurrences(of: "push-continuation", with: "") continuatedPushReplyMessage = completeMessage.replacingOccurrences(of: "push-continuation", with: "")
// FIXME: strip "PUSH_REPLY" and "push-continuation 2" // FIXME: strip "PUSH_REPLY" and "push-continuation 2"
return return
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
@ -1025,9 +1025,9 @@ public class OpenVPNSession: Session {
let rawList: [Data] let rawList: [Data]
do { do {
rawList = try controlChannel.writeOutboundPackets() rawList = try controlChannel.writeOutboundPackets()
} catch let e { } catch {
log.warning("Failed control packet serialization: \(e)") log.warning("Failed control packet serialization: \(error)")
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
for raw in rawList { for raw in rawList {
@ -1110,8 +1110,8 @@ public class OpenVPNSession: Session {
sessionId, sessionId,
remoteSessionId remoteSessionId
) )
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }
@ -1141,12 +1141,12 @@ public class OpenVPNSession: Session {
} }
tunnel?.writePackets(decryptedPackets, completionHandler: nil) tunnel?.writePackets(decryptedPackets, completionHandler: nil)
} catch let e { } catch {
guard !e.isOpenVPNError() else { if let nativeError = error.asNativeOpenVPNError {
deferStop(.shutdown, e) deferStop(.shutdown, nativeError)
return return
} }
deferStop(.reconnect, e) deferStop(.reconnect, error)
} }
} }
@ -1181,12 +1181,12 @@ public class OpenVPNSession: Session {
// log.verbose("Data: \(encryptedPackets.count) packets successfully written to LINK") // log.verbose("Data: \(encryptedPackets.count) packets successfully written to LINK")
} }
} }
} catch let e { } catch {
guard !e.isOpenVPNError() else { if let nativeError = error.asNativeOpenVPNError {
deferStop(.shutdown, e) deferStop(.shutdown, nativeError)
return return
} }
deferStop(.reconnect, e) deferStop(.reconnect, error)
} }
} }
@ -1206,8 +1206,8 @@ public class OpenVPNSession: Session {
ackPacketIds: [controlPacket.packetId], ackPacketIds: [controlPacket.packetId],
ackRemoteSessionId: controlPacket.sessionId ackRemoteSessionId: controlPacket.sessionId
) )
} catch let e { } catch {
deferStop(.shutdown, e) deferStop(.shutdown, error)
return return
} }

View File

@ -36,7 +36,7 @@ open class WireGuardTunnelProvider: NEPacketTunnelProvider {
cfg = try fromDictionary(WireGuard.ProviderConfiguration.self, providerConfiguration) cfg = try fromDictionary(WireGuard.ProviderConfiguration.self, providerConfiguration)
tunnelConfiguration = cfg.configuration.tunnelConfiguration tunnelConfiguration = cfg.configuration.tunnelConfiguration
} catch { } catch {
completionHandler(WireGuardProviderError.savedProtocolConfigurationIsInvalid) completionHandler(TunnelKitWireGuardError.savedProtocolConfigurationIsInvalid)
return return
} }
@ -59,24 +59,24 @@ open class WireGuardTunnelProvider: NEPacketTunnelProvider {
case .cannotLocateTunnelFileDescriptor: case .cannotLocateTunnelFileDescriptor:
wg_log(.error, staticMessage: "Starting tunnel failed: could not determine file descriptor") wg_log(.error, staticMessage: "Starting tunnel failed: could not determine file descriptor")
self.cfg._appexSetLastError(.couldNotDetermineFileDescriptor) self.cfg._appexSetLastError(.couldNotDetermineFileDescriptor)
completionHandler(WireGuardProviderError.couldNotDetermineFileDescriptor) completionHandler(TunnelKitWireGuardError.couldNotDetermineFileDescriptor)
case .dnsResolution(let dnsErrors): case .dnsResolution(let dnsErrors):
let hostnamesWithDnsResolutionFailure = dnsErrors.map(\.address) let hostnamesWithDnsResolutionFailure = dnsErrors.map(\.address)
.joined(separator: ", ") .joined(separator: ", ")
wg_log(.error, message: "DNS resolution failed for the following hostnames: \(hostnamesWithDnsResolutionFailure)") wg_log(.error, message: "DNS resolution failed for the following hostnames: \(hostnamesWithDnsResolutionFailure)")
self.cfg._appexSetLastError(.dnsResolutionFailure) self.cfg._appexSetLastError(.dnsResolutionFailure)
completionHandler(WireGuardProviderError.dnsResolutionFailure) completionHandler(TunnelKitWireGuardError.dnsResolutionFailure)
case .setNetworkSettings(let error): case .setNetworkSettings(let error):
wg_log(.error, message: "Starting tunnel failed with setTunnelNetworkSettings returning \(error.localizedDescription)") wg_log(.error, message: "Starting tunnel failed with setTunnelNetworkSettings returning \(error.localizedDescription)")
self.cfg._appexSetLastError(.couldNotSetNetworkSettings) self.cfg._appexSetLastError(.couldNotSetNetworkSettings)
completionHandler(WireGuardProviderError.couldNotSetNetworkSettings) completionHandler(TunnelKitWireGuardError.couldNotSetNetworkSettings)
case .startWireGuardBackend(let errorCode): case .startWireGuardBackend(let errorCode):
wg_log(.error, message: "Starting tunnel failed with wgTurnOn returning \(errorCode)") wg_log(.error, message: "Starting tunnel failed with wgTurnOn returning \(errorCode)")
self.cfg._appexSetLastError(.couldNotStartBackend) self.cfg._appexSetLastError(.couldNotStartBackend)
completionHandler(WireGuardProviderError.couldNotStartBackend) completionHandler(TunnelKitWireGuardError.couldNotStartBackend)
case .invalidState: case .invalidState:
// Must never happen // Must never happen

View File

@ -3,7 +3,7 @@
import Foundation import Foundation
public enum WireGuardProviderError: String, Error { public enum TunnelKitWireGuardError: String, Error {
case savedProtocolConfigurationIsInvalid case savedProtocolConfigurationIsInvalid
case dnsResolutionFailure case dnsResolutionFailure
case couldNotStartBackend case couldNotStartBackend

View File

@ -91,7 +91,7 @@ extension WireGuard.ProviderConfiguration: NetworkExtensionConfiguration {
// MARK: Shared data // MARK: Shared data
extension WireGuard.ProviderConfiguration { extension WireGuard.ProviderConfiguration {
public var lastError: WireGuardProviderError? { public var lastError: TunnelKitWireGuardError? {
return defaults?.wireGuardLastError return defaults?.wireGuardLastError
} }
@ -105,7 +105,7 @@ extension WireGuard.ProviderConfiguration {
} }
extension WireGuard.ProviderConfiguration { extension WireGuard.ProviderConfiguration {
public func _appexSetLastError(_ newValue: WireGuardProviderError?) { public func _appexSetLastError(_ newValue: TunnelKitWireGuardError?) {
defaults?.wireGuardLastError = newValue defaults?.wireGuardLastError = newValue
} }
@ -131,12 +131,12 @@ extension UserDefaults {
.appendingPathComponent(path) .appendingPathComponent(path)
} }
public fileprivate(set) var wireGuardLastError: WireGuardProviderError? { public fileprivate(set) var wireGuardLastError: TunnelKitWireGuardError? {
get { get {
guard let rawValue = string(forKey: WireGuard.ProviderConfiguration.Keys.lastError.rawValue) else { guard let rawValue = string(forKey: WireGuard.ProviderConfiguration.Keys.lastError.rawValue) else {
return nil return nil
} }
return WireGuardProviderError(rawValue: rawValue) return TunnelKitWireGuardError(rawValue: rawValue)
} }
set { set {
guard let newValue = newValue else { guard let newValue = newValue else {

View File

@ -82,8 +82,8 @@ class ControlChannelTests: XCTestCase {
let packet: ControlPacket let packet: ControlPacket
do { do {
packet = try client.deserialize(data: original, start: 0, end: nil) packet = try client.deserialize(data: original, start: 0, end: nil)
} catch let e { } catch {
XCTAssertNil(e) XCTAssertNil(error)
return return
} }
XCTAssertEqual(packet.code, .hardResetClientV2) XCTAssertEqual(packet.code, .hardResetClientV2)
@ -94,8 +94,8 @@ class ControlChannelTests: XCTestCase {
let raw: Data let raw: Data
do { do {
raw = try server.serialize(packet: packet, timestamp: timestamp) raw = try server.serialize(packet: packet, timestamp: timestamp)
} catch let e { } catch {
XCTAssertNil(e) XCTAssertNil(error)
return return
} }
print("raw: \(raw.toHex())") print("raw: \(raw.toHex())")
@ -113,8 +113,8 @@ class ControlChannelTests: XCTestCase {
let packet: ControlPacket let packet: ControlPacket
do { do {
packet = try client.deserialize(data: original, start: 0, end: nil) packet = try client.deserialize(data: original, start: 0, end: nil)
} catch let e { } catch {
XCTAssertNil(e) XCTAssertNil(error)
return return
} }
XCTAssertEqual(packet.code, .hardResetServerV2) XCTAssertEqual(packet.code, .hardResetServerV2)
@ -126,8 +126,8 @@ class ControlChannelTests: XCTestCase {
let raw: Data let raw: Data
do { do {
raw = try server.serialize(packet: packet, timestamp: timestamp) raw = try server.serialize(packet: packet, timestamp: timestamp)
} catch let e { } catch {
XCTAssertNil(e) XCTAssertNil(error)
return return
} }
print("raw: \(raw.toHex())") print("raw: \(raw.toHex())")