wireguard-go-bridge: account for network changes

Everytime the network changes, we need to recreate the UDP socket,
because the ephemeral listen port is tied to the old physical interface.
As well, we need to re-set the IP addresses for each endpoint, so that
they're passed to getaddrinfo and are then resolved using DNS46.
This commit is contained in:
Jason A. Donenfeld 2018-12-07 21:47:19 +01:00
parent 2e5d467bc7
commit 99f0e457c3
1 changed files with 85 additions and 5 deletions

View File

@ -25,6 +25,8 @@ import (
"os/signal" "os/signal"
"runtime" "runtime"
"strings" "strings"
"syscall"
"time"
"unsafe" "unsafe"
) )
@ -46,12 +48,54 @@ func (l *CLogger) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
var tunnelHandles map[int32]*Device type DeviceState struct {
device *Device
logger *Logger
endpointsTimer *time.Timer
endpointsSettings string
}
var tunnelHandles map[int32]*DeviceState
func listenForRouteChanges() {
//TODO: replace with NWPathMonitor
data := make([]byte, os.Getpagesize())
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return
}
for {
n, err := unix.Read(routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
continue
}
return
}
if n < 4 {
continue
}
for _, deviceState := range tunnelHandles {
if deviceState.endpointsTimer == nil {
deviceState.endpointsTimer = time.AfterFunc(time.Second, func() {
deviceState.endpointsTimer = nil
bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(deviceState.endpointsSettings)), bufio.NewWriter(ioutil.Discard))
deviceState.logger.Info.Println("Setting endpoints for re-resolution due to network change")
err := ipcSetOperation(deviceState.device, bufferedSettings)
if err != nil {
deviceState.logger.Error.Println(err)
}
})
}
}
}
}
func init() { func init() {
versionString = C.CString(WireGuardGoVersion) versionString = C.CString(WireGuardGoVersion)
roamingDisabled = true roamingDisabled = true
tunnelHandles = make(map[int32]*Device) tunnelHandles = make(map[int32]*DeviceState)
signals := make(chan os.Signal) signals := make(chan os.Signal)
signal.Notify(signals, unix.SIGUSR2) signal.Notify(signals, unix.SIGUSR2)
go func() { go func() {
@ -67,6 +111,7 @@ func init() {
} }
} }
}() }()
go listenForRouteChanges()
} }
//export wgSetLogger //export wgSetLogger
@ -74,6 +119,32 @@ func wgSetLogger(loggerFn uintptr) {
loggerFunc = unsafe.Pointer(loggerFn) loggerFunc = unsafe.Pointer(loggerFn)
} }
func extractEndpointFromSettings(settings string) string {
var b strings.Builder
pubkey := ""
endpoint := ""
listenPort := "listen_port=0"
for _, line := range strings.Split(settings, "\n") {
if strings.HasPrefix(line, "listen_port=") {
listenPort = line
} else if strings.HasPrefix(line, "public_key=") {
if pubkey != "" && endpoint != "" {
b.WriteString(pubkey + "\n" + endpoint + "\n")
}
pubkey = line
} else if strings.HasPrefix(line, "endpoint=") {
endpoint = line
} else if line == "remove=true" {
pubkey = ""
endpoint = ""
}
}
if pubkey != "" && endpoint != "" {
b.WriteString(pubkey + "\n" + endpoint + "\n")
}
return listenPort + "\n" + b.String()
}
//export wgTurnOn //export wgTurnOn
func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 { func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
interfaceName := string([]byte(ifnameRef)) interfaceName := string([]byte(ifnameRef))
@ -113,18 +184,27 @@ func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
if i == math.MaxInt32 { if i == math.MaxInt32 {
return -1 return -1
} }
tunnelHandles[i] = device tunnelHandles[i] = &DeviceState{
device: device,
logger: logger,
endpointsSettings: extractEndpointFromSettings(settings),
}
return i return i
} }
//export wgTurnOff //export wgTurnOff
func wgTurnOff(tunnelHandle int32) { func wgTurnOff(tunnelHandle int32) {
device, ok := tunnelHandles[tunnelHandle] deviceState, ok := tunnelHandles[tunnelHandle]
if !ok { if !ok {
return return
} }
delete(tunnelHandles, tunnelHandle) delete(tunnelHandles, tunnelHandle)
device.Close() t := deviceState.endpointsTimer
if t != nil {
deviceState.endpointsTimer = nil
t.Stop()
}
deviceState.device.Close()
} }
//export wgVersion //export wgVersion