From 56eed120200287b12e3bec391453d5a61f3436b4 Mon Sep 17 00:00:00 2001 From: Jeroen Leenarts Date: Tue, 28 Aug 2018 14:04:38 +0200 Subject: [PATCH] Move connection logic into seperate function. --- .../PacketTunnelProvider.swift | 89 ++++++++++--------- 1 file changed, 46 insertions(+), 43 deletions(-) diff --git a/WireGuardNetworkExtension/PacketTunnelProvider.swift b/WireGuardNetworkExtension/PacketTunnelProvider.swift index b7a481e..07c33e5 100644 --- a/WireGuardNetworkExtension/PacketTunnelProvider.swift +++ b/WireGuardNetworkExtension/PacketTunnelProvider.swift @@ -46,49 +46,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { configureLogger() - let handle = withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in - return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in - return wgTurnOn(nameGoStr, settingsGoStr, - // read_fn: Read from the TUN interface and pass it on to WireGuard - { (wgCtxPtr, buf, len) -> Int in - guard let wgCtxPtr = wgCtxPtr else { return 0 } - guard let buf = buf else { return 0 } - let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee - var isTunnelClosed = false - guard let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed) else { return 0 } - if isTunnelClosed { return -1 } - let packetData = packet.data - if packetData.count <= len { - packetData.copyBytes(to: buf, count: packetData.count) - return packetData.count - } - return 0 - }, - // write_fn: Receive packets from WireGuard and write to the TUN interface - { (wgCtxPtr, buf, len) -> Int in - guard let wgCtxPtr = wgCtxPtr else { return 0 } - guard let buf = buf else { return 0 } - guard len > 0 else { return 0 } - let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee - let ipVersionBits = (buf[0] & 0xf0) >> 4 - let ipVersion: sa_family_t? = { - if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4 - if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6 - return nil - }() - guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") } - let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily) - var isTunnelClosed = false - let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed) - if isTunnelClosed { return -1 } - if isWritten { - return len - } - return 0 - }, - wgCtxPtr) - } - } + let handle = connect(interfaceName: interfaceName, settings: settings) if handle < 0 { startTunnelCompletionHandler(PacketTunnelProviderError.tunnelSetupFailed) @@ -181,6 +139,51 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } } + private func connect(interfaceName: String, settings: String) -> Int32 { // swiftlint:disable:this cyclomatic_complexity + return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in + return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in + return wgTurnOn(nameGoStr, settingsGoStr, + // read_fn: Read from the TUN interface and pass it on to WireGuard + { (wgCtxPtr, buf, len) -> Int in + guard let wgCtxPtr = wgCtxPtr else { return 0 } + guard let buf = buf else { return 0 } + let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee + var isTunnelClosed = false + guard let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed) else { return 0 } + if isTunnelClosed { return -1 } + let packetData = packet.data + if packetData.count <= len { + packetData.copyBytes(to: buf, count: packetData.count) + return packetData.count + } + return 0 + }, + // write_fn: Receive packets from WireGuard and write to the TUN interface + { (wgCtxPtr, buf, len) -> Int in + guard let wgCtxPtr = wgCtxPtr else { return 0 } + guard let buf = buf else { return 0 } + guard len > 0 else { return 0 } + let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee + let ipVersionBits = (buf[0] & 0xf0) >> 4 + let ipVersion: sa_family_t? = { + if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4 + if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6 + return nil + }() + guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") } + let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily) + var isTunnelClosed = false + let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed) + if isTunnelClosed { return -1 } + if isWritten { + return len + } + return 0 + }, + wgCtxPtr) + } + } + } } class WireGuardContext {