Move control channel logic to PlainSerializer
This commit is contained in:
parent
595cae3563
commit
11cb312c02
|
@ -29,6 +29,14 @@ import SwiftyBeaver
|
|||
|
||||
private let log = SwiftyBeaver.self
|
||||
|
||||
class ControlChannelError: Error, CustomStringConvertible {
|
||||
let description: String
|
||||
|
||||
init(_ message: String) {
|
||||
description = "\(String(describing: ControlChannelError.self))(\(message))"
|
||||
}
|
||||
}
|
||||
|
||||
class ControlChannel {
|
||||
private let serializer: ControlChannelSerializer
|
||||
|
||||
|
@ -71,10 +79,20 @@ class ControlChannel {
|
|||
pendingAcks.removeAll()
|
||||
plainBuffer.zero()
|
||||
dataCount.reset()
|
||||
serializer.reset()
|
||||
}
|
||||
|
||||
func readInboundPacket(withCode code: PacketCode, key: UInt8, sessionId inboundSessionId: Data, packetId: UInt32, payload: Data?) -> [ControlPacket] {
|
||||
let packet = ControlPacket(code: code, key: key, sessionId: inboundSessionId, packetId: packetId, payload: payload)
|
||||
|
||||
func readInboundPacket(withCode code: PacketCode, key: UInt8, data: Data, offset: Int) throws -> ControlPacket {
|
||||
log.debug("Control: Try read packet with code \(code) and key \(key)")
|
||||
let packet = try serializer.deserialize(code: code, key: key, data: data, start: offset, end: nil)
|
||||
log.debug("Control: Read packet \(packet)")
|
||||
if let ackIds = packet.ackIds as? [UInt32], let ackRemoteSessionId = packet.ackRemoteSessionId {
|
||||
try readAcks(ackIds, acksRemoteSessionId: ackRemoteSessionId)
|
||||
}
|
||||
return packet
|
||||
}
|
||||
|
||||
func enqueueInboundPacket(packet: ControlPacket) -> [ControlPacket] {
|
||||
queue.inbound.append(packet)
|
||||
queue.inbound.sort { $0.packetId < $1.packetId }
|
||||
|
||||
|
@ -127,7 +145,7 @@ class ControlChannel {
|
|||
}
|
||||
}
|
||||
|
||||
func writeOutboundPackets() -> [Data] {
|
||||
func writeOutboundPackets() throws -> [Data] {
|
||||
var rawList: [Data] = []
|
||||
for packet in queue.outbound {
|
||||
if let sentDate = packet.sentDate {
|
||||
|
@ -138,17 +156,9 @@ class ControlChannel {
|
|||
}
|
||||
}
|
||||
|
||||
log.debug("Send control packet with code \(packet.code.rawValue)")
|
||||
|
||||
if let payload = packet.payload {
|
||||
if CoreConfiguration.logsSensitiveData {
|
||||
log.debug("Control packet has payload (\(payload.count) bytes): \(payload.toHex())")
|
||||
} else {
|
||||
log.debug("Control packet has payload (\(payload.count) bytes)")
|
||||
}
|
||||
}
|
||||
|
||||
let raw = packet.serialized()
|
||||
log.debug("Control: Write control packet \(packet)")
|
||||
|
||||
let raw = try serializer.serialize(packet: packet)
|
||||
rawList.append(raw)
|
||||
packet.sentDate = Date()
|
||||
|
||||
|
@ -163,7 +173,8 @@ class ControlChannel {
|
|||
return !pendingAcks.isEmpty
|
||||
}
|
||||
|
||||
func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws {
|
||||
// Ruby: handle_acks
|
||||
private func readAcks(_ packetIds: [UInt32], acksRemoteSessionId: Data) throws {
|
||||
guard let sessionId = sessionId else {
|
||||
throw SessionError.missingSessionId
|
||||
}
|
||||
|
@ -189,8 +200,9 @@ class ControlChannel {
|
|||
guard let sessionId = sessionId else {
|
||||
throw SessionError.missingSessionId
|
||||
}
|
||||
let ackPacket = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
|
||||
return ackPacket.serialized()
|
||||
let packet = ControlPacket(key: key, sessionId: sessionId, ackIds: ackPacketIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
|
||||
log.debug("Control: Write ack packet \(packet)")
|
||||
return try serializer.serialize(packet: packet)
|
||||
}
|
||||
|
||||
func currentControlData(withTLS tls: TLSBox) throws -> ZeroingData {
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
//
|
||||
|
||||
import Foundation
|
||||
import __TunnelKitNative
|
||||
import SwiftyBeaver
|
||||
|
||||
private let log = SwiftyBeaver.self
|
||||
|
@ -33,23 +34,95 @@ protocol ControlChannelSerializer {
|
|||
|
||||
func serialize(packet: ControlPacket) throws -> Data
|
||||
|
||||
func deserialize(data: Data, from: Int) throws -> ControlPacket
|
||||
func deserialize(code: PacketCode, key: UInt8, data: Data, start: Int, end: Int?) throws -> ControlPacket
|
||||
}
|
||||
|
||||
extension ControlChannel {
|
||||
class PlainSerializer: ControlChannelSerializer {
|
||||
func reset() {
|
||||
// TODO
|
||||
}
|
||||
|
||||
func serialize(packet: ControlPacket) throws -> Data {
|
||||
// TODO
|
||||
throw SessionError.pingTimeout
|
||||
return packet.serialized()
|
||||
}
|
||||
|
||||
func deserialize(data: Data, from: Int) throws -> ControlPacket {
|
||||
// TODO
|
||||
throw SessionError.pingTimeout
|
||||
func deserialize(code: PacketCode, key: UInt8, data packet: Data, start: Int, end: Int?) throws -> ControlPacket {
|
||||
var offset = start
|
||||
let end = end ?? packet.count
|
||||
|
||||
guard end >= offset + PacketSessionIdLength else {
|
||||
throw ControlChannelError("Missing sessionId")
|
||||
}
|
||||
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
|
||||
offset += PacketSessionIdLength
|
||||
|
||||
guard end >= offset + 1 else {
|
||||
throw ControlChannelError("Missing ackSize")
|
||||
}
|
||||
let ackSize = packet[offset]
|
||||
offset += 1
|
||||
|
||||
var ackIds: [UInt32]?
|
||||
var ackRemoteSessionId: Data?
|
||||
if ackSize > 0 {
|
||||
guard end >= (offset + Int(ackSize) * PacketIdLength) else {
|
||||
throw ControlChannelError("Missing acks")
|
||||
}
|
||||
var ids: [UInt32] = []
|
||||
for _ in 0..<ackSize {
|
||||
let id = packet.networkUInt32Value(from: offset)
|
||||
ids.append(id)
|
||||
offset += PacketIdLength
|
||||
}
|
||||
|
||||
guard end >= offset + PacketSessionIdLength else {
|
||||
throw ControlChannelError("Missing remoteSessionId")
|
||||
}
|
||||
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
|
||||
offset += PacketSessionIdLength
|
||||
|
||||
log.debug("Server acked packetIds \(ids) with remoteSessionId \(remoteSessionId.toHex())")
|
||||
|
||||
ackIds = ids
|
||||
ackRemoteSessionId = remoteSessionId
|
||||
}
|
||||
|
||||
if code == .ackV1 {
|
||||
guard let ackIds = ackIds else {
|
||||
throw ControlChannelError("Ack packet without ids")
|
||||
}
|
||||
guard let ackRemoteSessionId = ackRemoteSessionId else {
|
||||
throw ControlChannelError("Ack packet without remoteSessionId")
|
||||
}
|
||||
return ControlPacket(key: key, sessionId: sessionId, ackIds: ackIds as [NSNumber], ackRemoteSessionId: ackRemoteSessionId)
|
||||
}
|
||||
|
||||
guard end >= offset + PacketIdLength else {
|
||||
throw ControlChannelError("Missing packetId")
|
||||
}
|
||||
let packetId = packet.networkUInt32Value(from: offset)
|
||||
log.debug("Control packet has packetId \(packetId)")
|
||||
offset += PacketIdLength
|
||||
|
||||
var payload: Data?
|
||||
if offset < end {
|
||||
payload = packet.subdata(in: offset..<end)
|
||||
|
||||
if let payload = payload {
|
||||
if CoreConfiguration.logsSensitiveData {
|
||||
log.debug("Control packet payload (\(payload.count) bytes): \(payload.toHex())")
|
||||
} else {
|
||||
log.debug("Control packet payload (\(payload.count) bytes)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let controlPacket = ControlPacket(code: code, key: key, sessionId: sessionId, packetId: packetId, payload: payload)
|
||||
if let ackIds = ackIds {
|
||||
controlPacket.ackIds = ackIds as [NSNumber]
|
||||
controlPacket.ackRemoteSessionId = ackRemoteSessionId
|
||||
}
|
||||
return controlPacket
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -453,75 +453,27 @@ public class SessionProxy {
|
|||
|
||||
continue
|
||||
}
|
||||
|
||||
guard packet.count >= offset + PacketSessionIdLength else {
|
||||
log.warning("Dropped malformed packet (missing sessionId)")
|
||||
continue
|
||||
}
|
||||
let sessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
|
||||
offset += PacketSessionIdLength
|
||||
|
||||
guard packet.count >= offset + 1 else {
|
||||
log.warning("Dropped malformed packet (missing ackSize)")
|
||||
continue
|
||||
}
|
||||
let ackSize = packet[offset]
|
||||
offset += 1
|
||||
|
||||
log.debug("Packet has code \(code.rawValue), key \(key), sessionId \(sessionId.toHex()) and \(ackSize) acks entries")
|
||||
|
||||
if (ackSize > 0) {
|
||||
guard packet.count >= (offset + Int(ackSize) * PacketIdLength) else {
|
||||
log.warning("Dropped malformed packet (missing acks)")
|
||||
log.debug("Packet has code \(code.rawValue), key \(key)")
|
||||
let controlPacket: ControlPacket
|
||||
do {
|
||||
let parsedPacket = try controlChannel.readInboundPacket(withCode: code, key: key, data: packet, offset: offset)
|
||||
handleAcks()
|
||||
if parsedPacket.code == .ackV1 {
|
||||
continue
|
||||
}
|
||||
var ackedPacketIds = [UInt32]()
|
||||
for _ in 0..<ackSize {
|
||||
let ackedPacketId = packet.networkUInt32Value(from: offset)
|
||||
ackedPacketIds.append(ackedPacketId)
|
||||
offset += PacketIdLength
|
||||
}
|
||||
|
||||
guard packet.count >= offset + PacketSessionIdLength else {
|
||||
log.warning("Dropped malformed packet (missing remoteSessionId)")
|
||||
continue
|
||||
}
|
||||
let remoteSessionId = packet.subdata(offset: offset, count: PacketSessionIdLength)
|
||||
offset += PacketSessionIdLength
|
||||
|
||||
log.debug("Server acked packetIds \(ackedPacketIds) with remoteSessionId \(remoteSessionId.toHex())")
|
||||
|
||||
handleAcks(ackedPacketIds, remoteSessionId: remoteSessionId)
|
||||
}
|
||||
|
||||
if (code == .ackV1) {
|
||||
controlPacket = parsedPacket
|
||||
} catch let e {
|
||||
log.warning("Dropped malformed packet: \(e)")
|
||||
continue
|
||||
// deferStop(.shutdown, e)
|
||||
// return
|
||||
}
|
||||
|
||||
guard packet.count >= offset + PacketIdLength else {
|
||||
log.warning("Dropped malformed packet (missing packetId)")
|
||||
continue
|
||||
}
|
||||
let packetId = packet.networkUInt32Value(from: offset)
|
||||
log.debug("Control packet has packetId \(packetId)")
|
||||
offset += PacketIdLength
|
||||
log.debug("Packet has sessionId \(controlPacket.sessionId.toHex()) and \(controlPacket.ackIds?.count ?? 0) acks entries")
|
||||
sendAck(for: controlPacket)
|
||||
|
||||
sendAck(key: key, packetId: packetId, remoteSessionId: sessionId)
|
||||
|
||||
var payload: Data?
|
||||
if (offset < packet.count) {
|
||||
payload = packet.subdata(in: offset..<packet.count)
|
||||
|
||||
if let payload = payload {
|
||||
if CoreConfiguration.logsSensitiveData {
|
||||
log.debug("Control packet payload (\(payload.count) bytes): \(payload.toHex())")
|
||||
} else {
|
||||
log.debug("Control packet payload (\(payload.count) bytes)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pendingInboundQueue = controlChannel.readInboundPacket(withCode: code, key: key, sessionId: sessionId, packetId: packetId, payload: payload)
|
||||
let pendingInboundQueue = controlChannel.enqueueInboundPacket(packet: controlPacket)
|
||||
for inboundPacket in pendingInboundQueue {
|
||||
handleControlPacket(inboundPacket)
|
||||
}
|
||||
|
@ -899,7 +851,14 @@ public class SessionProxy {
|
|||
|
||||
// Ruby: flush_ctrl_q_out
|
||||
private func flushControlQueue() {
|
||||
let rawList = controlChannel.writeOutboundPackets()
|
||||
let rawList: [Data]
|
||||
do {
|
||||
rawList = try controlChannel.writeOutboundPackets()
|
||||
} catch let e {
|
||||
log.warning("Failed control packet serialization: \(e)")
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
}
|
||||
for raw in rawList {
|
||||
log.debug("Send control packet (\(raw.count) bytes): \(raw.toHex())")
|
||||
}
|
||||
|
@ -1044,14 +1003,7 @@ public class SessionProxy {
|
|||
|
||||
// MARK: Acks
|
||||
|
||||
// Ruby: handle_acks
|
||||
private func handleAcks(_ packetIds: [UInt32], remoteSessionId: Data) {
|
||||
do {
|
||||
try controlChannel.readAcks(packetIds, acksRemoteSessionId: remoteSessionId)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
}
|
||||
private func handleAcks() {
|
||||
|
||||
// retry PUSH_REQUEST if ack queue is empty (all sent packets were ack'ed)
|
||||
if isReliableLink && !controlChannel.hasPendingAcks() {
|
||||
|
@ -1060,12 +1012,16 @@ public class SessionProxy {
|
|||
}
|
||||
|
||||
// Ruby: send_ack
|
||||
private func sendAck(key: UInt8, packetId: UInt32, remoteSessionId: Data) {
|
||||
log.debug("Send ack for received packetId \(packetId)")
|
||||
private func sendAck(for controlPacket: ControlPacket) {
|
||||
log.debug("Send ack for received packetId \(controlPacket.packetId)")
|
||||
|
||||
let raw: Data
|
||||
do {
|
||||
raw = try controlChannel.writeAcks(withKey: key, ackPacketIds: [packetId], ackRemoteSessionId: remoteSessionId)
|
||||
raw = try controlChannel.writeAcks(
|
||||
withKey: controlPacket.key,
|
||||
ackPacketIds: [controlPacket.packetId],
|
||||
ackRemoteSessionId: controlPacket.sessionId
|
||||
)
|
||||
} catch let e {
|
||||
deferStop(.shutdown, e)
|
||||
return
|
||||
|
@ -1075,12 +1031,12 @@ public class SessionProxy {
|
|||
link?.writePacket(raw) { [weak self] (error) in
|
||||
if let error = error {
|
||||
self?.queue.sync {
|
||||
log.error("Failed LINK write during send ack for packetId \(packetId): \(error)")
|
||||
log.error("Failed LINK write during send ack for packetId \(controlPacket.packetId): \(error)")
|
||||
self?.deferStop(.reconnect, SessionError.failedLinkWrite)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.debug("Ack successfully written to LINK for packetId \(packetId)")
|
||||
log.debug("Ack successfully written to LINK for packetId \(controlPacket.packetId)")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue