diff --git a/TunnelKit/Sources/Core/ControlChannel.swift b/TunnelKit/Sources/Core/ControlChannel.swift index ac0078b..0c1bc3d 100644 --- a/TunnelKit/Sources/Core/ControlChannel.swift +++ b/TunnelKit/Sources/Core/ControlChannel.swift @@ -29,6 +29,14 @@ import SwiftyBeaver private let log = SwiftyBeaver.self +class ControlChannelError: Error, CustomStringConvertible { + let description: String + + init(_ message: String) { + description = "\(String(describing: ControlChannelError.self))(\(message))" + } +} + class ControlChannel { private let serializer: ControlChannelSerializer @@ -71,10 +79,20 @@ class ControlChannel { pendingAcks.removeAll() plainBuffer.zero() dataCount.reset() + serializer.reset() } - - func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId: UInt32, payload: Data?) -> [ControlPacket] { - let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: packetId, payload: payload) + + func readInboundPacket(withCode code: PacketCode, key: UInt8, data: Data, offset: Int) throws -> ControlPacket { + log.debug("Control: Try read packet with code \(code) and key \(key)") + let packet = try serializer.deserialize(code: code, key: key, data: data, start: offset, end: nil) + log.debug("Control: Read packet \(packet)") + if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId { + try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId) + } + return packet + } + + func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] { queue.inbound.append(packet) queue.inbound.sort { $0.packetId < $1.packetId } @@ -127,7 +145,7 @@ class ControlChannel { } } - func writeOutboundPackets() -> [Data] { + func writeOutboundPackets() throws -> [Data] { var rawList: [Data] = [] for packet in queue.outbound { if let sentDate = packet.sentDate { @@ -138,17 +156,9 @@ class ControlChannel { } } - log.debug("Send control packet with code \(packet.code.rawValue)") - - if let payload = packet.payload { - if CoreConfiguration.logsSensitiveData { - log.debug("Control packet has payload (\(payload.count) bytes): \(payload.toHex())") - } else { - log.debug("Control packet has payload (\(payload.count) bytes)") - } - } - - let raw = packet.serialized() + log.debug("Control: Write control packet \(packet)") + + let raw = try serializer.serialize(packet: packet) rawList.append(raw) packet.sentDate = Date() @@ -163,7 +173,8 @@ class ControlChannel { return !pendingAcks.isEmpty } - func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws { + // Ruby: handle_acks + private func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws { guard let sessionId = sessionId else { throw SessionError.missingSessionId } @@ -189,8 +200,9 @@ class ControlChannel { guard let sessionId = sessionId else { throw SessionError.missingSessionId } - let ackPacket = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId) - return ackPacket.serialized() + let packet = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId) + log.debug("Control: Write ack packet \(packet)") + return try serializer.serialize(packet: packet) } func currentControlData(withTLS tls: TLSBox) throws -> ZeroingData { diff --git a/TunnelKit/Sources/Core/ControlChannelSerializer.swift b/TunnelKit/Sources/Core/ControlChannelSerializer.swift index 7f75c46..6b09a23 100644 --- a/TunnelKit/Sources/Core/ControlChannelSerializer.swift +++ b/TunnelKit/Sources/Core/ControlChannelSerializer.swift @@ -24,6 +24,7 @@ // import Foundation +import __TunnelKitNative import SwiftyBeaver private let log = SwiftyBeaver.self @@ -33,23 +34,95 @@ protocol ControlChannelSerializer { func serialize(packet: ControlPacket) throws -> Data - func deserialize(data: Data, from: Int) throws -> ControlPacket + func deserialize(code: PacketCode, key: UInt8, data: Data, start: Int, end: Int?) throws -> ControlPacket } extension ControlChannel { class PlainSerializer: ControlChannelSerializer { func reset() { - // TODO } func serialize(packet: ControlPacket) throws -> Data { - // TODO - throw SessionError.pingTimeout + return packet.serialized() } - func deserialize(data: Data, from: Int) throws -> ControlPacket { - // TODO - throw SessionError.pingTimeout + func deserialize(code: PacketCode, key: UInt8, data packet: Data, start: Int, end: Int?) throws -> ControlPacket { + var offset = start + let end = end ?? packet.count + + guard end >= offset + PacketSessionIdLength else { + throw ControlChannelError("Missing sessionId") + } + let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) + offset += PacketSessionIdLength + + guard end >= offset + 1 else { + throw ControlChannelError("Missing ackSize") + } + let ackSize = packet[offset] + offset += 1 + + var ackIds: [UInt32]? + var ackRemoteSessionId: Data? + if ackSize > 0 { + guard end >= (offset + Int(ackSize) * PacketIdLength) else { + throw ControlChannelError("Missing acks") + } + var ids: [UInt32] = [] + for _ in 0..= offset + PacketSessionIdLength else { + throw ControlChannelError("Missing remoteSessionId") + } + let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) + offset += PacketSessionIdLength + + log.debug("Server acked packetIds \(ids) with remoteSessionId \(remoteSessionId.toHex())") + + ackIds = ids + ackRemoteSessionId = remoteSessionId + } + + if code == .ackV1 { + guard let ackIds = ackIds else { + throw ControlChannelError("Ack packet without ids") + } + guard let ackRemoteSessionId = ackRemoteSessionId else { + throw ControlChannelError("Ack packet without remoteSessionId") + } + return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId) + } + + guard end >= offset + PacketIdLength else { + throw ControlChannelError("Missing packetId") + } + let packetId = packet.networkUInt32Value(from: offset) + log.debug("Control packet has packetId \(packetId)") + offset += PacketIdLength + + var payload: Data? + if offset < end { + payload = packet.subdata(in: offset..= offset + PacketSessionIdLength else { - log.warning("Dropped malformed packet (missing sessionId)") - continue - } - let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) - offset += PacketSessionIdLength - - guard packet.count >= offset + 1 else { - log.warning("Dropped malformed packet (missing ackSize)") - continue - } - let ackSize = packet[offset] - offset += 1 - log.debug("Packet has code \(code.rawValue), key \(key), sessionId \(sessionId.toHex()) and \(ackSize) acks entries") - - if (ackSize > 0) { - guard packet.count >= (offset + Int(ackSize) * PacketIdLength) else { - log.warning("Dropped malformed packet (missing acks)") + log.debug("Packet has code \(code.rawValue), key \(key)") + let controlPacket: ControlPacket + do { + let parsedPacket = try controlChannel.readInboundPacket(withCode: code, key: key, data: packet, offset: offset) + handleAcks() + if parsedPacket.code == .ackV1 { continue } - var ackedPacketIds = [UInt32]() - for _ in 0..= offset + PacketSessionIdLength else { - log.warning("Dropped malformed packet (missing remoteSessionId)") - continue - } - let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength) - offset += PacketSessionIdLength - - log.debug("Server acked packetIds \(ackedPacketIds) with remoteSessionId \(remoteSessionId.toHex())") - - handleAcks(ackedPacketIds, remoteSessionId: remoteSessionId) - } - - if (code == .ackV1) { + controlPacket = parsedPacket + } catch let e { + log.warning("Dropped malformed packet: \(e)") continue +// deferStop(.shutdown, e) +// return } - guard packet.count >= offset + PacketIdLength else { - log.warning("Dropped malformed packet (missing packetId)") - continue - } - let packetId = packet.networkUInt32Value(from: offset) - log.debug("Control packet has packetId \(packetId)") - offset += PacketIdLength + log.debug("Packet has sessionId \(controlPacket.sessionId.toHex()) and \(controlPacket.ackIds?.count ?? 0) acks entries") + sendAck(for: controlPacket) - sendAck(key: key, packetId: packetId, remoteSessionId: sessionId) - - var payload: Data? - if (offset < packet.count) { - payload = packet.subdata(in: offset..