wireguard-go-bridge: take fd instead of fnptr

This commit is contained in:
Jason A. Donenfeld 2018-11-06 15:46:44 +01:00
parent 2a7aa578d2
commit 040f0a25ea
7 changed files with 77 additions and 198 deletions

View File

@ -17,7 +17,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
// MARK: Properties
private var wgHandle: Int32?
private var wgContext: WireGuardContext?
// MARK: NEPacketTunnelProvider
@ -64,9 +63,14 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
configureLogger()
wgContext = WireGuardContext(packetFlow: self.packetFlow)
let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, mtu: mtu.uint16Value)
let fd = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32
if fd < 0 {
os_log("Starting tunnel failed: Could not determine file descriptor", log: OSLog.default, type: .error)
startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard)
return
}
let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, fd: fd)
if handle < 0 {
os_log("Starting tunnel failed: Could not start WireGuard", log: OSLog.default, type: .error)
@ -114,9 +118,8 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
// MTU
if (mtu == 0) {
// 0 imples automatic MTU, where we set overhead as 95 bytes,
// 80 for WireGuard and the 15 to make sure WireGuard's padding will work.
networkSettings.tunnelOverheadBytes = 95
// 0 imples automatic MTU, where we set overhead as 80 bytes, which is the worst case for WireGuard
networkSettings.tunnelOverheadBytes = 80
} else {
networkSettings.mtu = mtu
}
@ -134,7 +137,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
/// Begin the process of stopping the tunnel.
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
os_log("Stopping tunnel", log: OSLog.default, type: .info)
wgContext?.closeTunnel()
if let handle = wgHandle {
wgTurnOff(handle)
}
@ -159,99 +161,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
}
}
private func connect(interfaceName: String, settings: String, mtu: UInt16) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
private func connect(interfaceName: String, settings: String, fd: Int32) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in
return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in
return wgTurnOn(nameGoStr, settingsGoStr, mtu, { (wgCtxPtr, buf, len) -> Int in
autoreleasepool {
// read_fn: Read from the TUN interface and pass it on to WireGuard
guard let wgCtxPtr = wgCtxPtr else { return 0 }
guard let buf = buf else { return 0 }
let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
var isTunnelClosed = false
let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed)
if isTunnelClosed { return -1 }
guard let packetData = packet?.data else { return 0 }
if packetData.count <= len {
packetData.copyBytes(to: buf, count: packetData.count)
return packetData.count
}
return 0
}
}, { (wgCtxPtr, buf, len) -> Int in
autoreleasepool {
// write_fn: Receive packets from WireGuard and write to the TUN interface
guard let wgCtxPtr = wgCtxPtr else { return 0 }
guard let buf = buf else { return 0 }
guard len > 0 else { return 0 }
let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
let ipVersionBits = (buf[0] & 0xf0) >> 4
let ipVersion: sa_family_t? = {
if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4
if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6
return nil
}()
guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") }
let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily)
var isTunnelClosed = false
let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed)
if isTunnelClosed { return -1 }
if isWritten {
return len
}
return 0
}
},
wgCtxPtr)
}
return wgTurnOn(nameGoStr, settingsGoStr, fd)
}
}
}
class WireGuardContext {
private var packetFlow: NEPacketTunnelFlow
private var outboundPackets: [NEPacket] = []
private var isTunnelClosed: Bool = false
private var readPacketCondition = NSCondition()
init(packetFlow: NEPacketTunnelFlow) {
self.packetFlow = packetFlow
}
func closeTunnel() {
isTunnelClosed = true
readPacketCondition.signal()
}
func packetsRead(packets: [NEPacket]) {
readPacketCondition.lock()
outboundPackets.append(contentsOf: packets)
readPacketCondition.unlock()
readPacketCondition.signal()
}
func readPacket(isTunnelClosed: inout Bool) -> NEPacket? {
if outboundPackets.isEmpty {
readPacketCondition.lock()
packetFlow.readPacketObjects(completionHandler: packetsRead)
while outboundPackets.isEmpty && !self.isTunnelClosed {
readPacketCondition.wait()
}
readPacketCondition.unlock()
}
isTunnelClosed = self.isTunnelClosed
if !outboundPackets.isEmpty {
return outboundPackets.removeFirst()
}
return nil
}
func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool {
isTunnelClosed = self.isTunnelClosed
return packetFlow.writePacketObjects([packet])
}
}
private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R {
return str1.withCString { (s1cStr) -> R in
let gstr1 = gostring_t(p: s1cStr, n: str1.utf8.count)

@ -1 +1 @@
Subproject commit 738d027f0bfc59e14384e36c44753d7b61fb1c43
Subproject commit 276bf973e8a086da7767dc25ebe116926c0b59db

View File

@ -15,6 +15,7 @@ import "C"
import (
"bufio"
"errors"
"git.zx2c4.com/wireguard-go/tun"
"golang.org/x/sys/unix"
"io/ioutil"
@ -25,7 +26,6 @@ import (
"runtime"
"strings"
"unsafe"
"errors"
)
var loggerFunc unsafe.Pointer
@ -75,7 +75,7 @@ func wgSetLogger(loggerFn uintptr) {
}
//export wgTurnOn
func wgTurnOn(ifnameRef string, settings string, mtu uint16, readFn uintptr, writeFn uintptr, ctx uintptr) int32 {
func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
interfaceName := string([]byte(ifnameRef))
logger := &Logger{
@ -86,12 +86,14 @@ func wgTurnOn(ifnameRef string, settings string, mtu uint16, readFn uintptr, wri
logger.Debug.Println("Debug log enabled")
tun := tun.CreateTUN(mtu, unsafe.Pointer(readFn), unsafe.Pointer(writeFn), unsafe.Pointer(ctx))
tun, _, err := tun.CreateTUNFromFD(int(tunFd))
if err != nil {
logger.Error.Println(err)
return -1
}
logger.Info.Println("Attaching to interface")
device := NewDevice(tun, logger)
logger.Debug.Println("Interface has MTU", device.tun.mtu)
bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(settings)), bufio.NewWriter(ioutil.Discard))
setError := ipcSetOperation(device, bufferedSettings)
if setError != nil {

View File

@ -8,9 +8,9 @@ package main
/* Fit within memory limits for iOS */
const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 1700
PreallocatedBuffersPerPool = 1024
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 1700
PreallocatedBuffersPerPool = 1024
)

View File

@ -0,0 +1,52 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package tun
import (
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/sys/unix"
"net"
"os"
)
func CreateTUNFromFD(tunFd int) (TUNDevice, string, error) {
file := os.NewFile(uintptr(tunFd), "/dev/tun")
tun := &nativeTun{
tunFile: file,
fd: file.Fd(),
events: make(chan TUNEvent, 5),
errors: make(chan error, 5),
}
var err error
tun.rwcancel, err = rwcancel.NewRWCancel(tunFd)
if err != nil {
return nil, "", err
}
name, err := tun.Name()
if err != nil {
tun.rwcancel.Cancel()
return nil, "", err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.tunFile.Close()
return nil, "", err
}
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, "", err
}
go tun.routineRouteListener(tunIfindex)
return tun, name, nil
}

View File

@ -1,90 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package tun
// #include <sys/types.h>
// static ssize_t callFnWithCtx(const void *func, const void *ctx, const void *buffer, size_t len)
// {
// return ((ssize_t(*)(const void *, const unsigned char *, size_t))func)(ctx, buffer, len);
// }
import "C"
import (
"os"
"syscall"
"unsafe"
)
type nativeTun struct {
events chan TUNEvent
mtu int
readFn unsafe.Pointer
writeFn unsafe.Pointer
ctx unsafe.Pointer
}
func CreateTUN(mtu uint16, readFn unsafe.Pointer, writeFn unsafe.Pointer, ctx unsafe.Pointer) TUNDevice {
if mtu == 0 {
/* 0 means automatic MTU, which iOS makes outerMTU-80-15. The 80 is for
* WireGuard and the 15 ensures our padding will work. Therefore, it's
* safe to have this code assume a massive MTU.
*/
mtu = ^mtu
}
tun := &nativeTun{
events: make(chan TUNEvent, 10),
mtu: int(mtu),
readFn: readFn,
writeFn: writeFn,
ctx: ctx,
}
tun.events <- TUNEventUp
return tun
}
func (tun *nativeTun) Name() (string, error) {
return "tun", nil
}
func (tun *nativeTun) File() *os.File {
return nil
}
func (tun *nativeTun) Events() chan TUNEvent {
return tun.events
}
func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
ret := C.callFnWithCtx(tun.readFn, tun.ctx, unsafe.Pointer(&buff[offset]), C.size_t(len(buff) - offset))
if ret < 0 {
return 0, syscall.Errno(-ret)
}
return int(ret), nil
}
func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
ret := C.callFnWithCtx(tun.writeFn, tun.ctx, unsafe.Pointer(&buff[offset]), C.size_t(len(buff) - offset))
if ret < 0 {
return 0, syscall.Errno(-ret)
}
return int(ret), nil
}
func (tun *nativeTun) Close() error {
if tun.events != nil {
close(tun.events)
}
return nil
}
func (tun *nativeTun) setMTU(n int) error {
tun.mtu = n
return nil
}
func (tun *nativeTun) MTU() (int, error) {
return tun.mtu, nil
}

View File

@ -10,10 +10,9 @@
#include <stdint.h>
typedef struct { const char *p; size_t n; } gostring_t;
typedef ssize_t(*read_write_fn_t)(void *ctx, unsigned char *buf, size_t len);
typedef void(*logger_fn_t)(int level, const char *msg);
extern void wgSetLogger(logger_fn_t logger_fn);
extern int wgTurnOn(gostring_t ifname, gostring_t settings, uint16_t mtu, read_write_fn_t read_fn, read_write_fn_t write_fn, void *ctx);
extern int wgTurnOn(gostring_t ifname, gostring_t settings, int32_t tun_fd);
extern void wgTurnOff(int handle);
extern char *wgVersion();