Fix race condition in resource loader when a load task is reused

This commit is contained in:
Pedro J. Estébanez 2023-02-20 19:00:26 +01:00
parent daa29d1007
commit 618bb173ba
2 changed files with 24 additions and 27 deletions

View File

@ -33,6 +33,7 @@
#include "core/config/project_settings.h" #include "core/config/project_settings.h"
#include "core/io/file_access.h" #include "core/io/file_access.h"
#include "core/io/resource_importer.h" #include "core/io/resource_importer.h"
#include "core/os/condition_variable.h"
#include "core/os/os.h" #include "core/os/os.h"
#include "core/string/print_string.h" #include "core/string/print_string.h"
#include "core/string/translation.h" #include "core/string/translation.h"
@ -233,7 +234,7 @@ void ResourceLoader::_thread_load_function(void *p_userdata) {
ThreadLoadTask &load_task = *(ThreadLoadTask *)p_userdata; ThreadLoadTask &load_task = *(ThreadLoadTask *)p_userdata;
load_task.loader_id = Thread::get_caller_id(); load_task.loader_id = Thread::get_caller_id();
if (load_task.semaphore) { if (load_task.cond_var) {
//this is an actual thread, so wait for Ok from semaphore //this is an actual thread, so wait for Ok from semaphore
thread_load_semaphore->wait(); //wait until its ok to start loading thread_load_semaphore->wait(); //wait until its ok to start loading
} }
@ -247,7 +248,7 @@ void ResourceLoader::_thread_load_function(void *p_userdata) {
} else { } else {
load_task.status = THREAD_LOAD_LOADED; load_task.status = THREAD_LOAD_LOADED;
} }
if (load_task.semaphore) { if (load_task.cond_var) {
if (load_task.start_next && thread_waiting_count > 0) { if (load_task.start_next && thread_waiting_count > 0) {
thread_waiting_count--; thread_waiting_count--;
//thread loading count remains constant, this ends but another one begins //thread loading count remains constant, this ends but another one begins
@ -258,11 +259,9 @@ void ResourceLoader::_thread_load_function(void *p_userdata) {
print_lt("END: load count: " + itos(thread_loading_count) + " / wait count: " + itos(thread_waiting_count) + " / suspended count: " + itos(thread_suspended_count) + " / active: " + itos(thread_loading_count - thread_suspended_count)); print_lt("END: load count: " + itos(thread_loading_count) + " / wait count: " + itos(thread_waiting_count) + " / suspended count: " + itos(thread_suspended_count) + " / active: " + itos(thread_loading_count - thread_suspended_count));
for (int i = 0; i < load_task.poll_requests; i++) { load_task.cond_var->notify_all();
load_task.semaphore->post(); memdelete(load_task.cond_var);
} load_task.cond_var = nullptr;
memdelete(load_task.semaphore);
load_task.semaphore = nullptr;
} }
if (load_task.resource.is_valid()) { if (load_task.resource.is_valid()) {
@ -373,7 +372,7 @@ Error ResourceLoader::load_threaded_request(const String &p_path, const String &
if (load_task.resource.is_null()) { //needs to be loaded in thread if (load_task.resource.is_null()) { //needs to be loaded in thread
load_task.semaphore = memnew(Semaphore); load_task.cond_var = memnew(ConditionVariable);
if (thread_loading_count < thread_load_max) { if (thread_loading_count < thread_load_max) {
thread_loading_count++; thread_loading_count++;
thread_load_semaphore->post(); //we have free threads, so allow one thread_load_semaphore->post(); //we have free threads, so allow one
@ -438,9 +437,8 @@ ResourceLoader::ThreadLoadStatus ResourceLoader::load_threaded_get_status(const
Ref<Resource> ResourceLoader::load_threaded_get(const String &p_path, Error *r_error) { Ref<Resource> ResourceLoader::load_threaded_get(const String &p_path, Error *r_error) {
String local_path = _validate_local_path(p_path); String local_path = _validate_local_path(p_path);
thread_load_mutex->lock(); MutexLock thread_load_lock(*thread_load_mutex);
if (!thread_load_tasks.has(local_path)) { if (!thread_load_tasks.has(local_path)) {
thread_load_mutex->unlock();
if (r_error) { if (r_error) {
*r_error = ERR_INVALID_PARAMETER; *r_error = ERR_INVALID_PARAMETER;
} }
@ -449,13 +447,10 @@ Ref<Resource> ResourceLoader::load_threaded_get(const String &p_path, Error *r_e
ThreadLoadTask &load_task = thread_load_tasks[local_path]; ThreadLoadTask &load_task = thread_load_tasks[local_path];
//semaphore still exists, meaning it's still loading, request poll //cond var still exists, meaning it's still loading, request poll
Semaphore *semaphore = load_task.semaphore; if (load_task.cond_var) {
if (semaphore) {
load_task.poll_requests++;
{ {
// As we got a semaphore, this means we are going to have to wait // As we got a cond var, this means we are going to have to wait
// until the sub-resource is done loading // until the sub-resource is done loading
// //
// As this thread will become 'blocked' we should "exchange" its // As this thread will become 'blocked' we should "exchange" its
@ -477,14 +472,13 @@ Ref<Resource> ResourceLoader::load_threaded_get(const String &p_path, Error *r_e
print_lt("GET: load count: " + itos(thread_loading_count) + " / wait count: " + itos(thread_waiting_count) + " / suspended count: " + itos(thread_suspended_count) + " / active: " + itos(thread_loading_count - thread_suspended_count)); print_lt("GET: load count: " + itos(thread_loading_count) + " / wait count: " + itos(thread_waiting_count) + " / suspended count: " + itos(thread_suspended_count) + " / active: " + itos(thread_loading_count - thread_suspended_count));
} }
thread_load_mutex->unlock(); do {
semaphore->wait(); load_task.cond_var->wait(thread_load_lock);
thread_load_mutex->lock(); } while (load_task.cond_var); // In case of spurious wakeup.
thread_suspended_count--; thread_suspended_count--;
if (!thread_load_tasks.has(local_path)) { //may have been erased during unlock and this was always an invalid call if (!thread_load_tasks.has(local_path)) { //may have been erased during unlock and this was always an invalid call
thread_load_mutex->unlock();
if (r_error) { if (r_error) {
*r_error = ERR_INVALID_PARAMETER; *r_error = ERR_INVALID_PARAMETER;
} }
@ -507,8 +501,6 @@ Ref<Resource> ResourceLoader::load_threaded_get(const String &p_path, Error *r_e
thread_load_tasks.erase(local_path); thread_load_tasks.erase(local_path);
} }
thread_load_mutex->unlock();
return resource; return resource;
} }
@ -1067,7 +1059,7 @@ void ResourceLoader::remove_custom_loaders() {
} }
void ResourceLoader::initialize() { void ResourceLoader::initialize() {
thread_load_mutex = memnew(Mutex); thread_load_mutex = memnew(SafeBinaryMutex<BINARY_MUTEX_TAG>);
thread_load_max = OS::get_singleton()->get_processor_count(); thread_load_max = OS::get_singleton()->get_processor_count();
thread_loading_count = 0; thread_loading_count = 0;
thread_waiting_count = 0; thread_waiting_count = 0;
@ -1090,7 +1082,9 @@ bool ResourceLoader::create_missing_resources_if_class_unavailable = false;
bool ResourceLoader::abort_on_missing_resource = true; bool ResourceLoader::abort_on_missing_resource = true;
bool ResourceLoader::timestamp_on_load = false; bool ResourceLoader::timestamp_on_load = false;
Mutex *ResourceLoader::thread_load_mutex = nullptr; template <>
thread_local uint32_t SafeBinaryMutex<ResourceLoader::BINARY_MUTEX_TAG>::count = 0;
SafeBinaryMutex<ResourceLoader::BINARY_MUTEX_TAG> *ResourceLoader::thread_load_mutex = nullptr;
HashMap<String, ResourceLoader::ThreadLoadTask> ResourceLoader::thread_load_tasks; HashMap<String, ResourceLoader::ThreadLoadTask> ResourceLoader::thread_load_tasks;
Semaphore *ResourceLoader::thread_load_semaphore = nullptr; Semaphore *ResourceLoader::thread_load_semaphore = nullptr;

View File

@ -37,6 +37,8 @@
#include "core/os/semaphore.h" #include "core/os/semaphore.h"
#include "core/os/thread.h" #include "core/os/thread.h"
class ConditionVariable;
class ResourceFormatLoader : public RefCounted { class ResourceFormatLoader : public RefCounted {
GDCLASS(ResourceFormatLoader, RefCounted); GDCLASS(ResourceFormatLoader, RefCounted);
@ -105,6 +107,8 @@ public:
THREAD_LOAD_LOADED THREAD_LOAD_LOADED
}; };
static const int BINARY_MUTEX_TAG = 1;
private: private:
static Ref<ResourceFormatLoader> loader[MAX_LOADERS]; static Ref<ResourceFormatLoader> loader[MAX_LOADERS];
static int loader_count; static int loader_count;
@ -136,7 +140,7 @@ private:
struct ThreadLoadTask { struct ThreadLoadTask {
Thread *thread = nullptr; Thread *thread = nullptr;
Thread::ID loader_id = 0; Thread::ID loader_id = 0;
Semaphore *semaphore = nullptr; ConditionVariable *cond_var = nullptr;
String local_path; String local_path;
String remapped_path; String remapped_path;
String type_hint; String type_hint;
@ -149,12 +153,11 @@ private:
bool use_sub_threads = false; bool use_sub_threads = false;
bool start_next = true; bool start_next = true;
int requests = 0; int requests = 0;
int poll_requests = 0;
HashSet<String> sub_tasks; HashSet<String> sub_tasks;
}; };
static void _thread_load_function(void *p_userdata); static void _thread_load_function(void *p_userdata);
static Mutex *thread_load_mutex; static SafeBinaryMutex<BINARY_MUTEX_TAG> *thread_load_mutex;
static HashMap<String, ThreadLoadTask> thread_load_tasks; static HashMap<String, ThreadLoadTask> thread_load_tasks;
static Semaphore *thread_load_semaphore; static Semaphore *thread_load_semaphore;
static int thread_waiting_count; static int thread_waiting_count;