Move sessionId and remoteSessionId
This commit is contained in:
parent
1573b2070a
commit
3608860b9d
|
@ -30,9 +30,13 @@ import SwiftyBeaver
|
|||
private let log = SwiftyBeaver.self
|
||||
|
||||
class ControlChannel {
|
||||
private(set) var sessionId: Data?
|
||||
|
||||
var remoteSessionId: Data?
|
||||
|
||||
private var queue: BidirectionalState<[ControlPacket]>
|
||||
|
||||
private var packetId: BidirectionalState<UInt32>
|
||||
private var currentPacketId: BidirectionalState<UInt32>
|
||||
|
||||
private var pendingAcks: Set<UInt32>
|
||||
|
||||
|
@ -41,48 +45,66 @@ class ControlChannel {
|
|||
private var dataCount: BidirectionalState<Int>
|
||||
|
||||
init() {
|
||||
sessionId = nil
|
||||
remoteSessionId = nil
|
||||
queue = BidirectionalState(withResetValue: [])
|
||||
packetId = BidirectionalState(withResetValue: 0)
|
||||
currentPacketId = BidirectionalState(withResetValue: 0)
|
||||
pendingAcks = []
|
||||
plainBuffer = Z(count: TLSBoxMaxBufferLength)
|
||||
dataCount = BidirectionalState(withResetValue: 0)
|
||||
}
|
||||
|
||||
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId inboundPacketId: UInt32, payload: Data?) -> [ControlPacket] {
|
||||
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: inboundPacketId, payload: payload)
|
||||
func reset(forNewSession: Bool) throws {
|
||||
if forNewSession {
|
||||
try sessionId = SecureRandom.data(length: PacketSessionIdLength)
|
||||
remoteSessionId = nil
|
||||
}
|
||||
queue.reset()
|
||||
currentPacketId.reset()
|
||||
pendingAcks.removeAll()
|
||||
plainBuffer.zero()
|
||||
dataCount.reset()
|
||||
}
|
||||
|
||||
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId: UInt32, payload: Data?) -> [ControlPacket] {
|
||||
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: packetId, payload: payload)
|
||||
queue.inbound.append(packet)
|
||||
queue.inbound.sort { $0.packetId < $1.packetId }
|
||||
|
||||
var toHandle: [ControlPacket] = []
|
||||
for queuedPacket in queue.inbound {
|
||||
if queuedPacket.packetId < packetId.inbound {
|
||||
if queuedPacket.packetId < currentPacketId.inbound {
|
||||
queue.inbound.removeFirst()
|
||||
continue
|
||||
}
|
||||
if queuedPacket.packetId != packetId.inbound {
|
||||
if queuedPacket.packetId != currentPacketId.inbound {
|
||||
continue
|
||||
}
|
||||
|
||||
toHandle.append(queuedPacket)
|
||||
|
||||
packetId.inbound += 1
|
||||
currentPacketId.inbound += 1
|
||||
queue.inbound.removeFirst()
|
||||
}
|
||||
return toHandle
|
||||
}
|
||||
|
||||
func enqueueOutboundPackets(withCode code: PacketCode, key: UInt8, sessionId: Data, payload: Data, maxPacketSize: Int) {
|
||||
let oldIdOut = packetId.outbound
|
||||
func enqueueOutboundPackets(withCode code: PacketCode, key: UInt8, payload: Data, maxPacketSize: Int) {
|
||||
guard let sessionId = sessionId else {
|
||||
fatalError("Missing sessionId, do reset(forNewSession: true) first")
|
||||
}
|
||||
|
||||
let oldIdOut = currentPacketId.outbound
|
||||
var queuedCount = 0
|
||||
var offset = 0
|
||||
|
||||
repeat {
|
||||
let subPayloadLength = min(maxPacketSize, payload.count - offset)
|
||||
let subPayloadData = payload.subdata(offset: offset, count: subPayloadLength)
|
||||
let packet = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: packetId.outbound, payload: subPayloadData)
|
||||
let packet = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: currentPacketId.outbound, payload: subPayloadData)
|
||||
|
||||
queue.outbound.append(packet)
|
||||
packetId.outbound += 1
|
||||
currentPacketId.outbound += 1
|
||||
offset += maxPacketSize
|
||||
queuedCount += subPayloadLength
|
||||
} while (offset < payload.count)
|
||||
|
@ -90,9 +112,9 @@ class ControlChannel {
|
|||
assert(queuedCount == payload.count)
|
||||
|
||||
// packet count
|
||||
let packetCount = packetId.outbound - oldIdOut
|
||||
let packetCount = currentPacketId.outbound - oldIdOut
|
||||
if (packetCount > 1) {
|
||||
log.debug("Enqueued \(packetCount) control packets [\(oldIdOut)-\(packetId.outbound - 1)]")
|
||||
log.debug("Enqueued \(packetCount) control packets [\(oldIdOut)-\(currentPacketId.outbound - 1)]")
|
||||
} else {
|
||||
log.debug("Enqueued 1 control packet [\(oldIdOut)]")
|
||||
}
|
||||
|
@ -122,7 +144,7 @@ class ControlChannel {
|
|||
let raw = packet.serialized()
|
||||
rawList.append(raw)
|
||||
packet.sentDate = Date()
|
||||
|
||||
|
||||
// track pending acks for sent packets
|
||||
pendingAcks.insert(packet.packetId)
|
||||
}
|
||||
|
@ -134,7 +156,14 @@ class ControlChannel {
|
|||
return !pendingAcks.isEmpty
|
||||
}
|
||||
|
||||
func readAcks(_ packetIds: [UInt32]) {
|
||||
func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws {
|
||||
guard let sessionId = sessionId else {
|
||||
throw SessionError.missingSessionId
|
||||
}
|
||||
guard acksRemoteSessionId == sessionId else {
|
||||
log.error("Ack session mismatch (\(acksRemoteSessionId.toHex()) != \(sessionId.toHex()))")
|
||||
throw SessionError.sessionMismatch
|
||||
}
|
||||
|
||||
// drop queued out packets if ack-ed
|
||||
for (i, packet) in queue.outbound.enumerated() {
|
||||
|
@ -149,7 +178,10 @@ class ControlChannel {
|
|||
// log.verbose("Packets still pending ack: \(pendingAcks)")
|
||||
}
|
||||
|
||||
func writeAcks(withKey key: UInt8, sessionId: Data, ackPacketIds: [UInt32], ackRemoteSessionId: Data) -> Data {
|
||||
func writeAcks(withKey key: UInt8, ackPacketIds: [UInt32], ackRemoteSessionId: Data) throws -> Data {
|
||||
guard let sessionId = sessionId else {
|
||||
throw SessionError.missingSessionId
|
||||
}
|
||||
let ackPacket = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
|
||||
return ackPacket.serialized()
|
||||
}
|
||||
|
@ -171,12 +203,4 @@ class ControlChannel {
|
|||
func currentDataCount() -> (Int, Int) {
|
||||
return dataCount.pair
|
||||
}
|
||||
|
||||
func reset() {
|
||||
plainBuffer.zero()
|
||||
queue.reset()
|
||||
pendingAcks.removeAll()
|
||||
packetId.reset()
|
||||
dataCount.reset()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -136,10 +136,6 @@ public class SessionProxy {
|
|||
return link?.isReliable ?? false
|
||||
}
|
||||
|
||||
private var sessionId: Data?
|
||||
|
||||
private var remoteSessionId: Data?
|
||||
|
||||
private var pushReply: SessionReply?
|
||||
|
||||
private var nextPushRequestDate: Date?
|
||||
|
@ -312,8 +308,6 @@ public class SessionProxy {
|
|||
negotiationKeyIdx = 0
|
||||
currentKeyIdx = nil
|
||||
|
||||
sessionId = nil
|
||||
remoteSessionId = nil
|
||||
nextPushRequestDate = nil
|
||||
connectedDate = nil
|
||||
authenticator = nil
|
||||
|
@ -593,23 +587,21 @@ public class SessionProxy {
|
|||
// MARK: Handshake
|
||||
|
||||
// Ruby: reset_ctrl
|
||||
private func resetControlChannel() {
|
||||
controlChannel.reset()
|
||||
private func resetControlChannel(forNewSession: Bool) {
|
||||
authenticator = nil
|
||||
do {
|
||||
try controlChannel.reset(forNewSession: forNewSession)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
}
|
||||
}
|
||||
|
||||
// Ruby: hard_reset
|
||||
private func hardReset() {
|
||||
log.debug("Send hard reset")
|
||||
|
||||
resetControlChannel()
|
||||
resetControlChannel(forNewSession: true)
|
||||
pushReply = nil
|
||||
do {
|
||||
try sessionId = SecureRandom.data(length: PacketSessionIdLength)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
}
|
||||
negotiationKeyIdx = 0
|
||||
let newKey = SessionKey(id: UInt8(negotiationKeyIdx))
|
||||
keys[negotiationKeyIdx] = newKey
|
||||
|
@ -623,7 +615,7 @@ public class SessionProxy {
|
|||
private func softReset() {
|
||||
log.debug("Send soft reset")
|
||||
|
||||
resetControlChannel()
|
||||
resetControlChannel(forNewSession: false)
|
||||
negotiationKeyIdx = max(1, (negotiationKeyIdx + 1) % ProtocolMacros.numberOfKeys)
|
||||
let newKey = SessionKey(id: UInt8(negotiationKeyIdx))
|
||||
keys[negotiationKeyIdx] = newKey
|
||||
|
@ -725,10 +717,10 @@ public class SessionProxy {
|
|||
((packet.code == .softResetV1) && (negotiationKey.state == .softReset))) {
|
||||
|
||||
if negotiationKey.state == .hardReset {
|
||||
remoteSessionId = packet.sessionId
|
||||
controlChannel.remoteSessionId = packet.sessionId
|
||||
}
|
||||
guard let remoteSessionId = remoteSessionId else {
|
||||
log.error("No remote session id")
|
||||
guard let remoteSessionId = controlChannel.remoteSessionId else {
|
||||
log.error("No remote session id (never set)")
|
||||
deferStop(.shutdown, SessionError.missingSessionId)
|
||||
return
|
||||
}
|
||||
|
@ -764,7 +756,8 @@ public class SessionProxy {
|
|||
enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut)
|
||||
}
|
||||
else if ((packet.code == .controlV1) && (negotiationKey.state == .tls)) {
|
||||
guard let remoteSessionId = remoteSessionId else {
|
||||
guard let remoteSessionId = controlChannel.remoteSessionId else {
|
||||
log.error("No remote session id found in packet (control packets before server HARD_RESET)")
|
||||
deferStop(.shutdown, SessionError.missingSessionId)
|
||||
return
|
||||
}
|
||||
|
@ -899,12 +892,8 @@ public class SessionProxy {
|
|||
log.warning("Not writing to LINK, interface is down")
|
||||
return
|
||||
}
|
||||
guard let sessionId = sessionId else {
|
||||
fatalError("Missing sessionId, do hardReset() first")
|
||||
}
|
||||
|
||||
// FIXME: init controlChannel with sessionId
|
||||
controlChannel.enqueueOutboundPackets(withCode: code, key: key, sessionId: sessionId, payload: payload, maxPacketSize: link.mtu)
|
||||
controlChannel.enqueueOutboundPackets(withCode: code, key: key, payload: payload, maxPacketSize: link.mtu)
|
||||
flushControlQueue()
|
||||
}
|
||||
|
||||
|
@ -932,10 +921,10 @@ public class SessionProxy {
|
|||
guard let auth = authenticator else {
|
||||
fatalError("Setting up encryption without having authenticated")
|
||||
}
|
||||
guard let sessionId = sessionId else {
|
||||
guard let sessionId = controlChannel.sessionId else {
|
||||
fatalError("Setting up encryption without a local sessionId")
|
||||
}
|
||||
guard let remoteSessionId = remoteSessionId else {
|
||||
guard let remoteSessionId = controlChannel.remoteSessionId else {
|
||||
fatalError("Setting up encryption without a remote sessionId")
|
||||
}
|
||||
guard let serverRandom1 = auth.serverRandom1, let serverRandom2 = auth.serverRandom2 else {
|
||||
|
@ -1057,16 +1046,13 @@ public class SessionProxy {
|
|||
|
||||
// Ruby: handle_acks
|
||||
private func handleAcks(_ packetIds: [UInt32], remoteSessionId: Data) {
|
||||
guard (remoteSessionId == sessionId) else {
|
||||
if let sessionId = sessionId {
|
||||
log.error("Ack session mismatch (\(remoteSessionId.toHex()) != \(sessionId.toHex()))")
|
||||
}
|
||||
deferStop(.shutdown, SessionError.sessionMismatch)
|
||||
do {
|
||||
try controlChannel.readAcks(packetIds, acksRemoteSessionId: remoteSessionId)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
}
|
||||
|
||||
controlChannel.readAcks(packetIds)
|
||||
|
||||
// retry PUSH_REQUEST if ack queue is empty (all sent packets were ack'ed)
|
||||
if isReliableLink && !controlChannel.hasPendingAcks() {
|
||||
pushRequest()
|
||||
|
@ -1077,13 +1063,13 @@ public class SessionProxy {
|
|||
private func sendAck(key: UInt8, packetId: UInt32, remoteSessionId: Data) {
|
||||
log.debug("Send ack for received packetId \(packetId)")
|
||||
|
||||
guard let sessionId = sessionId else {
|
||||
log.warning("Sending ack without a sessionId?")
|
||||
deferStop(.shutdown, SessionError.missingSessionId)
|
||||
let raw: Data
|
||||
do {
|
||||
raw = try controlChannel.writeAcks(withKey: key, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
}
|
||||
|
||||
let raw = controlChannel.writeAcks(withKey: key, sessionId: sessionId, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId)
|
||||
|
||||
// WARNING: runs in Network.framework queue
|
||||
link?.writePacket(raw) { [weak self] (error) in
|
||||
|
|
Loading…
Reference in New Issue