Merge branch 'refactor-shutdown-code'
This commit is contained in:
commit
3f3a712bac
|
@ -44,8 +44,6 @@ protocol LinkProducer {
|
||||||
protocol GenericSocketDelegate: class {
|
protocol GenericSocketDelegate: class {
|
||||||
func socketDidTimeout(_ socket: GenericSocket)
|
func socketDidTimeout(_ socket: GenericSocket)
|
||||||
|
|
||||||
func socketShouldChangeProtocol(_ socket: GenericSocket) -> Bool
|
|
||||||
|
|
||||||
func socketDidBecomeActive(_ socket: GenericSocket)
|
func socketDidBecomeActive(_ socket: GenericSocket)
|
||||||
|
|
||||||
func socket(_ socket: GenericSocket, didShutdownWithFailure failure: Bool)
|
func socket(_ socket: GenericSocket, didShutdownWithFailure failure: Bool)
|
||||||
|
|
|
@ -79,7 +79,6 @@ class NETCPSocket: NSObject, GenericSocket {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
guard _self.isActive else {
|
guard _self.isActive else {
|
||||||
_ = _self.delegate?.socketShouldChangeProtocol(_self)
|
|
||||||
_self.delegate?.socketDidTimeout(_self)
|
_self.delegate?.socketDidTimeout(_self)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ extension TunnelKitProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Defines the communication protocol of an endpoint.
|
/// Defines the communication protocol of an endpoint.
|
||||||
public struct EndpointProtocol: Equatable, CustomStringConvertible {
|
public struct EndpointProtocol: RawRepresentable, Equatable, CustomStringConvertible {
|
||||||
|
|
||||||
/// The socket type.
|
/// The socket type.
|
||||||
public let socketType: SocketType
|
public let socketType: SocketType
|
||||||
|
@ -70,23 +70,25 @@ extension TunnelKitProvider {
|
||||||
self.port = port
|
self.port = port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: RawRepresentable
|
||||||
|
|
||||||
/// :nodoc:
|
/// :nodoc:
|
||||||
public static func deserialized(_ string: String) throws -> EndpointProtocol {
|
public init?(rawValue: String) {
|
||||||
let components = string.components(separatedBy: ":")
|
let components = rawValue.components(separatedBy: ":")
|
||||||
guard components.count == 2 else {
|
guard components.count == 2 else {
|
||||||
throw ProviderConfigurationError.parameter(name: "endpointProtocol")
|
return nil
|
||||||
}
|
}
|
||||||
guard let socketType = SocketType(rawValue: components[0]) else {
|
guard let socketType = SocketType(rawValue: components[0]) else {
|
||||||
throw ProviderConfigurationError.parameter(name: "endpointProtocol.socketType")
|
return nil
|
||||||
}
|
}
|
||||||
guard let port = UInt16(components[1]) else {
|
guard let port = UInt16(components[1]) else {
|
||||||
throw ProviderConfigurationError.parameter(name: "endpointProtocol.port")
|
return nil
|
||||||
}
|
}
|
||||||
return EndpointProtocol(socketType, port)
|
self.init(socketType, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// :nodoc:
|
/// :nodoc:
|
||||||
public func serialized() -> String {
|
public var rawValue: String {
|
||||||
return "\(socketType.rawValue):\(port)"
|
return "\(socketType.rawValue):\(port)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +103,7 @@ extension TunnelKitProvider {
|
||||||
|
|
||||||
/// :nodoc:
|
/// :nodoc:
|
||||||
public var description: String {
|
public var description: String {
|
||||||
return serialized()
|
return rawValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,7 +231,12 @@ extension TunnelKitProvider {
|
||||||
guard let endpointProtocolsStrings = providerConfiguration[S.endpointProtocols] as? [String], !endpointProtocolsStrings.isEmpty else {
|
guard let endpointProtocolsStrings = providerConfiguration[S.endpointProtocols] as? [String], !endpointProtocolsStrings.isEmpty else {
|
||||||
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] is nil or empty")
|
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] is nil or empty")
|
||||||
}
|
}
|
||||||
endpointProtocols = try endpointProtocolsStrings.map { try EndpointProtocol.deserialized($0) }
|
endpointProtocols = try endpointProtocolsStrings.map {
|
||||||
|
guard let ep = EndpointProtocol(rawValue: $0) else {
|
||||||
|
throw ProviderConfigurationError.parameter(name: "protocolConfiguration.providerConfiguration[\(S.endpointProtocols)] has a badly formed element")
|
||||||
|
}
|
||||||
|
return ep
|
||||||
|
}
|
||||||
|
|
||||||
self.cipher = cipher
|
self.cipher = cipher
|
||||||
self.digest = digest
|
self.digest = digest
|
||||||
|
@ -444,7 +451,7 @@ extension TunnelKitProvider {
|
||||||
var dict: [String: Any] = [
|
var dict: [String: Any] = [
|
||||||
S.appGroup: appGroup,
|
S.appGroup: appGroup,
|
||||||
S.prefersResolvedAddresses: prefersResolvedAddresses,
|
S.prefersResolvedAddresses: prefersResolvedAddresses,
|
||||||
S.endpointProtocols: endpointProtocols.map { $0.serialized() },
|
S.endpointProtocols: endpointProtocols.map { $0.rawValue },
|
||||||
S.cipherAlgorithm: cipher.rawValue,
|
S.cipherAlgorithm: cipher.rawValue,
|
||||||
S.digestAlgorithm: digest.rawValue,
|
S.digestAlgorithm: digest.rawValue,
|
||||||
S.ca: ca.pem,
|
S.ca: ca.pem,
|
||||||
|
@ -602,12 +609,14 @@ extension TunnelKitProvider.Configuration: Equatable {
|
||||||
extension TunnelKitProvider.EndpointProtocol: Codable {
|
extension TunnelKitProvider.EndpointProtocol: Codable {
|
||||||
public init(from decoder: Decoder) throws {
|
public init(from decoder: Decoder) throws {
|
||||||
let container = try decoder.singleValueContainer()
|
let container = try decoder.singleValueContainer()
|
||||||
let proto = try TunnelKitProvider.EndpointProtocol.deserialized(container.decode(String.self))
|
guard let proto = try TunnelKitProvider.EndpointProtocol(rawValue: container.decode(String.self)) else {
|
||||||
|
throw TunnelKitProvider.ProviderConfigurationError.parameter(name: "endpointProtocol.decodable")
|
||||||
|
}
|
||||||
self.init(proto.socketType, proto.port)
|
self.init(proto.socketType, proto.port)
|
||||||
}
|
}
|
||||||
|
|
||||||
public func encode(to encoder: Encoder) throws {
|
public func encode(to encoder: Encoder) throws {
|
||||||
var container = encoder.singleValueContainer()
|
var container = encoder.singleValueContainer()
|
||||||
try container.encode(serialized())
|
try container.encode(rawValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,8 +106,6 @@ open class TunnelKitProvider: NEPacketTunnelProvider {
|
||||||
|
|
||||||
private var socket: GenericSocket?
|
private var socket: GenericSocket?
|
||||||
|
|
||||||
private var linkFailures = 0
|
|
||||||
|
|
||||||
private var pendingStartHandler: ((Error?) -> Void)?
|
private var pendingStartHandler: ((Error?) -> Void)?
|
||||||
|
|
||||||
private var pendingStopHandler: (() -> Void)?
|
private var pendingStopHandler: (() -> Void)?
|
||||||
|
@ -400,14 +398,14 @@ extension TunnelKitProvider: GenericSocketDelegate {
|
||||||
log.debug("Socket timed out waiting for activity, cancelling...")
|
log.debug("Socket timed out waiting for activity, cancelling...")
|
||||||
reasserting = true
|
reasserting = true
|
||||||
socket.shutdown()
|
socket.shutdown()
|
||||||
}
|
|
||||||
|
|
||||||
func socketShouldChangeProtocol(_ socket: GenericSocket) -> Bool {
|
// fallback: TCP connection timeout suggests falling back
|
||||||
guard strategy.tryNextProtocol() else {
|
if let _ = socket as? NETCPSocket {
|
||||||
disposeTunnel(error: ProviderError.exhaustedProtocols)
|
guard tryNextProtocol() else {
|
||||||
return false
|
// disposeTunnel
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func socketDidBecomeActive(_ socket: GenericSocket) {
|
func socketDidBecomeActive(_ socket: GenericSocket) {
|
||||||
|
@ -428,19 +426,17 @@ extension TunnelKitProvider: GenericSocketDelegate {
|
||||||
}
|
}
|
||||||
|
|
||||||
var shutdownError: Error?
|
var shutdownError: Error?
|
||||||
if !failure {
|
let didTimeoutNegotiation: Bool
|
||||||
shutdownError = proxy.stopError
|
var upgradedSocket: GenericSocket?
|
||||||
} else {
|
|
||||||
shutdownError = proxy.stopError ?? ProviderError.linkError
|
|
||||||
linkFailures += 1
|
|
||||||
log.debug("Link failures so far: \(linkFailures) (max = \(maxLinkFailures))")
|
|
||||||
}
|
|
||||||
|
|
||||||
// neg timeout?
|
// look for error causing shutdown
|
||||||
let didTimeoutNegotiation = (proxy.stopError as? SessionError == .negotiationTimeout)
|
shutdownError = proxy.stopError
|
||||||
|
if failure && (shutdownError == nil) {
|
||||||
|
shutdownError = ProviderError.linkError
|
||||||
|
}
|
||||||
|
didTimeoutNegotiation = (shutdownError as? SessionError == .negotiationTimeout)
|
||||||
|
|
||||||
// only try upgrade on network errors
|
// only try upgrade on network errors
|
||||||
var upgradedSocket: GenericSocket? = nil
|
|
||||||
if shutdownError as? SessionError == nil {
|
if shutdownError as? SessionError == nil {
|
||||||
upgradedSocket = socket.upgraded()
|
upgradedSocket = socket.upgraded()
|
||||||
}
|
}
|
||||||
|
@ -448,9 +444,9 @@ extension TunnelKitProvider: GenericSocketDelegate {
|
||||||
// clean up
|
// clean up
|
||||||
finishTunnelDisconnection(error: shutdownError)
|
finishTunnelDisconnection(error: shutdownError)
|
||||||
|
|
||||||
// treat negotiation timeout as socket timeout, UDP is connection-less
|
// fallback: UDP is connection-less, treat negotiation timeout as socket timeout
|
||||||
if didTimeoutNegotiation {
|
if didTimeoutNegotiation {
|
||||||
guard socketShouldChangeProtocol(socket) else {
|
guard tryNextProtocol() else {
|
||||||
// disposeTunnel
|
// disposeTunnel
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -458,12 +454,6 @@ extension TunnelKitProvider: GenericSocketDelegate {
|
||||||
|
|
||||||
// reconnect?
|
// reconnect?
|
||||||
if reasserting {
|
if reasserting {
|
||||||
guard (linkFailures < maxLinkFailures) else {
|
|
||||||
log.debug("Too many link failures (\(linkFailures)), tunnel will die now")
|
|
||||||
reasserting = false
|
|
||||||
disposeTunnel(error: shutdownError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.debug("Disconnection is recoverable, tunnel will reconnect in \(reconnectionDelay) milliseconds...")
|
log.debug("Disconnection is recoverable, tunnel will reconnect in \(reconnectionDelay) milliseconds...")
|
||||||
tunnelQueue.schedule(after: .milliseconds(reconnectionDelay)) {
|
tunnelQueue.schedule(after: .milliseconds(reconnectionDelay)) {
|
||||||
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress)
|
self.connectTunnel(upgradedSocket: upgradedSocket, preferredAddress: socket.remoteAddress)
|
||||||
|
@ -574,6 +564,13 @@ extension TunnelKitProvider: SessionProxyDelegate {
|
||||||
}
|
}
|
||||||
|
|
||||||
extension TunnelKitProvider {
|
extension TunnelKitProvider {
|
||||||
|
private func tryNextProtocol() -> Bool {
|
||||||
|
guard strategy.tryNextProtocol() else {
|
||||||
|
disposeTunnel(error: ProviderError.exhaustedProtocols)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// MARK: Logging
|
// MARK: Logging
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue