From aea3c4d93e2e6b2008fa76b5b3283ed7c7ad729b Mon Sep 17 00:00:00 2001 From: Fabio Alessandrelli Date: Mon, 13 Jan 2020 13:10:12 +0100 Subject: [PATCH] Custom godot sockets for ENet now support DTLS. Non-DTLS implementation uses plain NetSocket for performance as before. --- thirdparty/enet/enet/enet.h | 2 + thirdparty/enet/godot.cpp | 375 ++++++++++++++++++++++++++++++------ 2 files changed, 321 insertions(+), 56 deletions(-) diff --git a/thirdparty/enet/enet/enet.h b/thirdparty/enet/enet/enet.h index 966e3a465dc..ac7552adb26 100644 --- a/thirdparty/enet/enet/enet.h +++ b/thirdparty/enet/enet/enet.h @@ -578,6 +578,8 @@ ENET_API void enet_host_channel_limit (ENetHost *, size_t); ENET_API void enet_host_bandwidth_limit (ENetHost *, enet_uint32, enet_uint32); extern void enet_host_bandwidth_throttle (ENetHost *); extern enet_uint32 enet_host_random_seed (void); +ENET_API void enet_host_dtls_server_setup (ENetHost *, void *, void *); +ENET_API void enet_host_dtls_client_setup (ENetHost *, void *, uint8_t, const char *); ENET_API int enet_peer_send (ENetPeer *, enet_uint8, ENetPacket *); ENET_API ENetPacket * enet_peer_receive (ENetPeer *, enet_uint8 * channelID); diff --git a/thirdparty/enet/godot.cpp b/thirdparty/enet/godot.cpp index 63580b6d1a2..da3a86277b3 100644 --- a/thirdparty/enet/godot.cpp +++ b/thirdparty/enet/godot.cpp @@ -32,13 +32,313 @@ @brief ENet Godot specific functions */ +#include "core/io/dtls_server.h" #include "core/io/ip.h" #include "core/io/net_socket.h" +#include "core/io/packet_peer_dtls.h" +#include "core/io/udp_server.h" #include "core/os/os.h" // This must be last for windows to compile (tested with MinGW) #include "enet/enet.h" +/// Abstract ENet interface for UDP/DTLS. +class ENetGodotSocket { + +public: + virtual Error bind(IP_Address p_ip, uint16_t p_port) = 0; + virtual Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) = 0; + virtual Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) = 0; + virtual int set_option(ENetSocketOption p_option, int p_value) = 0; + virtual void close() = 0; + virtual ~ENetGodotSocket(){}; +}; + +class ENetDTLSClient; +class ENetDTLSServer; + +/// NetSocket interface +class ENetUDP : public ENetGodotSocket { + + friend class ENetDTLSClient; + friend class ENetDTLSServer; + +private: + Ref sock; + IP_Address address; + uint16_t port; + bool bound; + +public: + ENetUDP() { + sock = Ref(NetSocket::create()); + IP::Type ip_type = IP::TYPE_ANY; + bound = false; + sock->open(NetSocket::TYPE_UDP, ip_type); + } + + ~ENetUDP() { + sock->close(); + } + + Error bind(IP_Address p_ip, uint16_t p_port) { + address = p_ip; + port = p_port; + bound = true; + return sock->bind(address, port); + } + + Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) { + return sock->sendto(p_buffer, p_len, r_sent, p_ip, p_port); + } + + Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) { + Error err = sock->poll(NetSocket::POLL_TYPE_IN, 0); + if (err != OK) + return err; + return sock->recvfrom(p_buffer, p_len, r_read, r_ip, r_port); + } + + int set_option(ENetSocketOption p_option, int p_value) { + switch (p_option) { + case ENET_SOCKOPT_NONBLOCK: { + sock->set_blocking_enabled(p_value ? false : true); + return 0; + } break; + + case ENET_SOCKOPT_BROADCAST: { + sock->set_broadcasting_enabled(p_value ? true : false); + return 0; + } break; + + case ENET_SOCKOPT_REUSEADDR: { + sock->set_reuse_address_enabled(p_value ? true : false); + return 0; + } break; + + case ENET_SOCKOPT_RCVBUF: { + return -1; + } break; + + case ENET_SOCKOPT_SNDBUF: { + return -1; + } break; + + case ENET_SOCKOPT_RCVTIMEO: { + return -1; + } break; + + case ENET_SOCKOPT_SNDTIMEO: { + return -1; + } break; + + case ENET_SOCKOPT_NODELAY: { + sock->set_tcp_no_delay_enabled(p_value ? true : false); + return 0; + } break; + } + + return -1; + } + + void close() { + sock->close(); + } +}; + +/// DTLS Client ENet interface +class ENetDTLSClient : public ENetGodotSocket { + + bool connected; + Ref udp; + Ref dtls; + bool verify; + String for_hostname; + Ref cert; + +public: + ENetDTLSClient(ENetUDP *p_base, Ref p_cert, bool p_verify, String p_for_hostname) { + verify = p_verify; + for_hostname = p_for_hostname; + cert = p_cert; + udp.instance(); + dtls = Ref(PacketPeerDTLS::create()); + p_base->close(); + if (p_base->bound) { + bind(p_base->address, p_base->port); + } + connected = false; + } + + ~ENetDTLSClient() { + close(); + } + + Error bind(IP_Address p_ip, uint16_t p_port) { + return udp->listen(p_port, p_ip); + } + + Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) { + if (!connected) { + udp->connect_to_host(p_ip, p_port); + dtls->connect_to_peer(udp, verify, for_hostname, cert); + connected = true; + } + dtls->poll(); + if (dtls->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING) + return ERR_BUSY; + else if (dtls->get_status() != PacketPeerDTLS::STATUS_CONNECTED) + return FAILED; + r_sent = p_len; + return dtls->put_packet(p_buffer, p_len); + } + + Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) { + dtls->poll(); + if (dtls->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING) + return ERR_BUSY; + if (dtls->get_status() != PacketPeerDTLS::STATUS_CONNECTED) + return FAILED; + int pc = dtls->get_available_packet_count(); + if (pc == 0) + return ERR_BUSY; + else if (pc < 0) + return FAILED; + + const uint8_t *buffer; + Error err = dtls->get_packet(&buffer, r_read); + ERR_FAIL_COND_V(err != OK, err); + ERR_FAIL_COND_V(p_len < r_read, ERR_OUT_OF_MEMORY); + + copymem(p_buffer, buffer, r_read); + r_ip = udp->get_packet_address(); + r_port = udp->get_packet_port(); + return err; + } + + int set_option(ENetSocketOption p_option, int p_value) { + return -1; + } + + void close() { + dtls->disconnect_from_peer(); + udp->close(); + } +}; + +/// DTLSServer - ENet interface +class ENetDTLSServer : public ENetGodotSocket { + + Ref server; + Ref udp_server; + Map > peers; + int last_service; + +public: + ENetDTLSServer(ENetUDP *p_base, Ref p_key, Ref p_cert) { + last_service = 0; + udp_server.instance(); + p_base->close(); + if (p_base->bound) { + bind(p_base->address, p_base->port); + } + server = Ref(DTLSServer::create()); + server->setup(p_key, p_cert); + } + + ~ENetDTLSServer() { + close(); + } + + Error bind(IP_Address p_ip, uint16_t p_port) { + return udp_server->listen(p_port, p_ip); + } + + Error sendto(const uint8_t *p_buffer, int p_len, int &r_sent, IP_Address p_ip, uint16_t p_port) { + String key = String(p_ip) + ":" + itos(p_port); + ERR_FAIL_COND_V(!peers.has(key), ERR_UNAVAILABLE); + Ref peer = peers[key]; + Error err = peer->put_packet(p_buffer, p_len); + if (err == OK) + r_sent = p_len; + else if (err == ERR_BUSY) + r_sent = 0; + else + r_sent = -1; + return err; + } + + Error recvfrom(uint8_t *p_buffer, int p_len, int &r_read, IP_Address &r_ip, uint16_t &r_port) { + // TODO limits? Maybe we can better enforce allowed connections! + if (udp_server->is_connection_available()) { + Ref udp = udp_server->take_connection(); + IP_Address peer_ip = udp->get_packet_address(); + int peer_port = udp->get_packet_port(); + Ref peer = server->take_connection(udp); + PacketPeerDTLS::Status status = peer->get_status(); + if (status == PacketPeerDTLS::STATUS_HANDSHAKING || status == PacketPeerDTLS::STATUS_CONNECTED) { + String key = String(peer_ip) + ":" + itos(peer_port); + peers[key] = peer; + } + } + + List remove; + Error err = ERR_BUSY; + // TODO this needs to be fair! + for (Map >::Element *E = peers.front(); E; E = E->next()) { + Ref peer = E->get(); + peer->poll(); + + if (peer->get_status() == PacketPeerDTLS::STATUS_HANDSHAKING) + continue; + else if (peer->get_status() != PacketPeerDTLS::STATUS_CONNECTED) { + // Peer disconnected, removing it. + remove.push_back(E->key()); + continue; + } + + if (peer->get_available_packet_count() > 0) { + const uint8_t *buffer; + err = peer->get_packet(&buffer, r_read); + if (err != OK || p_len < r_read) { + // Something wrong with this peer, removing it. + remove.push_back(E->key()); + err = FAILED; + continue; + } + + Vector s = E->key().rsplit(":", false, 1); + ERR_CONTINUE(s.size() != 2); // BUG! + + copymem(p_buffer, buffer, r_read); + r_ip = s[0]; + r_port = s[1].to_int(); + break; // err = OK + } + } + + // Remove disconnected peers from map. + for (List::Element *E = remove.front(); E; E = E->next()) { + peers.erase(E->get()); + } + + return err; // OK, ERR_BUSY, or possibly an error. + } + + int set_option(ENetSocketOption p_option, int p_value) { + return -1; + } + + void close() { + for (Map >::Element *E = peers.front(); E; E = E->next()) { + E->get()->disconnect_from_peer(); + } + peers.clear(); + udp_server->stop(); + server->stop(); + } +}; + static enet_uint32 timeBase = 0; int enet_initialize(void) { @@ -92,13 +392,23 @@ int enet_address_get_host(const ENetAddress *address, char *name, size_t nameLen ENetSocket enet_socket_create(ENetSocketType type) { - NetSocket *socket = NetSocket::create(); - IP::Type ip_type = IP::TYPE_ANY; - socket->open(NetSocket::TYPE_UDP, ip_type); + ENetUDP *socket = memnew(ENetUDP); return socket; } +void enet_host_dtls_server_setup(ENetHost *host, void *p_key, void *p_cert) { + ENetUDP *sock = (ENetUDP *)host->socket; + host->socket = memnew(ENetDTLSServer(sock, Ref((CryptoKey *)p_key), Ref((X509Certificate *)p_cert))); + memdelete(sock); +} + +void enet_host_dtls_client_setup(ENetHost *host, void *p_cert, uint8_t p_verify, const char *p_for_hostname) { + ENetUDP *sock = (ENetUDP *)host->socket; + host->socket = memnew(ENetDTLSClient(sock, Ref((X509Certificate *)p_cert), p_verify, String(p_for_hostname))); + memdelete(sock); +} + int enet_socket_bind(ENetSocket socket, const ENetAddress *address) { IP_Address ip; @@ -108,7 +418,7 @@ int enet_socket_bind(ENetSocket socket, const ENetAddress *address) { ip.set_ipv6(address->host); } - NetSocket *sock = (NetSocket *)socket; + ENetGodotSocket *sock = (ENetGodotSocket *)socket; if (sock->bind(ip, address->port) != OK) { return -1; } @@ -116,7 +426,7 @@ int enet_socket_bind(ENetSocket socket, const ENetAddress *address) { } void enet_socket_destroy(ENetSocket socket) { - NetSocket *sock = (NetSocket *)socket; + ENetGodotSocket *sock = (ENetGodotSocket *)socket; sock->close(); memdelete(sock); } @@ -125,7 +435,7 @@ int enet_socket_send(ENetSocket socket, const ENetAddress *address, const ENetBu ERR_FAIL_COND_V(address == NULL, -1); - NetSocket *sock = (NetSocket *)socket; + ENetGodotSocket *sock = (ENetGodotSocket *)socket; IP_Address dest; Error err; size_t i = 0; @@ -167,15 +477,7 @@ int enet_socket_receive(ENetSocket socket, ENetAddress *address, ENetBuffer *buf ERR_FAIL_COND_V(bufferCount != 1, -1); - NetSocket *sock = (NetSocket *)socket; - - Error ret = sock->poll(NetSocket::POLL_TYPE_IN, 0); - - if (ret == ERR_BUSY) - return 0; - - if (ret != OK) - return -1; + ENetGodotSocket *sock = (ENetGodotSocket *)socket; int read; IP_Address ip; @@ -215,47 +517,8 @@ int enet_socket_listen(ENetSocket socket, int backlog) { int enet_socket_set_option(ENetSocket socket, ENetSocketOption option, int value) { - NetSocket *sock = (NetSocket *)socket; - - switch (option) { - case ENET_SOCKOPT_NONBLOCK: { - sock->set_blocking_enabled(value ? false : true); - return 0; - } break; - - case ENET_SOCKOPT_BROADCAST: { - sock->set_broadcasting_enabled(value ? true : false); - return 0; - } break; - - case ENET_SOCKOPT_REUSEADDR: { - sock->set_reuse_address_enabled(value ? true : false); - return 0; - } break; - - case ENET_SOCKOPT_RCVBUF: { - return -1; - } break; - - case ENET_SOCKOPT_SNDBUF: { - return -1; - } break; - - case ENET_SOCKOPT_RCVTIMEO: { - return -1; - } break; - - case ENET_SOCKOPT_SNDTIMEO: { - return -1; - } break; - - case ENET_SOCKOPT_NODELAY: { - sock->set_tcp_no_delay_enabled(value ? true : false); - return 0; - } break; - } - - return -1; + ENetGodotSocket *sock = (ENetGodotSocket *)socket; + return sock->set_option(option, value); } int enet_socket_get_option(ENetSocket socket, ENetSocketOption option, int *value) {