Hide errors behind façade TunnelKit*Error (#325)

This commit is contained in:
Davide De Rosa 2023-07-02 11:56:40 +02:00 committed by GitHub
parent d8563e7f15
commit 729e8973cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 292 additions and 246 deletions

View File

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- Hide errors behind façade TunnelKit\*Error. [#325](https://github.com/passepartoutvpn/tunnelkit/pull/325)
## 6.0.0 (2023-04-02)
### Added

View File

@ -53,8 +53,11 @@ public struct DNSRecord {
/// Errors coming from `DNSResolver`.
public enum DNSError: Error {
/// Resolution failed.
case failure
/// Resolution timed out.
case timeout
}
@ -71,8 +74,8 @@ public class DNSResolver {
- Parameter queue: The queue to execute the `completionHandler` in.
- Parameter completionHandler: The completion handler with the resolved addresses and an optional error.
*/
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) {
var pendingHandler: ((Result<[DNSRecord], DNSError>) -> Void)? = completionHandler
public static func resolve(_ hostname: String, timeout: Int, queue: DispatchQueue, completionHandler: @escaping (Result<[DNSRecord], Error>) -> Void) {
var pendingHandler: ((Result<[DNSRecord], Error>) -> Void)? = completionHandler
let host = CFHostCreateWithName(nil, hostname as CFString).takeRetainedValue()
DNSResolver.queue.async {
CFHostStartInfoResolution(host, .addresses, nil)
@ -91,15 +94,15 @@ public class DNSResolver {
return
}
CFHostCancelInfoResolution(host, .addresses)
handler(.failure(.timeout))
handler(.failure(TunnelKitCoreError.dnsResolver(.timeout)))
pendingHandler = nil
}
}
private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], DNSError>) -> Void) {
private static func didResolve(host: CFHost, completionHandler: @escaping (Result<[DNSRecord], Error>) -> Void) {
var success: DarwinBoolean = false
guard let rawAddresses = CFHostGetAddressing(host, &success)?.takeUnretainedValue() as Array? else {
completionHandler(.failure(.failure))
completionHandler(.failure(TunnelKitCoreError.dnsResolver(.failure)))
return
}
@ -129,7 +132,7 @@ public class DNSResolver {
}
}
guard !records.isEmpty else {
completionHandler(.failure(.failure))
completionHandler(.failure(TunnelKitCoreError.dnsResolver(.failure)))
return
}
completionHandler(.success(records))

View File

@ -62,23 +62,10 @@ public struct IPv4Settings: Codable, Equatable, CustomStringConvertible {
/// The address of the default gateway.
public let defaultGateway: String
/// The additional routes.
@available(*, deprecated, message: "Store routes separately")
public let routes: [Route]
public init(address: String, addressMask: String, defaultGateway: String) {
self.address = address
self.addressMask = addressMask
self.defaultGateway = defaultGateway
self.routes = []
}
@available(*, deprecated, message: "Store routes separately")
public init(address: String, addressMask: String, defaultGateway: String, routes: [Route]) {
self.address = address
self.addressMask = addressMask
self.defaultGateway = defaultGateway
self.routes = routes
}
// MARK: CustomStringConvertible

View File

@ -62,23 +62,10 @@ public struct IPv6Settings: Codable, Equatable, CustomStringConvertible {
/// The address of the default gateway.
public let defaultGateway: String
/// The additional routes.
@available(*, deprecated, message: "Store routes separately")
public let routes: [Route]
public init(address: String, addressPrefixLength: UInt8, defaultGateway: String) {
self.address = address
self.addressPrefixLength = addressPrefixLength
self.defaultGateway = defaultGateway
self.routes = []
}
@available(*, deprecated, message: "Store routes separately")
public init(address: String, addressPrefixLength: UInt8, defaultGateway: String, routes: [Route]) {
self.address = address
self.addressPrefixLength = addressPrefixLength
self.defaultGateway = defaultGateway
self.routes = routes
}
// MARK: CustomStringConvertible

View File

@ -54,7 +54,7 @@ public class SecureRandom {
var randomBuffer = [UInt8](repeating: 0, count: 4)
guard SecRandomCopyBytes(kSecRandomDefault, 4, &randomBuffer) == 0 else {
throw SecureRandomError.randomGenerator
throw TunnelKitCoreError.secureRandom(.randomGenerator)
}
var randomNumber: UInt32 = 0
@ -71,7 +71,7 @@ public class SecureRandom {
try withUnsafeMutablePointer(to: &randomNumber) {
try $0.withMemoryRebound(to: UInt8.self, capacity: 4) { (randomBytes: UnsafeMutablePointer<UInt8>) -> Void in
guard SecRandomCopyBytes(kSecRandomDefault, 4, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator
throw TunnelKitCoreError.secureRandom(.randomGenerator)
}
}
}
@ -85,7 +85,7 @@ public class SecureRandom {
try randomData.withUnsafeMutableBytes {
let randomBytes = $0.bytePointer
guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator
throw TunnelKitCoreError.secureRandom(.randomGenerator)
}
}
@ -101,7 +101,7 @@ public class SecureRandom {
}
guard SecRandomCopyBytes(kSecRandomDefault, length, randomBytes) == 0 else {
throw SecureRandomError.randomGenerator
throw TunnelKitCoreError.secureRandom(.randomGenerator)
}
return Z(bytes: randomBytes, count: length)

View File

@ -1,8 +1,8 @@
//
// Errors.swift
// TunnelKitCoreError.swift
// TunnelKit
//
// Created by Davide De Rosa on 5/19/19.
// Created by Davide De Rosa on 6/16/23.
// Copyright (c) 2023 Davide De Rosa. All rights reserved.
//
// https://github.com/passepartoutvpn
@ -24,19 +24,10 @@
//
import Foundation
import CTunnelKitOpenVPNCore
extension Error {
public func isOpenVPNError() -> Bool {
let te = self as NSError
return te.domain == OpenVPNErrorDomain
}
/// Errors returned by Core library.
public enum TunnelKitCoreError: Error {
case secureRandom(_ error: SecureRandomError)
public func openVPNErrorCode() -> OpenVPNErrorCode? {
let te = self as NSError
guard te.domain == OpenVPNErrorDomain else {
return nil
}
return OpenVPNErrorCode(rawValue: te.code)
}
case dnsResolver(_ error: DNSError)
}

View File

@ -92,11 +92,11 @@ public class Keychain {
return try passwordReference(for: username, context: context)
}
removePassword(for: username, context: context)
} catch let e as KeychainError {
} catch let error as KeychainError {
// rethrow cancelation
if e == .userCancelled {
throw e
if error == .userCancelled {
throw error
}
// otherwise, no pre-existing password
@ -114,7 +114,7 @@ public class Keychain {
var ref: CFTypeRef?
let status = SecItemAdd(query as CFDictionary, &ref)
guard status == errSecSuccess, let refData = ref as? Data else {
throw KeychainError.add
throw TunnelKitManagerError.keychain(.add)
}
return refData
}
@ -160,16 +160,16 @@ public class Keychain {
break
case errSecUserCanceled:
throw KeychainError.userCancelled
throw TunnelKitManagerError.keychain(.userCancelled)
default:
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
guard let data = result as? Data else {
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
guard let password = String(data: data, encoding: .utf8) else {
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
return password
}
@ -197,13 +197,13 @@ public class Keychain {
break
case errSecUserCanceled:
throw KeychainError.userCancelled
throw TunnelKitManagerError.keychain(.userCancelled)
default:
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
guard let data = result as? Data else {
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
return data
}
@ -226,16 +226,16 @@ public class Keychain {
break
case errSecUserCanceled:
throw KeychainError.userCancelled
throw TunnelKitManagerError.keychain(.userCancelled)
default:
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
guard let data = result as? Data else {
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
guard let password = String(data: data, encoding: .utf8) else {
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
return password
}
@ -265,7 +265,7 @@ public class Keychain {
let status = SecItemAdd(query as CFDictionary, nil)
guard status == errSecSuccess else {
throw KeychainError.add
throw TunnelKitManagerError.keychain(.add)
}
return try publicKey(withIdentifier: identifier)
}
@ -294,13 +294,13 @@ public class Keychain {
break
case errSecUserCanceled:
throw KeychainError.userCancelled
throw TunnelKitManagerError.keychain(.userCancelled)
default:
throw KeychainError.notFound
throw TunnelKitManagerError.keychain(.notFound)
}
// guard let key = result as? SecKey else {
// throw KeychainError.typeMismatch
// throw TunnelKitManagerError.keychain(.typeMismatch)
// }
// return key
return result as! SecKey

View File

@ -0,0 +1,31 @@
//
// TunnelKitManagerError.swift
// TunnelKit
//
// Created by Davide De Rosa on 6/16/23.
// Copyright (c) 2023 Davide De Rosa. All rights reserved.
//
// https://github.com/passepartoutvpn
//
// This file is part of TunnelKit.
//
// TunnelKit is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// TunnelKit is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with TunnelKit. If not, see <http://www.gnu.org/licenses/>.
//
import Foundation
/// Errors returned by Core library.
public enum TunnelKitManagerError: Error {
case keychain(_ error: KeychainError)
}

View File

@ -94,7 +94,7 @@ class ConnectionStrategy {
from provider: NEProvider,
timeout: Int,
queue: DispatchQueue,
completionHandler: @escaping (Result<GenericSocket, OpenVPNProviderError>) -> Void) {
completionHandler: @escaping (Result<GenericSocket, TunnelKitOpenVPNError>) -> Void) {
guard let remote = currentRemote else {
completionHandler(.failure(.exhaustedEndpoints))
return

View File

@ -145,28 +145,28 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// required configuration
do {
guard let tunnelProtocol = protocolConfiguration as? NETunnelProviderProtocol else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration")
throw ConfigurationError.parameter(name: "protocolConfiguration")
}
guard let _ = tunnelProtocol.serverAddress else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.serverAddress")
throw ConfigurationError.parameter(name: "protocolConfiguration.serverAddress")
}
guard let providerConfiguration = tunnelProtocol.providerConfiguration else {
throw OpenVPNProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration")
throw ConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration")
}
cfg = try fromDictionary(OpenVPN.ProviderConfiguration.self, providerConfiguration)
} catch let e {
var message: String?
if let te = e as? OpenVPNProviderConfigurationError {
switch te {
case .parameter(let name):
message = "Tunnel configuration incomplete: \(name)"
} catch let cfgError as ConfigurationError {
switch cfgError {
case .parameter(let name):
NSLog("Tunnel configuration incomplete: \(name)")
default:
break
}
default:
NSLog("Tunnel configuration error: \(cfgError)")
}
NSLog(message ?? "Unexpected error in tunnel configuration: \(e)")
completionHandler(e)
completionHandler(cfgError)
return
} catch {
NSLog("Unexpected error in tunnel configuration: \(error)")
completionHandler(error)
return
}
@ -188,7 +188,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
let credentials: OpenVPN.Credentials?
if let username = protocolConfiguration.username, let passwordReference = protocolConfiguration.passwordReference {
guard let password = try? Keychain.password(forReference: passwordReference) else {
completionHandler(OpenVPNProviderConfigurationError.credentials(details: "Keychain.password(forReference:)"))
completionHandler(ConfigurationError.credentials(details: "Keychain.password(forReference:)"))
return
}
credentials = OpenVPN.Credentials(username, password)
@ -200,7 +200,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
cfg._appexSetLastError(nil)
guard OpenVPN.prepareRandomNumberGenerator(seedLength: prngSeedLength) else {
completionHandler(OpenVPNProviderConfigurationError.prngInitialization)
completionHandler(ConfigurationError.prngInitialization)
return
}
@ -216,8 +216,8 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
do {
session = try OpenVPNSession(queue: tunnelQueue, configuration: cfg.configuration, cachesURL: cachesURL)
refreshDataCount()
} catch let e {
completionHandler(e)
} catch {
completionHandler(error)
return
}
session.credentials = credentials
@ -359,7 +359,7 @@ open class OpenVPNTunnelProvider: NEPacketTunnelProvider {
// from stopTunnel(), in which case we don't need to feed an error parameter to
// the stop completion handler
//
pendingStartHandler?(error ?? OpenVPNProviderError.socketActivity)
pendingStartHandler?(error ?? TunnelKitOpenVPNError.socketActivity)
pendingStartHandler = nil
}
// stopped intentionally
@ -434,9 +434,13 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
// look for error causing shutdown
shutdownError = session.stopError
if failure && (shutdownError == nil) {
shutdownError = OpenVPNProviderError.linkError
shutdownError = TunnelKitOpenVPNError.linkError
}
if case .negotiationTimeout = shutdownError as? OpenVPNError {
didTimeoutNegotiation = true
} else {
didTimeoutNegotiation = false
}
didTimeoutNegotiation = (shutdownError as? OpenVPNError == .negotiationTimeout)
// only try upgrade on network errors
if shutdownError as? OpenVPNError == nil {
@ -479,7 +483,7 @@ extension OpenVPNTunnelProvider: GenericSocketDelegate {
public func socketHasBetterPath(_ socket: GenericSocket) {
log.debug("Stopping tunnel due to a new better path")
logCurrentSSID()
session?.reconnect(error: OpenVPNProviderError.networkChanged)
session?.reconnect(error: TunnelKitOpenVPNError.networkChanged)
}
}
@ -546,7 +550,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
let newSettings = NetworkSettingsBuilder(remoteAddress: remoteAddress, localOptions: localOptions, remoteOptions: remoteOptions)
guard !newSettings.isGateway || newSettings.hasGateway else {
session?.shutdown(error: OpenVPNProviderError.gatewayUnattainable)
session?.shutdown(error: TunnelKitOpenVPNError.gatewayUnattainable)
return
}
@ -594,7 +598,7 @@ extension OpenVPNTunnelProvider: OpenVPNSessionDelegate {
extension OpenVPNTunnelProvider {
private func tryNextEndpoint() -> Bool {
guard strategy.tryNextEndpoint() else {
disposeTunnel(error: OpenVPNProviderError.exhaustedEndpoints)
disposeTunnel(error: TunnelKitOpenVPNError.exhaustedEndpoints)
return false
}
return true
@ -647,44 +651,41 @@ extension OpenVPNTunnelProvider {
// let anyObject = object as AnyObject
// return Unmanaged<AnyObject>.passUnretained(anyObject).toOpaque()
// }
}
// MARK: Errors
// MARK: Errors
private func setErrorStatus(with error: Error) {
private extension OpenVPNTunnelProvider {
enum ConfigurationError: Error {
/// A field in the `OpenVPNProvider.Configuration` provided is incorrect or incomplete.
case parameter(name: String)
/// Credentials are missing or inaccessible.
case credentials(details: String)
/// The pseudo-random number generator could not be initialized.
case prngInitialization
/// The TLS certificate could not be serialized.
case certificateSerialization
}
func setErrorStatus(with error: Error) {
cfg._appexSetLastError(unifiedError(from: error))
}
private func unifiedError(from error: Error) -> OpenVPNProviderError {
if let te = error.openVPNErrorCode() {
switch te {
case .cryptoRandomGenerator, .cryptoAlgorithm:
return .encryptionInitialization
func unifiedError(from error: Error) -> TunnelKitOpenVPNError {
case .cryptoEncryption, .cryptoHMAC:
return .encryptionData
// XXX: error handling is limited by lastError serialization
// requirement, cannot return a generic Error here
// openVPNError(from: error) ?? error
openVPNError(from: error) ?? .linkError
}
case .tlscaRead, .tlscaUse, .tlscaPeerVerification,
.tlsClientCertificateRead, .tlsClientCertificateUse,
.tlsClientKeyRead, .tlsClientKeyUse:
return .tlsInitialization
case .tlsServerCertificate, .tlsServerEKU, .tlsServerHost:
return .tlsServerVerification
case .tlsHandshake:
return .tlsHandshake
case .dataPathOverflow, .dataPathPeerIdMismatch:
return .unexpectedReply
case .dataPathCompression:
return .serverCompression
default:
break
}
} else if let se = error as? OpenVPNError {
switch se {
func openVPNError(from error: Error) -> TunnelKitOpenVPNError? {
if let specificError = error as? OpenVPNError {
switch specificError.asNativeOpenVPNError ?? specificError {
case .negotiationTimeout, .pingTimeout, .staleSession:
return .timeout
@ -703,14 +704,45 @@ extension OpenVPNTunnelProvider {
case .serverShutdown:
return .serverShutdown
case .native(let code):
switch code {
case .cryptoRandomGenerator, .cryptoAlgorithm:
return .encryptionInitialization
case .cryptoEncryption, .cryptoHMAC:
return .encryptionData
case .tlscaRead, .tlscaUse, .tlscaPeerVerification,
.tlsClientCertificateRead, .tlsClientCertificateUse,
.tlsClientKeyRead, .tlsClientKeyUse:
return .tlsInitialization
case .tlsServerCertificate, .tlsServerEKU, .tlsServerHost:
return .tlsServerVerification
case .tlsHandshake:
return .tlsHandshake
case .dataPathOverflow, .dataPathPeerIdMismatch:
return .unexpectedReply
case .dataPathCompression:
return .serverCompression
default:
break
}
default:
return .unexpectedReply
}
}
return error as? OpenVPNProviderError ?? .linkError
return nil
}
}
// MARK: Hacks
private extension NEPacketTunnelProvider {
func forceExitOnMac() {
#if os(macOS)

View File

@ -64,7 +64,7 @@ class ResolvedRemote: CustomStringConvertible {
}
}
private func handleResult(_ result: Result<[DNSRecord], DNSError>) {
private func handleResult(_ result: Result<[DNSRecord], Error>) {
switch result {
case .success(let records):
log.debug("DNS resolved addresses: \(records.map { $0.address }.maskedDescription)")

View File

@ -44,5 +44,8 @@ extension OpenVPN {
/// Encryption passphrase is incorrect or key is corrupt.
case unableToDecrypt(error: Error)
/// The PUSH_REPLY is multipart.
case continuationPushReply
}
}

View File

@ -339,7 +339,7 @@ extension OpenVPN {
isContinuation = ($0.first == "2")
}
guard !isContinuation else {
throw OpenVPNError.continuationPushReply
throw ConfigurationError.continuationPushReply
}
// MARK: Inline content
@ -805,8 +805,8 @@ extension OpenVPN {
}
do {
sessionBuilder.clientKey = try clientKey.decrypted(with: passphrase)
} catch let e {
throw ConfigurationError.unableToDecrypt(error: e)
} catch {
throw ConfigurationError.unableToDecrypt(error: error)
}
} else {
sessionBuilder.clientKey = optClientKey

View File

@ -35,9 +35,10 @@
//
import Foundation
import CTunnelKitOpenVPNCore
/// The possible errors raised/thrown during `OpenVPNSession` operation.
public enum OpenVPNError: String, Error {
public enum OpenVPNError: Error {
/// The negotiation timed out.
case negotiationTimeout
@ -51,15 +52,15 @@ public enum OpenVPNError: String, Error {
/// The connection key is wrong or wasn't expected.
case badKey
/// Control channel failure.
case controlChannel(message: String)
/// The control packet has an incorrect prefix payload.
case wrongControlDataPrefix
/// The provided credentials failed authentication.
case badCredentials
/// The PUSH_REPLY is multipart.
case continuationPushReply
/// The reply to PUSH_REQUEST is malformed.
case malformedPushReply
@ -80,4 +81,17 @@ public enum OpenVPNError: String, Error {
/// Remote server shut down (--explicit-exit-notify).
case serverShutdown
/// NSError from ObjC layer.
case native(code: OpenVPNErrorCode)
}
extension Error {
public var asNativeOpenVPNError: OpenVPNError? {
let nativeError = self as NSError
guard nativeError.domain == OpenVPNErrorDomain, let code = OpenVPNErrorCode(rawValue: nativeError.code) else {
return nil
}
return .native(code: code)
}
}

View File

@ -139,7 +139,7 @@ extension OpenVPN.ProviderConfiguration {
/**
The last error reported by the tunnel, if any.
*/
public var lastError: OpenVPNProviderError? {
public var lastError: TunnelKitOpenVPNError? {
return defaults?.openVPNLastError
}
@ -164,7 +164,7 @@ extension OpenVPN.ProviderConfiguration {
defaults?.openVPNServerConfiguration = newValue
}
public func _appexSetLastError(_ newValue: OpenVPNProviderError?) {
public func _appexSetLastError(_ newValue: TunnelKitOpenVPNError?) {
defaults?.openVPNLastError = newValue
}
@ -250,12 +250,12 @@ extension UserDefaults {
}
}
public fileprivate(set) var openVPNLastError: OpenVPNProviderError? {
public fileprivate(set) var openVPNLastError: TunnelKitOpenVPNError? {
get {
guard let rawValue = string(forKey: OpenVPN.ProviderConfiguration.Keys.lastError.rawValue) else {
return nil
}
return OpenVPNProviderError(rawValue: rawValue)
return TunnelKitOpenVPNError(rawValue: rawValue)
}
set {
guard let newValue = newValue else {

View File

@ -1,5 +1,5 @@
//
// OpenVPNProviderError.swift
// TunnelKitOpenVPNError.swift
// TunnelKit
//
// Created by Davide De Rosa on 11/8/21.
@ -35,25 +35,10 @@
//
import Foundation
/// Mostly programming errors by host app.
public enum OpenVPNProviderConfigurationError: Error {
/// A field in the `OpenVPNProvider.Configuration` provided is incorrect or incomplete.
case parameter(name: String)
/// Credentials are missing or inaccessible.
case credentials(details: String)
/// The pseudo-random number generator could not be initialized.
case prngInitialization
/// The TLS certificate could not be serialized.
case certificateSerialization
}
import TunnelKitOpenVPNCore
/// The errors causing a tunnel disconnection.
public enum OpenVPNProviderError: String, Error {
public enum TunnelKitOpenVPNError: String, Error {
/// Socket endpoint could not be resolved.
case dnsFailure

View File

@ -33,14 +33,6 @@ import CTunnelKitOpenVPNProtocol
private let log = SwiftyBeaver.self
extension OpenVPN {
class ControlChannelError: Error, CustomStringConvertible {
let description: String
init(_ message: String) {
description = "\(String(describing: ControlChannelError.self))(\(message))"
}
}
class ControlChannel {
private let serializer: ControlChannelSerializer
@ -101,12 +93,17 @@ extension OpenVPN {
}
func readInboundPacket(withData data: Data, offset: Int) throws -> ControlPacket {
let packet = try serializer.deserialize(data: data, start: offset, end: nil)
log.debug("Control: Read packet \(packet)")
if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId {
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId)
do {
let packet = try serializer.deserialize(data: data, start: offset, end: nil)
log.debug("Control: Read packet \(packet)")
if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId {
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId)
}
return packet
} catch {
log.error("Control: Channel failure \(error)")
throw error
}
return packet
}
func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] {

View File

@ -54,11 +54,11 @@ extension OpenVPN.ControlChannel {
let end = end ?? packet.count
guard end >= offset + PacketOpcodeLength else {
throw OpenVPN.ControlChannelError("Missing opcode")
throw OpenVPNError.controlChannel(message: "Missing opcode")
}
let codeValue = packet[offset] >> 3
guard let code = PacketCode(rawValue: codeValue) else {
throw OpenVPN.ControlChannelError("Unknown code: \(codeValue))")
throw OpenVPNError.controlChannel(message: "Unknown code: \(codeValue))")
}
let key = packet[offset] & 0b111
offset += PacketOpcodeLength
@ -66,13 +66,13 @@ extension OpenVPN.ControlChannel {
log.debug("Control: Try read packet with code \(code) and key \(key)")
guard end >= offset + PacketSessionIdLength else {
throw OpenVPN.ControlChannelError("Missing sessionId")
throw OpenVPNError.controlChannel(message: "Missing sessionId")
}
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength
guard end >= offset + 1 else {
throw OpenVPN.ControlChannelError("Missing ackSize")
throw OpenVPNError.controlChannel(message: "Missing ackSize")
}
let ackSize = packet[offset]
offset += 1
@ -81,7 +81,7 @@ extension OpenVPN.ControlChannel {
var ackRemoteSessionId: Data?
if ackSize > 0 {
guard end >= (offset + Int(ackSize) * PacketIdLength) else {
throw OpenVPN.ControlChannelError("Missing acks")
throw OpenVPNError.controlChannel(message: "Missing acks")
}
var ids: [UInt32] = []
for _ in 0..<ackSize {
@ -91,7 +91,7 @@ extension OpenVPN.ControlChannel {
}
guard end >= offset + PacketSessionIdLength else {
throw OpenVPN.ControlChannelError("Missing remoteSessionId")
throw OpenVPNError.controlChannel(message: "Missing remoteSessionId")
}
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength
@ -102,16 +102,16 @@ extension OpenVPN.ControlChannel {
if code == .ackV1 {
guard let ackIds = ackIds else {
throw OpenVPN.ControlChannelError("Ack packet without ids")
throw OpenVPNError.controlChannel(message: "Ack packet without ids")
}
guard let ackRemoteSessionId = ackRemoteSessionId else {
throw OpenVPN.ControlChannelError("Ack packet without remoteSessionId")
throw OpenVPNError.controlChannel(message: "Ack packet without remoteSessionId")
}
return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
}
guard end >= offset + PacketIdLength else {
throw OpenVPN.ControlChannelError("Missing packetId")
throw OpenVPNError.controlChannel(message: "Missing packetId")
}
let packetId = packet.networkUInt32Value(from: offset)
offset += PacketIdLength
@ -192,7 +192,7 @@ extension OpenVPN.ControlChannel {
// data starts with (prefix=(header + sessionId) + auth=(hmac + replayId))
guard end >= preambleLength else {
throw OpenVPN.ControlChannelError("Missing HMAC")
throw OpenVPNError.controlChannel(message: "Missing HMAC")
}
// needs a copy for swapping
@ -206,7 +206,12 @@ extension OpenVPN.ControlChannel {
// TODO: validate replay packet id
return try plain.deserialize(data: authPacket, start: authLength, end: nil)
do {
return try plain.deserialize(data: authPacket, start: authLength, end: nil)
} catch {
log.error("Control: Channel failure \(error)")
throw error
}
}
}
}
@ -269,7 +274,7 @@ extension OpenVPN.ControlChannel {
// data starts with (ad=(header + sessionId + replayId) + tag)
guard end >= start + adLength + tagLength else {
throw OpenVPN.ControlChannelError("Missing AD+TAG")
throw OpenVPNError.controlChannel(message: "Missing AD+TAG")
}
let encryptedCount = packet.count - adLength
@ -288,7 +293,12 @@ extension OpenVPN.ControlChannel {
// TODO: validate replay packet id
return try plain.deserialize(data: decryptedPacket, start: 0, end: nil)
do {
return try plain.deserialize(data: decryptedPacket, start: 0, end: nil)
} catch {
log.error("Control: Channel failure \(error)")
throw error
}
}
}
}

View File

@ -477,8 +477,8 @@ public class OpenVPNSession: Session {
continue
}
controlPacket = parsedPacket
} catch let e {
log.warning("Dropped malformed packet: \(e)")
} catch {
log.warning("Dropped malformed packet: \(error)")
continue
// deferStop(.shutdown, e)
// return
@ -573,8 +573,8 @@ public class OpenVPNSession: Session {
authenticator = nil
do {
try controlChannel.reset(forNewSession: forNewSession)
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
}
}
@ -658,18 +658,18 @@ public class OpenVPNSession: Session {
authenticator = try OpenVPN.Authenticator(credentials?.username, pushReply?.options.authToken ?? credentials?.password)
authenticator?.withLocalOptions = withLocalOptions
try authenticator?.putAuth(into: negotiationKey.tls, options: configuration)
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}
let cipherTextOut: Data
do {
cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e {
if let _ = e.openVPNErrorCode() {
log.error("TLS.auth: Failed pulling ciphertext (error: \(e))")
shutdown(error: e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.auth: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: nativeError)
return
}
log.verbose("TLS.auth: Still can't pull ciphertext")
@ -695,10 +695,10 @@ public class OpenVPNSession: Session {
let cipherTextOut: Data
do {
cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e {
if let _ = e.openVPNErrorCode() {
log.error("TLS.auth: Failed pulling ciphertext (error: \(e))")
shutdown(error: e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.auth: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: nativeError)
return
}
log.verbose("TLS.ifconfig: Still can't pull ciphertext")
@ -789,21 +789,21 @@ public class OpenVPNSession: Session {
negotiationKey.tlsOptional = tls
do {
try negotiationKey.tls.start()
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}
let cipherTextOut: Data
do {
cipherTextOut = try negotiationKey.tls.pullCipherText()
} catch let e {
if let _ = e.openVPNErrorCode() {
log.error("TLS.connect: Failed pulling ciphertext (error: \(e))")
shutdown(error: e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.connect: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: nativeError)
return
}
deferStop(.shutdown, e)
deferStop(.shutdown, error)
return
}
@ -836,10 +836,10 @@ public class OpenVPNSession: Session {
cipherTextOut = try negotiationKey.tls.pullCipherText()
log.debug("TLS.connect: Send pulled ciphertext (\(cipherTextOut.count) bytes)")
enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut)
} catch let e {
if let _ = e.openVPNErrorCode() {
log.error("TLS.connect: Failed pulling ciphertext (error: \(e))")
shutdown(error: e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
log.error("TLS.connect: Failed pulling ciphertext (error: \(nativeError))")
shutdown(error: nativeError)
return
}
log.verbose("TLS.connect: No available ciphertext to pull")
@ -878,8 +878,8 @@ public class OpenVPNSession: Session {
guard try auth.parseAuthReply() else {
return
}
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}
@ -962,12 +962,12 @@ public class OpenVPNSession: Session {
throw OpenVPNError.serverCompression
}
}
} catch OpenVPNError.continuationPushReply {
} catch OpenVPN.ConfigurationError.continuationPushReply {
continuatedPushReplyMessage = completeMessage.replacingOccurrences(of: "push-continuation", with: "")
// FIXME: strip "PUSH_REPLY" and "push-continuation 2"
return
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}
@ -1025,9 +1025,9 @@ public class OpenVPNSession: Session {
let rawList: [Data]
do {
rawList = try controlChannel.writeOutboundPackets()
} catch let e {
log.warning("Failed control packet serialization: \(e)")
deferStop(.shutdown, e)
} catch {
log.warning("Failed control packet serialization: \(error)")
deferStop(.shutdown, error)
return
}
for raw in rawList {
@ -1110,8 +1110,8 @@ public class OpenVPNSession: Session {
sessionId,
remoteSessionId
)
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}
@ -1141,12 +1141,12 @@ public class OpenVPNSession: Session {
}
tunnel?.writePackets(decryptedPackets, completionHandler: nil)
} catch let e {
guard !e.isOpenVPNError() else {
deferStop(.shutdown, e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
deferStop(.shutdown, nativeError)
return
}
deferStop(.reconnect, e)
deferStop(.reconnect, error)
}
}
@ -1181,12 +1181,12 @@ public class OpenVPNSession: Session {
// log.verbose("Data: \(encryptedPackets.count) packets successfully written to LINK")
}
}
} catch let e {
guard !e.isOpenVPNError() else {
deferStop(.shutdown, e)
} catch {
if let nativeError = error.asNativeOpenVPNError {
deferStop(.shutdown, nativeError)
return
}
deferStop(.reconnect, e)
deferStop(.reconnect, error)
}
}
@ -1206,8 +1206,8 @@ public class OpenVPNSession: Session {
ackPacketIds: [controlPacket.packetId],
ackRemoteSessionId: controlPacket.sessionId
)
} catch let e {
deferStop(.shutdown, e)
} catch {
deferStop(.shutdown, error)
return
}

View File

@ -36,7 +36,7 @@ open class WireGuardTunnelProvider: NEPacketTunnelProvider {
cfg = try fromDictionary(WireGuard.ProviderConfiguration.self, providerConfiguration)
tunnelConfiguration = cfg.configuration.tunnelConfiguration
} catch {
completionHandler(WireGuardProviderError.savedProtocolConfigurationIsInvalid)
completionHandler(TunnelKitWireGuardError.savedProtocolConfigurationIsInvalid)
return
}
@ -59,24 +59,24 @@ open class WireGuardTunnelProvider: NEPacketTunnelProvider {
case .cannotLocateTunnelFileDescriptor:
wg_log(.error, staticMessage: "Starting tunnel failed: could not determine file descriptor")
self.cfg._appexSetLastError(.couldNotDetermineFileDescriptor)
completionHandler(WireGuardProviderError.couldNotDetermineFileDescriptor)
completionHandler(TunnelKitWireGuardError.couldNotDetermineFileDescriptor)
case .dnsResolution(let dnsErrors):
let hostnamesWithDnsResolutionFailure = dnsErrors.map(\.address)
.joined(separator: ", ")
wg_log(.error, message: "DNS resolution failed for the following hostnames: \(hostnamesWithDnsResolutionFailure)")
self.cfg._appexSetLastError(.dnsResolutionFailure)
completionHandler(WireGuardProviderError.dnsResolutionFailure)
completionHandler(TunnelKitWireGuardError.dnsResolutionFailure)
case .setNetworkSettings(let error):
wg_log(.error, message: "Starting tunnel failed with setTunnelNetworkSettings returning \(error.localizedDescription)")
self.cfg._appexSetLastError(.couldNotSetNetworkSettings)
completionHandler(WireGuardProviderError.couldNotSetNetworkSettings)
completionHandler(TunnelKitWireGuardError.couldNotSetNetworkSettings)
case .startWireGuardBackend(let errorCode):
wg_log(.error, message: "Starting tunnel failed with wgTurnOn returning \(errorCode)")
self.cfg._appexSetLastError(.couldNotStartBackend)
completionHandler(WireGuardProviderError.couldNotStartBackend)
completionHandler(TunnelKitWireGuardError.couldNotStartBackend)
case .invalidState:
// Must never happen

View File

@ -3,7 +3,7 @@
import Foundation
public enum WireGuardProviderError: String, Error {
public enum TunnelKitWireGuardError: String, Error {
case savedProtocolConfigurationIsInvalid
case dnsResolutionFailure
case couldNotStartBackend

View File

@ -91,7 +91,7 @@ extension WireGuard.ProviderConfiguration: NetworkExtensionConfiguration {
// MARK: Shared data
extension WireGuard.ProviderConfiguration {
public var lastError: WireGuardProviderError? {
public var lastError: TunnelKitWireGuardError? {
return defaults?.wireGuardLastError
}
@ -105,7 +105,7 @@ extension WireGuard.ProviderConfiguration {
}
extension WireGuard.ProviderConfiguration {
public func _appexSetLastError(_ newValue: WireGuardProviderError?) {
public func _appexSetLastError(_ newValue: TunnelKitWireGuardError?) {
defaults?.wireGuardLastError = newValue
}
@ -131,12 +131,12 @@ extension UserDefaults {
.appendingPathComponent(path)
}
public fileprivate(set) var wireGuardLastError: WireGuardProviderError? {
public fileprivate(set) var wireGuardLastError: TunnelKitWireGuardError? {
get {
guard let rawValue = string(forKey: WireGuard.ProviderConfiguration.Keys.lastError.rawValue) else {
return nil
}
return WireGuardProviderError(rawValue: rawValue)
return TunnelKitWireGuardError(rawValue: rawValue)
}
set {
guard let newValue = newValue else {

View File

@ -82,8 +82,8 @@ class ControlChannelTests: XCTestCase {
let packet: ControlPacket
do {
packet = try client.deserialize(data: original, start: 0, end: nil)
} catch let e {
XCTAssertNil(e)
} catch {
XCTAssertNil(error)
return
}
XCTAssertEqual(packet.code, .hardResetClientV2)
@ -94,8 +94,8 @@ class ControlChannelTests: XCTestCase {
let raw: Data
do {
raw = try server.serialize(packet: packet, timestamp: timestamp)
} catch let e {
XCTAssertNil(e)
} catch {
XCTAssertNil(error)
return
}
print("raw: \(raw.toHex())")
@ -113,8 +113,8 @@ class ControlChannelTests: XCTestCase {
let packet: ControlPacket
do {
packet = try client.deserialize(data: original, start: 0, end: nil)
} catch let e {
XCTAssertNil(e)
} catch {
XCTAssertNil(error)
return
}
XCTAssertEqual(packet.code, .hardResetServerV2)
@ -126,8 +126,8 @@ class ControlChannelTests: XCTestCase {
let raw: Data
do {
raw = try server.serialize(packet: packet, timestamp: timestamp)
} catch let e {
XCTAssertNil(e)
} catch {
XCTAssertNil(error)
return
}
print("raw: \(raw.toHex())")