Move control channel logic to PlainSerializer

This commit is contained in:
Davide De Rosa 2018-09-10 11:12:58 +02:00
parent 595cae3563
commit 11cb312c02
3 changed files with 142 additions and 101 deletions

View File

@ -29,6 +29,14 @@ import SwiftyBeaver
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
class ControlChannelError: Error, CustomStringConvertible {
let description: String
init(_ message: String) {
description = "\(String(describing: ControlChannelError.self))(\(message))"
}
}
class ControlChannel { class ControlChannel {
private let serializer: ControlChannelSerializer private let serializer: ControlChannelSerializer
@ -71,10 +79,20 @@ class ControlChannel {
pendingAcks.removeAll() pendingAcks.removeAll()
plainBuffer.zero() plainBuffer.zero()
dataCount.reset() dataCount.reset()
serializer.reset()
} }
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId: UInt32, payload: Data?) -> [ControlPacket] { func readInboundPacket(withCode code: PacketCode, key: UInt8, data: Data, offset: Int) throws -> ControlPacket {
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: packetId, payload: payload) log.debug("Control: Try read packet with code \(code) and key \(key)")
let packet = try serializer.deserialize(code: code, key: key, data: data, start: offset, end: nil)
log.debug("Control: Read packet \(packet)")
if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId {
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId)
}
return packet
}
func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] {
queue.inbound.append(packet) queue.inbound.append(packet)
queue.inbound.sort { $0.packetId < $1.packetId } queue.inbound.sort { $0.packetId < $1.packetId }
@ -127,7 +145,7 @@ class ControlChannel {
} }
} }
func writeOutboundPackets() -> [Data] { func writeOutboundPackets() throws -> [Data] {
var rawList: [Data] = [] var rawList: [Data] = []
for packet in queue.outbound { for packet in queue.outbound {
if let sentDate = packet.sentDate { if let sentDate = packet.sentDate {
@ -138,17 +156,9 @@ class ControlChannel {
} }
} }
log.debug("Send control packet with code \(packet.code.rawValue)") log.debug("Control: Write control packet \(packet)")
if let payload = packet.payload { let raw = try serializer.serialize(packet: packet)
if CoreConfiguration.logsSensitiveData {
log.debug("Control packet has payload (\(payload.count) bytes): \(payload.toHex())")
} else {
log.debug("Control packet has payload (\(payload.count) bytes)")
}
}
let raw = packet.serialized()
rawList.append(raw) rawList.append(raw)
packet.sentDate = Date() packet.sentDate = Date()
@ -163,7 +173,8 @@ class ControlChannel {
return !pendingAcks.isEmpty return !pendingAcks.isEmpty
} }
func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws { // Ruby: handle_acks
private func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws {
guard let sessionId = sessionId else { guard let sessionId = sessionId else {
throw SessionError.missingSessionId throw SessionError.missingSessionId
} }
@ -189,8 +200,9 @@ class ControlChannel {
guard let sessionId = sessionId else { guard let sessionId = sessionId else {
throw SessionError.missingSessionId throw SessionError.missingSessionId
} }
let ackPacket = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId) let packet = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
return ackPacket.serialized() log.debug("Control: Write ack packet \(packet)")
return try serializer.serialize(packet: packet)
} }
func currentControlData(withTLS tls: TLSBox) throws -> ZeroingData { func currentControlData(withTLS tls: TLSBox) throws -> ZeroingData {

View File

@ -24,6 +24,7 @@
// //
import Foundation import Foundation
import __TunnelKitNative
import SwiftyBeaver import SwiftyBeaver
private let log = SwiftyBeaver.self private let log = SwiftyBeaver.self
@ -33,23 +34,95 @@ protocol ControlChannelSerializer {
func serialize(packet: ControlPacket) throws -> Data func serialize(packet: ControlPacket) throws -> Data
func deserialize(data: Data, from: Int) throws -> ControlPacket func deserialize(code: PacketCode, key: UInt8, data: Data, start: Int, end: Int?) throws -> ControlPacket
} }
extension ControlChannel { extension ControlChannel {
class PlainSerializer: ControlChannelSerializer { class PlainSerializer: ControlChannelSerializer {
func reset() { func reset() {
// TODO
} }
func serialize(packet: ControlPacket) throws -> Data { func serialize(packet: ControlPacket) throws -> Data {
// TODO return packet.serialized()
throw SessionError.pingTimeout
} }
func deserialize(data: Data, from: Int) throws -> ControlPacket { func deserialize(code: PacketCode, key: UInt8, data packet: Data, start: Int, end: Int?) throws -> ControlPacket {
// TODO var offset = start
throw SessionError.pingTimeout let end = end ?? packet.count
guard end >= offset + PacketSessionIdLength else {
throw ControlChannelError("Missing sessionId")
}
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength
guard end >= offset + 1 else {
throw ControlChannelError("Missing ackSize")
}
let ackSize = packet[offset]
offset += 1
var ackIds: [UInt32]?
var ackRemoteSessionId: Data?
if ackSize > 0 {
guard end >= (offset + Int(ackSize) * PacketIdLength) else {
throw ControlChannelError("Missing acks")
}
var ids: [UInt32] = []
for _ in 0..<ackSize {
let id = packet.networkUInt32Value(from: offset)
ids.append(id)
offset += PacketIdLength
}
guard end >= offset + PacketSessionIdLength else {
throw ControlChannelError("Missing remoteSessionId")
}
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength
log.debug("Server acked packetIds \(ids) with remoteSessionId \(remoteSessionId.toHex())")
ackIds = ids
ackRemoteSessionId = remoteSessionId
}
if code == .ackV1 {
guard let ackIds = ackIds else {
throw ControlChannelError("Ack packet without ids")
}
guard let ackRemoteSessionId = ackRemoteSessionId else {
throw ControlChannelError("Ack packet without remoteSessionId")
}
return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
}
guard end >= offset + PacketIdLength else {
throw ControlChannelError("Missing packetId")
}
let packetId = packet.networkUInt32Value(from: offset)
log.debug("Control packet has packetId \(packetId)")
offset += PacketIdLength
var payload: Data?
if offset < end {
payload = packet.subdata(in: offset..<end)
if let payload = payload {
if CoreConfiguration.logsSensitiveData {
log.debug("Control packet payload (\(payload.count) bytes): \(payload.toHex())")
} else {
log.debug("Control packet payload (\(payload.count) bytes)")
}
}
}
let controlPacket = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: packetId, payload: payload)
if let ackIds = ackIds {
controlPacket.ackIds = ackIds as [NSNumber]
controlPacket.ackRemoteSessionId = ackRemoteSessionId
}
return controlPacket
} }
} }
} }

View File

@ -454,74 +454,26 @@ public class SessionProxy {
continue continue
} }
guard packet.count >= offset + PacketSessionIdLength else { log.debug("Packet has code \(code.rawValue), key \(key)")
log.warning("Dropped malformed packet (missing sessionId)") let controlPacket: ControlPacket
do {
let parsedPacket = try controlChannel.readInboundPacket(withCode: code, key: key, data: packet, offset: offset)
handleAcks()
if parsedPacket.code == .ackV1 {
continue continue
} }
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) controlPacket = parsedPacket
offset += PacketSessionIdLength } catch let e {
log.warning("Dropped malformed packet: \(e)")
guard packet.count >= offset + 1 else {
log.warning("Dropped malformed packet (missing ackSize)")
continue continue
} // deferStop(.shutdown, e)
let ackSize = packet[offset] // return
offset += 1
log.debug("Packet has code \(code.rawValue), key \(key), sessionId \(sessionId.toHex()) and \(ackSize) acks entries")
if (ackSize > 0) {
guard packet.count >= (offset + Int(ackSize) * PacketIdLength) else {
log.warning("Dropped malformed packet (missing acks)")
continue
}
var ackedPacketIds = [UInt32]()
for _ in 0..<ackSize {
let ackedPacketId = packet.networkUInt32Value(from: offset)
ackedPacketIds.append(ackedPacketId)
offset += PacketIdLength
} }
guard packet.count >= offset + PacketSessionIdLength else { log.debug("Packet has sessionId \(controlPacket.sessionId.toHex()) and \(controlPacket.ackIds?.count ?? 0) acks entries")
log.warning("Dropped malformed packet (missing remoteSessionId)") sendAck(for: controlPacket)
continue
}
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
offset += PacketSessionIdLength
log.debug("Server acked packetIds \(ackedPacketIds) with remoteSessionId \(remoteSessionId.toHex())") let pendingInboundQueue = controlChannel.enqueueInboundPacket(packet: controlPacket)
handleAcks(ackedPacketIds, remoteSessionId: remoteSessionId)
}
if (code == .ackV1) {
continue
}
guard packet.count >= offset + PacketIdLength else {
log.warning("Dropped malformed packet (missing packetId)")
continue
}
let packetId = packet.networkUInt32Value(from: offset)
log.debug("Control packet has packetId \(packetId)")
offset += PacketIdLength
sendAck(key: key, packetId: packetId, remoteSessionId: sessionId)
var payload: Data?
if (offset < packet.count) {
payload = packet.subdata(in: offset..<packet.count)
if let payload = payload {
if CoreConfiguration.logsSensitiveData {
log.debug("Control packet payload (\(payload.count) bytes): \(payload.toHex())")
} else {
log.debug("Control packet payload (\(payload.count) bytes)")
}
}
}
let pendingInboundQueue = controlChannel.readInboundPacket(withCode: code, key: key, sessionId: sessionId, packetId: packetId, payload: payload)
for inboundPacket in pendingInboundQueue { for inboundPacket in pendingInboundQueue {
handleControlPacket(inboundPacket) handleControlPacket(inboundPacket)
} }
@ -899,7 +851,14 @@ public class SessionProxy {
// Ruby: flush_ctrl_q_out // Ruby: flush_ctrl_q_out
private func flushControlQueue() { private func flushControlQueue() {
let rawList = controlChannel.writeOutboundPackets() let rawList: [Data]
do {
rawList = try controlChannel.writeOutboundPackets()
} catch let e {
log.warning("Failed control packet serialization: \(e)")
deferStop(.shutdown, e)
return
}
for raw in rawList { for raw in rawList {
log.debug("Send control packet (\(raw.count) bytes): \(raw.toHex())") log.debug("Send control packet (\(raw.count) bytes): \(raw.toHex())")
} }
@ -1044,14 +1003,7 @@ public class SessionProxy {
// MARK: Acks // MARK: Acks
// Ruby: handle_acks private func handleAcks() {
private func handleAcks(_ packetIds: [UInt32], remoteSessionId: Data) {
do {
try controlChannel.readAcks(packetIds, acksRemoteSessionId: remoteSessionId)
} catch let e {
deferStop(.shutdown, e)
return
}
// retry PUSH_REQUEST if ack queue is empty (all sent packets were ack'ed) // retry PUSH_REQUEST if ack queue is empty (all sent packets were ack'ed)
if isReliableLink && !controlChannel.hasPendingAcks() { if isReliableLink && !controlChannel.hasPendingAcks() {
@ -1060,12 +1012,16 @@ public class SessionProxy {
} }
// Ruby: send_ack // Ruby: send_ack
private func sendAck(key: UInt8, packetId: UInt32, remoteSessionId: Data) { private func sendAck(for controlPacket: ControlPacket) {
log.debug("Send ack for received packetId \(packetId)") log.debug("Send ack for received packetId \(controlPacket.packetId)")
let raw: Data let raw: Data
do { do {
raw = try controlChannel.writeAcks(withKey: key, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId) raw = try controlChannel.writeAcks(
withKey: controlPacket.key,
ackPacketIds: [controlPacket.packetId],
ackRemoteSessionId: controlPacket.sessionId
)
} catch let e { } catch let e {
deferStop(.shutdown, e) deferStop(.shutdown, e)
return return
@ -1075,12 +1031,12 @@ public class SessionProxy {
link?.writePacket(raw) { [weak self] (error) in link?.writePacket(raw) { [weak self] (error) in
if let error = error { if let error = error {
self?.queue.sync { self?.queue.sync {
log.error("Failed LINK write during send ack for packetId \(packetId): \(error)") log.error("Failed LINK write during send ack for packetId \(controlPacket.packetId): \(error)")
self?.deferStop(.reconnect, SessionError.failedLinkWrite) self?.deferStop(.reconnect, SessionError.failedLinkWrite)
return return
} }
} }
log.debug("Ack successfully written to LINK for packetId \(packetId)") log.debug("Ack successfully written to LINK for packetId \(controlPacket.packetId)")
} }
} }