diff --git a/wireguard-go-bridge/src/api-ios.go b/wireguard-go-bridge/src/api-ios.go index 67ce785..0fdb3be 100644 --- a/wireguard-go-bridge/src/api-ios.go +++ b/wireguard-go-bridge/src/api-ios.go @@ -25,6 +25,8 @@ import ( "os/signal" "runtime" "strings" + "syscall" + "time" "unsafe" ) @@ -46,12 +48,54 @@ func (l *CLogger) Write(p []byte) (int, error) { return len(p), nil } -var tunnelHandles map[int32]*Device +type DeviceState struct { + device *Device + logger *Logger + endpointsTimer *time.Timer + endpointsSettings string +} + +var tunnelHandles map[int32]*DeviceState + +func listenForRouteChanges() { + //TODO: replace with NWPathMonitor + data := make([]byte, os.Getpagesize()) + routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return + } + for { + n, err := unix.Read(routeSocket, data) + if err != nil { + if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + continue + } + return + } + + if n < 4 { + continue + } + for _, deviceState := range tunnelHandles { + if deviceState.endpointsTimer == nil { + deviceState.endpointsTimer = time.AfterFunc(time.Second, func() { + deviceState.endpointsTimer = nil + bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(deviceState.endpointsSettings)), bufio.NewWriter(ioutil.Discard)) + deviceState.logger.Info.Println("Setting endpoints for re-resolution due to network change") + err := ipcSetOperation(deviceState.device, bufferedSettings) + if err != nil { + deviceState.logger.Error.Println(err) + } + }) + } + } + } +} func init() { versionString = C.CString(WireGuardGoVersion) roamingDisabled = true - tunnelHandles = make(map[int32]*Device) + tunnelHandles = make(map[int32]*DeviceState) signals := make(chan os.Signal) signal.Notify(signals, unix.SIGUSR2) go func() { @@ -67,6 +111,7 @@ func init() { } } }() + go listenForRouteChanges() } //export wgSetLogger @@ -74,6 +119,32 @@ func wgSetLogger(loggerFn uintptr) { loggerFunc = unsafe.Pointer(loggerFn) } +func extractEndpointFromSettings(settings string) string { + var b strings.Builder + pubkey := "" + endpoint := "" + listenPort := "listen_port=0" + for _, line := range strings.Split(settings, "\n") { + if strings.HasPrefix(line, "listen_port=") { + listenPort = line + } else if strings.HasPrefix(line, "public_key=") { + if pubkey != "" && endpoint != "" { + b.WriteString(pubkey + "\n" + endpoint + "\n") + } + pubkey = line + } else if strings.HasPrefix(line, "endpoint=") { + endpoint = line + } else if line == "remove=true" { + pubkey = "" + endpoint = "" + } + } + if pubkey != "" && endpoint != "" { + b.WriteString(pubkey + "\n" + endpoint + "\n") + } + return listenPort + "\n" + b.String() +} + //export wgTurnOn func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 { interfaceName := string([]byte(ifnameRef)) @@ -113,18 +184,27 @@ func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 { if i == math.MaxInt32 { return -1 } - tunnelHandles[i] = device + tunnelHandles[i] = &DeviceState{ + device: device, + logger: logger, + endpointsSettings: extractEndpointFromSettings(settings), + } return i } //export wgTurnOff func wgTurnOff(tunnelHandle int32) { - device, ok := tunnelHandles[tunnelHandle] + deviceState, ok := tunnelHandles[tunnelHandle] if !ok { return } delete(tunnelHandles, tunnelHandle) - device.Close() + t := deviceState.endpointsTimer + if t != nil { + deviceState.endpointsTimer = nil + t.Stop() + } + deviceState.device.Close() } //export wgVersion