Move sessionId and remoteSessionId

This commit is contained in:
Davide De Rosa 2018-09-09 17:56:29 +02:00
parent 1573b2070a
commit 3608860b9d
2 changed files with 73 additions and 63 deletions

View File

@ -30,9 +30,13 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self
class ControlChannel {
private(set) var sessionId: Data?
var remoteSessionId: Data?
private var queue: BidirectionalState<[ControlPacket]>
private var packetId: BidirectionalState<UInt32>
private var currentPacketId: BidirectionalState<UInt32>
private var pendingAcks: Set<UInt32>
@ -41,48 +45,66 @@ class ControlChannel {
private var dataCount: BidirectionalState<Int>
init() {
sessionId = nil
remoteSessionId = nil
queue = BidirectionalState(withResetValue: [])
packetId = BidirectionalState(withResetValue: 0)
currentPacketId = BidirectionalState(withResetValue: 0)
pendingAcks = []
plainBuffer = Z(count: TLSBoxMaxBufferLength)
dataCount = BidirectionalState(withResetValue: 0)
}
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId inboundPacketId: UInt32, payload: Data?) -> [ControlPacket] {
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: inboundPacketId, payload: payload)
func reset(forNewSession: Bool) throws {
if forNewSession {
try sessionId = SecureRandom.data(length: PacketSessionIdLength)
remoteSessionId = nil
}
queue.reset()
currentPacketId.reset()
pendingAcks.removeAll()
plainBuffer.zero()
dataCount.reset()
}
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId: UInt32, payload: Data?) -> [ControlPacket] {
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: packetId, payload: payload)
queue.inbound.append(packet)
queue.inbound.sort { $0.packetId < $1.packetId }
var toHandle: [ControlPacket] = []
for queuedPacket in queue.inbound {
if queuedPacket.packetId < packetId.inbound {
if queuedPacket.packetId < currentPacketId.inbound {
queue.inbound.removeFirst()
continue
}
if queuedPacket.packetId != packetId.inbound {
if queuedPacket.packetId != currentPacketId.inbound {
continue
}
toHandle.append(queuedPacket)
packetId.inbound += 1
currentPacketId.inbound += 1
queue.inbound.removeFirst()
}
return toHandle
}
func enqueueOutboundPackets(withCode code: PacketCode, key: UInt8, sessionId: Data, payload: Data, maxPacketSize: Int) {
let oldIdOut = packetId.outbound
func enqueueOutboundPackets(withCode code: PacketCode, key: UInt8, payload: Data, maxPacketSize: Int) {
guard let sessionId = sessionId else {
fatalError("Missing sessionId, do reset(forNewSession: true) first")
}
let oldIdOut = currentPacketId.outbound
var queuedCount = 0
var offset = 0
repeat {
let subPayloadLength = min(maxPacketSize, payload.count - offset)
let subPayloadData = payload.subdata(offset: offset, count: subPayloadLength)
let packet = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: packetId.outbound, payload: subPayloadData)
let packet = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: currentPacketId.outbound, payload: subPayloadData)
queue.outbound.append(packet)
packetId.outbound += 1
currentPacketId.outbound += 1
offset += maxPacketSize
queuedCount += subPayloadLength
} while (offset < payload.count)
@ -90,9 +112,9 @@ class ControlChannel {
assert(queuedCount == payload.count)
// packet count
let packetCount = packetId.outbound - oldIdOut
let packetCount = currentPacketId.outbound - oldIdOut
if (packetCount > 1) {
log.debug("Enqueued \(packetCount) control packets [\(oldIdOut)-\(packetId.outbound - 1)]")
log.debug("Enqueued \(packetCount) control packets [\(oldIdOut)-\(currentPacketId.outbound - 1)]")
} else {
log.debug("Enqueued 1 control packet [\(oldIdOut)]")
}
@ -122,7 +144,7 @@ class ControlChannel {
let raw = packet.serialized()
rawList.append(raw)
packet.sentDate = Date()
// track pending acks for sent packets
pendingAcks.insert(packet.packetId)
}
@ -134,7 +156,14 @@ class ControlChannel {
return !pendingAcks.isEmpty
}
func readAcks(_ packetIds: [UInt32]) {
func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws {
guard let sessionId = sessionId else {
throw SessionError.missingSessionId
}
guard acksRemoteSessionId == sessionId else {
log.error("Ack session mismatch (\(acksRemoteSessionId.toHex()) != \(sessionId.toHex()))")
throw SessionError.sessionMismatch
}
// drop queued out packets if ack-ed
for (i, packet) in queue.outbound.enumerated() {
@ -149,7 +178,10 @@ class ControlChannel {
// log.verbose("Packets still pending ack: \(pendingAcks)")
}
func writeAcks(withKey key: UInt8, sessionId: Data, ackPacketIds: [UInt32], ackRemoteSessionId: Data) -> Data {
func writeAcks(withKey key: UInt8, ackPacketIds: [UInt32], ackRemoteSessionId: Data) throws -> Data {
guard let sessionId = sessionId else {
throw SessionError.missingSessionId
}
let ackPacket = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
return ackPacket.serialized()
}
@ -171,12 +203,4 @@ class ControlChannel {
func currentDataCount() -> (Int, Int) {
return dataCount.pair
}
func reset() {
plainBuffer.zero()
queue.reset()
pendingAcks.removeAll()
packetId.reset()
dataCount.reset()
}
}

View File

@ -136,10 +136,6 @@ public class SessionProxy {
return link?.isReliable ?? false
}
private var sessionId: Data?
private var remoteSessionId: Data?
private var pushReply: SessionReply?
private var nextPushRequestDate: Date?
@ -312,8 +308,6 @@ public class SessionProxy {
negotiationKeyIdx = 0
currentKeyIdx = nil
sessionId = nil
remoteSessionId = nil
nextPushRequestDate = nil
connectedDate = nil
authenticator = nil
@ -593,23 +587,21 @@ public class SessionProxy {
// MARK: Handshake
// Ruby: reset_ctrl
private func resetControlChannel() {
controlChannel.reset()
private func resetControlChannel(forNewSession: Bool) {
authenticator = nil
do {
try controlChannel.reset(forNewSession: forNewSession)
} catch let e {
deferStop(.shutdown, e)
}
}
// Ruby: hard_reset
private func hardReset() {
log.debug("Send hard reset")
resetControlChannel()
resetControlChannel(forNewSession: true)
pushReply = nil
do {
try sessionId = SecureRandom.data(length: PacketSessionIdLength)
} catch let e {
deferStop(.shutdown, e)
return
}
negotiationKeyIdx = 0
let newKey = SessionKey(id: UInt8(negotiationKeyIdx))
keys[negotiationKeyIdx] = newKey
@ -623,7 +615,7 @@ public class SessionProxy {
private func softReset() {
log.debug("Send soft reset")
resetControlChannel()
resetControlChannel(forNewSession: false)
negotiationKeyIdx = max(1, (negotiationKeyIdx + 1) % ProtocolMacros.numberOfKeys)
let newKey = SessionKey(id: UInt8(negotiationKeyIdx))
keys[negotiationKeyIdx] = newKey
@ -725,10 +717,10 @@ public class SessionProxy {
((packet.code == .softResetV1) && (negotiationKey.state == .softReset))) {
if negotiationKey.state == .hardReset {
remoteSessionId = packet.sessionId
controlChannel.remoteSessionId = packet.sessionId
}
guard let remoteSessionId = remoteSessionId else {
log.error("No remote session id")
guard let remoteSessionId = controlChannel.remoteSessionId else {
log.error("No remote session id (never set)")
deferStop(.shutdown, SessionError.missingSessionId)
return
}
@ -764,7 +756,8 @@ public class SessionProxy {
enqueueControlPackets(code: .controlV1, key: negotiationKey.id, payload: cipherTextOut)
}
else if ((packet.code == .controlV1) && (negotiationKey.state == .tls)) {
guard let remoteSessionId = remoteSessionId else {
guard let remoteSessionId = controlChannel.remoteSessionId else {
log.error("No remote session id found in packet (control packets before server HARD_RESET)")
deferStop(.shutdown, SessionError.missingSessionId)
return
}
@ -899,12 +892,8 @@ public class SessionProxy {
log.warning("Not writing to LINK, interface is down")
return
}
guard let sessionId = sessionId else {
fatalError("Missing sessionId, do hardReset() first")
}
// FIXME: init controlChannel with sessionId
controlChannel.enqueueOutboundPackets(withCode: code, key: key, sessionId: sessionId, payload: payload, maxPacketSize: link.mtu)
controlChannel.enqueueOutboundPackets(withCode: code, key: key, payload: payload, maxPacketSize: link.mtu)
flushControlQueue()
}
@ -932,10 +921,10 @@ public class SessionProxy {
guard let auth = authenticator else {
fatalError("Setting up encryption without having authenticated")
}
guard let sessionId = sessionId else {
guard let sessionId = controlChannel.sessionId else {
fatalError("Setting up encryption without a local sessionId")
}
guard let remoteSessionId = remoteSessionId else {
guard let remoteSessionId = controlChannel.remoteSessionId else {
fatalError("Setting up encryption without a remote sessionId")
}
guard let serverRandom1 = auth.serverRandom1, let serverRandom2 = auth.serverRandom2 else {
@ -1057,16 +1046,13 @@ public class SessionProxy {
// Ruby: handle_acks
private func handleAcks(_ packetIds: [UInt32], remoteSessionId: Data) {
guard (remoteSessionId == sessionId) else {
if let sessionId = sessionId {
log.error("Ack session mismatch (\(remoteSessionId.toHex()) != \(sessionId.toHex()))")
}
deferStop(.shutdown, SessionError.sessionMismatch)
do {
try controlChannel.readAcks(packetIds, acksRemoteSessionId: remoteSessionId)
} catch let e {
deferStop(.shutdown, e)
return
}
controlChannel.readAcks(packetIds)
// retry PUSH_REQUEST if ack queue is empty (all sent packets were ack'ed)
if isReliableLink && !controlChannel.hasPendingAcks() {
pushRequest()
@ -1077,13 +1063,13 @@ public class SessionProxy {
private func sendAck(key: UInt8, packetId: UInt32, remoteSessionId: Data) {
log.debug("Send ack for received packetId \(packetId)")
guard let sessionId = sessionId else {
log.warning("Sending ack without a sessionId?")
deferStop(.shutdown, SessionError.missingSessionId)
let raw: Data
do {
raw = try controlChannel.writeAcks(withKey: key, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId)
} catch let e {
deferStop(.shutdown, e)
return
}
let raw = controlChannel.writeAcks(withKey: key, sessionId: sessionId, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId)
// WARNING: runs in Network.framework queue
link?.writePacket(raw) { [weak self] (error) in