wireguard-go-bridge: take fd instead of fnptr
This commit is contained in:
parent
2a7aa578d2
commit
040f0a25ea
|
@ -17,7 +17,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
// MARK: Properties
|
||||
|
||||
private var wgHandle: Int32?
|
||||
private var wgContext: WireGuardContext?
|
||||
|
||||
// MARK: NEPacketTunnelProvider
|
||||
|
||||
|
@ -64,9 +63,14 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
}
|
||||
|
||||
configureLogger()
|
||||
wgContext = WireGuardContext(packetFlow: self.packetFlow)
|
||||
|
||||
let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, mtu: mtu.uint16Value)
|
||||
let fd = packetFlow.value(forKeyPath: "socket.fileDescriptor") as! Int32
|
||||
if fd < 0 {
|
||||
os_log("Starting tunnel failed: Could not determine file descriptor", log: OSLog.default, type: .error)
|
||||
startTunnelCompletionHandler(PacketTunnelProviderError.couldNotStartWireGuard)
|
||||
return
|
||||
}
|
||||
let handle = connect(interfaceName: interfaceName, settings: wireguardSettings, fd: fd)
|
||||
|
||||
if handle < 0 {
|
||||
os_log("Starting tunnel failed: Could not start WireGuard", log: OSLog.default, type: .error)
|
||||
|
@ -114,9 +118,8 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
|
||||
// MTU
|
||||
if (mtu == 0) {
|
||||
// 0 imples automatic MTU, where we set overhead as 95 bytes,
|
||||
// 80 for WireGuard and the 15 to make sure WireGuard's padding will work.
|
||||
networkSettings.tunnelOverheadBytes = 95
|
||||
// 0 imples automatic MTU, where we set overhead as 80 bytes, which is the worst case for WireGuard
|
||||
networkSettings.tunnelOverheadBytes = 80
|
||||
} else {
|
||||
networkSettings.mtu = mtu
|
||||
}
|
||||
|
@ -134,7 +137,6 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
/// Begin the process of stopping the tunnel.
|
||||
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
|
||||
os_log("Stopping tunnel", log: OSLog.default, type: .info)
|
||||
wgContext?.closeTunnel()
|
||||
if let handle = wgHandle {
|
||||
wgTurnOff(handle)
|
||||
}
|
||||
|
@ -159,99 +161,13 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||
}
|
||||
}
|
||||
|
||||
private func connect(interfaceName: String, settings: String, mtu: UInt16) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
|
||||
private func connect(interfaceName: String, settings: String, fd: Int32) -> Int32 { // swiftlint:disable:this cyclomatic_complexity
|
||||
return withStringsAsGoStrings(interfaceName, settings) { (nameGoStr, settingsGoStr) -> Int32 in
|
||||
return withUnsafeMutablePointer(to: &wgContext) { (wgCtxPtr) -> Int32 in
|
||||
return wgTurnOn(nameGoStr, settingsGoStr, mtu, { (wgCtxPtr, buf, len) -> Int in
|
||||
autoreleasepool {
|
||||
// read_fn: Read from the TUN interface and pass it on to WireGuard
|
||||
guard let wgCtxPtr = wgCtxPtr else { return 0 }
|
||||
guard let buf = buf else { return 0 }
|
||||
let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
|
||||
var isTunnelClosed = false
|
||||
let packet = wgContext.readPacket(isTunnelClosed: &isTunnelClosed)
|
||||
if isTunnelClosed { return -1 }
|
||||
guard let packetData = packet?.data else { return 0 }
|
||||
if packetData.count <= len {
|
||||
packetData.copyBytes(to: buf, count: packetData.count)
|
||||
return packetData.count
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}, { (wgCtxPtr, buf, len) -> Int in
|
||||
autoreleasepool {
|
||||
// write_fn: Receive packets from WireGuard and write to the TUN interface
|
||||
guard let wgCtxPtr = wgCtxPtr else { return 0 }
|
||||
guard let buf = buf else { return 0 }
|
||||
guard len > 0 else { return 0 }
|
||||
let wgContext = wgCtxPtr.bindMemory(to: WireGuardContext.self, capacity: 1).pointee
|
||||
let ipVersionBits = (buf[0] & 0xf0) >> 4
|
||||
let ipVersion: sa_family_t? = {
|
||||
if ipVersionBits == 4 { return sa_family_t(AF_INET) } // IPv4
|
||||
if ipVersionBits == 6 { return sa_family_t(AF_INET6) } // IPv6
|
||||
return nil
|
||||
}()
|
||||
guard let protocolFamily = ipVersion else { fatalError("Unknown IP version") }
|
||||
let packet = NEPacket(data: Data(bytes: buf, count: len), protocolFamily: protocolFamily)
|
||||
var isTunnelClosed = false
|
||||
let isWritten = wgContext.writePacket(packet: packet, isTunnelClosed: &isTunnelClosed)
|
||||
if isTunnelClosed { return -1 }
|
||||
if isWritten {
|
||||
return len
|
||||
}
|
||||
return 0
|
||||
}
|
||||
},
|
||||
wgCtxPtr)
|
||||
}
|
||||
return wgTurnOn(nameGoStr, settingsGoStr, fd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class WireGuardContext {
|
||||
private var packetFlow: NEPacketTunnelFlow
|
||||
private var outboundPackets: [NEPacket] = []
|
||||
private var isTunnelClosed: Bool = false
|
||||
private var readPacketCondition = NSCondition()
|
||||
|
||||
init(packetFlow: NEPacketTunnelFlow) {
|
||||
self.packetFlow = packetFlow
|
||||
}
|
||||
|
||||
func closeTunnel() {
|
||||
isTunnelClosed = true
|
||||
readPacketCondition.signal()
|
||||
}
|
||||
|
||||
func packetsRead(packets: [NEPacket]) {
|
||||
readPacketCondition.lock()
|
||||
outboundPackets.append(contentsOf: packets)
|
||||
readPacketCondition.unlock()
|
||||
readPacketCondition.signal()
|
||||
}
|
||||
|
||||
func readPacket(isTunnelClosed: inout Bool) -> NEPacket? {
|
||||
if outboundPackets.isEmpty {
|
||||
readPacketCondition.lock()
|
||||
packetFlow.readPacketObjects(completionHandler: packetsRead)
|
||||
while outboundPackets.isEmpty && !self.isTunnelClosed {
|
||||
readPacketCondition.wait()
|
||||
}
|
||||
readPacketCondition.unlock()
|
||||
}
|
||||
isTunnelClosed = self.isTunnelClosed
|
||||
if !outboundPackets.isEmpty {
|
||||
return outboundPackets.removeFirst()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePacket(packet: NEPacket, isTunnelClosed: inout Bool) -> Bool {
|
||||
isTunnelClosed = self.isTunnelClosed
|
||||
return packetFlow.writePacketObjects([packet])
|
||||
}
|
||||
}
|
||||
|
||||
private func withStringsAsGoStrings<R>(_ str1: String, _ str2: String, closure: (gostring_t, gostring_t) -> R) -> R {
|
||||
return str1.withCString { (s1cStr) -> R in
|
||||
let gstr1 = gostring_t(p: s1cStr, n: str1.utf8.count)
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 738d027f0bfc59e14384e36c44753d7b61fb1c43
|
||||
Subproject commit 276bf973e8a086da7767dc25ebe116926c0b59db
|
|
@ -15,6 +15,7 @@ import "C"
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"git.zx2c4.com/wireguard-go/tun"
|
||||
"golang.org/x/sys/unix"
|
||||
"io/ioutil"
|
||||
|
@ -25,7 +26,6 @@ import (
|
|||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var loggerFunc unsafe.Pointer
|
||||
|
@ -75,7 +75,7 @@ func wgSetLogger(loggerFn uintptr) {
|
|||
}
|
||||
|
||||
//export wgTurnOn
|
||||
func wgTurnOn(ifnameRef string, settings string, mtu uint16, readFn uintptr, writeFn uintptr, ctx uintptr) int32 {
|
||||
func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
|
||||
interfaceName := string([]byte(ifnameRef))
|
||||
|
||||
logger := &Logger{
|
||||
|
@ -86,12 +86,14 @@ func wgTurnOn(ifnameRef string, settings string, mtu uint16, readFn uintptr, wri
|
|||
|
||||
logger.Debug.Println("Debug log enabled")
|
||||
|
||||
tun := tun.CreateTUN(mtu, unsafe.Pointer(readFn), unsafe.Pointer(writeFn), unsafe.Pointer(ctx))
|
||||
tun, _, err := tun.CreateTUNFromFD(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error.Println(err)
|
||||
return -1
|
||||
}
|
||||
logger.Info.Println("Attaching to interface")
|
||||
device := NewDevice(tun, logger)
|
||||
|
||||
logger.Debug.Println("Interface has MTU", device.tun.mtu)
|
||||
|
||||
bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(settings)), bufio.NewWriter(ioutil.Discard))
|
||||
setError := ipcSetOperation(device, bufferedSettings)
|
||||
if setError != nil {
|
||||
|
|
|
@ -8,9 +8,9 @@ package main
|
|||
/* Fit within memory limits for iOS */
|
||||
|
||||
const (
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
MaxSegmentSize = 1700
|
||||
PreallocatedBuffersPerPool = 1024
|
||||
QueueOutboundSize = 1024
|
||||
QueueInboundSize = 1024
|
||||
QueueHandshakeSize = 1024
|
||||
MaxSegmentSize = 1700
|
||||
PreallocatedBuffersPerPool = 1024
|
||||
)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/* SPDX-License-Identifier: GPL-2.0
|
||||
*
|
||||
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"git.zx2c4.com/wireguard-go/rwcancel"
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd int) (TUNDevice, string, error) {
|
||||
file := os.NewFile(uintptr(tunFd), "/dev/tun")
|
||||
tun := &nativeTun{
|
||||
tunFile: file,
|
||||
fd: file.Fd(),
|
||||
events: make(chan TUNEvent, 5),
|
||||
errors: make(chan error, 5),
|
||||
}
|
||||
var err error
|
||||
tun.rwcancel, err = rwcancel.NewRWCancel(tunFd)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
tun.rwcancel.Cancel()
|
||||
return nil, "", err
|
||||
}
|
||||
tunIfindex, err := func() (int, error) {
|
||||
iface, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return iface.Index, nil
|
||||
}()
|
||||
if err != nil {
|
||||
tun.tunFile.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
|
||||
if err != nil {
|
||||
tun.tunFile.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
go tun.routineRouteListener(tunIfindex)
|
||||
|
||||
return tun, name, nil
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
/* SPDX-License-Identifier: GPL-2.0
|
||||
*
|
||||
* Copyright (C) 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
// #include <sys/types.h>
|
||||
// static ssize_t callFnWithCtx(const void *func, const void *ctx, const void *buffer, size_t len)
|
||||
// {
|
||||
// return ((ssize_t(*)(const void *, const unsigned char *, size_t))func)(ctx, buffer, len);
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type nativeTun struct {
|
||||
events chan TUNEvent
|
||||
mtu int
|
||||
readFn unsafe.Pointer
|
||||
writeFn unsafe.Pointer
|
||||
ctx unsafe.Pointer
|
||||
}
|
||||
|
||||
func CreateTUN(mtu uint16, readFn unsafe.Pointer, writeFn unsafe.Pointer, ctx unsafe.Pointer) TUNDevice {
|
||||
if mtu == 0 {
|
||||
/* 0 means automatic MTU, which iOS makes outerMTU-80-15. The 80 is for
|
||||
* WireGuard and the 15 ensures our padding will work. Therefore, it's
|
||||
* safe to have this code assume a massive MTU.
|
||||
*/
|
||||
mtu = ^mtu
|
||||
}
|
||||
tun := &nativeTun{
|
||||
events: make(chan TUNEvent, 10),
|
||||
mtu: int(mtu),
|
||||
readFn: readFn,
|
||||
writeFn: writeFn,
|
||||
ctx: ctx,
|
||||
}
|
||||
tun.events <- TUNEventUp
|
||||
return tun
|
||||
}
|
||||
|
||||
func (tun *nativeTun) Name() (string, error) {
|
||||
return "tun", nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) File() *os.File {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) Events() chan TUNEvent {
|
||||
return tun.events
|
||||
}
|
||||
|
||||
func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
|
||||
ret := C.callFnWithCtx(tun.readFn, tun.ctx, unsafe.Pointer(&buff[offset]), C.size_t(len(buff) - offset))
|
||||
if ret < 0 {
|
||||
return 0, syscall.Errno(-ret)
|
||||
}
|
||||
return int(ret), nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
ret := C.callFnWithCtx(tun.writeFn, tun.ctx, unsafe.Pointer(&buff[offset]), C.size_t(len(buff) - offset))
|
||||
if ret < 0 {
|
||||
return 0, syscall.Errno(-ret)
|
||||
}
|
||||
return int(ret), nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) Close() error {
|
||||
if tun.events != nil {
|
||||
close(tun.events)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) setMTU(n int) error {
|
||||
tun.mtu = n
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *nativeTun) MTU() (int, error) {
|
||||
return tun.mtu, nil
|
||||
}
|
|
@ -10,10 +10,9 @@
|
|||
#include <stdint.h>
|
||||
|
||||
typedef struct { const char *p; size_t n; } gostring_t;
|
||||
typedef ssize_t(*read_write_fn_t)(void *ctx, unsigned char *buf, size_t len);
|
||||
typedef void(*logger_fn_t)(int level, const char *msg);
|
||||
extern void wgSetLogger(logger_fn_t logger_fn);
|
||||
extern int wgTurnOn(gostring_t ifname, gostring_t settings, uint16_t mtu, read_write_fn_t read_fn, read_write_fn_t write_fn, void *ctx);
|
||||
extern int wgTurnOn(gostring_t ifname, gostring_t settings, int32_t tun_fd);
|
||||
extern void wgTurnOff(int handle);
|
||||
extern char *wgVersion();
|
||||
|
||||
|
|
Loading…
Reference in New Issue