Merge branch 'refactor-shutdown-code'

This commit is contained in:
Davide De Rosa 2018-10-24 12:22:27 +02:00
commit 3f3a712bac
4 changed files with 46 additions and 43 deletions

View File

@ -44,8 +44,6 @@ protocol LinkProducer {
protocol GenericSocketDelegate: class { protocol GenericSocketDelegate: class {
func socketDidTimeout(_ socket: GenericSocket) func socketDidTimeout(_ socket: GenericSocket)
func socketShouldChangeProtocol(_ socket: GenericSocket) -> Bool
func socketDidBecomeActive(_ socket: GenericSocket) func socketDidBecomeActive(_ socket: GenericSocket)
func socket(_ socket: GenericSocket, didShutdownWithFailure failure: Bool) func socket(_ socket: GenericSocket, didShutdownWithFailure failure: Bool)

View File

@ -79,7 +79,6 @@ class NETCPSocket: NSObject, GenericSocket {
return return
} }
guard _self.isActive else { guard _self.isActive else {
_ = _self.delegate?.socketShouldChangeProtocol(_self)
_self.delegate?.socketDidTimeout(_self) _self.delegate?.socketDidTimeout(_self)
return return
} }

View File

@ -56,7 +56,7 @@ extension TunnelKitProvider {
} }
/// Defines the communication protocol of an endpoint. /// Defines the communication protocol of an endpoint.
public struct EndpointProtocol: Equatable, CustomStringConvertible { public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvertible {
/// The socket type. /// The socket type.
public let socketType: SocketType public let socketType: SocketType
@ -70,23 +70,25 @@ extension TunnelKitProvider {
self.port = port self.port = port
} }
// MARK: RawRepresentable
/// :nodoc: /// :nodoc:
public static func deserialized(_ string: String) throws -> EndpointProtocol { public init?(rawValue: String) {
let components = string.components(separatedBy: ":") let components = rawValue.components(separatedBy: ":")
guard components.count == 2 else { guard components.count == 2 else {
throw ProviderConfigurationError.parameter(name: "endpointProtocol") return nil
} }
guard let socketType = SocketType(rawValue: components[0]) else { guard let socketType = SocketType(rawValue: components[0]) else {
throw ProviderConfigurationError.parameter(name: "endpointProtocol.socketType") return nil
} }
guard let port = UInt16(components[1]) else { guard let port = UInt16(components[1]) else {
throw ProviderConfigurationError.parameter(name: "endpointProtocol.port") return nil
} }
return EndpointProtocol(socketType, port) self.init(socketType, port)
} }
/// :nodoc: /// :nodoc:
public func serialized() -> String { public var rawValue: String {
return "\(socketType.rawValue):\(port)" return "\(socketType.rawValue):\(port)"
} }
@ -101,7 +103,7 @@ extension TunnelKitProvider {
/// :nodoc: /// :nodoc:
public var description: String { public var description: String {
return serialized() return rawValue
} }
} }
@ -229,7 +231,12 @@ extension TunnelKitProvider {
guard let endpointProtocolsStrings = providerConfiguration[S.endpointProtocols] as? [String], !endpointProtocolsStrings.isEmpty else { guard let endpointProtocolsStrings = providerConfiguration[S.endpointProtocols] as? [String], !endpointProtocolsStrings.isEmpty else {
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] is nil or empty") throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] is nil or empty")
} }
endpointProtocols = try endpointProtocolsStrings.map { try EndpointProtocol.deserialized($0) } endpointProtocols = try endpointProtocolsStrings.map {
guard let ep = EndpointProtocol(rawValue: $0) else {
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] has a badly formed element")
}
return ep
}
self.cipher = cipher self.cipher = cipher
self.digest = digest self.digest = digest
@ -444,7 +451,7 @@ extension TunnelKitProvider {
var dict: [String: Any] = [ var dict: [String: Any] = [
S.appGroup: appGroup, S.appGroup: appGroup,
S.prefersResolvedAddresses: prefersResolvedAddresses, S.prefersResolvedAddresses: prefersResolvedAddresses,
S.endpointProtocols: endpointProtocols.map { $0.serialized() }, S.endpointProtocols: endpointProtocols.map { $0.rawValue },
S.cipherAlgorithm: cipher.rawValue, S.cipherAlgorithm: cipher.rawValue,
S.digestAlgorithm: digest.rawValue, S.digestAlgorithm: digest.rawValue,
S.ca: ca.pem, S.ca: ca.pem,
@ -602,12 +609,14 @@ extension TunnelKitProvider.Configuration: Equatable {
extension TunnelKitProvider.EndpointProtocol: Codable { extension TunnelKitProvider.EndpointProtocol: Codable {
public init(from decoder: Decoder) throws { public init(from decoder: Decoder) throws {
let container = try decoder.singleValueContainer() let container = try decoder.singleValueContainer()
let proto = try TunnelKitProvider.EndpointProtocol.deserialized(container.decode(String.self)) guard let proto = try TunnelKitProvider.EndpointProtocol(rawValue: container.decode(String.self)) else {
throw TunnelKitProvider.ProviderConfigurationError.parameter(name: "endpointProtocol.decodable")
}
self.init(proto.socketType, proto.port) self.init(proto.socketType, proto.port)
} }
public func encode(to encoder: Encoder) throws { public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer() var container = encoder.singleValueContainer()
try container.encode(serialized()) try container.encode(rawValue)
} }
} }

View File

@ -106,8 +106,6 @@ open class TunnelKitProvider: NEPacketTunnelProvider {
private var socket: GenericSocket? private var socket: GenericSocket?
private var linkFailures = 0
private var pendingStartHandler: ((Error?) -> Void)? private var pendingStartHandler: ((Error?) -> Void)?
private var pendingStopHandler: (() -> Void)? private var pendingStopHandler: (() -> Void)?
@ -400,14 +398,14 @@ extension TunnelKitProvider: GenericSocketDelegate {
log.debug("Socket timed out waiting for activity, cancelling...") log.debug("Socket timed out waiting for activity, cancelling...")
reasserting = true reasserting = true
socket.shutdown() socket.shutdown()
}
func socketShouldChangeProtocol(_ socket: GenericSocket) -> Bool { // fallback: TCP connection timeout suggests falling back
guard strategy.tryNextProtocol() else { if let _ = socket as? NETCPSocket {
disposeTunnel(error: ProviderError.exhaustedProtocols) guard tryNextProtocol() else {
return false // disposeTunnel
return
}
} }
return true
} }
func socketDidBecomeActive(_ socket: GenericSocket) { func socketDidBecomeActive(_ socket: GenericSocket) {
@ -428,19 +426,17 @@ extension TunnelKitProvider: GenericSocketDelegate {
} }
var shutdownError: Error? var shutdownError: Error?
if !failure { let didTimeoutNegotiation: Bool
shutdownError = proxy.stopError var upgradedSocket: GenericSocket?
} else {
shutdownError = proxy.stopError ?? ProviderError.linkError
linkFailures += 1
log.debug("Link failures so far: \(linkFailures) (max = \(maxLinkFailures))")
}
// neg timeout? // look for error causing shutdown
let didTimeoutNegotiation = (proxy.stopError as? SessionError == .negotiationTimeout) shutdownError = proxy.stopError
if failure && (shutdownError == nil) {
shutdownError = ProviderError.linkError
}
didTimeoutNegotiation = (shutdownError as? SessionError == .negotiationTimeout)
// only try upgrade on network errors // only try upgrade on network errors
var upgradedSocket: GenericSocket? = nil
if shutdownError as? SessionError == nil { if shutdownError as? SessionError == nil {
upgradedSocket = socket.upgraded() upgradedSocket = socket.upgraded()
} }
@ -448,9 +444,9 @@ extension TunnelKitProvider: GenericSocketDelegate {
// clean up // clean up
finishTunnelDisconnection(error: shutdownError) finishTunnelDisconnection(error: shutdownError)
// treat negotiation timeout as socket timeout, UDP is connection-less // fallback: UDP is connection-less, treat negotiation timeout as socket timeout
if didTimeoutNegotiation { if didTimeoutNegotiation {
guard socketShouldChangeProtocol(socket) else { guard tryNextProtocol() else {
// disposeTunnel // disposeTunnel
return return
} }
@ -458,12 +454,6 @@ extension TunnelKitProvider: GenericSocketDelegate {
// reconnect? // reconnect?
if reasserting { if reasserting {
guard (linkFailures < maxLinkFailures) else {
log.debug("Too many link failures (\(linkFailures)), tunnel will die now")
reasserting = false
disposeTunnel(error: shutdownError)
return
}
log.debug("Disconnection is recoverable, tunnel will reconnect in \(reconnectionDelay) milliseconds...") log.debug("Disconnection is recoverable, tunnel will reconnect in \(reconnectionDelay) milliseconds...")
tunnelQueue.schedule(after: .milliseconds(reconnectionDelay)) { tunnelQueue.schedule(after: .milliseconds(reconnectionDelay)) {
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress) self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress)
@ -574,6 +564,13 @@ extension TunnelKitProvider: SessionProxyDelegate {
} }
extension TunnelKitProvider { extension TunnelKitProvider {
private func tryNextProtocol() -> Bool {
guard strategy.tryNextProtocol() else {
disposeTunnel(error: ProviderError.exhaustedProtocols)
return false
}
return true
}
// MARK: Logging // MARK: Logging