wireguard-go-bridge: account for network changes

Everytime the network changes, we need to recreate the UDP socket,
because the ephemeral listen port is tied to the old physical interface.
As well, we need to re-set the IP addresses for each endpoint, so that
they're passed to getaddrinfo and are then resolved using DNS46.
This commit is contained in:
Jason A. Donenfeld 2018-12-07 21:47:19 +01:00
parent 2e5d467bc7
commit 99f0e457c3

View File

@ -25,6 +25,8 @@ import (
"os/signal"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
)
@ -46,12 +48,54 @@ func (l *CLogger) Write(p []byte) (int, error) {
return len(p), nil
}
var tunnelHandles map[int32]*Device
type DeviceState struct {
device *Device
logger *Logger
endpointsTimer *time.Timer
endpointsSettings string
}
var tunnelHandles map[int32]*DeviceState
func listenForRouteChanges() {
//TODO: replace with NWPathMonitor
data := make([]byte, os.Getpagesize())
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
return
}
for {
n, err := unix.Read(routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
continue
}
return
}
if n < 4 {
continue
}
for _, deviceState := range tunnelHandles {
if deviceState.endpointsTimer == nil {
deviceState.endpointsTimer = time.AfterFunc(time.Second, func() {
deviceState.endpointsTimer = nil
bufferedSettings := bufio.NewReadWriter(bufio.NewReader(strings.NewReader(deviceState.endpointsSettings)), bufio.NewWriter(ioutil.Discard))
deviceState.logger.Info.Println("Setting endpoints for re-resolution due to network change")
err := ipcSetOperation(deviceState.device, bufferedSettings)
if err != nil {
deviceState.logger.Error.Println(err)
}
})
}
}
}
}
func init() {
versionString = C.CString(WireGuardGoVersion)
roamingDisabled = true
tunnelHandles = make(map[int32]*Device)
tunnelHandles = make(map[int32]*DeviceState)
signals := make(chan os.Signal)
signal.Notify(signals, unix.SIGUSR2)
go func() {
@ -67,6 +111,7 @@ func init() {
}
}
}()
go listenForRouteChanges()
}
//export wgSetLogger
@ -74,6 +119,32 @@ func wgSetLogger(loggerFn uintptr) {
loggerFunc = unsafe.Pointer(loggerFn)
}
func extractEndpointFromSettings(settings string) string {
var b strings.Builder
pubkey := ""
endpoint := ""
listenPort := "listen_port=0"
for _, line := range strings.Split(settings, "\n") {
if strings.HasPrefix(line, "listen_port=") {
listenPort = line
} else if strings.HasPrefix(line, "public_key=") {
if pubkey != "" && endpoint != "" {
b.WriteString(pubkey + "\n" + endpoint + "\n")
}
pubkey = line
} else if strings.HasPrefix(line, "endpoint=") {
endpoint = line
} else if line == "remove=true" {
pubkey = ""
endpoint = ""
}
}
if pubkey != "" && endpoint != "" {
b.WriteString(pubkey + "\n" + endpoint + "\n")
}
return listenPort + "\n" + b.String()
}
//export wgTurnOn
func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
interfaceName := string([]byte(ifnameRef))
@ -113,18 +184,27 @@ func wgTurnOn(ifnameRef string, settings string, tunFd int32) int32 {
if i == math.MaxInt32 {
return -1
}
tunnelHandles[i] = device
tunnelHandles[i] = &DeviceState{
device: device,
logger: logger,
endpointsSettings: extractEndpointFromSettings(settings),
}
return i
}
//export wgTurnOff
func wgTurnOff(tunnelHandle int32) {
device, ok := tunnelHandles[tunnelHandle]
deviceState, ok := tunnelHandles[tunnelHandle]
if !ok {
return
}
delete(tunnelHandles, tunnelHandle)
device.Close()
t := deviceState.endpointsTimer
if t != nil {
deviceState.endpointsTimer = nil
t.Stop()
}
deviceState.device.Close()
}
//export wgVersion