Add OpenImageDenoise thirdparty library
This commit is contained in:
parent
782a548ee3
commit
ad8abef74c
118
modules/denoise/SCsub
Normal file
118
modules/denoise/SCsub
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import resource_to_cpp
|
||||||
|
|
||||||
|
Import("env")
|
||||||
|
Import("env_modules")
|
||||||
|
|
||||||
|
env_oidn = env_modules.Clone()
|
||||||
|
|
||||||
|
# Thirdparty source files
|
||||||
|
thirdparty_dir = "#thirdparty/oidn/"
|
||||||
|
thirdparty_sources = [
|
||||||
|
"core/api.cpp",
|
||||||
|
"core/device.cpp",
|
||||||
|
"core/filter.cpp",
|
||||||
|
"core/network.cpp",
|
||||||
|
"core/autoencoder.cpp",
|
||||||
|
"core/transfer_function.cpp",
|
||||||
|
"weights/rtlightmap_hdr.gen.cpp",
|
||||||
|
"mkl-dnn/src/common/batch_normalization.cpp",
|
||||||
|
"mkl-dnn/src/common/concat.cpp",
|
||||||
|
"mkl-dnn/src/common/convolution.cpp",
|
||||||
|
"mkl-dnn/src/common/convolution_pd.cpp",
|
||||||
|
"mkl-dnn/src/common/deconvolution.cpp",
|
||||||
|
"mkl-dnn/src/common/eltwise.cpp",
|
||||||
|
"mkl-dnn/src/common/engine.cpp",
|
||||||
|
"mkl-dnn/src/common/inner_product.cpp",
|
||||||
|
"mkl-dnn/src/common/inner_product_pd.cpp",
|
||||||
|
"mkl-dnn/src/common/lrn.cpp",
|
||||||
|
"mkl-dnn/src/common/memory.cpp",
|
||||||
|
"mkl-dnn/src/common/memory_desc_wrapper.cpp",
|
||||||
|
"mkl-dnn/src/common/mkldnn_debug.cpp",
|
||||||
|
"mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp",
|
||||||
|
"mkl-dnn/src/common/pooling.cpp",
|
||||||
|
"mkl-dnn/src/common/primitive.cpp",
|
||||||
|
"mkl-dnn/src/common/primitive_attr.cpp",
|
||||||
|
"mkl-dnn/src/common/primitive_desc.cpp",
|
||||||
|
"mkl-dnn/src/common/primitive_exec_types.cpp",
|
||||||
|
"mkl-dnn/src/common/primitive_iterator.cpp",
|
||||||
|
"mkl-dnn/src/common/query.cpp",
|
||||||
|
"mkl-dnn/src/common/reorder.cpp",
|
||||||
|
"mkl-dnn/src/common/rnn.cpp",
|
||||||
|
"mkl-dnn/src/common/scratchpad.cpp",
|
||||||
|
"mkl-dnn/src/common/shuffle.cpp",
|
||||||
|
"mkl-dnn/src/common/softmax.cpp",
|
||||||
|
"mkl-dnn/src/common/stream.cpp",
|
||||||
|
"mkl-dnn/src/common/sum.cpp",
|
||||||
|
"mkl-dnn/src/common/utils.cpp",
|
||||||
|
"mkl-dnn/src/common/verbose.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_barrier.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_concat.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_engine.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_memory.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_reducer.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_reorder.cpp",
|
||||||
|
"mkl-dnn/src/cpu/cpu_sum.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx2_convolution.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_sse42_convolution.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_transpose_src_utils.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_uni_eltwise.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_uni_pooling.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_uni_reorder.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_utils/jit_utils.cpp",
|
||||||
|
"mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c",
|
||||||
|
"common/platform.cpp",
|
||||||
|
"common/thread.cpp",
|
||||||
|
"common/tensor.cpp",
|
||||||
|
]
|
||||||
|
thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources]
|
||||||
|
|
||||||
|
thirdparty_include_dirs = [
|
||||||
|
"",
|
||||||
|
"include",
|
||||||
|
"mkl-dnn/include",
|
||||||
|
"mkl-dnn/src",
|
||||||
|
"mkl-dnn/src/common",
|
||||||
|
"mkl-dnn/src/cpu/xbyak",
|
||||||
|
"mkl-dnn/src/cpu",
|
||||||
|
]
|
||||||
|
thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs]
|
||||||
|
|
||||||
|
|
||||||
|
env_oidn.Prepend(CPPPATH=thirdparty_include_dirs)
|
||||||
|
env_oidn.Append(
|
||||||
|
CPPDEFINES=[
|
||||||
|
"MKLDNN_THR=MKLDNN_THR_SEQ",
|
||||||
|
"OIDN_STATIC_LIB",
|
||||||
|
"__STDC_CONSTANT_MACROS",
|
||||||
|
"__STDC_LIMIT_MACROS",
|
||||||
|
"DISABLE_VERBOSE",
|
||||||
|
"MKLDNN_ENABLE_CONCURRENT_EXEC",
|
||||||
|
"NDEBUG",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
env_thirdparty = env_oidn.Clone()
|
||||||
|
env_thirdparty.disable_warnings()
|
||||||
|
env_thirdparty.add_source_files(env.modules_sources, thirdparty_sources)
|
||||||
|
|
||||||
|
weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza"
|
||||||
|
weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp"
|
||||||
|
|
||||||
|
env_thirdparty.Depends(weights_out_path, weights_in_path)
|
||||||
|
env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp)
|
||||||
|
|
||||||
|
env_oidn.add_source_files(env.modules_sources, "denoise_wrapper.cpp")
|
||||||
|
env_modules.add_source_files(env.modules_sources, ["register_types.cpp", "lightmap_denoiser.cpp"])
|
15
modules/denoise/config.py
Normal file
15
modules/denoise/config.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
def can_build(env, platform):
|
||||||
|
# Thirdparty dependency OpenImage Denoise includes oneDNN library
|
||||||
|
# which only supports 64-bit architectures.
|
||||||
|
# It's also only relevant for tools build and desktop platforms,
|
||||||
|
# as doing lightmap generation and denoising on Android or HTML5
|
||||||
|
# would be a bit far-fetched.
|
||||||
|
# Note: oneDNN doesn't support ARM64, OIDN needs updating to the latest version
|
||||||
|
supported_platform = platform in ["x11", "osx", "windows", "server"]
|
||||||
|
supported_bits = env["bits"] == "64"
|
||||||
|
supported_arch = env["arch"] != "arm64"
|
||||||
|
return env["tools"] and supported_platform and supported_bits and supported_arch
|
||||||
|
|
||||||
|
|
||||||
|
def configure(env):
|
||||||
|
pass
|
69
modules/denoise/denoise_wrapper.cpp
Normal file
69
modules/denoise/denoise_wrapper.cpp
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* denoise_wrapper.cpp */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#include "denoise_wrapper.h"
|
||||||
|
#include "core/os/copymem.h"
|
||||||
|
#include "core/os/memory.h"
|
||||||
|
#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h"
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
void *oidn_denoiser_init() {
|
||||||
|
OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
|
||||||
|
oidnCommitDevice(device);
|
||||||
|
return device;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) {
|
||||||
|
OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr;
|
||||||
|
OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
|
||||||
|
void *input_buffer = memalloc(p_width * p_height * 3 * sizeof(float));
|
||||||
|
copymem(input_buffer, p_floats, p_width * p_height * 3 * sizeof(float));
|
||||||
|
oidnSetSharedFilterImage(filter, "color", input_buffer, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
|
||||||
|
oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
|
||||||
|
oidnSetFilter1b(filter, "hdr", true);
|
||||||
|
//oidnSetFilter1f(filter, "hdrScale", 1.0f);
|
||||||
|
//oidnSetFilter1i(filter, "verbose", 4);
|
||||||
|
oidnCommitFilter(filter);
|
||||||
|
oidnExecuteFilter(filter);
|
||||||
|
|
||||||
|
const char *msg;
|
||||||
|
bool success = true;
|
||||||
|
if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) {
|
||||||
|
printf("LightmapDenoiser: %s\n", msg);
|
||||||
|
success = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
oidnReleaseFilter(filter);
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
void oidn_denoiser_finish(void *device) {
|
||||||
|
oidnReleaseDevice((OIDNDeviceImpl *)device);
|
||||||
|
}
|
38
modules/denoise/denoise_wrapper.h
Normal file
38
modules/denoise/denoise_wrapper.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* denoise_wrapper.h */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#ifndef DENOISE_WRAPPER_H
|
||||||
|
#define DENOISE_WRAPPER_H
|
||||||
|
|
||||||
|
void *oidn_denoiser_init();
|
||||||
|
bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height);
|
||||||
|
void oidn_denoiser_finish(void *device);
|
||||||
|
|
||||||
|
#endif // DENOISE_WRAPPER_H
|
65
modules/denoise/lightmap_denoiser.cpp
Normal file
65
modules/denoise/lightmap_denoiser.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* lightmap_denoiser.cpp */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#include "lightmap_denoiser.h"
|
||||||
|
#include "denoise_wrapper.h"
|
||||||
|
|
||||||
|
LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() {
|
||||||
|
return memnew(LightmapDenoiserOIDN);
|
||||||
|
}
|
||||||
|
|
||||||
|
void LightmapDenoiserOIDN::make_default_denoiser() {
|
||||||
|
create_function = create_oidn_denoiser;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) {
|
||||||
|
Ref<Image> img = p_image->duplicate();
|
||||||
|
|
||||||
|
img->convert(Image::FORMAT_RGBF);
|
||||||
|
|
||||||
|
PoolByteArray data = img->get_data();
|
||||||
|
{
|
||||||
|
PoolByteArray::Write w = data.write();
|
||||||
|
if (!oidn_denoise(device, (float *)w.ptr(), img->get_width(), img->get_height())) {
|
||||||
|
return p_image;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
img->create(img->get_width(), img->get_height(), false, img->get_format(), data);
|
||||||
|
return img;
|
||||||
|
}
|
||||||
|
|
||||||
|
LightmapDenoiserOIDN::LightmapDenoiserOIDN() {
|
||||||
|
device = oidn_denoiser_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
LightmapDenoiserOIDN::~LightmapDenoiserOIDN() {
|
||||||
|
oidn_denoiser_finish(device);
|
||||||
|
}
|
56
modules/denoise/lightmap_denoiser.h
Normal file
56
modules/denoise/lightmap_denoiser.h
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* lightmap_denoiser.h */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#ifndef LIGHTMAP_DENOISER_H
|
||||||
|
#define LIGHTMAP_DENOISER_H
|
||||||
|
|
||||||
|
#include "core/class_db.h"
|
||||||
|
#include "scene/3d/lightmapper.h"
|
||||||
|
|
||||||
|
struct OIDNDeviceImpl;
|
||||||
|
|
||||||
|
class LightmapDenoiserOIDN : public LightmapDenoiser {
|
||||||
|
GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void *device = nullptr;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static LightmapDenoiser *create_oidn_denoiser();
|
||||||
|
|
||||||
|
Ref<Image> denoise_image(const Ref<Image> &p_image) override;
|
||||||
|
|
||||||
|
static void make_default_denoiser();
|
||||||
|
|
||||||
|
LightmapDenoiserOIDN();
|
||||||
|
~LightmapDenoiserOIDN();
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // LIGHTMAP_DENOISER_H
|
40
modules/denoise/register_types.cpp
Normal file
40
modules/denoise/register_types.cpp
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* register_types.cpp */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#include "register_types.h"
|
||||||
|
#include "core/engine.h"
|
||||||
|
#include "lightmap_denoiser.h"
|
||||||
|
|
||||||
|
void register_denoise_types() {
|
||||||
|
LightmapDenoiserOIDN::make_default_denoiser();
|
||||||
|
}
|
||||||
|
|
||||||
|
void unregister_denoise_types() {
|
||||||
|
}
|
37
modules/denoise/register_types.h
Normal file
37
modules/denoise/register_types.h
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
/*************************************************************************/
|
||||||
|
/* register_types.h */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* This file is part of: */
|
||||||
|
/* GODOT ENGINE */
|
||||||
|
/* https://godotengine.org */
|
||||||
|
/*************************************************************************/
|
||||||
|
/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */
|
||||||
|
/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */
|
||||||
|
/* */
|
||||||
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
||||||
|
/* a copy of this software and associated documentation files (the */
|
||||||
|
/* "Software"), to deal in the Software without restriction, including */
|
||||||
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
||||||
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
||||||
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
||||||
|
/* the following conditions: */
|
||||||
|
/* */
|
||||||
|
/* The above copyright notice and this permission notice shall be */
|
||||||
|
/* included in all copies or substantial portions of the Software. */
|
||||||
|
/* */
|
||||||
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
||||||
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
||||||
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
||||||
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
||||||
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
||||||
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
||||||
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
||||||
|
/*************************************************************************/
|
||||||
|
|
||||||
|
#ifndef DENOISE_REGISTER_TYPES_H
|
||||||
|
#define DENOISE_REGISTER_TYPES_H
|
||||||
|
|
||||||
|
void register_denoise_types();
|
||||||
|
void unregister_denoise_types();
|
||||||
|
|
||||||
|
#endif // DENOISE_REGISTER_TYPES_H
|
68
modules/denoise/resource_to_cpp.py
Normal file
68
modules/denoise/resource_to_cpp.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
## ======================================================================== ##
|
||||||
|
## Copyright 2009-2019 Intel Corporation ##
|
||||||
|
## ##
|
||||||
|
## Licensed under the Apache License, Version 2.0 (the "License"); ##
|
||||||
|
## you may not use this file except in compliance with the License. ##
|
||||||
|
## You may obtain a copy of the License at ##
|
||||||
|
## ##
|
||||||
|
## http://www.apache.org/licenses/LICENSE-2.0 ##
|
||||||
|
## ##
|
||||||
|
## Unless required by applicable law or agreed to in writing, software ##
|
||||||
|
## distributed under the License is distributed on an "AS IS" BASIS, ##
|
||||||
|
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ##
|
||||||
|
## See the License for the specific language governing permissions and ##
|
||||||
|
## limitations under the License. ##
|
||||||
|
## ======================================================================== ##
|
||||||
|
|
||||||
|
import os
|
||||||
|
from array import array
|
||||||
|
|
||||||
|
# Generates a C++ file from the specified binary resource file
|
||||||
|
def generate(in_path, out_path):
|
||||||
|
|
||||||
|
namespace = "oidn::weights"
|
||||||
|
scopes = namespace.split("::")
|
||||||
|
|
||||||
|
file_name = os.path.basename(in_path)
|
||||||
|
var_name = os.path.splitext(file_name)[0]
|
||||||
|
|
||||||
|
with open(in_path, "rb") as in_file, open(out_path, "w") as out_file:
|
||||||
|
# Header
|
||||||
|
out_file.write("// Generated from: %s\n" % file_name)
|
||||||
|
out_file.write("#include <cstddef>\n\n")
|
||||||
|
|
||||||
|
# Open the namespaces
|
||||||
|
for s in scopes:
|
||||||
|
out_file.write("namespace %s {\n" % s)
|
||||||
|
if scopes:
|
||||||
|
out_file.write("\n")
|
||||||
|
|
||||||
|
# Read the file
|
||||||
|
in_data = array("B", in_file.read())
|
||||||
|
|
||||||
|
# Write the size
|
||||||
|
out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data)))
|
||||||
|
|
||||||
|
# Write the data
|
||||||
|
out_file.write("unsigned char %s[] = {" % var_name)
|
||||||
|
for i in range(len(in_data)):
|
||||||
|
c = in_data[i]
|
||||||
|
if i > 0:
|
||||||
|
out_file.write(",")
|
||||||
|
if (i + 1) % 20 == 1:
|
||||||
|
out_file.write("\n")
|
||||||
|
out_file.write("%d" % c)
|
||||||
|
out_file.write("\n};\n")
|
||||||
|
|
||||||
|
# Close the namespaces
|
||||||
|
if scopes:
|
||||||
|
out_file.write("\n")
|
||||||
|
for scope in reversed(scopes):
|
||||||
|
out_file.write("} // namespace %s\n" % scope)
|
||||||
|
|
||||||
|
|
||||||
|
def tza_to_cpp(target, source, env):
|
||||||
|
for x in zip(source, target):
|
||||||
|
generate(str(x[0]), str(x[1]))
|
31
thirdparty/README.md
vendored
31
thirdparty/README.md
vendored
@ -360,6 +360,37 @@ Files extracted from the upstream source:
|
|||||||
- LICENSE.txt
|
- LICENSE.txt
|
||||||
|
|
||||||
|
|
||||||
|
## oidn
|
||||||
|
|
||||||
|
- Upstream: https://github.com/OpenImageDenoise/oidn
|
||||||
|
- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019)
|
||||||
|
- License: Apache 2.0
|
||||||
|
|
||||||
|
Files extracted from upstream source:
|
||||||
|
|
||||||
|
common/* (except tasking.* and CMakeLists.txt)
|
||||||
|
core/*
|
||||||
|
include/OpenImageDenoise/* (except version.h.in)
|
||||||
|
LICENSE.txt
|
||||||
|
mkl-dnn/include/*
|
||||||
|
mkl-dnn/src/* (except CMakeLists.txt)
|
||||||
|
weights/rtlightmap_hdr.tza
|
||||||
|
scripts/resource_to_cpp.py
|
||||||
|
|
||||||
|
Modified files:
|
||||||
|
Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`.
|
||||||
|
Patch files are provided in `oidn/patches/`.
|
||||||
|
|
||||||
|
core/autoencoder.cpp
|
||||||
|
core/autoencoder.h
|
||||||
|
core/common.h
|
||||||
|
core/device.cpp
|
||||||
|
core/device.h
|
||||||
|
core/transfer_function.cpp
|
||||||
|
|
||||||
|
scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py)
|
||||||
|
|
||||||
|
|
||||||
## opus
|
## opus
|
||||||
|
|
||||||
- Upstream: https://opus-codec.org
|
- Upstream: https://opus-codec.org
|
||||||
|
202
thirdparty/oidn/LICENSE.txt
vendored
Normal file
202
thirdparty/oidn/LICENSE.txt
vendored
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
52
thirdparty/oidn/common/barrier.h
vendored
Normal file
52
thirdparty/oidn/common/barrier.h
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
#include <mutex>
|
||||||
|
#include <condition_variable>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Barrier
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::mutex m;
|
||||||
|
std::condition_variable cv;
|
||||||
|
volatile int count;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Barrier(int count) : count(count) {}
|
||||||
|
|
||||||
|
void wait()
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lk(m);
|
||||||
|
count--;
|
||||||
|
|
||||||
|
if (count == 0)
|
||||||
|
{
|
||||||
|
lk.unlock();
|
||||||
|
cv.notify_all();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
cv.wait(lk, [&]{ return count == 0; });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
45
thirdparty/oidn/common/exception.h
vendored
Normal file
45
thirdparty/oidn/common/exception.h
vendored
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include "platform.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Exception : public std::exception
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
Error error;
|
||||||
|
const char* message;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Exception(Error error, const char* message)
|
||||||
|
: error(error), message(message) {}
|
||||||
|
|
||||||
|
Error code() const noexcept
|
||||||
|
{
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* what() const noexcept override
|
||||||
|
{
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
114
thirdparty/oidn/common/platform.cpp
vendored
Normal file
114
thirdparty/oidn/common/platform.cpp
vendored
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Common functions
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
void* alignedMalloc(size_t size, size_t alignment)
|
||||||
|
{
|
||||||
|
if (size == 0)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
assert((alignment & (alignment-1)) == 0);
|
||||||
|
void* ptr = _mm_malloc(size, alignment);
|
||||||
|
|
||||||
|
if (ptr == nullptr)
|
||||||
|
throw std::bad_alloc();
|
||||||
|
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void alignedFree(void* ptr)
|
||||||
|
{
|
||||||
|
if (ptr)
|
||||||
|
_mm_free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// System information
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
std::string getPlatformName()
|
||||||
|
{
|
||||||
|
std::string name;
|
||||||
|
|
||||||
|
#if defined(__linux__)
|
||||||
|
name = "Linux";
|
||||||
|
#elif defined(__FreeBSD__)
|
||||||
|
name = "FreeBSD";
|
||||||
|
#elif defined(__CYGWIN__)
|
||||||
|
name = "Cygwin";
|
||||||
|
#elif defined(_WIN32)
|
||||||
|
name = "Windows";
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
name = "macOS";
|
||||||
|
#elif defined(__unix__)
|
||||||
|
name = "Unix";
|
||||||
|
#else
|
||||||
|
return "Unknown";
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__)
|
||||||
|
name += " (64-bit)";
|
||||||
|
#else
|
||||||
|
name += " (32-bit)";
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string getCompilerName()
|
||||||
|
{
|
||||||
|
#if defined(__INTEL_COMPILER)
|
||||||
|
int mayor = __INTEL_COMPILER / 100 % 100;
|
||||||
|
int minor = __INTEL_COMPILER % 100;
|
||||||
|
std::string version = "Intel Compiler ";
|
||||||
|
version += toString(mayor);
|
||||||
|
version += "." + toString(minor);
|
||||||
|
#if defined(__INTEL_COMPILER_UPDATE)
|
||||||
|
version += "." + toString(__INTEL_COMPILER_UPDATE);
|
||||||
|
#endif
|
||||||
|
return version;
|
||||||
|
#elif defined(__clang__)
|
||||||
|
return "Clang " __clang_version__;
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
return "GCC " __VERSION__;
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
std::string version = toString(_MSC_FULL_VER);
|
||||||
|
version.insert(4, ".");
|
||||||
|
version.insert(9, ".");
|
||||||
|
version.insert(2, ".");
|
||||||
|
return "Visual C++ Compiler " + version;
|
||||||
|
#else
|
||||||
|
return "Unknown";
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string getBuildName()
|
||||||
|
{
|
||||||
|
#if defined(NDEBUG)
|
||||||
|
return "Release";
|
||||||
|
#else
|
||||||
|
return "Debug";
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
131
thirdparty/oidn/common/platform.h
vendored
Normal file
131
thirdparty/oidn/common/platform.h
vendored
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#define NOMINMAX
|
||||||
|
#include <windows.h>
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
#include <sys/sysctl.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <xmmintrin.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <climits>
|
||||||
|
#include <limits>
|
||||||
|
#include <atomic>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <cassert>
|
||||||
|
#include "include/OpenImageDenoise/oidn.hpp"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Macros
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
// Windows
|
||||||
|
#if !defined(__noinline)
|
||||||
|
#define __noinline __declspec(noinline)
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
// Unix
|
||||||
|
#if !defined(__forceinline)
|
||||||
|
#define __forceinline inline __attribute__((always_inline))
|
||||||
|
#endif
|
||||||
|
#if !defined(__noinline)
|
||||||
|
#define __noinline __attribute__((noinline))
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef UNUSED
|
||||||
|
#define UNUSED(x) ((void)x)
|
||||||
|
#endif
|
||||||
|
#ifndef MAYBE_UNUSED
|
||||||
|
#define MAYBE_UNUSED(x) UNUSED(x)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Error handling and debugging
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
struct Verbose
|
||||||
|
{
|
||||||
|
int verbose;
|
||||||
|
|
||||||
|
Verbose(int v = 0) : verbose(v) {}
|
||||||
|
__forceinline bool isVerbose(int v = 1) const { return v <= verbose; }
|
||||||
|
};
|
||||||
|
|
||||||
|
#define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; }
|
||||||
|
#define OIDN_FATAL(message) throw std::runtime_error(message);
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Common functions
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using std::min;
|
||||||
|
using std::max;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__forceinline T clamp(const T& value, const T& minValue, const T& maxValue)
|
||||||
|
{
|
||||||
|
return min(max(value, minValue), maxValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* alignedMalloc(size_t size, size_t alignment);
|
||||||
|
void alignedFree(void* ptr);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline std::string toString(const T& a)
|
||||||
|
{
|
||||||
|
std::stringstream sm;
|
||||||
|
sm << a;
|
||||||
|
return sm.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
template<typename T>
|
||||||
|
bool getSysctl(const char* name, T& value)
|
||||||
|
{
|
||||||
|
int64_t result = 0;
|
||||||
|
size_t size = sizeof(result);
|
||||||
|
|
||||||
|
if (sysctlbyname(name, &result, &size, nullptr, 0) != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
value = T(result);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// System information
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
std::string getPlatformName();
|
||||||
|
std::string getCompilerName();
|
||||||
|
std::string getBuildName();
|
||||||
|
|
||||||
|
} // namespace oidn
|
163
thirdparty/oidn/common/ref.h
vendored
Normal file
163
thirdparty/oidn/common/ref.h
vendored
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class RefCount
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::atomic<size_t> count;
|
||||||
|
|
||||||
|
public:
|
||||||
|
__forceinline RefCount(int count = 0) noexcept : count(count) {}
|
||||||
|
|
||||||
|
__forceinline size_t incRef() noexcept
|
||||||
|
{
|
||||||
|
return count.fetch_add(1) + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline size_t decRef()
|
||||||
|
{
|
||||||
|
const size_t newCount = decRefKeep();
|
||||||
|
if (newCount == 0)
|
||||||
|
destroy();
|
||||||
|
return newCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline size_t decRefKeep() noexcept
|
||||||
|
{
|
||||||
|
return count.fetch_add(-1) - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline void destroy()
|
||||||
|
{
|
||||||
|
delete this;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Disable copying
|
||||||
|
RefCount(const RefCount&) = delete;
|
||||||
|
RefCount& operator =(const RefCount&) = delete;
|
||||||
|
|
||||||
|
virtual ~RefCount() noexcept = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
class Ref
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
T* ptr;
|
||||||
|
|
||||||
|
public:
|
||||||
|
__forceinline Ref() noexcept : ptr(nullptr) {}
|
||||||
|
__forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {}
|
||||||
|
__forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); }
|
||||||
|
__forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; }
|
||||||
|
__forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
|
||||||
|
|
||||||
|
template<typename Y>
|
||||||
|
__forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); }
|
||||||
|
|
||||||
|
template<typename Y>
|
||||||
|
__forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
|
||||||
|
|
||||||
|
__forceinline ~Ref() { if (ptr) ptr->decRef(); }
|
||||||
|
|
||||||
|
__forceinline Ref& operator =(const Ref& other)
|
||||||
|
{
|
||||||
|
if (other.ptr)
|
||||||
|
other.ptr->incRef();
|
||||||
|
if (ptr)
|
||||||
|
ptr->decRef();
|
||||||
|
ptr = other.ptr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline Ref& operator =(Ref&& other)
|
||||||
|
{
|
||||||
|
if (ptr)
|
||||||
|
ptr->decRef();
|
||||||
|
ptr = other.ptr;
|
||||||
|
other.ptr = nullptr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline Ref& operator =(T* other)
|
||||||
|
{
|
||||||
|
if (other)
|
||||||
|
other->incRef();
|
||||||
|
if (ptr)
|
||||||
|
ptr->decRef();
|
||||||
|
ptr = other;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline Ref& operator =(std::nullptr_t)
|
||||||
|
{
|
||||||
|
if (ptr)
|
||||||
|
ptr->decRef();
|
||||||
|
ptr = nullptr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline operator bool() const noexcept { return ptr != nullptr; }
|
||||||
|
|
||||||
|
__forceinline T& operator *() const noexcept { return *ptr; }
|
||||||
|
__forceinline T* operator ->() const noexcept { return ptr; }
|
||||||
|
|
||||||
|
__forceinline T* get() const noexcept { return ptr; }
|
||||||
|
|
||||||
|
__forceinline T* detach() noexcept
|
||||||
|
{
|
||||||
|
T* res = ptr;
|
||||||
|
ptr = nullptr;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr < b.ptr; }
|
||||||
|
|
||||||
|
template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr == nullptr; }
|
||||||
|
template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.ptr; }
|
||||||
|
template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr == b.ptr; }
|
||||||
|
|
||||||
|
template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr != nullptr; }
|
||||||
|
template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.ptr; }
|
||||||
|
template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr != b.ptr; }
|
||||||
|
|
||||||
|
template<typename T, typename... Args>
|
||||||
|
__forceinline Ref<T> makeRef(Args&&... args)
|
||||||
|
{
|
||||||
|
return Ref<T>(new T(std::forward<Args>(args)...));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename Y>
|
||||||
|
__forceinline Ref<Y> staticRefCast(const Ref<T>& a)
|
||||||
|
{
|
||||||
|
return Ref<Y>(static_cast<Y*>(a.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename Y>
|
||||||
|
__forceinline Ref<Y> dynamicRefCast(const Ref<T>& a)
|
||||||
|
{
|
||||||
|
return Ref<Y>(dynamic_cast<Y*>(a.get()));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
83
thirdparty/oidn/common/tensor.cpp
vendored
Normal file
83
thirdparty/oidn/common/tensor.cpp
vendored
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "exception.h"
|
||||||
|
#include "tensor.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
std::map<std::string, Tensor> parseTensors(void* buffer)
|
||||||
|
{
|
||||||
|
char* input = (char*)buffer;
|
||||||
|
|
||||||
|
// Parse the magic value
|
||||||
|
const int magic = *(unsigned short*)input;
|
||||||
|
if (magic != 0x41D7)
|
||||||
|
throw Exception(Error::InvalidOperation, "invalid tensor archive");
|
||||||
|
input += sizeof(unsigned short);
|
||||||
|
|
||||||
|
// Parse the version
|
||||||
|
const int majorVersion = *(unsigned char*)input++;
|
||||||
|
const int minorVersion = *(unsigned char*)input++;
|
||||||
|
UNUSED(minorVersion);
|
||||||
|
if (majorVersion > 1)
|
||||||
|
throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
|
||||||
|
|
||||||
|
// Parse the number of tensors
|
||||||
|
const int numTensors = *(int*)input;
|
||||||
|
input += sizeof(int);
|
||||||
|
|
||||||
|
// Parse the tensors
|
||||||
|
std::map<std::string, Tensor> tensorMap;
|
||||||
|
for (int i = 0; i < numTensors; ++i)
|
||||||
|
{
|
||||||
|
Tensor tensor;
|
||||||
|
|
||||||
|
// Parse the name
|
||||||
|
const int nameLen = *(unsigned char*)input++;
|
||||||
|
std::string name(input, nameLen);
|
||||||
|
input += nameLen;
|
||||||
|
|
||||||
|
// Parse the number of dimensions
|
||||||
|
const int ndims = *(unsigned char*)input++;
|
||||||
|
|
||||||
|
// Parse the shape of the tensor
|
||||||
|
tensor.dims.resize(ndims);
|
||||||
|
for (int i = 0; i < ndims; ++i)
|
||||||
|
tensor.dims[i] = ((int*)input)[i];
|
||||||
|
input += ndims * sizeof(int);
|
||||||
|
|
||||||
|
// Parse the format of the tensor
|
||||||
|
tensor.format = std::string(input, input + ndims);
|
||||||
|
input += ndims;
|
||||||
|
|
||||||
|
// Parse the data type of the tensor
|
||||||
|
const char type = *(unsigned char*)input++;
|
||||||
|
if (type != 'f') // only float32 is supported
|
||||||
|
throw Exception(Error::InvalidOperation, "unsupported tensor data type");
|
||||||
|
|
||||||
|
// Skip the data
|
||||||
|
tensor.data = (float*)input;
|
||||||
|
input += tensor.size() * sizeof(float);
|
||||||
|
|
||||||
|
// Add the tensor to the map
|
||||||
|
tensorMap.emplace(name, std::move(tensor));
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensorMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
66
thirdparty/oidn/common/tensor.h
vendored
Normal file
66
thirdparty/oidn/common/tensor.h
vendored
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
using shared_vector = std::shared_ptr<std::vector<T>>;
|
||||||
|
|
||||||
|
// Generic tensor
|
||||||
|
struct Tensor
|
||||||
|
{
|
||||||
|
float* data;
|
||||||
|
std::vector<int64_t> dims;
|
||||||
|
std::string format;
|
||||||
|
shared_vector<char> buffer; // optional, only for reference counting
|
||||||
|
|
||||||
|
__forceinline Tensor() : data(nullptr) {}
|
||||||
|
|
||||||
|
__forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
|
||||||
|
: dims(dims),
|
||||||
|
format(format)
|
||||||
|
{
|
||||||
|
buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
|
||||||
|
data = (float*)buffer->data();
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline operator bool() const { return data != nullptr; }
|
||||||
|
|
||||||
|
__forceinline int ndims() const { return (int)dims.size(); }
|
||||||
|
|
||||||
|
// Returns the number of values
|
||||||
|
__forceinline size_t size() const
|
||||||
|
{
|
||||||
|
size_t size = 1;
|
||||||
|
for (int i = 0; i < ndims(); ++i)
|
||||||
|
size *= dims[i];
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float& operator [](size_t i) { return data[i]; }
|
||||||
|
__forceinline const float& operator [](size_t i) const { return data[i]; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parses tensors from a buffer
|
||||||
|
std::map<std::string, Tensor> parseTensors(void* buffer);
|
||||||
|
|
||||||
|
} // namespace oidn
|
297
thirdparty/oidn/common/thread.cpp
vendored
Normal file
297
thirdparty/oidn/common/thread.cpp
vendored
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
#include <mach/thread_act.h>
|
||||||
|
#include <mach/mach_init.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "thread.h"
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - Windows
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
||||||
|
: Verbose(verbose)
|
||||||
|
{
|
||||||
|
HMODULE hLib = GetModuleHandle(TEXT("kernel32"));
|
||||||
|
pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx");
|
||||||
|
pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity");
|
||||||
|
|
||||||
|
if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity)
|
||||||
|
{
|
||||||
|
// Get logical processor information
|
||||||
|
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr;
|
||||||
|
DWORD bufferSize = 0;
|
||||||
|
|
||||||
|
// First call the function with an empty buffer to get the required buffer size
|
||||||
|
BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
|
||||||
|
if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
|
||||||
|
{
|
||||||
|
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate the buffer
|
||||||
|
buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize);
|
||||||
|
if (!buffer)
|
||||||
|
{
|
||||||
|
OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call again the function but now with the properly sized buffer
|
||||||
|
result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
|
||||||
|
if (!result)
|
||||||
|
{
|
||||||
|
OIDN_WARNING("GetLogicalProcessorInformationEx failed");
|
||||||
|
free(buffer);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over the logical processor information structures
|
||||||
|
// There should be one structure for each physical core
|
||||||
|
char* ptr = (char*)buffer;
|
||||||
|
while (ptr < (char*)buffer + bufferSize)
|
||||||
|
{
|
||||||
|
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr;
|
||||||
|
if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0)
|
||||||
|
{
|
||||||
|
// Iterate over the groups
|
||||||
|
int numThreads = 0;
|
||||||
|
for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group)
|
||||||
|
{
|
||||||
|
GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group];
|
||||||
|
while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore))
|
||||||
|
{
|
||||||
|
// Extract the next set bit/thread from the mask
|
||||||
|
GROUP_AFFINITY threadAffinity = coreAffinity;
|
||||||
|
threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask;
|
||||||
|
|
||||||
|
// Push the affinity for this thread
|
||||||
|
affinities.push_back(threadAffinity);
|
||||||
|
oldAffinities.push_back(threadAffinity);
|
||||||
|
numThreads++;
|
||||||
|
|
||||||
|
// Remove this bit/thread from the mask
|
||||||
|
coreAffinity.Mask ^= threadAffinity.Mask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next structure
|
||||||
|
ptr += item->Size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free the buffer
|
||||||
|
free(buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::set(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Save the current affinity and set the new one
|
||||||
|
const HANDLE thread = GetCurrentThread();
|
||||||
|
if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex]))
|
||||||
|
OIDN_WARNING("SetThreadGroupAffinity failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::restore(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Restore the original affinity
|
||||||
|
const HANDLE thread = GetCurrentThread();
|
||||||
|
if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr))
|
||||||
|
OIDN_WARNING("SetThreadGroupAffinity failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif defined(__linux__)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - Linux
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
||||||
|
: Verbose(verbose)
|
||||||
|
{
|
||||||
|
std::vector<int> threadIds;
|
||||||
|
|
||||||
|
// Parse the thread/CPU topology
|
||||||
|
for (int cpuId = 0; ; cpuId++)
|
||||||
|
{
|
||||||
|
std::fstream fs;
|
||||||
|
std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list");
|
||||||
|
fs.open(cpu.c_str(), std::fstream::in);
|
||||||
|
if (fs.fail()) break;
|
||||||
|
|
||||||
|
int i;
|
||||||
|
int j = 0;
|
||||||
|
while ((j < numThreadsPerCore) && (fs >> i))
|
||||||
|
{
|
||||||
|
if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; }))
|
||||||
|
threadIds.push_back(i);
|
||||||
|
|
||||||
|
if (fs.peek() == ',')
|
||||||
|
fs.ignore();
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
for (size_t i = 0; i < thread_ids.size(); ++i)
|
||||||
|
std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Create the affinity structures
|
||||||
|
affinities.resize(threadIds.size());
|
||||||
|
oldAffinities.resize(threadIds.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < threadIds.size(); ++i)
|
||||||
|
{
|
||||||
|
cpu_set_t affinity;
|
||||||
|
CPU_ZERO(&affinity);
|
||||||
|
CPU_SET(threadIds[i], &affinity);
|
||||||
|
|
||||||
|
affinities[i] = affinity;
|
||||||
|
oldAffinities[i] = affinity;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::set(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const pthread_t thread = pthread_self();
|
||||||
|
|
||||||
|
// Save the current affinity
|
||||||
|
if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
|
||||||
|
{
|
||||||
|
OIDN_WARNING("pthread_getaffinity_np failed");
|
||||||
|
oldAffinities[threadIndex] = affinities[threadIndex];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the new affinity
|
||||||
|
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0)
|
||||||
|
OIDN_WARNING("pthread_setaffinity_np failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::restore(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const pthread_t thread = pthread_self();
|
||||||
|
|
||||||
|
// Restore the original affinity
|
||||||
|
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
|
||||||
|
OIDN_WARNING("pthread_setaffinity_np failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - macOS
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
|
||||||
|
: Verbose(verbose)
|
||||||
|
{
|
||||||
|
// Query the thread/CPU topology
|
||||||
|
int numPhysicalCpus;
|
||||||
|
int numLogicalCpus;
|
||||||
|
|
||||||
|
if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus))
|
||||||
|
{
|
||||||
|
OIDN_WARNING("sysctlbyname failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1))
|
||||||
|
return; // this shouldn't happen
|
||||||
|
const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus;
|
||||||
|
|
||||||
|
// Create the affinity structures
|
||||||
|
// macOS doesn't support binding a thread to a specific core, but we can at least group threads which
|
||||||
|
// should be on the same core together
|
||||||
|
for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1!
|
||||||
|
{
|
||||||
|
thread_affinity_policy affinity;
|
||||||
|
affinity.affinity_tag = core;
|
||||||
|
|
||||||
|
for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread)
|
||||||
|
{
|
||||||
|
affinities.push_back(affinity);
|
||||||
|
oldAffinities.push_back(affinity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::set(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const auto thread = mach_thread_self();
|
||||||
|
|
||||||
|
// Save the current affinity
|
||||||
|
mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT;
|
||||||
|
boolean_t getDefault = FALSE;
|
||||||
|
if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS)
|
||||||
|
{
|
||||||
|
OIDN_WARNING("thread_policy_get failed");
|
||||||
|
oldAffinities[threadIndex] = affinities[threadIndex];
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the new affinity
|
||||||
|
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
|
||||||
|
OIDN_WARNING("thread_policy_set failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadAffinity::restore(int threadIndex)
|
||||||
|
{
|
||||||
|
if (threadIndex >= (int)affinities.size())
|
||||||
|
return;
|
||||||
|
|
||||||
|
const auto thread = mach_thread_self();
|
||||||
|
|
||||||
|
// Restore the original affinity
|
||||||
|
if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
|
||||||
|
OIDN_WARNING("thread_policy_set failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace oidn
|
202
thirdparty/oidn/common/thread.h
vendored
Normal file
202
thirdparty/oidn/common/thread.h
vendored
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
|
||||||
|
#if !defined(_WIN32)
|
||||||
|
#include <pthread.h>
|
||||||
|
#include <sched.h>
|
||||||
|
#if defined(__APPLE__)
|
||||||
|
#include <mach/thread_policy.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadLocal
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Wrapper which makes any variable thread-local
|
||||||
|
template<typename T>
|
||||||
|
class ThreadLocal : public Verbose
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
#if defined(_WIN32)
|
||||||
|
DWORD key;
|
||||||
|
#else
|
||||||
|
pthread_key_t key;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::vector<T*> instances;
|
||||||
|
std::mutex mutex;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ThreadLocal(int verbose = 0)
|
||||||
|
: Verbose(verbose)
|
||||||
|
{
|
||||||
|
#if defined(_WIN32)
|
||||||
|
key = TlsAlloc();
|
||||||
|
if (key == TLS_OUT_OF_INDEXES)
|
||||||
|
OIDN_FATAL("TlsAlloc failed");
|
||||||
|
#else
|
||||||
|
if (pthread_key_create(&key, nullptr) != 0)
|
||||||
|
OIDN_FATAL("pthread_key_create failed");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
~ThreadLocal()
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
for (T* ptr : instances)
|
||||||
|
delete ptr;
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
if (!TlsFree(key))
|
||||||
|
OIDN_WARNING("TlsFree failed");
|
||||||
|
#else
|
||||||
|
if (pthread_key_delete(key) != 0)
|
||||||
|
OIDN_WARNING("pthread_key_delete failed");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
T& get()
|
||||||
|
{
|
||||||
|
#if defined(_WIN32)
|
||||||
|
T* ptr = (T*)TlsGetValue(key);
|
||||||
|
#else
|
||||||
|
T* ptr = (T*)pthread_getspecific(key);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (ptr)
|
||||||
|
return *ptr;
|
||||||
|
|
||||||
|
ptr = new T;
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
instances.push_back(ptr);
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
if (!TlsSetValue(key, ptr))
|
||||||
|
OIDN_FATAL("TlsSetValue failed");
|
||||||
|
#else
|
||||||
|
if (pthread_setspecific(key, ptr) != 0)
|
||||||
|
OIDN_FATAL("pthread_setspecific failed");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return *ptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - Windows
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ThreadAffinity : public Verbose
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP,
|
||||||
|
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
|
||||||
|
PDWORD);
|
||||||
|
|
||||||
|
typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE,
|
||||||
|
CONST GROUP_AFFINITY*,
|
||||||
|
PGROUP_AFFINITY);
|
||||||
|
|
||||||
|
GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr;
|
||||||
|
SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr;
|
||||||
|
|
||||||
|
std::vector<GROUP_AFFINITY> affinities; // thread affinities
|
||||||
|
std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities
|
||||||
|
|
||||||
|
public:
|
||||||
|
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
||||||
|
|
||||||
|
int getNumThreads() const
|
||||||
|
{
|
||||||
|
return (int)affinities.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
||||||
|
void set(int threadIndex);
|
||||||
|
|
||||||
|
// Restores the affinity of the thread
|
||||||
|
void restore(int threadIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
#elif defined(__linux__)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - Linux
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ThreadAffinity : public Verbose
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::vector<cpu_set_t> affinities; // thread affinities
|
||||||
|
std::vector<cpu_set_t> oldAffinities; // original thread affinities
|
||||||
|
|
||||||
|
public:
|
||||||
|
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
||||||
|
|
||||||
|
int getNumThreads() const
|
||||||
|
{
|
||||||
|
return (int)affinities.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
||||||
|
void set(int threadIndex);
|
||||||
|
|
||||||
|
// Restores the affinity of the thread
|
||||||
|
void restore(int threadIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
#elif defined(__APPLE__)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// ThreadAffinity - macOS
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ThreadAffinity : public Verbose
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::vector<thread_affinity_policy> affinities; // thread affinities
|
||||||
|
std::vector<thread_affinity_policy> oldAffinities; // original thread affinities
|
||||||
|
|
||||||
|
public:
|
||||||
|
ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
|
||||||
|
|
||||||
|
int getNumThreads() const
|
||||||
|
{
|
||||||
|
return (int)affinities.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
|
||||||
|
void set(int threadIndex);
|
||||||
|
|
||||||
|
// Restores the affinity of the thread
|
||||||
|
void restore(int threadIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace oidn
|
49
thirdparty/oidn/common/timer.h
vendored
Normal file
49
thirdparty/oidn/common/timer.h
vendored
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "platform.h"
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Timer
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
using clock = std::chrono::high_resolution_clock;
|
||||||
|
|
||||||
|
std::chrono::time_point<clock> start;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Timer()
|
||||||
|
{
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset()
|
||||||
|
{
|
||||||
|
start = clock::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
double query() const
|
||||||
|
{
|
||||||
|
auto end = clock::now();
|
||||||
|
return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
408
thirdparty/oidn/core/api.cpp
vendored
Normal file
408
thirdparty/oidn/core/api.cpp
vendored
Normal file
@ -0,0 +1,408 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
# define OIDN_API extern "C" __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
# define OIDN_API extern "C" __attribute__ ((visibility ("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Locks the device that owns the specified object
|
||||||
|
// Use *only* inside OIDN_TRY/CATCH!
|
||||||
|
#define OIDN_LOCK(obj) \
|
||||||
|
std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex());
|
||||||
|
|
||||||
|
// Try/catch for converting exceptions to errors
|
||||||
|
#define OIDN_TRY \
|
||||||
|
try {
|
||||||
|
|
||||||
|
#define OIDN_CATCH(obj) \
|
||||||
|
} catch (Exception& e) { \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \
|
||||||
|
} catch (std::bad_alloc&) { \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
|
||||||
|
} catch (mkldnn::error& e) { \
|
||||||
|
if (e.status == mkldnn_out_of_memory) \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
|
||||||
|
else \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \
|
||||||
|
} catch (std::exception& e) { \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \
|
||||||
|
} catch (...) { \
|
||||||
|
Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "device.h"
|
||||||
|
#include "filter.h"
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
__forceinline void checkHandle(void* handle)
|
||||||
|
{
|
||||||
|
if (handle == nullptr)
|
||||||
|
throw Exception(Error::InvalidArgument, "invalid handle");
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__forceinline void retainObject(T* obj)
|
||||||
|
{
|
||||||
|
if (obj)
|
||||||
|
{
|
||||||
|
obj->incRef();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(obj);
|
||||||
|
OIDN_CATCH(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__forceinline void releaseObject(T* obj)
|
||||||
|
{
|
||||||
|
if (obj == nullptr || obj->decRefKeep() == 0)
|
||||||
|
{
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(obj);
|
||||||
|
OIDN_LOCK(obj);
|
||||||
|
obj->destroy();
|
||||||
|
OIDN_CATCH(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
__forceinline void releaseObject(Device* obj)
|
||||||
|
{
|
||||||
|
if (obj == nullptr || obj->decRefKeep() == 0)
|
||||||
|
{
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(obj);
|
||||||
|
// Do NOT lock the device because it owns the mutex
|
||||||
|
obj->destroy();
|
||||||
|
OIDN_CATCH(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
|
||||||
|
{
|
||||||
|
Ref<Device> device = nullptr;
|
||||||
|
OIDN_TRY
|
||||||
|
if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
|
||||||
|
device = makeRef<Device>();
|
||||||
|
else
|
||||||
|
throw Exception(Error::InvalidArgument, "invalid device type");
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return (OIDNDevice)device.detach();
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
retainObject(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
releaseObject(device);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
device->set1i(name, value);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
device->set1i(name, value);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
return device->get1i(name);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
return device->get1i(name);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
device->setErrorFunction((ErrorFunction)func, userPtr);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
return (OIDNError)Device::getError(device, outMessage);
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
if (outMessage) *outMessage = "";
|
||||||
|
return OIDN_ERROR_UNKNOWN;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
device->commit();
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
Ref<Buffer> buffer = device->newBuffer(byteSize);
|
||||||
|
return (OIDNBuffer)buffer.detach();
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
|
||||||
|
return (OIDNBuffer)buffer.detach();
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
|
||||||
|
{
|
||||||
|
Buffer* buffer = (Buffer*)hBuffer;
|
||||||
|
retainObject(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
|
||||||
|
{
|
||||||
|
Buffer* buffer = (Buffer*)hBuffer;
|
||||||
|
releaseObject(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
|
||||||
|
{
|
||||||
|
Buffer* buffer = (Buffer*)hBuffer;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hBuffer);
|
||||||
|
OIDN_LOCK(buffer);
|
||||||
|
return buffer->map(byteOffset, byteSize);
|
||||||
|
OIDN_CATCH(buffer)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
|
||||||
|
{
|
||||||
|
Buffer* buffer = (Buffer*)hBuffer;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hBuffer);
|
||||||
|
OIDN_LOCK(buffer);
|
||||||
|
return buffer->unmap(mappedPtr);
|
||||||
|
OIDN_CATCH(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
|
||||||
|
{
|
||||||
|
Device* device = (Device*)hDevice;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hDevice);
|
||||||
|
OIDN_LOCK(device);
|
||||||
|
Ref<Filter> filter = device->newFilter(type);
|
||||||
|
return (OIDNFilter)filter.detach();
|
||||||
|
OIDN_CATCH(device)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
retainObject(filter);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
releaseObject(filter);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
|
||||||
|
OIDNBuffer hBuffer, OIDNFormat format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset,
|
||||||
|
size_t bytePixelStride, size_t byteRowStride)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
checkHandle(hBuffer);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
Ref<Buffer> buffer = (Buffer*)hBuffer;
|
||||||
|
if (buffer->getDevice() != filter->getDevice())
|
||||||
|
throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
|
||||||
|
Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
|
||||||
|
filter->setImage(name, data);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
|
||||||
|
void* ptr, OIDNFormat format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset,
|
||||||
|
size_t bytePixelStride, size_t byteRowStride)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
|
||||||
|
filter->setImage(name, data);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->set1i(name, int(value));
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
return filter->get1i(name);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->set1i(name, value);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
return filter->get1i(name);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->set1f(name, value);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
return filter->get1f(name);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->setProgressMonitorFunction(func, userPtr);
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->commit();
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
|
||||||
|
{
|
||||||
|
Filter* filter = (Filter*)hFilter;
|
||||||
|
OIDN_TRY
|
||||||
|
checkHandle(hFilter);
|
||||||
|
OIDN_LOCK(filter);
|
||||||
|
filter->execute();
|
||||||
|
OIDN_CATCH(filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
535
thirdparty/oidn/core/autoencoder.cpp
vendored
Normal file
535
thirdparty/oidn/core/autoencoder.cpp
vendored
Normal file
@ -0,0 +1,535 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "autoencoder.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// AutoencoderFilter
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
|
||||||
|
: Filter(device)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::setImage(const std::string& name, const Image& data)
|
||||||
|
{
|
||||||
|
if (name == "color")
|
||||||
|
color = data;
|
||||||
|
else if (name == "albedo")
|
||||||
|
albedo = data;
|
||||||
|
else if (name == "normal")
|
||||||
|
normal = data;
|
||||||
|
else if (name == "output")
|
||||||
|
output = data;
|
||||||
|
|
||||||
|
dirty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::set1i(const std::string& name, int value)
|
||||||
|
{
|
||||||
|
if (name == "hdr")
|
||||||
|
hdr = value;
|
||||||
|
else if (name == "srgb")
|
||||||
|
srgb = value;
|
||||||
|
else if (name == "maxMemoryMB")
|
||||||
|
maxMemoryMB = value;
|
||||||
|
|
||||||
|
dirty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AutoencoderFilter::get1i(const std::string& name)
|
||||||
|
{
|
||||||
|
if (name == "hdr")
|
||||||
|
return hdr;
|
||||||
|
else if (name == "srgb")
|
||||||
|
return srgb;
|
||||||
|
else if (name == "maxMemoryMB")
|
||||||
|
return maxMemoryMB;
|
||||||
|
else if (name == "alignment")
|
||||||
|
return alignment;
|
||||||
|
else if (name == "overlap")
|
||||||
|
return overlap;
|
||||||
|
else
|
||||||
|
throw Exception(Error::InvalidArgument, "invalid parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::set1f(const std::string& name, float value)
|
||||||
|
{
|
||||||
|
if (name == "hdrScale")
|
||||||
|
hdrScale = value;
|
||||||
|
|
||||||
|
dirty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
float AutoencoderFilter::get1f(const std::string& name)
|
||||||
|
{
|
||||||
|
if (name == "hdrScale")
|
||||||
|
return hdrScale;
|
||||||
|
else
|
||||||
|
throw Exception(Error::InvalidArgument, "invalid parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::commit()
|
||||||
|
{
|
||||||
|
if (!dirty)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
//device->executeTask([&]()
|
||||||
|
//{
|
||||||
|
// GODOT end --
|
||||||
|
|
||||||
|
if (mayiuse(avx512_common))
|
||||||
|
net = buildNet<16>();
|
||||||
|
else
|
||||||
|
net = buildNet<8>();
|
||||||
|
|
||||||
|
// GODOT start --
|
||||||
|
//});
|
||||||
|
// GODOT end --
|
||||||
|
|
||||||
|
dirty = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::execute()
|
||||||
|
{
|
||||||
|
if (dirty)
|
||||||
|
throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
|
||||||
|
|
||||||
|
if (!net)
|
||||||
|
return;
|
||||||
|
// -- GODOT start --
|
||||||
|
//device->executeTask([&]()
|
||||||
|
//{
|
||||||
|
// -- GODOT end --
|
||||||
|
Progress progress;
|
||||||
|
progress.func = progressFunc;
|
||||||
|
progress.userPtr = progressUserPtr;
|
||||||
|
progress.taskCount = tileCountH * tileCountW;
|
||||||
|
|
||||||
|
// Iterate over the tiles
|
||||||
|
int tileIndex = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < tileCountH; ++i)
|
||||||
|
{
|
||||||
|
const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
|
||||||
|
const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
|
||||||
|
const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
|
||||||
|
const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
|
||||||
|
const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
|
||||||
|
const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
|
||||||
|
|
||||||
|
for (int j = 0; j < tileCountW; ++j)
|
||||||
|
{
|
||||||
|
const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
|
||||||
|
const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
|
||||||
|
const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
|
||||||
|
const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
|
||||||
|
const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
|
||||||
|
const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
|
||||||
|
|
||||||
|
// Set the input tile
|
||||||
|
inputReorder->setTile(h, w,
|
||||||
|
alignOffsetH, alignOffsetW,
|
||||||
|
tileH1, tileW1);
|
||||||
|
|
||||||
|
// Set the output tile
|
||||||
|
outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
|
||||||
|
h + overlapBeginH, w + overlapBeginW,
|
||||||
|
tileH2, tileW2);
|
||||||
|
|
||||||
|
//printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
|
||||||
|
|
||||||
|
// Denoise the tile
|
||||||
|
net->execute(progress, tileIndex);
|
||||||
|
|
||||||
|
// Next tile
|
||||||
|
tileIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -- GODOT start --
|
||||||
|
//});
|
||||||
|
// -- GODOT end --
|
||||||
|
}
|
||||||
|
|
||||||
|
void AutoencoderFilter::computeTileSize()
|
||||||
|
{
|
||||||
|
const int minTileSize = 3*overlap;
|
||||||
|
const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
|
||||||
|
const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
|
||||||
|
|
||||||
|
tileCountH = 1;
|
||||||
|
tileCountW = 1;
|
||||||
|
tileH = roundUp(H, alignment);
|
||||||
|
tileW = roundUp(W, alignment);
|
||||||
|
|
||||||
|
// Divide the image into tiles until the tile size gets below the threshold
|
||||||
|
while (int64_t(tileH) * tileW > maxTilePixels)
|
||||||
|
{
|
||||||
|
if (tileH > minTileSize && tileH > tileW)
|
||||||
|
{
|
||||||
|
tileCountH++;
|
||||||
|
tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
|
||||||
|
}
|
||||||
|
else if (tileW > minTileSize)
|
||||||
|
{
|
||||||
|
tileCountW++;
|
||||||
|
tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the final number of tiles
|
||||||
|
tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
|
||||||
|
tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
|
||||||
|
|
||||||
|
if (device->isVerbose(2))
|
||||||
|
{
|
||||||
|
std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
|
||||||
|
std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Executable> AutoencoderFilter::buildNet()
|
||||||
|
{
|
||||||
|
H = color.height;
|
||||||
|
W = color.width;
|
||||||
|
|
||||||
|
// Configure the network
|
||||||
|
int inputC;
|
||||||
|
void* weightPtr;
|
||||||
|
|
||||||
|
if (srgb && hdr)
|
||||||
|
throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
|
||||||
|
|
||||||
|
if (color && !albedo && !normal && weightData.hdr)
|
||||||
|
{
|
||||||
|
inputC = 3;
|
||||||
|
weightPtr = hdr ? weightData.hdr : weightData.ldr;
|
||||||
|
}
|
||||||
|
else if (color && albedo && !normal && weightData.hdr_alb)
|
||||||
|
{
|
||||||
|
inputC = 6;
|
||||||
|
weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
|
||||||
|
}
|
||||||
|
else if (color && albedo && normal && weightData.hdr_alb_nrm)
|
||||||
|
{
|
||||||
|
inputC = 9;
|
||||||
|
weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw Exception(Error::InvalidOperation, "unsupported combination of input features");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!output)
|
||||||
|
throw Exception(Error::InvalidOperation, "output image not specified");
|
||||||
|
|
||||||
|
if ((color.format != Format::Float3)
|
||||||
|
|| (albedo && albedo.format != Format::Float3)
|
||||||
|
|| (normal && normal.format != Format::Float3)
|
||||||
|
|| (output.format != Format::Float3))
|
||||||
|
throw Exception(Error::InvalidOperation, "unsupported image format");
|
||||||
|
|
||||||
|
if ((albedo && (albedo.width != W || albedo.height != H))
|
||||||
|
|| (normal && (normal.width != W || normal.height != H))
|
||||||
|
|| (output.width != W || output.height != H))
|
||||||
|
throw Exception(Error::InvalidOperation, "image size mismatch");
|
||||||
|
|
||||||
|
// Compute the tile size
|
||||||
|
computeTileSize();
|
||||||
|
|
||||||
|
// If the image size is zero, there is nothing else to do
|
||||||
|
if (H <= 0 || W <= 0)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// Parse the weights
|
||||||
|
const auto weightMap = parseTensors(weightPtr);
|
||||||
|
|
||||||
|
// Create the network
|
||||||
|
std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
|
||||||
|
|
||||||
|
// Compute the tensor sizes
|
||||||
|
const auto inputDims = memory::dims({1, inputC, tileH, tileW});
|
||||||
|
const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0
|
||||||
|
|
||||||
|
const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
|
||||||
|
const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
|
||||||
|
const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
|
||||||
|
const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
|
||||||
|
const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
|
||||||
|
const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
|
||||||
|
const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
|
||||||
|
const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
|
||||||
|
const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
|
||||||
|
const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
|
||||||
|
const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
|
||||||
|
const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
|
||||||
|
const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
|
||||||
|
const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
|
||||||
|
const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
|
||||||
|
const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
|
||||||
|
const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
|
||||||
|
const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
|
||||||
|
const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
|
||||||
|
const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
|
||||||
|
const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
|
||||||
|
const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
|
||||||
|
const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
|
||||||
|
const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
|
||||||
|
const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
|
||||||
|
const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
|
||||||
|
const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
|
||||||
|
const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
|
||||||
|
const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
|
||||||
|
const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
|
||||||
|
const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
|
||||||
|
const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0
|
||||||
|
|
||||||
|
const auto outputDims = memory::dims({1, 3, tileH, tileW});
|
||||||
|
|
||||||
|
// Allocate two temporary ping-pong buffers to decrease memory usage
|
||||||
|
const auto temp0Dims = getMaxTensorDims({
|
||||||
|
conv1Dims,
|
||||||
|
conv2Dims,
|
||||||
|
conv3Dims,
|
||||||
|
conv4Dims,
|
||||||
|
conv5Dims,
|
||||||
|
conv6Dims,
|
||||||
|
conv7Dims,
|
||||||
|
conv8Dims,
|
||||||
|
conv9Dims,
|
||||||
|
conv10Dims,
|
||||||
|
conv11Dims
|
||||||
|
});
|
||||||
|
|
||||||
|
const auto temp1Dims = getMaxTensorDims({
|
||||||
|
conv1bDims,
|
||||||
|
pool5Dims,
|
||||||
|
conv6bDims,
|
||||||
|
conv7bDims,
|
||||||
|
conv8bDims,
|
||||||
|
conv9bDims,
|
||||||
|
conv10bDims,
|
||||||
|
});
|
||||||
|
|
||||||
|
auto temp0 = net->allocTensor(temp0Dims);
|
||||||
|
auto temp1 = net->allocTensor(temp1Dims);
|
||||||
|
|
||||||
|
// Allocate enough memory to hold the concat outputs. Then use the first
|
||||||
|
// half to hold the previous conv output and the second half to hold the
|
||||||
|
// pool/orig image output. This works because everything is C dimension
|
||||||
|
// outermost, padded to K floats, and all the concats are on the C dimension.
|
||||||
|
auto concat0Dst = net->allocTensor(concat0Dims);
|
||||||
|
auto concat1Dst = net->allocTensor(concat1Dims);
|
||||||
|
auto concat2Dst = net->allocTensor(concat2Dims);
|
||||||
|
auto concat3Dst = net->allocTensor(concat3Dims);
|
||||||
|
auto concat4Dst = net->allocTensor(concat4Dims);
|
||||||
|
|
||||||
|
// Transfer function
|
||||||
|
std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
|
||||||
|
|
||||||
|
// Autoexposure
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
|
||||||
|
{
|
||||||
|
if (isnan(hdrScale))
|
||||||
|
net->addAutoexposure(color, tf);
|
||||||
|
else
|
||||||
|
tf->setExposure(hdrScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Input reorder
|
||||||
|
auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
|
||||||
|
inputReorder = net->addInputReorder(color, albedo, normal,
|
||||||
|
transferFunc,
|
||||||
|
alignment, inputReorderDst);
|
||||||
|
|
||||||
|
// conv1
|
||||||
|
auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
|
||||||
|
|
||||||
|
// conv1b
|
||||||
|
auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
|
||||||
|
|
||||||
|
// pool1
|
||||||
|
// Adjust pointer for pool1 to eliminate concat1
|
||||||
|
auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
|
||||||
|
auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
|
||||||
|
|
||||||
|
// conv2
|
||||||
|
auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
|
||||||
|
|
||||||
|
// pool2
|
||||||
|
// Adjust pointer for pool2 to eliminate concat2
|
||||||
|
auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
|
||||||
|
auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
|
||||||
|
|
||||||
|
// conv3
|
||||||
|
auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
|
||||||
|
|
||||||
|
// pool3
|
||||||
|
// Adjust pointer for pool3 to eliminate concat3
|
||||||
|
auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
|
||||||
|
auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
|
||||||
|
|
||||||
|
// conv4
|
||||||
|
auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
|
||||||
|
|
||||||
|
// pool4
|
||||||
|
// Adjust pointer for pool4 to eliminate concat4
|
||||||
|
auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
|
||||||
|
auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
|
||||||
|
|
||||||
|
// conv5
|
||||||
|
auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
|
||||||
|
|
||||||
|
// pool5
|
||||||
|
auto pool5 = net->addPool(conv5->getDst(), temp1);
|
||||||
|
|
||||||
|
// upsample4
|
||||||
|
auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
|
||||||
|
auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
|
||||||
|
|
||||||
|
// conv6
|
||||||
|
auto conv6 = net->addConv("conv6", concat4Dst, temp0);
|
||||||
|
|
||||||
|
// conv6b
|
||||||
|
auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
|
||||||
|
|
||||||
|
// upsample3
|
||||||
|
auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
|
||||||
|
auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
|
||||||
|
|
||||||
|
// conv7
|
||||||
|
auto conv7 = net->addConv("conv7", concat3Dst, temp0);
|
||||||
|
|
||||||
|
// conv7b
|
||||||
|
auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
|
||||||
|
|
||||||
|
// upsample2
|
||||||
|
auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
|
||||||
|
auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
|
||||||
|
|
||||||
|
// conv8
|
||||||
|
auto conv8 = net->addConv("conv8", concat2Dst, temp0);
|
||||||
|
|
||||||
|
// conv8b
|
||||||
|
auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
|
||||||
|
|
||||||
|
// upsample1
|
||||||
|
auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
|
||||||
|
auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
|
||||||
|
|
||||||
|
// conv9
|
||||||
|
auto conv9 = net->addConv("conv9", concat1Dst, temp0);
|
||||||
|
|
||||||
|
// conv9b
|
||||||
|
auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
|
||||||
|
|
||||||
|
// upsample0
|
||||||
|
auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
|
||||||
|
auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
|
||||||
|
|
||||||
|
// conv10
|
||||||
|
auto conv10 = net->addConv("conv10", concat0Dst, temp0);
|
||||||
|
|
||||||
|
// conv10b
|
||||||
|
auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
|
||||||
|
|
||||||
|
// conv11
|
||||||
|
auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
|
||||||
|
|
||||||
|
// Output reorder
|
||||||
|
outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
|
||||||
|
|
||||||
|
net->finalize();
|
||||||
|
return net;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
|
||||||
|
{
|
||||||
|
if (hdr)
|
||||||
|
return std::make_shared<PQXTransferFunction>();
|
||||||
|
else if (srgb)
|
||||||
|
return std::make_shared<LinearTransferFunction>();
|
||||||
|
else
|
||||||
|
return std::make_shared<GammaTransferFunction>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
||||||
|
#if 0
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// RTFilter
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
namespace weights
|
||||||
|
{
|
||||||
|
// LDR
|
||||||
|
extern unsigned char rt_ldr[]; // color
|
||||||
|
extern unsigned char rt_ldr_alb[]; // color, albedo
|
||||||
|
extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal
|
||||||
|
|
||||||
|
// HDR
|
||||||
|
extern unsigned char rt_hdr[]; // color
|
||||||
|
extern unsigned char rt_hdr_alb[]; // color, albedo
|
||||||
|
extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
|
||||||
|
}
|
||||||
|
|
||||||
|
RTFilter::RTFilter(const Ref<Device>& device)
|
||||||
|
: AutoencoderFilter(device)
|
||||||
|
{
|
||||||
|
weightData.ldr = weights::rt_ldr;
|
||||||
|
weightData.ldr_alb = weights::rt_ldr_alb;
|
||||||
|
weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
|
||||||
|
weightData.hdr = weights::rt_hdr;
|
||||||
|
weightData.hdr_alb = weights::rt_hdr_alb;
|
||||||
|
weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
|
||||||
|
}
|
||||||
|
// -- GODOT start --
|
||||||
|
#endif
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// RTLightmapFilter
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
namespace weights
|
||||||
|
{
|
||||||
|
// HDR
|
||||||
|
extern unsigned char rtlightmap_hdr[]; // color
|
||||||
|
}
|
||||||
|
|
||||||
|
RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
|
||||||
|
: AutoencoderFilter(device)
|
||||||
|
{
|
||||||
|
weightData.hdr = weights::rtlightmap_hdr;
|
||||||
|
|
||||||
|
hdr = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
|
||||||
|
{
|
||||||
|
return std::make_shared<LogTransferFunction>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
120
thirdparty/oidn/core/autoencoder.h
vendored
Normal file
120
thirdparty/oidn/core/autoencoder.h
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "filter.h"
|
||||||
|
#include "network.h"
|
||||||
|
#include "transfer_function.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// AutoencoderFilter - Direct-predicting autoencoder
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AutoencoderFilter : public Filter
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary)
|
||||||
|
static constexpr int receptiveField = 222; // receptive field in pixels
|
||||||
|
static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
|
||||||
|
|
||||||
|
static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage
|
||||||
|
static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8
|
||||||
|
static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16
|
||||||
|
|
||||||
|
Image color;
|
||||||
|
Image albedo;
|
||||||
|
Image normal;
|
||||||
|
Image output;
|
||||||
|
bool hdr = false;
|
||||||
|
float hdrScale = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
bool srgb = false;
|
||||||
|
int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
|
||||||
|
|
||||||
|
int H = 0; // image height
|
||||||
|
int W = 0; // image width
|
||||||
|
int tileH = 0; // tile height
|
||||||
|
int tileW = 0; // tile width
|
||||||
|
int tileCountH = 1; // number of tiles in H dimension
|
||||||
|
int tileCountW = 1; // number of tiles in W dimension
|
||||||
|
|
||||||
|
std::shared_ptr<Executable> net;
|
||||||
|
std::shared_ptr<Node> inputReorder;
|
||||||
|
std::shared_ptr<Node> outputReorder;
|
||||||
|
|
||||||
|
struct
|
||||||
|
{
|
||||||
|
void* ldr = nullptr;
|
||||||
|
void* ldr_alb = nullptr;
|
||||||
|
void* ldr_alb_nrm = nullptr;
|
||||||
|
void* hdr = nullptr;
|
||||||
|
void* hdr_alb = nullptr;
|
||||||
|
void* hdr_alb_nrm = nullptr;
|
||||||
|
} weightData;
|
||||||
|
|
||||||
|
explicit AutoencoderFilter(const Ref<Device>& device);
|
||||||
|
virtual std::shared_ptr<TransferFunction> makeTransferFunc();
|
||||||
|
|
||||||
|
public:
|
||||||
|
void setImage(const std::string& name, const Image& data) override;
|
||||||
|
void set1i(const std::string& name, int value) override;
|
||||||
|
int get1i(const std::string& name) override;
|
||||||
|
void set1f(const std::string& name, float value) override;
|
||||||
|
float get1f(const std::string& name) override;
|
||||||
|
|
||||||
|
void commit() override;
|
||||||
|
void execute() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void computeTileSize();
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Executable> buildNet();
|
||||||
|
|
||||||
|
bool isCommitted() const { return bool(net); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// RTFilter - Generic ray tracing denoiser
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
||||||
|
#if 0
|
||||||
|
// -- GODOT end --
|
||||||
|
class RTFilter : public AutoencoderFilter
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit RTFilter(const Ref<Device>& device);
|
||||||
|
};
|
||||||
|
// -- GODOT start --
|
||||||
|
#endif
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// RTLightmapFilter - Ray traced lightmap denoiser
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RTLightmapFilter : public AutoencoderFilter
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
explicit RTLightmapFilter(const Ref<Device>& device);
|
||||||
|
std::shared_ptr<TransferFunction> makeTransferFunc() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
75
thirdparty/oidn/core/buffer.h
vendored
Normal file
75
thirdparty/oidn/core/buffer.h
vendored
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "device.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Device;
|
||||||
|
|
||||||
|
// Buffer which may or may not own its data
|
||||||
|
class Buffer : public RefCount
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
char* ptr;
|
||||||
|
size_t byteSize;
|
||||||
|
bool shared;
|
||||||
|
Ref<Device> device;
|
||||||
|
|
||||||
|
public:
|
||||||
|
__forceinline Buffer(const Ref<Device>& device, size_t size)
|
||||||
|
: ptr((char*)alignedMalloc(size, 64)),
|
||||||
|
byteSize(size),
|
||||||
|
shared(false),
|
||||||
|
device(device) {}
|
||||||
|
|
||||||
|
__forceinline Buffer(const Ref<Device>& device, void* data, size_t size)
|
||||||
|
: ptr((char*)data),
|
||||||
|
byteSize(size),
|
||||||
|
shared(true),
|
||||||
|
device(device)
|
||||||
|
{
|
||||||
|
if (data == nullptr)
|
||||||
|
throw Exception(Error::InvalidArgument, "buffer pointer null");
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline ~Buffer()
|
||||||
|
{
|
||||||
|
if (!shared)
|
||||||
|
alignedFree(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline char* data() { return ptr; }
|
||||||
|
__forceinline const char* data() const { return ptr; }
|
||||||
|
__forceinline size_t size() const { return byteSize; }
|
||||||
|
|
||||||
|
void* map(size_t offset, size_t size)
|
||||||
|
{
|
||||||
|
if (offset + size > byteSize)
|
||||||
|
throw Exception(Error::InvalidArgument, "buffer region out of range");
|
||||||
|
|
||||||
|
return ptr + offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
void unmap(void* mappedPtr) {}
|
||||||
|
|
||||||
|
Device* getDevice() { return device.get(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
136
thirdparty/oidn/core/common.h
vendored
Normal file
136
thirdparty/oidn/core/common.h
vendored
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common/platform.h"
|
||||||
|
|
||||||
|
#include "mkl-dnn/include/mkldnn.hpp"
|
||||||
|
#include "mkl-dnn/include/mkldnn_debug.h"
|
||||||
|
#include "mkl-dnn/src/common/mkldnn_thread.hpp"
|
||||||
|
#include "mkl-dnn/src/common/type_helpers.hpp"
|
||||||
|
#include "mkl-dnn/src/cpu/jit_generator.hpp"
|
||||||
|
|
||||||
|
#include "common/ref.h"
|
||||||
|
#include "common/exception.h"
|
||||||
|
#include "common/thread.h"
|
||||||
|
// -- GODOT start --
|
||||||
|
//#include "common/tasking.h"
|
||||||
|
// -- GODOT end --
|
||||||
|
#include "math.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
using namespace mkldnn;
|
||||||
|
using namespace mkldnn::impl::cpu;
|
||||||
|
using mkldnn::impl::parallel_nd;
|
||||||
|
using mkldnn::impl::memory_desc_matches_tag;
|
||||||
|
|
||||||
|
|
||||||
|
inline size_t getFormatBytes(Format format)
|
||||||
|
{
|
||||||
|
switch (format)
|
||||||
|
{
|
||||||
|
case Format::Undefined: return 1;
|
||||||
|
case Format::Float: return sizeof(float);
|
||||||
|
case Format::Float2: return sizeof(float)*2;
|
||||||
|
case Format::Float3: return sizeof(float)*3;
|
||||||
|
case Format::Float4: return sizeof(float)*4;
|
||||||
|
}
|
||||||
|
assert(0);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
||||||
|
return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
||||||
|
return memory::data_type(desc.data_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the number of values in a tensor
|
||||||
|
inline size_t getTensorSize(const memory::dims& dims)
|
||||||
|
{
|
||||||
|
size_t res = 1;
|
||||||
|
for (int i = 0; i < (int)dims.size(); ++i)
|
||||||
|
res *= dims[i];
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
|
||||||
|
{
|
||||||
|
memory::dims result;
|
||||||
|
size_t maxSize = 0;
|
||||||
|
|
||||||
|
for (const auto& d : dims)
|
||||||
|
{
|
||||||
|
const size_t size = getTensorSize(d);
|
||||||
|
if (size > maxSize)
|
||||||
|
{
|
||||||
|
result = d;
|
||||||
|
maxSize = size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
|
||||||
|
{
|
||||||
|
return getTensorSize(getTensorDims(mem));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
inline int getPadded(int dim)
|
||||||
|
{
|
||||||
|
return (dim + (K-1)) & ~(K-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
inline memory::dims getPadded_nchw(const memory::dims& dims)
|
||||||
|
{
|
||||||
|
assert(dims.size() == 4);
|
||||||
|
memory::dims padDims = dims;
|
||||||
|
padDims[1] = getPadded<K>(dims[1]); // pad C
|
||||||
|
return padDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
struct BlockedFormat;
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct BlockedFormat<8>
|
||||||
|
{
|
||||||
|
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
|
||||||
|
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct BlockedFormat<16>
|
||||||
|
{
|
||||||
|
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
|
||||||
|
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
238
thirdparty/oidn/core/device.cpp
vendored
Normal file
238
thirdparty/oidn/core/device.cpp
vendored
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "device.h"
|
||||||
|
#include "autoencoder.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
thread_local Device::ErrorState Device::globalError;
|
||||||
|
|
||||||
|
Device::Device()
|
||||||
|
{
|
||||||
|
if (!mayiuse(sse41))
|
||||||
|
throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum");
|
||||||
|
}
|
||||||
|
|
||||||
|
Device::~Device()
|
||||||
|
{
|
||||||
|
// -- GODOT start --
|
||||||
|
//observer.reset();
|
||||||
|
// -- GODOT end --
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::setError(Device* device, Error code, const std::string& message)
|
||||||
|
{
|
||||||
|
// Update the stored error only if the previous error was queried
|
||||||
|
if (device)
|
||||||
|
{
|
||||||
|
ErrorState& curError = device->error.get();
|
||||||
|
|
||||||
|
if (curError.code == Error::None)
|
||||||
|
{
|
||||||
|
curError.code = code;
|
||||||
|
curError.message = message;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the error message in verbose mode
|
||||||
|
if (device->isVerbose())
|
||||||
|
std::cerr << "Error: " << message << std::endl;
|
||||||
|
|
||||||
|
// Call the error callback function
|
||||||
|
ErrorFunction errorFunc;
|
||||||
|
void* errorUserPtr;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(device->mutex);
|
||||||
|
errorFunc = device->errorFunc;
|
||||||
|
errorUserPtr = device->errorUserPtr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (errorFunc)
|
||||||
|
errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (globalError.code == Error::None)
|
||||||
|
{
|
||||||
|
globalError.code = code;
|
||||||
|
globalError.message = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Error Device::getError(Device* device, const char** outMessage)
|
||||||
|
{
|
||||||
|
// Return and clear the stored error code, but keep the error message so pointers to it will
|
||||||
|
// remain valid until the next getError call
|
||||||
|
if (device)
|
||||||
|
{
|
||||||
|
ErrorState& curError = device->error.get();
|
||||||
|
const Error code = curError.code;
|
||||||
|
if (outMessage)
|
||||||
|
*outMessage = (code == Error::None) ? nullptr : curError.message.c_str();
|
||||||
|
curError.code = Error::None;
|
||||||
|
return code;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const Error code = globalError.code;
|
||||||
|
if (outMessage)
|
||||||
|
*outMessage = (code == Error::None) ? nullptr : globalError.message.c_str();
|
||||||
|
globalError.code = Error::None;
|
||||||
|
return code;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::setErrorFunction(ErrorFunction func, void* userPtr)
|
||||||
|
{
|
||||||
|
errorFunc = func;
|
||||||
|
errorUserPtr = userPtr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Device::get1i(const std::string& name)
|
||||||
|
{
|
||||||
|
if (name == "numThreads")
|
||||||
|
return numThreads;
|
||||||
|
else if (name == "setAffinity")
|
||||||
|
return setAffinity;
|
||||||
|
else if (name == "verbose")
|
||||||
|
return verbose;
|
||||||
|
else if (name == "version")
|
||||||
|
return OIDN_VERSION;
|
||||||
|
else if (name == "versionMajor")
|
||||||
|
return OIDN_VERSION_MAJOR;
|
||||||
|
else if (name == "versionMinor")
|
||||||
|
return OIDN_VERSION_MINOR;
|
||||||
|
else if (name == "versionPatch")
|
||||||
|
return OIDN_VERSION_PATCH;
|
||||||
|
else
|
||||||
|
throw Exception(Error::InvalidArgument, "invalid parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::set1i(const std::string& name, int value)
|
||||||
|
{
|
||||||
|
if (name == "numThreads")
|
||||||
|
numThreads = value;
|
||||||
|
else if (name == "setAffinity")
|
||||||
|
setAffinity = value;
|
||||||
|
else if (name == "verbose")
|
||||||
|
{
|
||||||
|
verbose = value;
|
||||||
|
error.verbose = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
dirty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::commit()
|
||||||
|
{
|
||||||
|
if (isCommitted())
|
||||||
|
throw Exception(Error::InvalidOperation, "device can be committed only once");
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
#if 0
|
||||||
|
// -- GODOT end --
|
||||||
|
// Get the optimal thread affinities
|
||||||
|
if (setAffinity)
|
||||||
|
{
|
||||||
|
affinity = std::make_shared<ThreadAffinity>(1, verbose); // one thread per core
|
||||||
|
if (affinity->getNumThreads() == 0)
|
||||||
|
affinity.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the task arena
|
||||||
|
const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency();
|
||||||
|
numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads;
|
||||||
|
arena = std::make_shared<tbb::task_arena>(numThreads);
|
||||||
|
|
||||||
|
// Automatically set the thread affinities
|
||||||
|
if (affinity)
|
||||||
|
observer = std::make_shared<PinningObserver>(affinity, *arena);
|
||||||
|
// -- GODOT start --
|
||||||
|
#endif
|
||||||
|
numThreads = 1;
|
||||||
|
// -- GODOT end --
|
||||||
|
dirty = false;
|
||||||
|
|
||||||
|
if (isVerbose())
|
||||||
|
print();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::checkCommitted()
|
||||||
|
{
|
||||||
|
if (dirty)
|
||||||
|
throw Exception(Error::InvalidOperation, "changes to the device are not committed");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<Buffer> Device::newBuffer(size_t byteSize)
|
||||||
|
{
|
||||||
|
checkCommitted();
|
||||||
|
return makeRef<Buffer>(Ref<Device>(this), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize)
|
||||||
|
{
|
||||||
|
checkCommitted();
|
||||||
|
return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<Filter> Device::newFilter(const std::string& type)
|
||||||
|
{
|
||||||
|
checkCommitted();
|
||||||
|
|
||||||
|
if (isVerbose())
|
||||||
|
std::cout << "Filter: " << type << std::endl;
|
||||||
|
|
||||||
|
Ref<Filter> filter;
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
||||||
|
#if 0
|
||||||
|
// -- GODOT end --
|
||||||
|
if (type == "RT")
|
||||||
|
filter = makeRef<RTFilter>(Ref<Device>(this));
|
||||||
|
// -- GODOT start --
|
||||||
|
// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
|
||||||
|
#endif
|
||||||
|
if (type == "RTLightmap")
|
||||||
|
// -- GODOT end --
|
||||||
|
filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
|
||||||
|
else
|
||||||
|
throw Exception(Error::InvalidArgument, "unknown filter type");
|
||||||
|
|
||||||
|
return filter;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::print()
|
||||||
|
{
|
||||||
|
std::cout << std::endl;
|
||||||
|
|
||||||
|
std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl;
|
||||||
|
std::cout << " Compiler: " << getCompilerName() << std::endl;
|
||||||
|
std::cout << " Build : " << getBuildName() << std::endl;
|
||||||
|
std::cout << " Platform: " << getPlatformName() << std::endl;
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// std::cout << " Tasking :";
|
||||||
|
// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
|
||||||
|
// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
|
||||||
|
// std::cout << std::endl;
|
||||||
|
// -- GODOT end --
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
102
thirdparty/oidn/core/device.h
vendored
Normal file
102
thirdparty/oidn/core/device.h
vendored
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Buffer;
|
||||||
|
class Filter;
|
||||||
|
|
||||||
|
class Device : public RefCount, public Verbose
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
// Thread-safety
|
||||||
|
std::mutex mutex;
|
||||||
|
|
||||||
|
// Error handling
|
||||||
|
struct ErrorState
|
||||||
|
{
|
||||||
|
Error code = Error::None;
|
||||||
|
std::string message;
|
||||||
|
};
|
||||||
|
|
||||||
|
static thread_local ErrorState globalError;
|
||||||
|
ThreadLocal<ErrorState> error;
|
||||||
|
ErrorFunction errorFunc = nullptr;
|
||||||
|
void* errorUserPtr = nullptr;
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// // Tasking
|
||||||
|
// std::shared_ptr<tbb::task_arena> arena;
|
||||||
|
// std::shared_ptr<PinningObserver> observer;
|
||||||
|
// std::shared_ptr<ThreadAffinity> affinity;
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
// Parameters
|
||||||
|
int numThreads = 0; // autodetect by default
|
||||||
|
bool setAffinity = true;
|
||||||
|
|
||||||
|
bool dirty = true;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Device();
|
||||||
|
~Device();
|
||||||
|
|
||||||
|
static void setError(Device* device, Error code, const std::string& message);
|
||||||
|
static Error getError(Device* device, const char** outMessage);
|
||||||
|
|
||||||
|
void setErrorFunction(ErrorFunction func, void* userPtr);
|
||||||
|
|
||||||
|
int get1i(const std::string& name);
|
||||||
|
void set1i(const std::string& name, int value);
|
||||||
|
|
||||||
|
void commit();
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// template<typename F>
|
||||||
|
// void executeTask(F& f)
|
||||||
|
// {
|
||||||
|
// arena->execute(f);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// template<typename F>
|
||||||
|
// void executeTask(const F& f)
|
||||||
|
// {
|
||||||
|
// arena->execute(f);
|
||||||
|
// }
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
Ref<Buffer> newBuffer(size_t byteSize);
|
||||||
|
Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
|
||||||
|
Ref<Filter> newFilter(const std::string& type);
|
||||||
|
|
||||||
|
__forceinline Device* getDevice() { return this; }
|
||||||
|
__forceinline std::mutex& getMutex() { return mutex; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// -- GODOT start --
|
||||||
|
//bool isCommitted() const { return bool(arena); }
|
||||||
|
bool isCommitted() const { return false; }
|
||||||
|
// -- GODOT end --
|
||||||
|
void checkCommitted();
|
||||||
|
|
||||||
|
void print();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
27
thirdparty/oidn/core/filter.cpp
vendored
Normal file
27
thirdparty/oidn/core/filter.cpp
vendored
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "filter.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr)
|
||||||
|
{
|
||||||
|
progressFunc = func;
|
||||||
|
progressUserPtr = userPtr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
52
thirdparty/oidn/core/filter.h
vendored
Normal file
52
thirdparty/oidn/core/filter.h
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "device.h"
|
||||||
|
#include "image.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Filter : public RefCount
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
Ref<Device> device;
|
||||||
|
|
||||||
|
ProgressMonitorFunction progressFunc = nullptr;
|
||||||
|
void* progressUserPtr = nullptr;
|
||||||
|
|
||||||
|
bool dirty = true;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit Filter(const Ref<Device>& device) : device(device) {}
|
||||||
|
|
||||||
|
virtual void setImage(const std::string& name, const Image& data) = 0;
|
||||||
|
virtual void set1i(const std::string& name, int value) = 0;
|
||||||
|
virtual int get1i(const std::string& name) = 0;
|
||||||
|
virtual void set1f(const std::string& name, float value) = 0;
|
||||||
|
virtual float get1f(const std::string& name) = 0;
|
||||||
|
|
||||||
|
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr);
|
||||||
|
|
||||||
|
virtual void commit() = 0;
|
||||||
|
virtual void execute() = 0;
|
||||||
|
|
||||||
|
Device* getDevice() { return device.get(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
111
thirdparty/oidn/core/image.h
vendored
Normal file
111
thirdparty/oidn/core/image.h
vendored
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "buffer.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
struct Image
|
||||||
|
{
|
||||||
|
static constexpr int maxSize = 65536;
|
||||||
|
|
||||||
|
char* ptr; // pointer to the first pixel
|
||||||
|
int width; // width in number of pixels
|
||||||
|
int height; // height in number of pixels
|
||||||
|
size_t bytePixelStride; // pixel stride in number of *bytes*
|
||||||
|
size_t rowStride; // row stride in number of *pixel strides*
|
||||||
|
Format format; // pixel format
|
||||||
|
Ref<Buffer> buffer; // buffer containing the image data
|
||||||
|
|
||||||
|
Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {}
|
||||||
|
|
||||||
|
Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
|
||||||
|
{
|
||||||
|
if (ptr == nullptr)
|
||||||
|
throw Exception(Error::InvalidArgument, "buffer pointer null");
|
||||||
|
|
||||||
|
init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
|
||||||
|
{
|
||||||
|
init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
|
||||||
|
|
||||||
|
if (byteOffset + height * rowStride * bytePixelStride > buffer->size())
|
||||||
|
throw Exception(Error::InvalidArgument, "buffer region out of range");
|
||||||
|
}
|
||||||
|
|
||||||
|
void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride)
|
||||||
|
{
|
||||||
|
assert(width >= 0);
|
||||||
|
assert(height >= 0);
|
||||||
|
if (width > maxSize || height > maxSize)
|
||||||
|
throw Exception(Error::InvalidArgument, "image size too large");
|
||||||
|
|
||||||
|
this->ptr = ptr;
|
||||||
|
this->width = width;
|
||||||
|
this->height = height;
|
||||||
|
|
||||||
|
const size_t pixelSize = getFormatBytes(format);
|
||||||
|
if (inBytePixelStride != 0)
|
||||||
|
{
|
||||||
|
if (inBytePixelStride < pixelSize)
|
||||||
|
throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size");
|
||||||
|
|
||||||
|
this->bytePixelStride = inBytePixelStride;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this->bytePixelStride = pixelSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inByteRowStride != 0)
|
||||||
|
{
|
||||||
|
if (inByteRowStride < width * this->bytePixelStride)
|
||||||
|
throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride");
|
||||||
|
if (inByteRowStride % this->bytePixelStride != 0)
|
||||||
|
throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride");
|
||||||
|
|
||||||
|
this->rowStride = inByteRowStride / this->bytePixelStride;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this->rowStride = width;
|
||||||
|
}
|
||||||
|
|
||||||
|
this->format = format;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline char* get(int y, int x)
|
||||||
|
{
|
||||||
|
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline const char* get(int y, int x) const
|
||||||
|
{
|
||||||
|
return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() const
|
||||||
|
{
|
||||||
|
return ptr != nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
232
thirdparty/oidn/core/input_reorder.h
vendored
Normal file
232
thirdparty/oidn/core/input_reorder.h
vendored
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "node.h"
|
||||||
|
#include "image.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// Input reorder node
|
||||||
|
template<int K, class TransferFunction>
|
||||||
|
class InputReorderNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
// Source
|
||||||
|
Image color;
|
||||||
|
Image albedo;
|
||||||
|
Image normal;
|
||||||
|
|
||||||
|
// Destination
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
float* dstPtr;
|
||||||
|
int C2;
|
||||||
|
int H2;
|
||||||
|
int W2;
|
||||||
|
|
||||||
|
// Tile
|
||||||
|
int h1Begin;
|
||||||
|
int w1Begin;
|
||||||
|
int h2Begin;
|
||||||
|
int w2Begin;
|
||||||
|
int H;
|
||||||
|
int W;
|
||||||
|
|
||||||
|
std::shared_ptr<TransferFunction> transferFunc;
|
||||||
|
|
||||||
|
public:
|
||||||
|
InputReorderNode(const Image& color,
|
||||||
|
const Image& albedo,
|
||||||
|
const Image& normal,
|
||||||
|
const std::shared_ptr<memory>& dst,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc)
|
||||||
|
: color(color), albedo(albedo), normal(normal),
|
||||||
|
dst(dst),
|
||||||
|
h1Begin(0), w1Begin(0),
|
||||||
|
H(color.height), W(color.width),
|
||||||
|
transferFunc(transferFunc)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
||||||
|
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
||||||
|
assert(dstDesc.ndims == 4);
|
||||||
|
assert(dstDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(dstDesc.dims[0] == 1);
|
||||||
|
//assert(dstDesc.dims[1] >= getPadded<K>(C1));
|
||||||
|
|
||||||
|
dstPtr = (float*)dst->get_data_handle();
|
||||||
|
C2 = dstDesc.dims[1];
|
||||||
|
H2 = dstDesc.dims[2];
|
||||||
|
W2 = dstDesc.dims[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
|
||||||
|
{
|
||||||
|
h1Begin = h1;
|
||||||
|
w1Begin = w1;
|
||||||
|
h2Begin = h2;
|
||||||
|
w2Begin = w2;
|
||||||
|
this->H = H;
|
||||||
|
this->W = W;
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
assert(H + h1Begin <= color.height);
|
||||||
|
assert(W + w1Begin <= color.width);
|
||||||
|
assert(H + h2Begin <= H2);
|
||||||
|
assert(W + w2Begin <= W2);
|
||||||
|
|
||||||
|
parallel_nd(H2, [&](int h2)
|
||||||
|
{
|
||||||
|
const int h = h2 - h2Begin;
|
||||||
|
|
||||||
|
if (h >= 0 && h < H)
|
||||||
|
{
|
||||||
|
const int h1 = h + h1Begin;
|
||||||
|
|
||||||
|
// Zero pad
|
||||||
|
for (int w2 = 0; w2 < w2Begin; ++w2)
|
||||||
|
{
|
||||||
|
int c = 0;
|
||||||
|
while (c < C2)
|
||||||
|
store(h2, w2, c, 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reorder
|
||||||
|
for (int w = 0; w < W; ++w)
|
||||||
|
{
|
||||||
|
const int w1 = w + w1Begin;
|
||||||
|
const int w2 = w + w2Begin;
|
||||||
|
|
||||||
|
int c = 0;
|
||||||
|
storeColor(h2, w2, c, (float*)color.get(h1, w1));
|
||||||
|
if (albedo)
|
||||||
|
storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
|
||||||
|
if (normal)
|
||||||
|
storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
|
||||||
|
while (c < C2)
|
||||||
|
store(h2, w2, c, 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero pad
|
||||||
|
for (int w2 = W + w2Begin; w2 < W2; ++w2)
|
||||||
|
{
|
||||||
|
int c = 0;
|
||||||
|
while (c < C2)
|
||||||
|
store(h2, w2, c, 0.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Zero pad
|
||||||
|
for (int w2 = 0; w2 < W2; ++w2)
|
||||||
|
{
|
||||||
|
int c = 0;
|
||||||
|
while (c < C2)
|
||||||
|
store(h2, w2, c, 0.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Stores a single value
|
||||||
|
__forceinline void store(int h, int w, int& c, float value)
|
||||||
|
{
|
||||||
|
// Destination is in nChwKc format
|
||||||
|
float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
|
||||||
|
*dst_c = value;
|
||||||
|
c++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores a color
|
||||||
|
__forceinline void storeColor(int h, int w, int& c, const float* values)
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 3; ++i)
|
||||||
|
{
|
||||||
|
// Load the value
|
||||||
|
float x = values[i];
|
||||||
|
|
||||||
|
// Sanitize the value
|
||||||
|
x = maxSafe(x, 0.f);
|
||||||
|
|
||||||
|
// Apply the transfer function
|
||||||
|
x = transferFunc->forward(x);
|
||||||
|
|
||||||
|
// Store the value
|
||||||
|
store(h, w, c, x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores an albedo
|
||||||
|
__forceinline void storeAlbedo(int h, int w, int& c, const float* values)
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 3; ++i)
|
||||||
|
{
|
||||||
|
// Load the value
|
||||||
|
float x = values[i];
|
||||||
|
|
||||||
|
// Sanitize the value
|
||||||
|
x = clampSafe(x, 0.f, 1.f);
|
||||||
|
|
||||||
|
// Store the value
|
||||||
|
store(h, w, c, x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores a normal
|
||||||
|
__forceinline void storeNormal(int h, int w, int& c, const float* values)
|
||||||
|
{
|
||||||
|
// Load the normal
|
||||||
|
float x = values[0];
|
||||||
|
float y = values[1];
|
||||||
|
float z = values[2];
|
||||||
|
|
||||||
|
// Compute the length of the normal
|
||||||
|
const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
|
||||||
|
|
||||||
|
// Normalize the normal and transform it to [0..1]
|
||||||
|
if (isfinite(lengthSqr))
|
||||||
|
{
|
||||||
|
const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
|
||||||
|
|
||||||
|
const float scale = invLength * 0.5f;
|
||||||
|
const float offset = 0.5f;
|
||||||
|
|
||||||
|
x = x * scale + offset;
|
||||||
|
y = y * scale + offset;
|
||||||
|
z = z * scale + offset;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
x = 0.f;
|
||||||
|
y = 0.f;
|
||||||
|
z = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the normal
|
||||||
|
store(h, w, c, x);
|
||||||
|
store(h, w, c, y);
|
||||||
|
store(h, w, c, z);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
78
thirdparty/oidn/core/math.h
vendored
Normal file
78
thirdparty/oidn/core/math.h
vendored
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common/platform.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
constexpr float minVectorLength = 1e-10f;
|
||||||
|
constexpr float minVectorLengthSqr = minVectorLength * minVectorLength;
|
||||||
|
|
||||||
|
using std::log;
|
||||||
|
using std::log2;
|
||||||
|
using std::exp;
|
||||||
|
using std::exp2;
|
||||||
|
using std::pow;
|
||||||
|
using std::isfinite;
|
||||||
|
using std::isnan;
|
||||||
|
|
||||||
|
__forceinline float sqr(float x)
|
||||||
|
{
|
||||||
|
return x * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float rcp(float x)
|
||||||
|
{
|
||||||
|
__m128 r = _mm_rcp_ss(_mm_set_ss(x));
|
||||||
|
return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x))));
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float rsqrt(float x)
|
||||||
|
{
|
||||||
|
__m128 r = _mm_rsqrt_ss(_mm_set_ss(x));
|
||||||
|
return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r),
|
||||||
|
_mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))));
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float maxSafe(float value, float minValue)
|
||||||
|
{
|
||||||
|
return isfinite(value) ? max(value, minValue) : minValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float clampSafe(float value, float minValue, float maxValue)
|
||||||
|
{
|
||||||
|
return isfinite(value) ? clamp(value, minValue, maxValue) : minValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns ceil(a / b) for non-negative integers
|
||||||
|
template<class Int>
|
||||||
|
__forceinline constexpr Int ceilDiv(Int a, Int b)
|
||||||
|
{
|
||||||
|
//assert(a >= 0);
|
||||||
|
//assert(b > 0);
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a rounded up to multiple of b
|
||||||
|
template<class Int>
|
||||||
|
__forceinline constexpr Int roundUp(Int a, Int b)
|
||||||
|
{
|
||||||
|
return ceilDiv(a, b) * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
436
thirdparty/oidn/core/network.cpp
vendored
Normal file
436
thirdparty/oidn/core/network.cpp
vendored
Normal file
@ -0,0 +1,436 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "upsample.h"
|
||||||
|
#include "weights_reorder.h"
|
||||||
|
#include "network.h"
|
||||||
|
// -- GODOT start --
|
||||||
|
#include <cstring>
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap)
|
||||||
|
: device(device),
|
||||||
|
eng(engine::cpu, 0),
|
||||||
|
sm(eng),
|
||||||
|
weightMap(weightMap)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
void Network<K>::execute(const Progress& progress, int taskIndex)
|
||||||
|
{
|
||||||
|
if (progress.func)
|
||||||
|
{
|
||||||
|
const double value = double(taskIndex) / double(progress.taskCount);
|
||||||
|
if (!progress.func(progress.userPtr, value))
|
||||||
|
throw Exception(Error::Cancelled, "execution was cancelled");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < nodes.size(); ++i)
|
||||||
|
{
|
||||||
|
nodes[i]->execute(sm);
|
||||||
|
|
||||||
|
if (progress.func)
|
||||||
|
{
|
||||||
|
const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount);
|
||||||
|
if (!progress.func(progress.userPtr, value))
|
||||||
|
throw Exception(Error::Cancelled, "execution was cancelled");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims,
|
||||||
|
memory::format_tag format,
|
||||||
|
void* data)
|
||||||
|
{
|
||||||
|
if (format == memory::format_tag::any)
|
||||||
|
{
|
||||||
|
if (dims.size() == 4)
|
||||||
|
format = BlockedFormat<K>::nChwKc;
|
||||||
|
else if (dims.size() == 1)
|
||||||
|
format = memory::format_tag::x;
|
||||||
|
else
|
||||||
|
assert(0);
|
||||||
|
}
|
||||||
|
memory::desc desc(dims, memory::data_type::f32, format);
|
||||||
|
if (data == nullptr)
|
||||||
|
{
|
||||||
|
const size_t bytes = getTensorSize(dims) * sizeof(float);
|
||||||
|
if (format == BlockedFormat<K>::nChwKc)
|
||||||
|
activationAllocBytes += bytes;
|
||||||
|
totalAllocBytes += bytes;
|
||||||
|
|
||||||
|
return std::make_shared<memory>(desc, eng);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return std::make_shared<memory>(desc, eng, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
size_t srcOffset,
|
||||||
|
memory::format_tag format)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
MAYBE_UNUSED(srcDesc);
|
||||||
|
assert(srcDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(getTensorSize(src) >= srcOffset + getTensorSize(dims));
|
||||||
|
|
||||||
|
if (format == memory::format_tag::any)
|
||||||
|
{
|
||||||
|
if (dims.size() == 4)
|
||||||
|
format = BlockedFormat<K>::nChwKc;
|
||||||
|
else if (dims.size() == 1)
|
||||||
|
format = memory::format_tag::x;
|
||||||
|
else
|
||||||
|
assert(0);
|
||||||
|
}
|
||||||
|
memory::desc desc(dims, memory::data_type::f32, format);
|
||||||
|
float* srcPtr = (float*)src->get_data_handle() + srcOffset;
|
||||||
|
return std::make_shared<memory>(desc, eng, srcPtr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const memory::dims& srcOffset)
|
||||||
|
{
|
||||||
|
return castTensor(dims, src, getTensorSize(srcOffset));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst)
|
||||||
|
{
|
||||||
|
assert(getTensorType(dst) == memory::data_type::f32);
|
||||||
|
memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment)
|
||||||
|
{
|
||||||
|
memory::dims dstDims = srcDims;
|
||||||
|
dstDims[1] = getPadded<K>(srcDims[1]); // round up C
|
||||||
|
dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H
|
||||||
|
dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W
|
||||||
|
return dstDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color,
|
||||||
|
const Image& albedo,
|
||||||
|
const Image& normal,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
||||||
|
int alignment,
|
||||||
|
const std::shared_ptr<memory>& userDst)
|
||||||
|
{
|
||||||
|
assert(color);
|
||||||
|
int inputC = 3;
|
||||||
|
if (albedo) inputC += 3;
|
||||||
|
if (normal) inputC += 3;
|
||||||
|
|
||||||
|
memory::dims srcDims = {1, inputC, color.height, color.width};
|
||||||
|
memory::dims dstDims = getInputReorderDims(srcDims, alignment);
|
||||||
|
|
||||||
|
// Allocate padded memory
|
||||||
|
auto dst = userDst;
|
||||||
|
if (!dst)
|
||||||
|
dst = allocTensor(dstDims);
|
||||||
|
|
||||||
|
// Push node
|
||||||
|
std::shared_ptr<Node> node;
|
||||||
|
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf);
|
||||||
|
else
|
||||||
|
assert(0);
|
||||||
|
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
||||||
|
const Image& output)
|
||||||
|
{
|
||||||
|
memory::dims srcDims = getTensorDims(src);
|
||||||
|
assert(srcDims[1] == K);
|
||||||
|
|
||||||
|
// Push node
|
||||||
|
std::shared_ptr<Node> node;
|
||||||
|
|
||||||
|
if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf);
|
||||||
|
else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
|
||||||
|
node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf);
|
||||||
|
else
|
||||||
|
assert(0);
|
||||||
|
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims)
|
||||||
|
{
|
||||||
|
auto b = weightMap[name + "/b"];
|
||||||
|
memory::dims dstDims = srcDims;
|
||||||
|
dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC)
|
||||||
|
return dstDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addConv(const std::string& name,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst,
|
||||||
|
bool relu)
|
||||||
|
{
|
||||||
|
const memory::dims strides = {1, 1};
|
||||||
|
const memory::dims padding = {1, 1};
|
||||||
|
|
||||||
|
memory::dims srcDims = getTensorDims(src);
|
||||||
|
|
||||||
|
// Get the weights
|
||||||
|
const auto& W = weightMap[name + "/W"];
|
||||||
|
if (W.ndims() != 4 || W.format != "oihw")
|
||||||
|
throw Exception(Error::InvalidOperation, "invalid convolution weights");
|
||||||
|
memory::dims weightsDims = W.dims;
|
||||||
|
auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data);
|
||||||
|
|
||||||
|
// Pad the weights
|
||||||
|
memory::dims weightsPadDims = weightsDims;
|
||||||
|
weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC
|
||||||
|
weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC
|
||||||
|
assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC]
|
||||||
|
auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw);
|
||||||
|
WeightsReorderNode<K>(userWeights, weightsPad).execute(sm);
|
||||||
|
|
||||||
|
// Get the biases
|
||||||
|
const auto& b = weightMap[name + "/b"];
|
||||||
|
if (b.ndims() != 1)
|
||||||
|
throw Exception(Error::InvalidOperation, "invalid convolution biases");
|
||||||
|
memory::dims biasDims = b.dims;
|
||||||
|
|
||||||
|
// Copy/pad the biases
|
||||||
|
memory::dims biasPadDims = {getPadded<K>(biasDims[0])};
|
||||||
|
auto bias = allocTensor(biasPadDims);
|
||||||
|
if (biasDims[0] != biasPadDims[0])
|
||||||
|
memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float));
|
||||||
|
memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float));
|
||||||
|
|
||||||
|
// Allocate memory for destination
|
||||||
|
memory::dims dstDims = srcDims;
|
||||||
|
dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC]
|
||||||
|
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
if (!userDst)
|
||||||
|
dst = allocTensor(dstDims);
|
||||||
|
else if (getTensorDims(userDst) == dstDims)
|
||||||
|
dst = userDst;
|
||||||
|
else
|
||||||
|
dst = castTensor(dstDims, userDst);
|
||||||
|
|
||||||
|
// Create a convolution
|
||||||
|
// Let the convolution primitive choose the weights format
|
||||||
|
auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any);
|
||||||
|
|
||||||
|
auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct;
|
||||||
|
auto convDesc = convolution_forward::desc(
|
||||||
|
prop_kind::forward_inference, convAlgo,
|
||||||
|
src->get_desc(),
|
||||||
|
weightsDesc,
|
||||||
|
bias->get_desc(),
|
||||||
|
dst->get_desc(),
|
||||||
|
strides, padding, padding, padding_kind::zero);
|
||||||
|
|
||||||
|
// Incorporate relu
|
||||||
|
mkldnn::primitive_attr convAttr;
|
||||||
|
if (relu)
|
||||||
|
{
|
||||||
|
mkldnn::post_ops ops;
|
||||||
|
ops.append_eltwise(
|
||||||
|
1.f, // scale factor, not used
|
||||||
|
algorithm::eltwise_relu,
|
||||||
|
0.f, // max with
|
||||||
|
0.f // unused
|
||||||
|
);
|
||||||
|
convAttr.set_post_ops(ops);
|
||||||
|
}
|
||||||
|
convAttr.set_scratchpad_mode(scratchpad_mode_user);
|
||||||
|
|
||||||
|
auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng);
|
||||||
|
|
||||||
|
// Reorder the weights to the final format, if necessary
|
||||||
|
auto weights = weightsPad;
|
||||||
|
if (convPrimDesc.weights_desc() != weightsPad->get_desc())
|
||||||
|
{
|
||||||
|
weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng);
|
||||||
|
ReorderNode(weightsPad, weights).execute(sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create convolution node and add it to the net
|
||||||
|
auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst);
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
memory::dims Network<K>::getPoolDims(const memory::dims& srcDims)
|
||||||
|
{
|
||||||
|
memory::dims dstDims = srcDims;
|
||||||
|
dstDims[2] /= 2; // H/2
|
||||||
|
dstDims[3] /= 2; // W/2
|
||||||
|
return dstDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst)
|
||||||
|
{
|
||||||
|
const memory::dims kernel = {2, 2};
|
||||||
|
const memory::dims strides = {2, 2};
|
||||||
|
const memory::dims padding = {0, 0};
|
||||||
|
|
||||||
|
memory::dims srcDims = getTensorDims(src);
|
||||||
|
memory::dims dstDims = getPoolDims(srcDims);
|
||||||
|
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
if (!userDst)
|
||||||
|
dst = allocTensor(dstDims);
|
||||||
|
else if (getTensorDims(userDst) == dstDims)
|
||||||
|
dst = userDst;
|
||||||
|
else
|
||||||
|
dst = castTensor(dstDims, userDst);
|
||||||
|
|
||||||
|
auto poolDesc = pooling_forward::desc(
|
||||||
|
prop_kind::forward_inference, pooling_max,
|
||||||
|
src->get_desc(),
|
||||||
|
dst->get_desc(),
|
||||||
|
strides, kernel, padding, padding, padding_kind::zero);
|
||||||
|
|
||||||
|
mkldnn::primitive_attr poolAttr;
|
||||||
|
poolAttr.set_scratchpad_mode(scratchpad_mode_user);
|
||||||
|
|
||||||
|
auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng);
|
||||||
|
|
||||||
|
auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst);
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims)
|
||||||
|
{
|
||||||
|
memory::dims dstDims = srcDims;
|
||||||
|
dstDims[2] *= 2; // H*2
|
||||||
|
dstDims[3] *= 2; // W*2
|
||||||
|
return dstDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst)
|
||||||
|
{
|
||||||
|
memory::dims srcDims = getTensorDims(src);
|
||||||
|
memory::dims dstDims = getUpsampleDims(srcDims);
|
||||||
|
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
if (!userDst)
|
||||||
|
dst = allocTensor(dstDims);
|
||||||
|
else if (getTensorDims(userDst) == dstDims)
|
||||||
|
dst = userDst;
|
||||||
|
else
|
||||||
|
dst = castTensor(dstDims, userDst);
|
||||||
|
|
||||||
|
// Create upsampling node and add it to net
|
||||||
|
auto node = std::make_shared<UpsampleNode<K>>(src, dst);
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims)
|
||||||
|
{
|
||||||
|
assert(src1Dims[0] == src2Dims[0]); // N
|
||||||
|
assert(src1Dims[2] == src2Dims[2]); // H
|
||||||
|
assert(src1Dims[3] == src2Dims[3]); // W
|
||||||
|
|
||||||
|
memory::dims dstDims = src1Dims;
|
||||||
|
dstDims[1] += src2Dims[1]; // C
|
||||||
|
return dstDims;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color,
|
||||||
|
const std::shared_ptr<HDRTransferFunction>& transferFunc)
|
||||||
|
{
|
||||||
|
auto node = std::make_shared<AutoexposureNode>(color, transferFunc);
|
||||||
|
nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int K>
|
||||||
|
void Network<K>::finalize()
|
||||||
|
{
|
||||||
|
// Compute the size of the scratchpad
|
||||||
|
size_t scratchpadSize = 0;
|
||||||
|
for (const auto& node : nodes)
|
||||||
|
scratchpadSize = max(scratchpadSize, node->getScratchpadSize());
|
||||||
|
|
||||||
|
// Allocate the scratchpad
|
||||||
|
memory::dims scratchpadDims = { memory::dim(scratchpadSize) };
|
||||||
|
memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x);
|
||||||
|
auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng);
|
||||||
|
activationAllocBytes += scratchpadSize;
|
||||||
|
totalAllocBytes += scratchpadSize;
|
||||||
|
|
||||||
|
// Set the scratchpad for the nodes
|
||||||
|
for (auto& node : nodes)
|
||||||
|
node->setScratchpad(scratchpad);
|
||||||
|
|
||||||
|
// Free the weights
|
||||||
|
weightMap.clear();
|
||||||
|
|
||||||
|
// Print statistics
|
||||||
|
if (device->isVerbose(2))
|
||||||
|
{
|
||||||
|
std::cout << "Activation bytes: " << activationAllocBytes << std::endl;
|
||||||
|
std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl;
|
||||||
|
std::cout << "Total bytes : " << totalAllocBytes << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template class Network<8>;
|
||||||
|
template class Network<16>;
|
||||||
|
|
||||||
|
} // namespace oidn
|
112
thirdparty/oidn/core/network.h
vendored
Normal file
112
thirdparty/oidn/core/network.h
vendored
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "common/tensor.h"
|
||||||
|
#include "image.h"
|
||||||
|
#include "node.h"
|
||||||
|
#include "input_reorder.h"
|
||||||
|
#include "output_reorder.h"
|
||||||
|
#include "transfer_function.h"
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// Progress state
|
||||||
|
struct Progress
|
||||||
|
{
|
||||||
|
ProgressMonitorFunction func;
|
||||||
|
void* userPtr;
|
||||||
|
int taskCount;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Executable
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~Executable() {}
|
||||||
|
virtual void execute(const Progress& progress, int taskIndex) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int K>
|
||||||
|
class Network : public Executable
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
|
||||||
|
|
||||||
|
void execute(const Progress& progress, int taskIndex) override;
|
||||||
|
|
||||||
|
std::shared_ptr<memory> allocTensor(const memory::dims& dims,
|
||||||
|
memory::format_tag format = memory::format_tag::any,
|
||||||
|
void* data = nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
size_t srcOffset = 0,
|
||||||
|
memory::format_tag format = memory::format_tag::any);
|
||||||
|
|
||||||
|
std::shared_ptr<memory> castTensor(const memory::dims& dims,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const memory::dims& srcOffset);
|
||||||
|
|
||||||
|
void zeroTensor(const std::shared_ptr<memory>& dst);
|
||||||
|
|
||||||
|
memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
|
||||||
|
|
||||||
|
std::shared_ptr<Node> addInputReorder(const Image& color,
|
||||||
|
const Image& albedo,
|
||||||
|
const Image& normal,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
||||||
|
int alignment,
|
||||||
|
const std::shared_ptr<memory>& userDst = nullptr);
|
||||||
|
|
||||||
|
std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc,
|
||||||
|
const Image& output);
|
||||||
|
|
||||||
|
memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
|
||||||
|
std::shared_ptr<Node> addConv(const std::string& name,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst = nullptr,
|
||||||
|
bool relu = true);
|
||||||
|
|
||||||
|
memory::dims getPoolDims(const memory::dims& srcDims);
|
||||||
|
std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst = nullptr);
|
||||||
|
|
||||||
|
memory::dims getUpsampleDims(const memory::dims& srcDims);
|
||||||
|
std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& userDst = nullptr);
|
||||||
|
|
||||||
|
memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
|
||||||
|
|
||||||
|
std::shared_ptr<Node> addAutoexposure(const Image& color,
|
||||||
|
const std::shared_ptr<HDRTransferFunction>& transferFunc);
|
||||||
|
|
||||||
|
void finalize();
|
||||||
|
|
||||||
|
private:
|
||||||
|
Ref<Device> device;
|
||||||
|
engine eng;
|
||||||
|
stream sm;
|
||||||
|
std::vector<std::shared_ptr<Node>> nodes;
|
||||||
|
std::map<std::string, Tensor> weightMap;
|
||||||
|
|
||||||
|
// Memory allocation statistics
|
||||||
|
size_t activationAllocBytes = 0; // number of allocated activation bytes
|
||||||
|
size_t totalAllocBytes = 0; // total number of allocated bytes
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
142
thirdparty/oidn/core/node.h
vendored
Normal file
142
thirdparty/oidn/core/node.h
vendored
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
class Node
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~Node() = default;
|
||||||
|
|
||||||
|
virtual void execute(stream& sm) = 0;
|
||||||
|
|
||||||
|
virtual std::shared_ptr<memory> getDst() const { return nullptr; }
|
||||||
|
|
||||||
|
virtual size_t getScratchpadSize() const { return 0; }
|
||||||
|
virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
|
||||||
|
|
||||||
|
virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
|
||||||
|
{
|
||||||
|
assert(0); // not supported
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Node wrapping an MKL-DNN primitive
|
||||||
|
class MklNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
primitive prim;
|
||||||
|
std::unordered_map<int, memory> args;
|
||||||
|
std::shared_ptr<memory> scratchpad;
|
||||||
|
|
||||||
|
public:
|
||||||
|
MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
|
||||||
|
: prim(prim),
|
||||||
|
args(args)
|
||||||
|
{}
|
||||||
|
|
||||||
|
size_t getScratchpadSize() const override
|
||||||
|
{
|
||||||
|
const auto primDesc = prim.get_primitive_desc();
|
||||||
|
const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
|
||||||
|
if (scratchpadDesc == nullptr)
|
||||||
|
return 0;
|
||||||
|
return mkldnn_memory_desc_get_size(scratchpadDesc);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setScratchpad(const std::shared_ptr<memory>& mem) override
|
||||||
|
{
|
||||||
|
scratchpad = mem;
|
||||||
|
args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
prim.execute(sm, args);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convolution node
|
||||||
|
class ConvNode : public MklNode
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
std::shared_ptr<memory> weights;
|
||||||
|
std::shared_ptr<memory> bias;
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ConvNode(const convolution_forward::primitive_desc& desc,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& weights,
|
||||||
|
const std::shared_ptr<memory>& bias,
|
||||||
|
const std::shared_ptr<memory>& dst)
|
||||||
|
: MklNode(convolution_forward(desc),
|
||||||
|
{ { MKLDNN_ARG_SRC, *src },
|
||||||
|
{ MKLDNN_ARG_WEIGHTS, *weights },
|
||||||
|
{ MKLDNN_ARG_BIAS, *bias },
|
||||||
|
{ MKLDNN_ARG_DST, *dst } }),
|
||||||
|
src(src), weights(weights), bias(bias), dst(dst)
|
||||||
|
{}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pooling node
|
||||||
|
class PoolNode : public MklNode
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PoolNode(const pooling_forward::primitive_desc& desc,
|
||||||
|
const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& dst)
|
||||||
|
: MklNode(pooling_forward(desc),
|
||||||
|
{ { MKLDNN_ARG_SRC, *src },
|
||||||
|
{ MKLDNN_ARG_DST, *dst } }),
|
||||||
|
src(src), dst(dst)
|
||||||
|
{}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reorder node
|
||||||
|
class ReorderNode : public MklNode
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ReorderNode(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& dst)
|
||||||
|
: MklNode(reorder(reorder::primitive_desc(*src, *dst)),
|
||||||
|
{ { MKLDNN_ARG_SRC, *src },
|
||||||
|
{ MKLDNN_ARG_DST, *dst } }),
|
||||||
|
src(src), dst(dst)
|
||||||
|
{}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
126
thirdparty/oidn/core/output_reorder.h
vendored
Normal file
126
thirdparty/oidn/core/output_reorder.h
vendored
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "node.h"
|
||||||
|
#include "image.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// Output reorder node
|
||||||
|
template<int K, class TransferFunction>
|
||||||
|
class OutputReorderNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
// Source
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
const float* srcPtr;
|
||||||
|
int H1;
|
||||||
|
int W1;
|
||||||
|
|
||||||
|
// Destination
|
||||||
|
Image output;
|
||||||
|
|
||||||
|
// Tile
|
||||||
|
int h1Begin;
|
||||||
|
int w1Begin;
|
||||||
|
int h2Begin;
|
||||||
|
int w2Begin;
|
||||||
|
int H;
|
||||||
|
int W;
|
||||||
|
|
||||||
|
std::shared_ptr<TransferFunction> transferFunc;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OutputReorderNode(const std::shared_ptr<memory>& src,
|
||||||
|
const Image& output,
|
||||||
|
const std::shared_ptr<TransferFunction>& transferFunc)
|
||||||
|
: src(src),
|
||||||
|
output(output),
|
||||||
|
h1Begin(0), w1Begin(0),
|
||||||
|
h2Begin(0), w2Begin(0),
|
||||||
|
H(output.height), W(output.width),
|
||||||
|
transferFunc(transferFunc)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
MAYBE_UNUSED(srcDesc);
|
||||||
|
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
||||||
|
assert(srcDesc.ndims == 4);
|
||||||
|
assert(srcDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(srcDesc.dims[0] == 1);
|
||||||
|
// We assume output data is <= K OC
|
||||||
|
assert(srcDesc.dims[1] == K);
|
||||||
|
|
||||||
|
srcPtr = (float*)src->get_data_handle();
|
||||||
|
H1 = srcDesc.dims[2];
|
||||||
|
W1 = srcDesc.dims[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
void setTile(int h1, int w1, int h2, int w2, int H, int W) override
|
||||||
|
{
|
||||||
|
h1Begin = h1;
|
||||||
|
w1Begin = w1;
|
||||||
|
h2Begin = h2;
|
||||||
|
w2Begin = w2;
|
||||||
|
this->H = H;
|
||||||
|
this->W = W;
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
assert(h1Begin + H <= H1);
|
||||||
|
assert(w1Begin + W <= W1);
|
||||||
|
assert(h2Begin + H <= output.height);
|
||||||
|
assert(w2Begin + W <= output.width);
|
||||||
|
|
||||||
|
const int C1 = K;
|
||||||
|
|
||||||
|
parallel_nd(H, [&](int h)
|
||||||
|
{
|
||||||
|
const int h1 = h + h1Begin;
|
||||||
|
const int h2 = h + h2Begin;
|
||||||
|
|
||||||
|
for (int w = 0; w < W; ++w)
|
||||||
|
{
|
||||||
|
const int w1 = w + w1Begin;
|
||||||
|
const int w2 = w + w2Begin;
|
||||||
|
float* dstPtr_C = (float*)output.get(h2, w2);
|
||||||
|
|
||||||
|
// Source is in nChwKc format. In this case C is 1 so this is really nhwc
|
||||||
|
const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 3; ++i)
|
||||||
|
{
|
||||||
|
// Load the value
|
||||||
|
float x = srcPtr_C[i];
|
||||||
|
|
||||||
|
// The CNN output may contain negative values or even NaNs, so it must be sanitized
|
||||||
|
x = maxSafe(x, 0.f);
|
||||||
|
|
||||||
|
// Apply the inverse transfer function
|
||||||
|
x = transferFunc->inverse(x);
|
||||||
|
|
||||||
|
// Sanitize and store the final value
|
||||||
|
dstPtr_C[i] = max(x, 0.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
103
thirdparty/oidn/core/transfer_function.cpp
vendored
Normal file
103
thirdparty/oidn/core/transfer_function.cpp
vendored
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#include "transfer_function.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f);
|
||||||
|
const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale);
|
||||||
|
|
||||||
|
float AutoexposureNode::autoexposure(const Image& color)
|
||||||
|
{
|
||||||
|
assert(color.format == Format::Float3);
|
||||||
|
|
||||||
|
constexpr float key = 0.18f;
|
||||||
|
constexpr float eps = 1e-8f;
|
||||||
|
constexpr int K = 16; // downsampling amount
|
||||||
|
|
||||||
|
// Downsample the image to minimize sensitivity to noise
|
||||||
|
const int H = color.height; // original height
|
||||||
|
const int W = color.width; // original width
|
||||||
|
const int HK = (H + K/2) / K; // downsampled height
|
||||||
|
const int WK = (W + K/2) / K; // downsampled width
|
||||||
|
|
||||||
|
// Compute the average log luminance of the downsampled image
|
||||||
|
using Sum = std::pair<float, int>;
|
||||||
|
|
||||||
|
// -- GODOT start --
|
||||||
|
// Sum sum =
|
||||||
|
// tbb::parallel_reduce(
|
||||||
|
// tbb::blocked_range2d<int>(0, HK, 0, WK),
|
||||||
|
// Sum(0.f, 0),
|
||||||
|
// [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
|
||||||
|
// {
|
||||||
|
// // Iterate over blocks
|
||||||
|
// for (int i = r.rows().begin(); i != r.rows().end(); ++i)
|
||||||
|
// {
|
||||||
|
// for (int j = r.cols().begin(); j != r.cols().end(); ++j)
|
||||||
|
// {
|
||||||
|
|
||||||
|
Sum sum = Sum(0.0f, 0);
|
||||||
|
|
||||||
|
for (int i = 0; i != HK; ++i)
|
||||||
|
{
|
||||||
|
for (int j = 0; j != WK; ++j)
|
||||||
|
{
|
||||||
|
// Compute the average luminance in the current block
|
||||||
|
const int beginH = int(ptrdiff_t(i) * H / HK);
|
||||||
|
const int beginW = int(ptrdiff_t(j) * W / WK);
|
||||||
|
const int endH = int(ptrdiff_t(i+1) * H / HK);
|
||||||
|
const int endW = int(ptrdiff_t(j+1) * W / WK);
|
||||||
|
|
||||||
|
float L = 0.f;
|
||||||
|
|
||||||
|
for (int h = beginH; h < endH; ++h)
|
||||||
|
{
|
||||||
|
for (int w = beginW; w < endW; ++w)
|
||||||
|
{
|
||||||
|
const float* rgb = (const float*)color.get(h, w);
|
||||||
|
|
||||||
|
const float r = maxSafe(rgb[0], 0.f);
|
||||||
|
const float g = maxSafe(rgb[1], 0.f);
|
||||||
|
const float b = maxSafe(rgb[2], 0.f);
|
||||||
|
|
||||||
|
L += luminance(r, g, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
L /= (endH - beginH) * (endW - beginW);
|
||||||
|
|
||||||
|
// Accumulate the log luminance
|
||||||
|
if (L > eps)
|
||||||
|
{
|
||||||
|
sum.first += log2(L);
|
||||||
|
sum.second++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// return sum;
|
||||||
|
// },
|
||||||
|
// [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
|
||||||
|
// tbb::static_partitioner()
|
||||||
|
// );
|
||||||
|
// -- GODOT end --
|
||||||
|
|
||||||
|
return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
201
thirdparty/oidn/core/transfer_function.h
vendored
Normal file
201
thirdparty/oidn/core/transfer_function.h
vendored
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "image.h"
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
__forceinline float luminance(float r, float g, float b)
|
||||||
|
{
|
||||||
|
return 0.212671f * r + 0.715160f * g + 0.072169f * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Color transfer function base class
|
||||||
|
class TransferFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
virtual ~TransferFunction() = default;
|
||||||
|
|
||||||
|
virtual float forward(float y) const = 0;
|
||||||
|
virtual float inverse(float x) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// HDR transfer function base class
|
||||||
|
class HDRTransferFunction : public TransferFunction
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
static constexpr float yMax = 65504.f;
|
||||||
|
|
||||||
|
float exposure;
|
||||||
|
float rcpExposure;
|
||||||
|
|
||||||
|
public:
|
||||||
|
HDRTransferFunction(float exposure = 1.f)
|
||||||
|
{
|
||||||
|
setExposure(exposure);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setExposure(float exposure)
|
||||||
|
{
|
||||||
|
this->exposure = exposure;
|
||||||
|
this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Linear transfer function (LDR)
|
||||||
|
class LinearTransferFunction : public TransferFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
__forceinline float forward(float y) const override
|
||||||
|
{
|
||||||
|
return min(y, 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float inverse(float x) const override
|
||||||
|
{
|
||||||
|
return min(x, 1.f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 2.2 gamma transfer function (LDR)
|
||||||
|
class GammaTransferFunction : public TransferFunction
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
__forceinline float forward(float y) const override
|
||||||
|
{
|
||||||
|
return min(pow(y, 1.f/2.2f), 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float inverse(float x) const override
|
||||||
|
{
|
||||||
|
return min(pow(x, 2.2f), 1.f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Logarithmic transfer function (HDR)
|
||||||
|
// Compresses [0..65504] to [0..1]
|
||||||
|
class LogTransferFunction : public HDRTransferFunction
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
static const float xScale;
|
||||||
|
|
||||||
|
public:
|
||||||
|
LogTransferFunction(float exposure = 1.f)
|
||||||
|
: HDRTransferFunction(exposure)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float forward(float y) const override
|
||||||
|
{
|
||||||
|
return log(y * exposure + 1.f) * xScale;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float inverse(float x) const override
|
||||||
|
{
|
||||||
|
return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// PQX transfer function (HDR)
|
||||||
|
// Compresses [0..65504] to [0..1]
|
||||||
|
class PQXTransferFunction : public HDRTransferFunction
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
static constexpr float m1 = 2610.f / 4096.f / 4.f;
|
||||||
|
static constexpr float m2 = 2523.f / 4096.f * 128.f;
|
||||||
|
static constexpr float c1 = 3424.f / 4096.f;
|
||||||
|
static constexpr float c2 = 2413.f / 4096.f * 32.f;
|
||||||
|
static constexpr float c3 = 2392.f / 4096.f * 32.f;
|
||||||
|
static constexpr float a = 3711.f / 4096.f / 8.f;
|
||||||
|
|
||||||
|
static constexpr float yScale = 100.f / 10000.f;
|
||||||
|
static const float xScale;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PQXTransferFunction(float exposure = 1.f)
|
||||||
|
: HDRTransferFunction(exposure)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float forward(float y) const override
|
||||||
|
{
|
||||||
|
return pqxForward(y * exposure * yScale) * xScale;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline float inverse(float x) const override
|
||||||
|
{
|
||||||
|
return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static __forceinline float pqForward(float y)
|
||||||
|
{
|
||||||
|
const float yp = pow(y, m1);
|
||||||
|
return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __forceinline float pqxForward(float y)
|
||||||
|
{
|
||||||
|
if (y <= 1.f)
|
||||||
|
return pqForward(y);
|
||||||
|
else
|
||||||
|
return a * log(y) + 1.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __forceinline float pqInverse(float x)
|
||||||
|
{
|
||||||
|
const float xp = pow(x, 1.f/m2);
|
||||||
|
return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __forceinline float pqxInverse(float x)
|
||||||
|
{
|
||||||
|
if (x <= 1.f)
|
||||||
|
return pqInverse(x);
|
||||||
|
else
|
||||||
|
return exp((x - 1.f) * (1.f/a));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Autoexposure node
|
||||||
|
class AutoexposureNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
Image color;
|
||||||
|
std::shared_ptr<HDRTransferFunction> transferFunc;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AutoexposureNode(const Image& color,
|
||||||
|
const std::shared_ptr<HDRTransferFunction>& transferFunc)
|
||||||
|
: color(color),
|
||||||
|
transferFunc(transferFunc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
const float exposure = autoexposure(color);
|
||||||
|
//printf("exposure = %f\n", exposure);
|
||||||
|
transferFunc->setExposure(exposure);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static float autoexposure(const Image& color);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
92
thirdparty/oidn/core/upsample.h
vendored
Normal file
92
thirdparty/oidn/core/upsample.h
vendored
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// 2x2 nearest-neighbor upsampling node
|
||||||
|
template<int K>
|
||||||
|
class UpsampleNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
|
||||||
|
public:
|
||||||
|
UpsampleNode(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& dst)
|
||||||
|
: src(src),
|
||||||
|
dst(dst)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
||||||
|
MAYBE_UNUSED(srcDesc);
|
||||||
|
MAYBE_UNUSED(dstDesc);
|
||||||
|
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
||||||
|
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
|
||||||
|
assert(srcDesc.ndims == 4);
|
||||||
|
assert(dstDesc.ndims == 4);
|
||||||
|
assert(srcDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(dstDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(srcDesc.dims[0] == 1);
|
||||||
|
assert(dstDesc.dims[0] == 1);
|
||||||
|
// 2x2 upsampling
|
||||||
|
assert(dstDesc.dims[2] == srcDesc.dims[2] * 2);
|
||||||
|
assert(dstDesc.dims[3] == srcDesc.dims[3] * 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
|
||||||
|
const float* srcPtr = (float*)src->get_data_handle();
|
||||||
|
float* dstPtr = (float*)dst->get_data_handle();
|
||||||
|
|
||||||
|
const int C = srcDesc.dims[1];
|
||||||
|
const int H = srcDesc.dims[2];
|
||||||
|
const int W = srcDesc.dims[3];
|
||||||
|
const int CK = C / K;
|
||||||
|
|
||||||
|
parallel_nd(CK, H, [&](int ck, int h)
|
||||||
|
{
|
||||||
|
const size_t offset = ck*H*W*K + h*W*K;
|
||||||
|
const float* srcPtr_line = srcPtr + offset;
|
||||||
|
float* dstPtr_line0 = dstPtr + offset * 4;
|
||||||
|
float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line
|
||||||
|
|
||||||
|
for (int w = 0; w < W; ++w)
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < K; k += 4)
|
||||||
|
{
|
||||||
|
const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]);
|
||||||
|
|
||||||
|
_mm_stream_ps(&dstPtr_line0[w*2*K + k], m);
|
||||||
|
_mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m);
|
||||||
|
_mm_stream_ps(&dstPtr_line1[w*2*K + k], m);
|
||||||
|
_mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
99
thirdparty/oidn/core/weights_reorder.h
vendored
Normal file
99
thirdparty/oidn/core/weights_reorder.h
vendored
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// Reorders weights from oihw to padded oihw format
|
||||||
|
template<int K>
|
||||||
|
class WeightsReorderNode : public Node
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::shared_ptr<memory> src;
|
||||||
|
std::shared_ptr<memory> dst;
|
||||||
|
|
||||||
|
public:
|
||||||
|
WeightsReorderNode(const std::shared_ptr<memory>& src,
|
||||||
|
const std::shared_ptr<memory>& dst)
|
||||||
|
: src(src),
|
||||||
|
dst(dst)
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
||||||
|
MAYBE_UNUSED(srcDesc);
|
||||||
|
MAYBE_UNUSED(dstDesc);
|
||||||
|
assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
|
||||||
|
assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
|
||||||
|
assert(srcDesc.ndims == 4);
|
||||||
|
assert(dstDesc.ndims == 4);
|
||||||
|
assert(srcDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(dstDesc.data_type == memory::data_type::f32);
|
||||||
|
assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC
|
||||||
|
assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC
|
||||||
|
assert(srcDesc.dims[2] == dstDesc.dims[2]);
|
||||||
|
assert(srcDesc.dims[3] == dstDesc.dims[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void execute(stream& sm) override
|
||||||
|
{
|
||||||
|
const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
|
||||||
|
const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
|
||||||
|
|
||||||
|
const float* srcPtr = (float*)src->get_data_handle();
|
||||||
|
float* dstPtr = (float*)dst->get_data_handle();
|
||||||
|
|
||||||
|
const int OC1 = srcDesc.dims[0];
|
||||||
|
const int OC2 = dstDesc.dims[0];
|
||||||
|
const int IC1 = srcDesc.dims[1];
|
||||||
|
const int IC2 = dstDesc.dims[1];
|
||||||
|
const int H = dstDesc.dims[2];
|
||||||
|
const int W = dstDesc.dims[3];
|
||||||
|
|
||||||
|
for (int oc = 0; oc < OC2; ++oc)
|
||||||
|
{
|
||||||
|
for (int ic = 0; ic < IC2; ++ic)
|
||||||
|
{
|
||||||
|
for (int h = 0; h < H; ++h)
|
||||||
|
{
|
||||||
|
for (int w = 0; w < W; ++w)
|
||||||
|
{
|
||||||
|
// Output is in oihw format
|
||||||
|
float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w;
|
||||||
|
|
||||||
|
if (oc < OC1 && ic < IC1)
|
||||||
|
{
|
||||||
|
// Input is in oihw format
|
||||||
|
const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w;
|
||||||
|
*dstPtr_c = *srcPtr_c;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// padding
|
||||||
|
*dstPtr_c = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace oidn
|
214
thirdparty/oidn/include/OpenImageDenoise/oidn.h
vendored
Normal file
214
thirdparty/oidn/include/OpenImageDenoise/oidn.h
vendored
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "version.h"
|
||||||
|
|
||||||
|
#if defined(__cplusplus)
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef OIDN_API
|
||||||
|
#if defined(_WIN32) && !defined(OIDN_STATIC_LIB)
|
||||||
|
# define OIDN_API __declspec(dllimport)
|
||||||
|
#else
|
||||||
|
# define OIDN_API
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Device
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Device types
|
||||||
|
typedef enum
|
||||||
|
{
|
||||||
|
OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically
|
||||||
|
|
||||||
|
OIDN_DEVICE_TYPE_CPU = 1, // CPU device
|
||||||
|
} OIDNDeviceType;
|
||||||
|
|
||||||
|
// Error codes
|
||||||
|
typedef enum
|
||||||
|
{
|
||||||
|
OIDN_ERROR_NONE = 0, // no error occurred
|
||||||
|
OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred
|
||||||
|
OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified
|
||||||
|
OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed
|
||||||
|
OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation
|
||||||
|
OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported
|
||||||
|
OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user
|
||||||
|
} OIDNError;
|
||||||
|
|
||||||
|
// Error callback function
|
||||||
|
typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message);
|
||||||
|
|
||||||
|
// Device handle
|
||||||
|
typedef struct OIDNDeviceImpl* OIDNDevice;
|
||||||
|
|
||||||
|
// Creates a new device.
|
||||||
|
OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type);
|
||||||
|
|
||||||
|
// Retains the device (increments the reference count).
|
||||||
|
OIDN_API void oidnRetainDevice(OIDNDevice device);
|
||||||
|
|
||||||
|
// Releases the device (decrements the reference count).
|
||||||
|
OIDN_API void oidnReleaseDevice(OIDNDevice device);
|
||||||
|
|
||||||
|
// Sets a boolean parameter of the device.
|
||||||
|
OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value);
|
||||||
|
|
||||||
|
// Sets an integer parameter of the device.
|
||||||
|
OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value);
|
||||||
|
|
||||||
|
// Gets a boolean parameter of the device.
|
||||||
|
OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name);
|
||||||
|
|
||||||
|
// Gets an integer parameter of the device (e.g. "version").
|
||||||
|
OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name);
|
||||||
|
|
||||||
|
// Sets the error callback function of the device.
|
||||||
|
OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr);
|
||||||
|
|
||||||
|
// Returns the first unqueried error code stored in the device for the current
|
||||||
|
// thread, optionally also returning a string message (if not NULL), and clears
|
||||||
|
// the stored error. Can be called with a NULL device as well to check why a
|
||||||
|
// device creation failed.
|
||||||
|
OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage);
|
||||||
|
|
||||||
|
// Commits all previous changes to the device.
|
||||||
|
// Must be called before first using the device (e.g. creating filters).
|
||||||
|
OIDN_API void oidnCommitDevice(OIDNDevice device);
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Buffer
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Formats for images and other data stored in buffers
|
||||||
|
typedef enum
|
||||||
|
{
|
||||||
|
OIDN_FORMAT_UNDEFINED = 0,
|
||||||
|
|
||||||
|
// 32-bit single-precision floating point scalar and vector formats
|
||||||
|
OIDN_FORMAT_FLOAT = 1,
|
||||||
|
OIDN_FORMAT_FLOAT2 = 2,
|
||||||
|
OIDN_FORMAT_FLOAT3 = 3,
|
||||||
|
OIDN_FORMAT_FLOAT4 = 4,
|
||||||
|
} OIDNFormat;
|
||||||
|
|
||||||
|
// Access modes for mapping buffers
|
||||||
|
typedef enum
|
||||||
|
{
|
||||||
|
OIDN_ACCESS_READ = 0, // read-only access
|
||||||
|
OIDN_ACCESS_WRITE = 1, // write-only access
|
||||||
|
OIDN_ACCESS_READ_WRITE = 2, // read and write access
|
||||||
|
OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded
|
||||||
|
} OIDNAccess;
|
||||||
|
|
||||||
|
// Buffer handle
|
||||||
|
typedef struct OIDNBufferImpl* OIDNBuffer;
|
||||||
|
|
||||||
|
// Creates a new buffer (data allocated and owned by the device).
|
||||||
|
OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize);
|
||||||
|
|
||||||
|
// Creates a new shared buffer (data allocated and owned by the user).
|
||||||
|
OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize);
|
||||||
|
|
||||||
|
// Maps a region of the buffer to host memory.
|
||||||
|
// If byteSize is 0, the maximum available amount of memory will be mapped.
|
||||||
|
OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize);
|
||||||
|
|
||||||
|
// Unmaps a region of the buffer.
|
||||||
|
// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer.
|
||||||
|
OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr);
|
||||||
|
|
||||||
|
// Retains the buffer (increments the reference count).
|
||||||
|
OIDN_API void oidnRetainBuffer(OIDNBuffer buffer);
|
||||||
|
|
||||||
|
// Releases the buffer (decrements the reference count).
|
||||||
|
OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer);
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Filter
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Progress monitor callback function
|
||||||
|
typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n);
|
||||||
|
|
||||||
|
// Filter handle
|
||||||
|
typedef struct OIDNFilterImpl* OIDNFilter;
|
||||||
|
|
||||||
|
// Creates a new filter of the specified type (e.g. "RT").
|
||||||
|
OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type);
|
||||||
|
|
||||||
|
// Retains the filter (increments the reference count).
|
||||||
|
OIDN_API void oidnRetainFilter(OIDNFilter filter);
|
||||||
|
|
||||||
|
// Releases the filter (decrements the reference count).
|
||||||
|
OIDN_API void oidnReleaseFilter(OIDNFilter filter);
|
||||||
|
|
||||||
|
// Sets an image parameter of the filter (stored in a buffer).
|
||||||
|
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
|
||||||
|
OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name,
|
||||||
|
OIDNBuffer buffer, OIDNFormat format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset,
|
||||||
|
size_t bytePixelStride, size_t byteRowStride);
|
||||||
|
|
||||||
|
// Sets an image parameter of the filter (owned by the user).
|
||||||
|
// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
|
||||||
|
OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name,
|
||||||
|
void* ptr, OIDNFormat format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset,
|
||||||
|
size_t bytePixelStride, size_t byteRowStride);
|
||||||
|
|
||||||
|
// Sets a boolean parameter of the filter.
|
||||||
|
OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value);
|
||||||
|
|
||||||
|
// Gets a boolean parameter of the filter.
|
||||||
|
OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name);
|
||||||
|
|
||||||
|
// Sets an integer parameter of the filter.
|
||||||
|
OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value);
|
||||||
|
|
||||||
|
// Gets an integer parameter of the filter.
|
||||||
|
OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name);
|
||||||
|
|
||||||
|
// Sets a float parameter of the filter.
|
||||||
|
OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value);
|
||||||
|
|
||||||
|
// Gets a float parameter of the filter.
|
||||||
|
OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name);
|
||||||
|
|
||||||
|
// Sets the progress monitor callback function of the filter.
|
||||||
|
OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr);
|
||||||
|
|
||||||
|
// Commits all previous changes to the filter.
|
||||||
|
// Must be called before first executing the filter.
|
||||||
|
OIDN_API void oidnCommitFilter(OIDNFilter filter);
|
||||||
|
|
||||||
|
// Executes the filter.
|
||||||
|
OIDN_API void oidnExecuteFilter(OIDNFilter filter);
|
||||||
|
|
||||||
|
#if defined(__cplusplus)
|
||||||
|
}
|
||||||
|
#endif
|
468
thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
vendored
Normal file
468
thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
vendored
Normal file
@ -0,0 +1,468 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include "oidn.h"
|
||||||
|
|
||||||
|
namespace oidn {
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Buffer
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Formats for images and other data stored in buffers
|
||||||
|
enum class Format
|
||||||
|
{
|
||||||
|
Undefined = OIDN_FORMAT_UNDEFINED,
|
||||||
|
|
||||||
|
// 32-bit single-precision floating point scalar and vector formats
|
||||||
|
Float = OIDN_FORMAT_FLOAT,
|
||||||
|
Float2 = OIDN_FORMAT_FLOAT2,
|
||||||
|
Float3 = OIDN_FORMAT_FLOAT3,
|
||||||
|
Float4 = OIDN_FORMAT_FLOAT4,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Access modes for mapping buffers
|
||||||
|
enum class Access
|
||||||
|
{
|
||||||
|
Read = OIDN_ACCESS_READ, // read-only access
|
||||||
|
Write = OIDN_ACCESS_WRITE, // write-only access
|
||||||
|
ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access
|
||||||
|
WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded
|
||||||
|
};
|
||||||
|
|
||||||
|
// Buffer object with automatic reference counting
|
||||||
|
class BufferRef
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
OIDNBuffer handle;
|
||||||
|
|
||||||
|
public:
|
||||||
|
BufferRef() : handle(nullptr) {}
|
||||||
|
BufferRef(OIDNBuffer handle) : handle(handle) {}
|
||||||
|
|
||||||
|
BufferRef(const BufferRef& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnRetainBuffer(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferRef(BufferRef&& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
other.handle = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferRef& operator =(const BufferRef& other)
|
||||||
|
{
|
||||||
|
if (&other != this)
|
||||||
|
{
|
||||||
|
if (other.handle)
|
||||||
|
oidnRetainBuffer(other.handle);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseBuffer(handle);
|
||||||
|
handle = other.handle;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferRef& operator =(BufferRef&& other)
|
||||||
|
{
|
||||||
|
std::swap(handle, other.handle);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferRef& operator =(OIDNBuffer other)
|
||||||
|
{
|
||||||
|
if (other)
|
||||||
|
oidnRetainBuffer(other);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseBuffer(handle);
|
||||||
|
handle = other;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~BufferRef()
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseBuffer(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDNBuffer getHandle() const
|
||||||
|
{
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() const
|
||||||
|
{
|
||||||
|
return handle != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maps a region of the buffer to host memory.
|
||||||
|
// If byteSize is 0, the maximum available amount of memory will be mapped.
|
||||||
|
void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0)
|
||||||
|
{
|
||||||
|
return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmaps a region of the buffer.
|
||||||
|
// mappedPtr must be a pointer returned by a previous call to map.
|
||||||
|
void unmap(void* mappedPtr)
|
||||||
|
{
|
||||||
|
oidnUnmapBuffer(handle, mappedPtr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Filter
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Progress monitor callback function
|
||||||
|
typedef bool (*ProgressMonitorFunction)(void* userPtr, double n);
|
||||||
|
|
||||||
|
// Filter object with automatic reference counting
|
||||||
|
class FilterRef
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
OIDNFilter handle;
|
||||||
|
|
||||||
|
public:
|
||||||
|
FilterRef() : handle(nullptr) {}
|
||||||
|
FilterRef(OIDNFilter handle) : handle(handle) {}
|
||||||
|
|
||||||
|
FilterRef(const FilterRef& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnRetainFilter(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
FilterRef(FilterRef&& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
other.handle = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
FilterRef& operator =(const FilterRef& other)
|
||||||
|
{
|
||||||
|
if (&other != this)
|
||||||
|
{
|
||||||
|
if (other.handle)
|
||||||
|
oidnRetainFilter(other.handle);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseFilter(handle);
|
||||||
|
handle = other.handle;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FilterRef& operator =(FilterRef&& other)
|
||||||
|
{
|
||||||
|
std::swap(handle, other.handle);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
FilterRef& operator =(OIDNFilter other)
|
||||||
|
{
|
||||||
|
if (other)
|
||||||
|
oidnRetainFilter(other);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseFilter(handle);
|
||||||
|
handle = other;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~FilterRef()
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseFilter(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDNFilter getHandle() const
|
||||||
|
{
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() const
|
||||||
|
{
|
||||||
|
return handle != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets an image parameter of the filter (stored in a buffer).
|
||||||
|
void setImage(const char* name,
|
||||||
|
const BufferRef& buffer, Format format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset = 0,
|
||||||
|
size_t bytePixelStride = 0, size_t byteRowStride = 0)
|
||||||
|
{
|
||||||
|
oidnSetFilterImage(handle, name,
|
||||||
|
buffer.getHandle(), (OIDNFormat)format,
|
||||||
|
width, height,
|
||||||
|
byteOffset,
|
||||||
|
bytePixelStride, byteRowStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets an image parameter of the filter (owned by the user).
|
||||||
|
void setImage(const char* name,
|
||||||
|
void* ptr, Format format,
|
||||||
|
size_t width, size_t height,
|
||||||
|
size_t byteOffset = 0,
|
||||||
|
size_t bytePixelStride = 0, size_t byteRowStride = 0)
|
||||||
|
{
|
||||||
|
oidnSetSharedFilterImage(handle, name,
|
||||||
|
ptr, (OIDNFormat)format,
|
||||||
|
width, height,
|
||||||
|
byteOffset,
|
||||||
|
bytePixelStride, byteRowStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets a boolean parameter of the filter.
|
||||||
|
void set(const char* name, bool value)
|
||||||
|
{
|
||||||
|
oidnSetFilter1b(handle, name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets an integer parameter of the filter.
|
||||||
|
void set(const char* name, int value)
|
||||||
|
{
|
||||||
|
oidnSetFilter1i(handle, name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets a float parameter of the filter.
|
||||||
|
void set(const char* name, float value)
|
||||||
|
{
|
||||||
|
oidnSetFilter1f(handle, name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets a parameter of the filter.
|
||||||
|
template<typename T>
|
||||||
|
T get(const char* name);
|
||||||
|
|
||||||
|
// Sets the progress monitor callback function of the filter.
|
||||||
|
void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr)
|
||||||
|
{
|
||||||
|
oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commits all previous changes to the filter.
|
||||||
|
void commit()
|
||||||
|
{
|
||||||
|
oidnCommitFilter(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Executes the filter.
|
||||||
|
void execute()
|
||||||
|
{
|
||||||
|
oidnExecuteFilter(handle);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Gets a boolean parameter of the filter.
|
||||||
|
template<>
|
||||||
|
inline bool FilterRef::get(const char* name)
|
||||||
|
{
|
||||||
|
return oidnGetFilter1b(handle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets an integer parameter of the filter.
|
||||||
|
template<>
|
||||||
|
inline int FilterRef::get(const char* name)
|
||||||
|
{
|
||||||
|
return oidnGetFilter1i(handle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets a float parameter of the filter.
|
||||||
|
template<>
|
||||||
|
inline float FilterRef::get(const char* name)
|
||||||
|
{
|
||||||
|
return oidnGetFilter1f(handle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Device
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Device types
|
||||||
|
enum class DeviceType
|
||||||
|
{
|
||||||
|
Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically
|
||||||
|
|
||||||
|
CPU = OIDN_DEVICE_TYPE_CPU, // CPU device
|
||||||
|
};
|
||||||
|
|
||||||
|
// Error codes
|
||||||
|
enum class Error
|
||||||
|
{
|
||||||
|
None = OIDN_ERROR_NONE, // no error occurred
|
||||||
|
Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred
|
||||||
|
InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified
|
||||||
|
InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed
|
||||||
|
OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation
|
||||||
|
UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported
|
||||||
|
Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user
|
||||||
|
};
|
||||||
|
|
||||||
|
// Error callback function
|
||||||
|
typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message);
|
||||||
|
|
||||||
|
// Device object with automatic reference counting
|
||||||
|
class DeviceRef
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
OIDNDevice handle;
|
||||||
|
|
||||||
|
public:
|
||||||
|
DeviceRef() : handle(nullptr) {}
|
||||||
|
DeviceRef(OIDNDevice handle) : handle(handle) {}
|
||||||
|
|
||||||
|
DeviceRef(const DeviceRef& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnRetainDevice(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceRef(DeviceRef&& other) : handle(other.handle)
|
||||||
|
{
|
||||||
|
other.handle = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceRef& operator =(const DeviceRef& other)
|
||||||
|
{
|
||||||
|
if (&other != this)
|
||||||
|
{
|
||||||
|
if (other.handle)
|
||||||
|
oidnRetainDevice(other.handle);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseDevice(handle);
|
||||||
|
handle = other.handle;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceRef& operator =(DeviceRef&& other)
|
||||||
|
{
|
||||||
|
std::swap(handle, other.handle);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
DeviceRef& operator =(OIDNDevice other)
|
||||||
|
{
|
||||||
|
if (other)
|
||||||
|
oidnRetainDevice(other);
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseDevice(handle);
|
||||||
|
handle = other;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~DeviceRef()
|
||||||
|
{
|
||||||
|
if (handle)
|
||||||
|
oidnReleaseDevice(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
OIDNDevice getHandle() const
|
||||||
|
{
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
operator bool() const
|
||||||
|
{
|
||||||
|
return handle != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets a boolean parameter of the device.
|
||||||
|
void set(const char* name, bool value)
|
||||||
|
{
|
||||||
|
oidnSetDevice1b(handle, name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets an integer parameter of the device.
|
||||||
|
void set(const char* name, int value)
|
||||||
|
{
|
||||||
|
oidnSetDevice1i(handle, name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets a parameter of the device.
|
||||||
|
template<typename T>
|
||||||
|
T get(const char* name);
|
||||||
|
|
||||||
|
// Sets the error callback function of the device.
|
||||||
|
void setErrorFunction(ErrorFunction func, void* userPtr = nullptr)
|
||||||
|
{
|
||||||
|
oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the first unqueried error code and clears the stored error.
|
||||||
|
// Can be called for a null device as well to check why a device creation failed.
|
||||||
|
Error getError()
|
||||||
|
{
|
||||||
|
return (Error)oidnGetDeviceError(handle, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the first unqueried error code and string message, and clears the stored error.
|
||||||
|
// Can be called for a null device as well to check why a device creation failed.
|
||||||
|
Error getError(const char*& outMessage)
|
||||||
|
{
|
||||||
|
return (Error)oidnGetDeviceError(handle, &outMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commits all previous changes to the device.
|
||||||
|
// Must be called before first using the device (e.g. creating filters).
|
||||||
|
void commit()
|
||||||
|
{
|
||||||
|
oidnCommitDevice(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new buffer (data allocated and owned by the device).
|
||||||
|
BufferRef newBuffer(size_t byteSize)
|
||||||
|
{
|
||||||
|
return oidnNewBuffer(handle, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new shared buffer (data allocated and owned by the user).
|
||||||
|
BufferRef newBuffer(void* ptr, size_t byteSize)
|
||||||
|
{
|
||||||
|
return oidnNewSharedBuffer(handle, ptr, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new filter of the specified type (e.g. "RT").
|
||||||
|
FilterRef newFilter(const char* type)
|
||||||
|
{
|
||||||
|
return oidnNewFilter(handle, type);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Gets a boolean parameter of the device.
|
||||||
|
template<>
|
||||||
|
inline bool DeviceRef::get(const char* name)
|
||||||
|
{
|
||||||
|
return oidnGetDevice1b(handle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets an integer parameter of the device (e.g. "version").
|
||||||
|
template<>
|
||||||
|
inline int DeviceRef::get(const char* name)
|
||||||
|
{
|
||||||
|
return oidnGetDevice1i(handle, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new device.
|
||||||
|
inline DeviceRef newDevice(DeviceType type = DeviceType::Default)
|
||||||
|
{
|
||||||
|
return DeviceRef(oidnNewDevice((OIDNDeviceType)type));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace oidn
|
23
thirdparty/oidn/include/OpenImageDenoise/version.h
vendored
Normal file
23
thirdparty/oidn/include/OpenImageDenoise/version.h
vendored
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
// ======================================================================== //
|
||||||
|
// Copyright 2009-2019 Intel Corporation //
|
||||||
|
// //
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||||||
|
// you may not use this file except in compliance with the License. //
|
||||||
|
// You may obtain a copy of the License at //
|
||||||
|
// //
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||||||
|
// //
|
||||||
|
// Unless required by applicable law or agreed to in writing, software //
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||||||
|
// See the License for the specific language governing permissions and //
|
||||||
|
// limitations under the License. //
|
||||||
|
// ======================================================================== //
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define OIDN_VERSION_MAJOR 1
|
||||||
|
#define OIDN_VERSION_MINOR 1
|
||||||
|
#define OIDN_VERSION_PATCH 0
|
||||||
|
#define OIDN_VERSION 10100
|
||||||
|
#define OIDN_VERSION_STRING "1.1.0"
|
214
thirdparty/oidn/mkl-dnn/LICENSE
vendored
Normal file
214
thirdparty/oidn/mkl-dnn/LICENSE
vendored
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright {yyyy} {name of copyright owner}
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
============================================================================
|
||||||
|
|
||||||
|
Intel MKL-DNN includes components with separate copyright
|
||||||
|
notices and license terms.
|
||||||
|
|
||||||
|
XByak, 3-clause BSD license
|
||||||
|
Copyright (c) 2007 MITSUNARI Shigeo
|
||||||
|
See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT
|
||||||
|
|
||||||
|
gtest, 3-clause BSD license
|
||||||
|
Copyright 2008, Google Inc.
|
||||||
|
See full copyright notice and license text in tests/gtests/gtest/LICENSE
|
1771
thirdparty/oidn/mkl-dnn/include/mkldnn.h
vendored
Normal file
1771
thirdparty/oidn/mkl-dnn/include/mkldnn.h
vendored
Normal file
File diff suppressed because it is too large
Load Diff
2615
thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
vendored
Normal file
2615
thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
98
thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
vendored
Normal file
98
thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
vendored
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018-2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
/* DO NOT EDIT, AUTO-GENERATED */
|
||||||
|
|
||||||
|
#ifndef MKLDNN_DEBUG_H
|
||||||
|
#define MKLDNN_DEBUG_H
|
||||||
|
|
||||||
|
#ifndef DOXYGEN_SHOULD_SKIP_THIS
|
||||||
|
|
||||||
|
/* All symbols shall be internal unless marked as MKLDNN_API */
|
||||||
|
#if defined _WIN32 || defined __CYGWIN__
|
||||||
|
# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
|
||||||
|
# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
# if __GNUC__ >= 4
|
||||||
|
# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
|
||||||
|
# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
|
||||||
|
# else
|
||||||
|
# define MKLDNN_HELPER_DLL_IMPORT
|
||||||
|
# define MKLDNN_HELPER_DLL_EXPORT
|
||||||
|
# endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef MKLDNN_DLL
|
||||||
|
# ifdef MKLDNN_DLL_EXPORTS
|
||||||
|
# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
|
||||||
|
# else
|
||||||
|
# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
|
||||||
|
# endif
|
||||||
|
#else
|
||||||
|
# define MKLDNN_API
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined (__GNUC__)
|
||||||
|
# define MKLDNN_DEPRECATED __attribute__((deprecated))
|
||||||
|
#elif defined(_MSC_VER)
|
||||||
|
# define MKLDNN_DEPRECATED __declspec(deprecated)
|
||||||
|
#else
|
||||||
|
# define MKLDNN_DEPRECATED
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "mkldnn_types.h"
|
||||||
|
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v);
|
||||||
|
const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v);
|
||||||
|
|
||||||
|
/** Forms a format string for a given memory descriptor.
|
||||||
|
*
|
||||||
|
* The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'.
|
||||||
|
* Here:
|
||||||
|
* - dt -- data type
|
||||||
|
* - p -- indicates there is non-trivial padding
|
||||||
|
* - o -- indicates there is non-trivial padding offset
|
||||||
|
* - 0 -- indicates there is non-trivial offset0
|
||||||
|
* - fmt_kind -- format kind (blocked, wino, etc...)
|
||||||
|
* - fmt -- extended format string (format_kind specific)
|
||||||
|
* - extra -- shows extra fields (underspecified)
|
||||||
|
*/
|
||||||
|
int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len,
|
||||||
|
const mkldnn_memory_desc_t *md);
|
||||||
|
|
||||||
|
/** Forms a dimension string for a given memory descriptor.
|
||||||
|
*
|
||||||
|
* The format is defined as: 'dim0xdim1x...xdimN
|
||||||
|
*/
|
||||||
|
int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len,
|
||||||
|
const mkldnn_memory_desc_t *md);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
1415
thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
vendored
Normal file
1415
thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
vendored
Normal file
File diff suppressed because it is too large
Load Diff
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
vendored
Normal file
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
vendored
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MKLDNN_VERSION_H
|
||||||
|
#define MKLDNN_VERSION_H
|
||||||
|
|
||||||
|
/* Major version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_MAJOR 0
|
||||||
|
|
||||||
|
/* Minor version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_MINOR 90
|
||||||
|
|
||||||
|
/* Patch version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_PATCH 0
|
||||||
|
|
||||||
|
/* Git Commit Hash of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c"
|
||||||
|
|
||||||
|
#endif
|
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in
vendored
Normal file
32
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in
vendored
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MKLDNN_VERSION_H
|
||||||
|
#define MKLDNN_VERSION_H
|
||||||
|
|
||||||
|
/* Major version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@
|
||||||
|
|
||||||
|
/* Minor version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@
|
||||||
|
|
||||||
|
/* Patch version of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@
|
||||||
|
|
||||||
|
/* Git Commit Hash of MKL-DNN */
|
||||||
|
#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@"
|
||||||
|
|
||||||
|
#endif
|
104
thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
vendored
Normal file
104
thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
vendored
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
|
||||||
|
prop_kind_t prop_kind, const memory_desc_t *data_desc,
|
||||||
|
const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(bnrm_desc, data_desc)
|
||||||
|
&& one_of(prop_kind, forward_training, forward_inference,
|
||||||
|
backward_data, backward)
|
||||||
|
&& IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto bd = batch_normalization_desc_t();
|
||||||
|
bd.primitive_kind = primitive_kind::batch_normalization;
|
||||||
|
bd.prop_kind = prop_kind;
|
||||||
|
|
||||||
|
bd.data_desc = *data_desc;
|
||||||
|
bd.diff_data_desc = zero_md();
|
||||||
|
if ( one_of(bd.prop_kind,backward_data, backward) )
|
||||||
|
bd.diff_data_desc = *diff_data_desc;
|
||||||
|
|
||||||
|
dims_t scaleshift_dims = { 2, data_desc->dims[1] };
|
||||||
|
mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
|
||||||
|
scaleshift_dims, data_type::f32, mkldnn_nc);
|
||||||
|
bd.diff_data_scaleshift_desc = zero_md();
|
||||||
|
if (bd.prop_kind == backward) {
|
||||||
|
bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
dims_t stats_dims = { data_desc->dims[1] };
|
||||||
|
mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
|
||||||
|
data_type::f32, mkldnn_x);
|
||||||
|
bd.variance_desc = bd.mean_desc;
|
||||||
|
bd.batch_norm_epsilon = epsilon;
|
||||||
|
|
||||||
|
unsigned bnorm_flags =
|
||||||
|
mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
|
||||||
|
if ((~bnorm_flags & flags) != 0) return invalid_arguments;
|
||||||
|
|
||||||
|
bd.flags = flags;
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& utils::one_of(bd.data_desc.ndims, 2, 4, 5);
|
||||||
|
if (bd.prop_kind == backward_data)
|
||||||
|
consistency = consistency
|
||||||
|
&& utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
|
||||||
|
&& array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
|
||||||
|
bd.diff_data_desc.ndims);
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*bnrm_desc = bd;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_batch_normalization_forward_desc_init(
|
||||||
|
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
|
||||||
|
const memory_desc_t *data_desc, float epsilon, unsigned flags) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
|
||||||
|
epsilon, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_batch_normalization_backward_desc_init(
|
||||||
|
batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
|
||||||
|
const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
|
||||||
|
float epsilon, unsigned flags) {
|
||||||
|
if (!one_of(prop_kind, backward, backward_data))
|
||||||
|
return invalid_arguments;
|
||||||
|
return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
|
||||||
|
epsilon, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
240
thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
vendored
Normal file
240
thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
vendored
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef BATCH_NORMALIZATION_PD_HPP
|
||||||
|
#define BATCH_NORMALIZATION_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct batch_normalization_fwd_pd_t;
|
||||||
|
|
||||||
|
struct batch_normalization_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::batch_normalization;
|
||||||
|
|
||||||
|
batch_normalization_pd_t(engine_t *engine,
|
||||||
|
const batch_normalization_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, data_md_(desc_.data_desc)
|
||||||
|
, stat_md_(desc_.mean_desc)
|
||||||
|
, scaleshift_md_(desc_.data_scaleshift_desc)
|
||||||
|
, ws_md_()
|
||||||
|
{}
|
||||||
|
|
||||||
|
const batch_normalization_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::batch_normalization_d:
|
||||||
|
*(const batch_normalization_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common batch_normalization aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return data_desc().dims[0]; }
|
||||||
|
dim_t C() const { return data_desc().dims[1]; }
|
||||||
|
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
||||||
|
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
||||||
|
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
||||||
|
|
||||||
|
int ndims() const { return desc_.data_desc.ndims; }
|
||||||
|
|
||||||
|
bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
|
||||||
|
bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
|
||||||
|
bool use_global_stats() const
|
||||||
|
{ return desc_.flags & mkldnn_use_global_stats; }
|
||||||
|
bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
|
||||||
|
bool with_relu_post_op() const {
|
||||||
|
const auto &p = this->attr()->post_ops_;
|
||||||
|
return p.len_ == 1 && p.entry_[0].is_relu(true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
bool is_bwd() const { return !this->is_fwd(); }
|
||||||
|
bool is_training() const
|
||||||
|
{ return desc_.prop_kind == prop_kind::forward_training; }
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const
|
||||||
|
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
batch_normalization_desc_t desc_;
|
||||||
|
const batch_normalization_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
memory_desc_t data_md_;
|
||||||
|
memory_desc_t stat_md_;
|
||||||
|
memory_desc_t scaleshift_md_;
|
||||||
|
|
||||||
|
memory_desc_t ws_md_;
|
||||||
|
|
||||||
|
void init_default_ws(size_t bits_per_element) {
|
||||||
|
const auto data_mdw = memory_desc_wrapper(data_md_);
|
||||||
|
|
||||||
|
const dim_t data_nelems = data_mdw.nelems(true);
|
||||||
|
const dim_t bits_per_byte = 8;
|
||||||
|
const dims_t ws_sz = { (dim_t)utils::div_up(
|
||||||
|
data_nelems * bits_per_element, bits_per_byte) };
|
||||||
|
mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
|
||||||
|
format_tag::x);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
|
||||||
|
typedef batch_normalization_fwd_pd_t base_class;
|
||||||
|
typedef batch_normalization_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
batch_normalization_fwd_pd_t(engine_t *engine,
|
||||||
|
const batch_normalization_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
|
||||||
|
if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
|
||||||
|
if (stats_is_src()) return arg_usage_t::input;
|
||||||
|
if (!stats_is_src() && is_training()) return arg_usage_t::output;
|
||||||
|
return arg_usage_t::unused;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &data_md_;
|
||||||
|
if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &data_md_;
|
||||||
|
if (!stats_is_src() && is_training() && (index == 1 || index == 2))
|
||||||
|
return &stat_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &scaleshift_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
const memory_desc_t *stat_md() const
|
||||||
|
{ return stats_is_src() ? src_md(1) : dst_md(1); }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 1 + 2 * stats_is_src() + use_scaleshift(); }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
|
||||||
|
typedef batch_normalization_bwd_pd_t base_class;
|
||||||
|
typedef batch_normalization_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
batch_normalization_bwd_pd_t(engine_t *engine,
|
||||||
|
const batch_normalization_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const batch_normalization_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_data_md_(desc_.diff_data_desc)
|
||||||
|
, diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
|
||||||
|
MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &scaleshift_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_weights_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_scaleshift_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
const memory_desc_t *stat_md() const { return src_md(1); }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 4 + use_scaleshift() + fuse_bn_relu(); }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 1 + (desc_.prop_kind == prop_kind::backward); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_data_md_;
|
||||||
|
memory_desc_t diff_scaleshift_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
550
thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
vendored
Normal file
550
thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
vendored
Normal file
@ -0,0 +1,550 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef TYPE_MAPPING_HPP
|
||||||
|
#define TYPE_MAPPING_HPP
|
||||||
|
|
||||||
|
#include "mkldnn_types.h"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
// TODO: autogenerate this
|
||||||
|
|
||||||
|
using dim_t = mkldnn_dim_t;
|
||||||
|
using dims_t = mkldnn_dims_t;
|
||||||
|
using stride_t = mkldnn_dim_t;
|
||||||
|
using strides_t = mkldnn_strides_t;
|
||||||
|
|
||||||
|
using status_t = mkldnn_status_t;
|
||||||
|
namespace status {
|
||||||
|
const status_t success = mkldnn_success;
|
||||||
|
const status_t out_of_memory = mkldnn_out_of_memory;
|
||||||
|
const status_t try_again = mkldnn_try_again;
|
||||||
|
const status_t invalid_arguments = mkldnn_invalid_arguments;
|
||||||
|
const status_t not_ready = mkldnn_not_ready;
|
||||||
|
const status_t unimplemented = mkldnn_unimplemented;
|
||||||
|
const status_t iterator_ends = mkldnn_iterator_ends;
|
||||||
|
const status_t runtime_error = mkldnn_runtime_error;
|
||||||
|
const status_t not_required = mkldnn_not_required;
|
||||||
|
}
|
||||||
|
|
||||||
|
using prop_kind_t = mkldnn_prop_kind_t;
|
||||||
|
namespace prop_kind {
|
||||||
|
const prop_kind_t undef = mkldnn_prop_kind_undef;
|
||||||
|
const prop_kind_t forward_training = mkldnn_forward_training;
|
||||||
|
const prop_kind_t forward_inference = mkldnn_forward_inference;
|
||||||
|
const prop_kind_t forward_scoring = mkldnn_forward_scoring;
|
||||||
|
const prop_kind_t forward = mkldnn_forward;
|
||||||
|
const prop_kind_t backward = mkldnn_backward;
|
||||||
|
const prop_kind_t backward_data = mkldnn_backward_data;
|
||||||
|
const prop_kind_t backward_weights = mkldnn_backward_weights;
|
||||||
|
const prop_kind_t backward_bias = mkldnn_backward_bias;
|
||||||
|
}
|
||||||
|
|
||||||
|
using alg_kind_t = mkldnn_alg_kind_t;
|
||||||
|
namespace alg_kind {
|
||||||
|
const alg_kind_t undef = mkldnn_alg_kind_undef;
|
||||||
|
const alg_kind_t convolution_auto = mkldnn_convolution_auto;
|
||||||
|
const alg_kind_t convolution_direct = mkldnn_convolution_direct;
|
||||||
|
const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
|
||||||
|
const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
|
||||||
|
const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
|
||||||
|
const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
|
||||||
|
const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
|
||||||
|
const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
|
||||||
|
const alg_kind_t eltwise_square = mkldnn_eltwise_square;
|
||||||
|
const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
|
||||||
|
const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
|
||||||
|
const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
|
||||||
|
const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
|
||||||
|
const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
|
||||||
|
const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
|
||||||
|
const alg_kind_t pooling_max = mkldnn_pooling_max;
|
||||||
|
const alg_kind_t pooling_avg = mkldnn_pooling_avg;
|
||||||
|
const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
|
||||||
|
const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
|
||||||
|
const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
|
||||||
|
const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
|
||||||
|
const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
|
||||||
|
const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
|
||||||
|
const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
|
||||||
|
const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
|
||||||
|
}
|
||||||
|
|
||||||
|
using data_type_t = mkldnn_data_type_t;
|
||||||
|
namespace data_type {
|
||||||
|
const data_type_t undef = mkldnn_data_type_undef;
|
||||||
|
const data_type_t f32 = mkldnn_f32;
|
||||||
|
const data_type_t s32 = mkldnn_s32;
|
||||||
|
const data_type_t s8 = mkldnn_s8;
|
||||||
|
const data_type_t u8 = mkldnn_u8;
|
||||||
|
}
|
||||||
|
|
||||||
|
using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
|
||||||
|
namespace scratchpad_mode {
|
||||||
|
const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
|
||||||
|
const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
|
||||||
|
}
|
||||||
|
|
||||||
|
using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
|
||||||
|
namespace rnn_packed_format {
|
||||||
|
const rnn_packed_format_t undef = mkldnn_packed_format_undef;
|
||||||
|
const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
|
||||||
|
const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
|
||||||
|
}
|
||||||
|
|
||||||
|
using format_kind_t = mkldnn_format_kind_t;
|
||||||
|
namespace format_kind {
|
||||||
|
const format_kind_t undef = mkldnn_format_kind_undef;
|
||||||
|
const format_kind_t any = mkldnn_format_kind_any;
|
||||||
|
const format_kind_t blocked = mkldnn_blocked;
|
||||||
|
const format_kind_t wino = mkldnn_format_kind_wino;
|
||||||
|
const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
|
||||||
|
}
|
||||||
|
|
||||||
|
using format_tag_t = mkldnn_format_tag_t;
|
||||||
|
namespace format_tag {
|
||||||
|
const format_tag_t undef = mkldnn_format_tag_undef;
|
||||||
|
const format_tag_t any = mkldnn_format_tag_any;
|
||||||
|
const format_tag_t a = mkldnn_a;
|
||||||
|
const format_tag_t ab = mkldnn_ab;
|
||||||
|
const format_tag_t abc = mkldnn_abc;
|
||||||
|
const format_tag_t abcd = mkldnn_abcd;
|
||||||
|
const format_tag_t abcde = mkldnn_abcde;
|
||||||
|
const format_tag_t abcdef = mkldnn_abcdef;
|
||||||
|
const format_tag_t abdec = mkldnn_abdec;
|
||||||
|
const format_tag_t acb = mkldnn_acb;
|
||||||
|
const format_tag_t acbde = mkldnn_acbde;
|
||||||
|
const format_tag_t acdb = mkldnn_acdb;
|
||||||
|
const format_tag_t acdeb = mkldnn_acdeb;
|
||||||
|
const format_tag_t ba = mkldnn_ba;
|
||||||
|
const format_tag_t bac = mkldnn_bac;
|
||||||
|
const format_tag_t bacd = mkldnn_bacd;
|
||||||
|
const format_tag_t bcda = mkldnn_bcda;
|
||||||
|
const format_tag_t cba = mkldnn_cba;
|
||||||
|
const format_tag_t cdba = mkldnn_cdba;
|
||||||
|
const format_tag_t cdeba = mkldnn_cdeba;
|
||||||
|
const format_tag_t decab = mkldnn_decab;
|
||||||
|
const format_tag_t Abc16a = mkldnn_Abc16a;
|
||||||
|
const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
|
||||||
|
const format_tag_t aBc16b = mkldnn_aBc16b;
|
||||||
|
const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
|
||||||
|
const format_tag_t Abc4a = mkldnn_Abc4a;
|
||||||
|
const format_tag_t aBc4b = mkldnn_aBc4b;
|
||||||
|
const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
|
||||||
|
const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
|
||||||
|
const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
|
||||||
|
const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
|
||||||
|
const format_tag_t aBc8b = mkldnn_aBc8b;
|
||||||
|
const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
|
||||||
|
const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
|
||||||
|
const format_tag_t Abcd16a = mkldnn_Abcd16a;
|
||||||
|
const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
|
||||||
|
const format_tag_t aBcd16b = mkldnn_aBcd16b;
|
||||||
|
const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
|
||||||
|
const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
|
||||||
|
const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
|
||||||
|
const format_tag_t Abcd4a = mkldnn_Abcd4a;
|
||||||
|
const format_tag_t aBcd4b = mkldnn_aBcd4b;
|
||||||
|
const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
|
||||||
|
const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
|
||||||
|
const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
|
||||||
|
const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
|
||||||
|
const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
|
||||||
|
const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
|
||||||
|
const format_tag_t aBcd8b = mkldnn_aBcd8b;
|
||||||
|
const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
|
||||||
|
const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
|
||||||
|
const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
|
||||||
|
const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
|
||||||
|
const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
|
||||||
|
const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
|
||||||
|
const format_tag_t Abcde16a = mkldnn_Abcde16a;
|
||||||
|
const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
|
||||||
|
const format_tag_t aBcde16b = mkldnn_aBcde16b;
|
||||||
|
const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
|
||||||
|
const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
|
||||||
|
const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
|
||||||
|
const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
|
||||||
|
const format_tag_t Abcde4a = mkldnn_Abcde4a;
|
||||||
|
const format_tag_t aBcde4b = mkldnn_aBcde4b;
|
||||||
|
const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
|
||||||
|
const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
|
||||||
|
const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
|
||||||
|
const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
|
||||||
|
const format_tag_t Abcde8a = mkldnn_Abcde8a;
|
||||||
|
const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
|
||||||
|
const format_tag_t aBcde8b = mkldnn_aBcde8b;
|
||||||
|
const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
|
||||||
|
const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
|
||||||
|
const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
|
||||||
|
const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
|
||||||
|
const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
|
||||||
|
const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
|
||||||
|
const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
|
||||||
|
const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
|
||||||
|
const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
|
||||||
|
const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
|
||||||
|
const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
|
||||||
|
const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
|
||||||
|
const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
|
||||||
|
const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
|
||||||
|
const format_tag_t aBdc16b = mkldnn_aBdc16b;
|
||||||
|
const format_tag_t aBdc4b = mkldnn_aBdc4b;
|
||||||
|
const format_tag_t aBdc8b = mkldnn_aBdc8b;
|
||||||
|
const format_tag_t aBdec16b = mkldnn_aBdec16b;
|
||||||
|
const format_tag_t aBdec4b = mkldnn_aBdec4b;
|
||||||
|
const format_tag_t aBdec8b = mkldnn_aBdec8b;
|
||||||
|
const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
|
||||||
|
const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
|
||||||
|
const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
|
||||||
|
const format_tag_t Acb16a = mkldnn_Acb16a;
|
||||||
|
const format_tag_t Acb4a = mkldnn_Acb4a;
|
||||||
|
const format_tag_t Acb8a = mkldnn_Acb8a;
|
||||||
|
const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
|
||||||
|
const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
|
||||||
|
const format_tag_t Acdb16a = mkldnn_Acdb16a;
|
||||||
|
const format_tag_t Acdb4a = mkldnn_Acdb4a;
|
||||||
|
const format_tag_t Acdb8a = mkldnn_Acdb8a;
|
||||||
|
const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
|
||||||
|
const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
|
||||||
|
const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
|
||||||
|
const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
|
||||||
|
const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
|
||||||
|
const format_tag_t last = mkldnn_format_tag_last;
|
||||||
|
|
||||||
|
const format_tag_t x = mkldnn_x;
|
||||||
|
const format_tag_t nc = mkldnn_nc;
|
||||||
|
const format_tag_t cn = mkldnn_cn;
|
||||||
|
const format_tag_t ncw = mkldnn_ncw;
|
||||||
|
const format_tag_t nwc = mkldnn_nwc;
|
||||||
|
const format_tag_t nchw = mkldnn_nchw;
|
||||||
|
const format_tag_t nhwc = mkldnn_nhwc;
|
||||||
|
const format_tag_t chwn = mkldnn_chwn;
|
||||||
|
const format_tag_t ncdhw = mkldnn_ncdhw;
|
||||||
|
const format_tag_t ndhwc = mkldnn_ndhwc;
|
||||||
|
const format_tag_t oi = mkldnn_oi;
|
||||||
|
const format_tag_t io = mkldnn_io;
|
||||||
|
const format_tag_t oiw = mkldnn_oiw;
|
||||||
|
const format_tag_t wio = mkldnn_wio;
|
||||||
|
const format_tag_t oihw = mkldnn_oihw;
|
||||||
|
const format_tag_t hwio = mkldnn_hwio;
|
||||||
|
const format_tag_t ihwo = mkldnn_ihwo;
|
||||||
|
const format_tag_t iohw = mkldnn_iohw;
|
||||||
|
const format_tag_t oidhw = mkldnn_oidhw;
|
||||||
|
const format_tag_t dhwio = mkldnn_dhwio;
|
||||||
|
const format_tag_t goiw = mkldnn_goiw;
|
||||||
|
const format_tag_t goihw = mkldnn_goihw;
|
||||||
|
const format_tag_t hwigo = mkldnn_hwigo;
|
||||||
|
const format_tag_t giohw = mkldnn_giohw;
|
||||||
|
const format_tag_t goidhw = mkldnn_goidhw;
|
||||||
|
const format_tag_t tnc = mkldnn_tnc;
|
||||||
|
const format_tag_t ntc = mkldnn_ntc;
|
||||||
|
const format_tag_t ldsnc = mkldnn_ldsnc;
|
||||||
|
const format_tag_t ldigo = mkldnn_ldigo;
|
||||||
|
const format_tag_t ldgoi = mkldnn_ldgoi;
|
||||||
|
const format_tag_t ldgo = mkldnn_ldgo;
|
||||||
|
const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
|
||||||
|
const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
|
||||||
|
const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
|
||||||
|
const format_tag_t nChw16c = mkldnn_nChw16c;
|
||||||
|
const format_tag_t nChw4c = mkldnn_nChw4c;
|
||||||
|
const format_tag_t nChw8c = mkldnn_nChw8c;
|
||||||
|
const format_tag_t nCw16c = mkldnn_nCw16c;
|
||||||
|
const format_tag_t nCw4c = mkldnn_nCw4c;
|
||||||
|
const format_tag_t nCw8c = mkldnn_nCw8c;
|
||||||
|
const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
|
||||||
|
const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
|
||||||
|
const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
|
||||||
|
const format_tag_t Oiw16o = mkldnn_Oiw16o;
|
||||||
|
const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
|
||||||
|
const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
|
||||||
|
const format_tag_t Oiw4o = mkldnn_Oiw4o;
|
||||||
|
const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
|
||||||
|
const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
|
||||||
|
const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
|
||||||
|
const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
|
||||||
|
const format_tag_t Owi16o = mkldnn_Owi16o;
|
||||||
|
const format_tag_t Owi4o = mkldnn_Owi4o;
|
||||||
|
const format_tag_t Owi8o = mkldnn_Owi8o;
|
||||||
|
const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
|
||||||
|
const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
|
||||||
|
const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
|
||||||
|
const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
|
||||||
|
const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
|
||||||
|
const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
|
||||||
|
const format_tag_t Oihw16o = mkldnn_Oihw16o;
|
||||||
|
const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
|
||||||
|
const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
|
||||||
|
const format_tag_t Oihw4o = mkldnn_Oihw4o;
|
||||||
|
const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
|
||||||
|
const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
|
||||||
|
const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
|
||||||
|
const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
|
||||||
|
const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
|
||||||
|
const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
|
||||||
|
const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
|
||||||
|
const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
|
||||||
|
const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
|
||||||
|
const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
|
||||||
|
const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
|
||||||
|
const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
|
||||||
|
const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
|
||||||
|
const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
|
||||||
|
const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
|
||||||
|
const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
|
||||||
|
const format_tag_t Goiw16g = mkldnn_Goiw16g;
|
||||||
|
const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
|
||||||
|
const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
|
||||||
|
const format_tag_t gOiw16o = mkldnn_gOiw16o;
|
||||||
|
const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
|
||||||
|
const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
|
||||||
|
const format_tag_t gOiw4o = mkldnn_gOiw4o;
|
||||||
|
const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
|
||||||
|
const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
|
||||||
|
const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
|
||||||
|
const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
|
||||||
|
const format_tag_t gOwi16o = mkldnn_gOwi16o;
|
||||||
|
const format_tag_t gOwi4o = mkldnn_gOwi4o;
|
||||||
|
const format_tag_t gOwi8o = mkldnn_gOwi8o;
|
||||||
|
const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
|
||||||
|
const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
|
||||||
|
const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
|
||||||
|
const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
|
||||||
|
const format_tag_t Goihw16g = mkldnn_Goihw16g;
|
||||||
|
const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
|
||||||
|
const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
|
||||||
|
const format_tag_t gOihw16o = mkldnn_gOihw16o;
|
||||||
|
const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
|
||||||
|
const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
|
||||||
|
const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
|
||||||
|
const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
|
||||||
|
const format_tag_t gOihw4o = mkldnn_gOihw4o;
|
||||||
|
const format_tag_t Goihw8g = mkldnn_Goihw8g;
|
||||||
|
const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
|
||||||
|
const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
|
||||||
|
const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
|
||||||
|
const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
|
||||||
|
const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
|
||||||
|
const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
|
||||||
|
const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
|
||||||
|
const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
|
||||||
|
const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
|
||||||
|
const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
|
||||||
|
const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
|
||||||
|
const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
|
||||||
|
const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
|
||||||
|
const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
|
||||||
|
const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
|
||||||
|
}
|
||||||
|
|
||||||
|
using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
|
||||||
|
namespace memory_extra_flags {
|
||||||
|
const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
|
||||||
|
const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
|
||||||
|
const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
|
||||||
|
}
|
||||||
|
|
||||||
|
using padding_kind_t = mkldnn_padding_kind_t;
|
||||||
|
namespace padding_kind {
|
||||||
|
const padding_kind_t padding_zero = mkldnn_padding_zero;
|
||||||
|
}
|
||||||
|
|
||||||
|
using engine_kind_t = mkldnn_engine_kind_t;
|
||||||
|
namespace engine_kind {
|
||||||
|
const engine_kind_t any_engine = mkldnn_any_engine;
|
||||||
|
const engine_kind_t cpu = mkldnn_cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
using primitive_kind_t = mkldnn_primitive_kind_t;
|
||||||
|
namespace primitive_kind {
|
||||||
|
const primitive_kind_t undefined = mkldnn_undefined_primitive;
|
||||||
|
const primitive_kind_t reorder = mkldnn_reorder;
|
||||||
|
const primitive_kind_t concat = mkldnn_concat;
|
||||||
|
const primitive_kind_t sum = mkldnn_sum;
|
||||||
|
const primitive_kind_t convolution = mkldnn_convolution;
|
||||||
|
const primitive_kind_t deconvolution = mkldnn_deconvolution;
|
||||||
|
const primitive_kind_t shuffle = mkldnn_shuffle;
|
||||||
|
const primitive_kind_t eltwise = mkldnn_eltwise;
|
||||||
|
const primitive_kind_t softmax = mkldnn_softmax;
|
||||||
|
const primitive_kind_t pooling = mkldnn_pooling;
|
||||||
|
const primitive_kind_t lrn = mkldnn_lrn;
|
||||||
|
const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
|
||||||
|
const primitive_kind_t inner_product = mkldnn_inner_product;
|
||||||
|
const primitive_kind_t rnn = mkldnn_rnn;
|
||||||
|
}
|
||||||
|
|
||||||
|
using query_t = mkldnn_query_t;
|
||||||
|
namespace query {
|
||||||
|
const query_t undef = mkldnn_query_undef;
|
||||||
|
|
||||||
|
const query_t engine = mkldnn_query_engine;
|
||||||
|
const query_t primitive_kind = mkldnn_query_primitive_kind;
|
||||||
|
|
||||||
|
const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
|
||||||
|
const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
|
||||||
|
|
||||||
|
const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
|
||||||
|
const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
|
||||||
|
|
||||||
|
const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
|
||||||
|
|
||||||
|
const query_t impl_info_str = mkldnn_query_impl_info_str;
|
||||||
|
|
||||||
|
const query_t some_d = mkldnn_query_some_d;
|
||||||
|
const query_t op_d = mkldnn_query_op_d;
|
||||||
|
const query_t convolution_d = mkldnn_query_convolution_d;
|
||||||
|
const query_t deconvolution_d = mkldnn_query_deconvolution_d;
|
||||||
|
const query_t shuffle_d = mkldnn_query_shuffle_d;
|
||||||
|
const query_t eltwise_d = mkldnn_query_eltwise_d;
|
||||||
|
const query_t softmax_d = mkldnn_query_softmax_d;
|
||||||
|
const query_t pooling_d = mkldnn_query_pooling_d;
|
||||||
|
const query_t lrn_d = mkldnn_query_lrn_d;
|
||||||
|
const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
|
||||||
|
const query_t inner_product_d = mkldnn_query_inner_product_d;
|
||||||
|
const query_t rnn_d = mkldnn_query_rnn_d;
|
||||||
|
|
||||||
|
const query_t some_md = mkldnn_query_some_md;
|
||||||
|
const query_t src_md = mkldnn_query_src_md;
|
||||||
|
const query_t diff_src_md = mkldnn_query_diff_src_md;
|
||||||
|
const query_t weights_md = mkldnn_query_weights_md;
|
||||||
|
const query_t diff_weights_md = mkldnn_query_diff_weights_md;
|
||||||
|
const query_t dst_md = mkldnn_query_dst_md;
|
||||||
|
const query_t diff_dst_md = mkldnn_query_diff_dst_md;
|
||||||
|
|
||||||
|
const query_t workspace_md = mkldnn_query_workspace_md;
|
||||||
|
const query_t scratchpad_md = mkldnn_query_scratchpad_md;
|
||||||
|
}
|
||||||
|
|
||||||
|
using blocking_desc_t = mkldnn_blocking_desc_t;
|
||||||
|
using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
|
||||||
|
using wino_desc_t = mkldnn_wino_desc_t;
|
||||||
|
using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
|
||||||
|
using memory_desc_t = mkldnn_memory_desc_t;
|
||||||
|
using convolution_desc_t = mkldnn_convolution_desc_t;
|
||||||
|
using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
|
||||||
|
using shuffle_desc_t = mkldnn_shuffle_desc_t;
|
||||||
|
using pooling_desc_t = mkldnn_pooling_desc_t;
|
||||||
|
using eltwise_desc_t = mkldnn_eltwise_desc_t;
|
||||||
|
using softmax_desc_t = mkldnn_softmax_desc_t;
|
||||||
|
using lrn_desc_t = mkldnn_lrn_desc_t;
|
||||||
|
using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
|
||||||
|
using inner_product_desc_t = mkldnn_inner_product_desc_t;
|
||||||
|
|
||||||
|
using rnn_direction_t = mkldnn_rnn_direction_t;
|
||||||
|
using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
|
||||||
|
using rnn_desc_t = mkldnn_rnn_desc_t;
|
||||||
|
|
||||||
|
/* C op_desc_t, which eventually are just (void*) */
|
||||||
|
using c_op_desc_t = mkldnn_op_desc_t;
|
||||||
|
using const_c_op_desc_t = const_mkldnn_op_desc_t;
|
||||||
|
|
||||||
|
struct op_desc_t {
|
||||||
|
union {
|
||||||
|
primitive_kind_t kind;
|
||||||
|
convolution_desc_t convolution;
|
||||||
|
deconvolution_desc_t deconvolution;
|
||||||
|
shuffle_desc_t shuffle;
|
||||||
|
pooling_desc_t pooling;
|
||||||
|
eltwise_desc_t eltwise;
|
||||||
|
softmax_desc_t softmax;
|
||||||
|
lrn_desc_t lrn;
|
||||||
|
batch_normalization_desc_t batch_normalization;
|
||||||
|
inner_product_desc_t inner_product;
|
||||||
|
rnn_desc_t rnn;
|
||||||
|
};
|
||||||
|
|
||||||
|
op_desc_t(const primitive_kind_t &_): kind(_) {}
|
||||||
|
|
||||||
|
# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
|
||||||
|
op_desc_t(const c_type &_): name(_) {} \
|
||||||
|
static op_desc_t *convert_from_c(c_type *_) \
|
||||||
|
{ return reinterpret_cast<op_desc_t*>(_); } \
|
||||||
|
static const op_desc_t *convert_from_c(const c_type *_) \
|
||||||
|
{ return reinterpret_cast<const op_desc_t*>(_); }
|
||||||
|
|
||||||
|
DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
|
||||||
|
DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
|
||||||
|
|
||||||
|
# undef DECL_CTOR_AND_CONVERTERS
|
||||||
|
};
|
||||||
|
|
||||||
|
using engine_t = mkldnn_engine;
|
||||||
|
using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
|
||||||
|
using primitive_desc_t = mkldnn_primitive_desc;
|
||||||
|
using primitive_attr_t = mkldnn_primitive_attr;
|
||||||
|
using post_ops_t = mkldnn_post_ops;
|
||||||
|
using memory_t = mkldnn_memory;
|
||||||
|
using primitive_t = mkldnn_primitive;
|
||||||
|
|
||||||
|
using primitive_arg_index_t = int;
|
||||||
|
|
||||||
|
using stream_flags_t = mkldnn_stream_flags_t;
|
||||||
|
namespace stream_flags {
|
||||||
|
const stream_flags_t default_flags = mkldnn_stream_default_flags;
|
||||||
|
}
|
||||||
|
using stream_t = mkldnn_stream;
|
||||||
|
|
||||||
|
/* forward declaration of the internal primitive_desc types */
|
||||||
|
struct batch_normalization_bwd_pd_t;
|
||||||
|
struct batch_normalization_fwd_pd_t;
|
||||||
|
struct batch_normalization_pd_t;
|
||||||
|
struct concat_pd_t;
|
||||||
|
struct convolution_bwd_data_pd_t;
|
||||||
|
struct convolution_bwd_weights_pd_t;
|
||||||
|
struct convolution_fwd_pd_t;
|
||||||
|
struct convolution_pd_t;
|
||||||
|
struct deconvolution_bwd_data_pd_t;
|
||||||
|
struct deconvolution_bwd_weights_pd_t;
|
||||||
|
struct deconvolution_fwd_pd_t;
|
||||||
|
struct deconvolution_pd_t;
|
||||||
|
struct eltwise_bwd_pd_t;
|
||||||
|
struct eltwise_fwd_pd_t;
|
||||||
|
struct eltwise_pd_t;
|
||||||
|
struct inner_product_bwd_data_pd_t;
|
||||||
|
struct inner_product_bwd_weights_pd_t;
|
||||||
|
struct inner_product_fwd_pd_t;
|
||||||
|
struct inner_product_pd_t;
|
||||||
|
struct lrn_bwd_pd_t;
|
||||||
|
struct lrn_fwd_pd_t;
|
||||||
|
struct lrn_pd_t;
|
||||||
|
struct pooling_bwd_pd_t;
|
||||||
|
struct pooling_fwd_pd_t;
|
||||||
|
struct pooling_pd_t;
|
||||||
|
struct reorder_pd_t;
|
||||||
|
struct rnn_bwd_pd_t;
|
||||||
|
struct rnn_fwd_pd_t;
|
||||||
|
struct rnn_pd_t;
|
||||||
|
struct shuffle_pd_t;
|
||||||
|
struct softmax_bwd_pd_t;
|
||||||
|
struct softmax_fwd_pd_t;
|
||||||
|
struct softmax_pd_t;
|
||||||
|
struct sum_pd_t;
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
86
thirdparty/oidn/mkl-dnn/src/common/concat.cpp
vendored
Normal file
86
thirdparty/oidn/mkl-dnn/src/common/concat.cpp
vendored
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "concat_pd.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
|
||||||
|
const memory_desc_t *dst_md, int n, int concat_dim,
|
||||||
|
const memory_desc_t *src_mds,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
engine_t *engine) {
|
||||||
|
bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
const primitive_attr_t dummy_attr;
|
||||||
|
if (attr == NULL)
|
||||||
|
attr = &dummy_attr;
|
||||||
|
|
||||||
|
const int ndims = src_mds[0].ndims;
|
||||||
|
const dims_t &dims = src_mds[0].dims;
|
||||||
|
const data_type_t dt = src_mds[0].data_type;
|
||||||
|
|
||||||
|
int concat_dim_sz = dims[concat_dim];
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
if (src_mds[i].ndims != ndims) return invalid_arguments;
|
||||||
|
for (int d = 0; d < ndims; ++d) {
|
||||||
|
if (d == concat_dim) continue;
|
||||||
|
if (src_mds[i].dims[d] != dims[d])
|
||||||
|
return invalid_arguments;
|
||||||
|
}
|
||||||
|
if (src_mds[i].data_type != dt) return invalid_arguments;
|
||||||
|
concat_dim_sz += src_mds[i].dims[concat_dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t dummy_dst_md;
|
||||||
|
if (dst_md) {
|
||||||
|
if (dst_md->ndims != ndims) return invalid_arguments;
|
||||||
|
for (int d = 0; d < ndims; ++d) {
|
||||||
|
if (dst_md->dims[d] !=
|
||||||
|
(d == concat_dim ? concat_dim_sz : dims[d]))
|
||||||
|
return invalid_arguments;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dummy_dst_md = src_mds[0];
|
||||||
|
dummy_dst_md.dims[concat_dim] = concat_dim_sz;
|
||||||
|
dummy_dst_md.format_kind = format_kind::any;
|
||||||
|
dst_md = &dummy_dst_md;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
|
||||||
|
|
||||||
|
for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
|
||||||
|
if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
|
||||||
|
== success) {
|
||||||
|
(*c_pd)->init_info();
|
||||||
|
(*c_pd)->init_scratchpad_md();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return unimplemented;
|
||||||
|
}
|
211
thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
vendored
Normal file
211
thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
vendored
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef CONCAT_PD_HPP
|
||||||
|
#define CONCAT_PD_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct concat_pd_t: public primitive_desc_t {
|
||||||
|
concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
|
||||||
|
const memory_desc_t *dst_md, int n, int concat_dim,
|
||||||
|
const memory_desc_t *src_mds)
|
||||||
|
: primitive_desc_t(engine, attr, primitive_kind::concat)
|
||||||
|
, n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
|
||||||
|
{
|
||||||
|
src_mds_.reserve(n_);
|
||||||
|
for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
concat_pd_t(const concat_pd_t &rhs) = default;
|
||||||
|
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg >= MKLDNN_ARG_MULTIPLE_SRC
|
||||||
|
&& arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index < n_inputs() ? &src_mds_[index] : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return n_; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
int concat_dim() const { return concat_dim_; }
|
||||||
|
|
||||||
|
const memory_desc_t *src_image_md(int index = 0) const
|
||||||
|
{ return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int n_, concat_dim_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
nstl::vector<memory_desc_t> src_mds_;
|
||||||
|
|
||||||
|
/* contains images of srcs in the dst memory (if possible)
|
||||||
|
* Lives here to simplify some implementations. An implementation might
|
||||||
|
* use this auxiliary array iff init() returned success */
|
||||||
|
nstl::vector<memory_desc_t> src_image_mds_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
|
||||||
|
status_t init() {
|
||||||
|
bool ok = true
|
||||||
|
&& set_default_params() == status::success
|
||||||
|
&& attr()->has_default_values();
|
||||||
|
if (!ok) return status::unimplemented;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_; ++i) {
|
||||||
|
const memory_desc_wrapper i_d(&src_mds_[i]);
|
||||||
|
if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
|
||||||
|
return status::unimplemented;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ndims = dst_md_.ndims;
|
||||||
|
int current_concat_dim_offset = 0;
|
||||||
|
for (int i = 0; i < n_; ++i) {
|
||||||
|
const int dim = src_mds_[i].dims[concat_dim_];
|
||||||
|
dims_t dims, offsets = {};
|
||||||
|
utils::array_copy(dims, dst_md_.dims, ndims);
|
||||||
|
dims[concat_dim_] = dim;
|
||||||
|
offsets[concat_dim_] = current_concat_dim_offset;
|
||||||
|
|
||||||
|
memory_desc_t src_img_d;
|
||||||
|
status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
|
||||||
|
&dst_md_, dims, offsets);
|
||||||
|
if (status != status::success) return status;
|
||||||
|
src_image_mds_.push_back(src_img_d);
|
||||||
|
current_concat_dim_offset += dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t set_default_params() {
|
||||||
|
if (dst_md_.format_kind != format_kind::any)
|
||||||
|
return status::success;
|
||||||
|
|
||||||
|
const int ndims = dst_md_.ndims;
|
||||||
|
|
||||||
|
/* The stupidest ever heuristics (but not the same as we had before):
|
||||||
|
* - Pick the first non-plain format;
|
||||||
|
* - If all formats are plain or it is not possible to create a
|
||||||
|
* blocked format for the output, pick the format of the plain input
|
||||||
|
* - If this fails as well, use plain layout (abcd...)
|
||||||
|
*/
|
||||||
|
status_t status = status::unimplemented;
|
||||||
|
for (int i = 0; i < n_; ++i) {
|
||||||
|
const memory_desc_wrapper src_d(src_mds_[i]);
|
||||||
|
if (src_d.is_blocking_desc() && !src_d.is_plain()) {
|
||||||
|
status = memory_desc_init_by_blocking_desc(dst_md_,
|
||||||
|
src_d.blocking_desc());
|
||||||
|
if (status == status::success) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status == status::success) {
|
||||||
|
/* check if we can create a sub-memory for the dst */
|
||||||
|
bool desired_format_ok = true;
|
||||||
|
int current_concat_dim_offset = 0;
|
||||||
|
for (int i = 0; i < n_; ++i) {
|
||||||
|
const int dim = src_mds_[i].dims[concat_dim_];
|
||||||
|
dims_t dims, offsets = {};
|
||||||
|
utils::array_copy(dims, dst_md_.dims, ndims);
|
||||||
|
dims[concat_dim_] = dim;
|
||||||
|
offsets[concat_dim_] = current_concat_dim_offset;
|
||||||
|
|
||||||
|
memory_desc_t src_img_d;
|
||||||
|
status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
|
||||||
|
&dst_md_, dims, offsets);
|
||||||
|
if (status != status::success) {
|
||||||
|
desired_format_ok = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
current_concat_dim_offset += dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!desired_format_ok)
|
||||||
|
status = status::unimplemented;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* if no success so far, try using the format of the first plain input */
|
||||||
|
if (status != status::success) {
|
||||||
|
for (int i = 0; i < n_; ++i) {
|
||||||
|
const memory_desc_wrapper src_d(src_mds_[i]);
|
||||||
|
if (src_d.is_blocking_desc() && src_d.is_plain()) {
|
||||||
|
status = memory_desc_init_by_blocking_desc(dst_md_,
|
||||||
|
memory_desc_wrapper(src_mds_[0]).blocking_desc());
|
||||||
|
if (status == status::success) return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* the last line of defense: use plain abcd... format */
|
||||||
|
if (status != status::success)
|
||||||
|
status = memory_desc_init_by_strides(dst_md_, nullptr);
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DECLARE_CONCAT_PD_t(impl_name, ...) \
|
||||||
|
static status_t create(concat_pd_t **concat_pd, \
|
||||||
|
engine_t *engine, const primitive_attr_t *attr, \
|
||||||
|
const memory_desc_t *dst_md, int n, int concat_dim, \
|
||||||
|
const memory_desc_t *src_mds) { \
|
||||||
|
using namespace status; \
|
||||||
|
auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
|
||||||
|
if (_pd == nullptr) return out_of_memory; \
|
||||||
|
if (_pd->init() != success) { delete _pd; return unimplemented; } \
|
||||||
|
return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
|
||||||
|
} \
|
||||||
|
virtual status_t create_primitive(primitive_t **p) const override { \
|
||||||
|
double ms = get_msec(); \
|
||||||
|
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
|
||||||
|
ms = get_msec() - ms; \
|
||||||
|
if (mkldnn_verbose()->level >= 2) { \
|
||||||
|
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
|
||||||
|
fflush(0); \
|
||||||
|
} \
|
||||||
|
return ret; \
|
||||||
|
} \
|
||||||
|
virtual pd_t *clone() const override { return new pd_t(*this); } \
|
||||||
|
virtual const char *name() const override { return impl_name; } \
|
||||||
|
|
||||||
|
#define DECLARE_CONCAT_PD_T(impl_name, ...) \
|
||||||
|
DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
200
thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
vendored
Normal file
200
thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
vendored
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
status_t conv_desc_init(convolution_desc_t *conv_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t dilates,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
|
||||||
|
padding_l)
|
||||||
|
&& one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
|
||||||
|
&& one_of(padding_kind, padding_kind::padding_zero);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
if (padding_r == nullptr) padding_r = padding_l;
|
||||||
|
|
||||||
|
auto cd = convolution_desc_t();
|
||||||
|
cd.primitive_kind = primitive_kind::convolution;
|
||||||
|
cd.prop_kind = prop_kind;
|
||||||
|
cd.alg_kind = alg_kind;
|
||||||
|
|
||||||
|
cd.diff_src_desc = cd.src_desc = zero_md();
|
||||||
|
cd.diff_dst_desc = cd.dst_desc = zero_md();
|
||||||
|
cd.diff_weights_desc = cd.weights_desc = zero_md();
|
||||||
|
cd.diff_bias_desc = cd.bias_desc = zero_md();
|
||||||
|
|
||||||
|
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
||||||
|
const bool with_bias =
|
||||||
|
bias_desc && bias_desc->format_kind != format_kind::undef;
|
||||||
|
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
|
||||||
|
|
||||||
|
(prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
|
||||||
|
(is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
|
||||||
|
(prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
|
||||||
|
*weights_desc;
|
||||||
|
if (with_bias)
|
||||||
|
(prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
|
||||||
|
*bias_desc;
|
||||||
|
|
||||||
|
int sp_dims = src_desc->ndims - 2;
|
||||||
|
utils::array_copy(cd.strides, strides, sp_dims);
|
||||||
|
utils::array_copy(cd.padding[0], padding_l, sp_dims);
|
||||||
|
utils::array_copy(cd.padding[1], padding_r, sp_dims);
|
||||||
|
if (dilates)
|
||||||
|
utils::array_copy(cd.dilates, dilates, sp_dims);
|
||||||
|
else
|
||||||
|
utils::array_set(cd.dilates, 0, sp_dims);
|
||||||
|
|
||||||
|
cd.padding_kind = padding_kind;
|
||||||
|
cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
||||||
|
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
||||||
|
|
||||||
|
const int g = with_groups ? weights_desc->dims[0] : 1;
|
||||||
|
const int bias_dim = prop_kind == backward_data
|
||||||
|
? src_desc->dims[1]
|
||||||
|
: dst_desc->dims[1];
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& memory_desc_wrapper(weights_desc).nelems()
|
||||||
|
&& src_desc->ndims == dst_desc->ndims
|
||||||
|
&& utils::one_of(src_desc->ndims, 3, 4, 5)
|
||||||
|
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
|
||||||
|
src_desc->ndims + 1)
|
||||||
|
&& (with_bias ? bias_desc->ndims == 1 : true)
|
||||||
|
&& (with_bias ? bias_desc->dims[0] == bias_dim : true)
|
||||||
|
&& src_desc->dims[0] == dst_desc->dims[0]
|
||||||
|
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
|
||||||
|
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
|
||||||
|
for (int i = 2; i < src_desc->ndims; ++i)
|
||||||
|
{
|
||||||
|
int src = src_desc->dims[i];
|
||||||
|
int ker = weights_desc->dims[with_groups + i];
|
||||||
|
int dil = cd.dilates[i - 2];
|
||||||
|
int pad_l = padding_l[i - 2];
|
||||||
|
int pad_r = padding_r[i - 2];
|
||||||
|
int str = strides[i - 2];
|
||||||
|
int dst = dst_desc->dims[i];
|
||||||
|
int ker_range = 1 + (ker - 1) * (dil + 1);
|
||||||
|
|
||||||
|
if (str < 1) return invalid_arguments;
|
||||||
|
consistency = consistency
|
||||||
|
&& dil >= 0
|
||||||
|
&& pad_l >= 0
|
||||||
|
&& pad_r + str > 0
|
||||||
|
&& (src - ker_range + pad_l + pad_r) / str + 1 == dst;
|
||||||
|
}
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*conv_desc = cd;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
|
||||||
|
weights_desc, bias_desc, dst_desc, strides, nullptr,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_convolution_forward_desc_init(
|
||||||
|
convolution_desc_t *conv_desc, prop_kind_t prop_kind,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
||||||
|
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_desc, const dims_t strides,
|
||||||
|
const dims_t dilates, const dims_t padding_l,
|
||||||
|
const dims_t padding_r, padding_kind_t padding_kind) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
|
||||||
|
weights_desc, bias_desc, dst_desc, strides, dilates,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_convolution_backward_data_desc_init(
|
||||||
|
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
|
||||||
|
weights_desc, nullptr, diff_dst_desc, strides, nullptr,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_convolution_backward_data_desc_init(
|
||||||
|
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
|
||||||
|
weights_desc, nullptr, diff_dst_desc, strides, dilates,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_convolution_backward_weights_desc_init(
|
||||||
|
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
|
||||||
|
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
|
||||||
|
nullptr, padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_convolution_backward_weights_desc_init(
|
||||||
|
convolution_desc_t *conv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
|
||||||
|
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
|
||||||
|
dilates, padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
56
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
vendored
Normal file
56
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
vendored
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "convolution_pd.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
using namespace prop_kind;
|
||||||
|
|
||||||
|
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_data
|
||||||
|
? &desc->diff_src_desc : &desc->src_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_weights
|
||||||
|
? &desc->diff_weights_desc : &desc->weights_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_weights
|
||||||
|
? &desc->diff_bias_desc : &desc->bias_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
|
||||||
|
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
|
||||||
|
? &desc->dst_desc : &desc->diff_dst_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
|
||||||
|
{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
|
||||||
|
{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
|
||||||
|
{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
|
||||||
|
{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
348
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
vendored
Normal file
348
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
vendored
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef CONVOLUTION_PD_HPP
|
||||||
|
#define CONVOLUTION_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
status_t conv_desc_init(convolution_desc_t *conv_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t dilates,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind);
|
||||||
|
|
||||||
|
memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
|
||||||
|
memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
|
||||||
|
memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
|
||||||
|
memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
|
||||||
|
const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
|
||||||
|
const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
|
||||||
|
const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
|
||||||
|
const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
|
||||||
|
|
||||||
|
struct convolution_fwd_pd_t;
|
||||||
|
|
||||||
|
struct convolution_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::convolution;
|
||||||
|
|
||||||
|
convolution_pd_t(engine_t *engine,
|
||||||
|
const convolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const convolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
const convolution_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case pkind_traits<base_pkind>::query_d:
|
||||||
|
*(const convolution_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common conv aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return _src_md()->dims[0]; }
|
||||||
|
|
||||||
|
dim_t IC() const { return _src_md()->dims[1]; }
|
||||||
|
dim_t OC() const { return _dst_md()->dims[1]; }
|
||||||
|
dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
|
||||||
|
|
||||||
|
dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
|
||||||
|
dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
|
||||||
|
dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
|
||||||
|
|
||||||
|
dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
|
||||||
|
dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
|
||||||
|
dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
|
||||||
|
|
||||||
|
dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
|
||||||
|
dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
|
||||||
|
dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
|
||||||
|
|
||||||
|
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
||||||
|
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
||||||
|
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
|
||||||
|
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
|
||||||
|
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
||||||
|
dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
||||||
|
dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
||||||
|
dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
||||||
|
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
||||||
|
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
||||||
|
|
||||||
|
int ndims() const { return _src_md()->ndims; }
|
||||||
|
|
||||||
|
bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
|
||||||
|
bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const {
|
||||||
|
const auto s_d = memory_desc_wrapper(*_src_md());
|
||||||
|
const auto d_d = memory_desc_wrapper(*_dst_md());
|
||||||
|
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
convolution_desc_t desc_;
|
||||||
|
const convolution_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
bool set_default_formats_common_template(
|
||||||
|
memory_desc_t &src_md, format_tag_t src_tag,
|
||||||
|
memory_desc_t &wei_md, format_tag_t wei_tag,
|
||||||
|
memory_desc_t &dst_md, format_tag_t dst_tag,
|
||||||
|
memory_desc_t &bia_md) {
|
||||||
|
using namespace format_tag;
|
||||||
|
|
||||||
|
# define IS_OK(f) \
|
||||||
|
do { if ((f) != status::success) return false; } while(0)
|
||||||
|
if (src_md.format_kind == format_kind::any
|
||||||
|
&& !utils::one_of(src_tag, any, undef))
|
||||||
|
IS_OK(memory_desc_init_by_tag(src_md, src_tag));
|
||||||
|
if (dst_md.format_kind == format_kind::any
|
||||||
|
&& !utils::one_of(dst_tag, any, undef))
|
||||||
|
IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
|
||||||
|
if (wei_md.format_kind == format_kind::any
|
||||||
|
&& !utils::one_of(wei_tag, any, undef))
|
||||||
|
IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
|
||||||
|
if (with_bias() && bia_md.format_kind == format_kind::any)
|
||||||
|
IS_OK(memory_desc_init_by_tag(bia_md, x));
|
||||||
|
# undef IS_OK
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool set_default_alg_kind(alg_kind_t alg_kind) {
|
||||||
|
assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
|
||||||
|
alg_kind::convolution_winograd));
|
||||||
|
if (desc_.alg_kind == alg_kind::convolution_auto)
|
||||||
|
desc_.alg_kind = alg_kind;
|
||||||
|
return desc_.alg_kind == alg_kind;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
|
||||||
|
data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
|
||||||
|
bool ok = true
|
||||||
|
&& (src_dt == data_type::undef || _src_md()->data_type == src_dt)
|
||||||
|
&& (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
|
||||||
|
&& (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
|
||||||
|
&& (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
|
||||||
|
if (with_bias() && bia_dt != data_type::undef)
|
||||||
|
ok = ok && _bia_md()->data_type == bia_dt;
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
|
||||||
|
const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
|
||||||
|
const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
|
||||||
|
const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct convolution_fwd_pd_t: public convolution_pd_t {
|
||||||
|
typedef convolution_fwd_pd_t base_class;
|
||||||
|
typedef convolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
convolution_fwd_pd_t(engine_t *engine,
|
||||||
|
const convolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const convolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, bias_md_(desc_.bias_desc)
|
||||||
|
, dst_md_(desc_.dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2 + with_bias(); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t bias_md_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
|
||||||
|
bool set_default_formats_common(format_tag_t src_tag,
|
||||||
|
format_tag_t wei_tag, format_tag_t dst_tag) {
|
||||||
|
return set_default_formats_common_template(src_md_, src_tag,
|
||||||
|
weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct convolution_bwd_data_pd_t: public convolution_pd_t {
|
||||||
|
typedef convolution_bwd_data_pd_t base_class;
|
||||||
|
typedef convolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
convolution_bwd_data_pd_t(engine_t *engine,
|
||||||
|
const convolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const convolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_src_md_(desc_.diff_src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, bias_md_(desc_.bias_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2 + with_bias(); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
virtual bool support_bias() const { return false; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t bias_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
|
||||||
|
bool set_default_formats_common(format_tag_t diff_src_tag,
|
||||||
|
format_tag_t wei_tag, format_tag_t diff_dst_tag) {
|
||||||
|
return set_default_formats_common_template(diff_src_md_, diff_src_tag,
|
||||||
|
weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct convolution_bwd_weights_pd_t: public convolution_pd_t {
|
||||||
|
typedef convolution_bwd_weights_pd_t base_class;
|
||||||
|
typedef convolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
convolution_bwd_weights_pd_t(engine_t *engine,
|
||||||
|
const convolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const convolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, diff_weights_md_(desc_.diff_weights_desc)
|
||||||
|
, diff_bias_md_(desc_.diff_bias_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &diff_bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1 + with_bias(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t diff_weights_md_;
|
||||||
|
memory_desc_t diff_bias_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
|
||||||
|
bool set_default_formats_common(format_tag_t src_tag,
|
||||||
|
format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
|
||||||
|
return set_default_formats_common_template(src_md_, src_tag,
|
||||||
|
diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
|
||||||
|
diff_bias_md_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
188
thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
vendored
Normal file
188
thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
vendored
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t dilates, const dims_t padding_l,
|
||||||
|
const dims_t padding_r, padding_kind_t padding_kind) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
|
||||||
|
padding_l)
|
||||||
|
&& one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
|
||||||
|
&& one_of(padding_kind, padding_kind::padding_zero);
|
||||||
|
if (!args_ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
if (padding_r == nullptr)
|
||||||
|
padding_r = padding_l;
|
||||||
|
|
||||||
|
auto dd = deconvolution_desc_t();
|
||||||
|
dd.primitive_kind = primitive_kind::deconvolution;
|
||||||
|
dd.prop_kind = prop_kind;
|
||||||
|
dd.alg_kind = alg_kind;
|
||||||
|
|
||||||
|
dd.diff_src_desc = dd.src_desc = zero_md();
|
||||||
|
dd.diff_dst_desc = dd.dst_desc = zero_md();
|
||||||
|
dd.diff_weights_desc = dd.weights_desc = zero_md();
|
||||||
|
dd.diff_bias_desc = dd.bias_desc = zero_md();
|
||||||
|
|
||||||
|
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
||||||
|
const bool with_bias
|
||||||
|
= bias_desc && bias_desc->format_kind != format_kind::undef;
|
||||||
|
const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
|
||||||
|
|
||||||
|
(prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
|
||||||
|
(is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
|
||||||
|
(prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
|
||||||
|
= *weights_desc;
|
||||||
|
if (with_bias)
|
||||||
|
(prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
|
||||||
|
= *bias_desc;
|
||||||
|
|
||||||
|
int sp_dims = src_desc->ndims - 2;
|
||||||
|
utils::array_copy(dd.strides, strides, sp_dims);
|
||||||
|
utils::array_copy(dd.padding[0], padding_l, sp_dims);
|
||||||
|
utils::array_copy(dd.padding[1], padding_r, sp_dims);
|
||||||
|
if (dilates)
|
||||||
|
utils::array_copy(dd.dilates, dilates, sp_dims);
|
||||||
|
else
|
||||||
|
utils::array_set(dd.dilates, 0, sp_dims);
|
||||||
|
|
||||||
|
dd.padding_kind = padding_kind;
|
||||||
|
dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
||||||
|
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
||||||
|
|
||||||
|
const int g = with_groups ? weights_desc->dims[0] : 1;
|
||||||
|
bool consistency = true
|
||||||
|
&& src_desc->ndims == dst_desc->ndims
|
||||||
|
&& utils::one_of(src_desc->ndims, 3, 4, 5)
|
||||||
|
&& utils::one_of(weights_desc->ndims, src_desc->ndims,
|
||||||
|
src_desc->ndims + 1)
|
||||||
|
&& (with_bias ? bias_desc->ndims == 1 : true)
|
||||||
|
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
|
||||||
|
&& src_desc->dims[0] == dst_desc->dims[0]
|
||||||
|
&& src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
|
||||||
|
&& dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
|
||||||
|
for (int i = 2; i < src_desc->ndims; ++i) {
|
||||||
|
int src = src_desc->dims[i];
|
||||||
|
int ker = weights_desc->dims[with_groups + i];
|
||||||
|
int dil = dd.dilates[i - 2];
|
||||||
|
int pad = padding_l[i - 2] + padding_r[i - 2];
|
||||||
|
int str = strides[i - 2];
|
||||||
|
int dst = dst_desc->dims[i];
|
||||||
|
int ker_range = 1 + (ker - 1) * (dil + 1);
|
||||||
|
|
||||||
|
consistency
|
||||||
|
= consistency && (dst - ker_range + pad) / str + 1 == src;
|
||||||
|
}
|
||||||
|
if (!consistency)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*deconv_desc = dd;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_deconvolution_forward_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
||||||
|
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_desc, const dims_t strides,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
|
||||||
|
weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
|
||||||
|
padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_deconvolution_forward_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *src_desc,
|
||||||
|
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_desc, const dims_t strides,
|
||||||
|
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
|
||||||
|
weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
|
||||||
|
padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_deconvolution_backward_data_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
|
||||||
|
weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
|
||||||
|
padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
|
||||||
|
weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l,
|
||||||
|
padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_deconvolution_backward_weights_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
|
||||||
|
const dims_t strides, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
|
||||||
|
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
|
||||||
|
deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
|
||||||
|
const dims_t strides, const dims_t dilates, const dims_t padding_l,
|
||||||
|
const dims_t padding_r, padding_kind_t padding_kind) {
|
||||||
|
return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
|
||||||
|
diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
|
||||||
|
padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
293
thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
vendored
Normal file
293
thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
vendored
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef DECONVOLUTION_PD_HPP
|
||||||
|
#define DECONVOLUTION_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "convolution_pd.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct deconvolution_fwd_pd_t;
|
||||||
|
|
||||||
|
struct deconvolution_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::deconvolution;
|
||||||
|
|
||||||
|
deconvolution_pd_t(engine_t *engine,
|
||||||
|
const deconvolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
const deconvolution_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case pkind_traits<base_pkind>::query_d:
|
||||||
|
*(const deconvolution_desc_t **)result = desc();
|
||||||
|
break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
|
||||||
|
|
||||||
|
dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
|
||||||
|
|
||||||
|
dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
|
||||||
|
dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
|
||||||
|
dim_t G() const
|
||||||
|
{ return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
|
||||||
|
|
||||||
|
dim_t ID() const {
|
||||||
|
return ndims() >= 5
|
||||||
|
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t IH() const {
|
||||||
|
return ndims() >= 4
|
||||||
|
? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t IW() const {
|
||||||
|
return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t OD() const {
|
||||||
|
return ndims() >= 5
|
||||||
|
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t OH() const {
|
||||||
|
return ndims() >= 4
|
||||||
|
? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t OW() const {
|
||||||
|
return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t KD() const {
|
||||||
|
const int w_ndims = ndims() + with_groups();
|
||||||
|
return ndims() >= 5
|
||||||
|
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t KH() const {
|
||||||
|
const int w_ndims = ndims() + with_groups();
|
||||||
|
return ndims() >= 4
|
||||||
|
? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t KW() const {
|
||||||
|
const int w_ndims = ndims() + with_groups();
|
||||||
|
return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
||||||
|
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
||||||
|
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
|
||||||
|
dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
|
||||||
|
dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t padFront() const
|
||||||
|
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
||||||
|
dim_t padBack() const
|
||||||
|
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
||||||
|
dim_t padT() const
|
||||||
|
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
||||||
|
dim_t padB() const
|
||||||
|
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
||||||
|
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
||||||
|
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
||||||
|
|
||||||
|
bool with_bias() const {
|
||||||
|
return
|
||||||
|
!memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool with_groups() const
|
||||||
|
{ return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
|
||||||
|
|
||||||
|
int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const {
|
||||||
|
const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
|
||||||
|
const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
|
||||||
|
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
deconvolution_desc_t desc_;
|
||||||
|
const deconvolution_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
|
||||||
|
typedef deconvolution_fwd_pd_t base_class;
|
||||||
|
typedef deconvolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
deconvolution_fwd_pd_t(engine_t *engine,
|
||||||
|
const deconvolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, bias_md_(desc_.bias_desc)
|
||||||
|
, dst_md_(desc_.dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2 + with_bias(); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t bias_md_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
|
||||||
|
typedef deconvolution_bwd_data_pd_t base_class;
|
||||||
|
typedef deconvolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
deconvolution_bwd_data_pd_t(engine_t *engine,
|
||||||
|
const deconvolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_src_md_(desc_.diff_src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &weights_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
|
||||||
|
typedef deconvolution_bwd_weights_pd_t base_class;
|
||||||
|
typedef deconvolution_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
deconvolution_bwd_weights_pd_t(engine_t *engine,
|
||||||
|
const deconvolution_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const deconvolution_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, diff_weights_md_(desc_.diff_weights_desc)
|
||||||
|
, diff_bias_md_(desc_.diff_bias_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &diff_bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1 + with_bias(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t diff_weights_md_;
|
||||||
|
memory_desc_t diff_bias_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
84
thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
vendored
Normal file
84
thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
vendored
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *data_desc,
|
||||||
|
const memory_desc_t *diff_data_desc, float alpha, float beta) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(eltwise_desc, data_desc)
|
||||||
|
&& one_of(prop_kind, forward_training, forward_inference,
|
||||||
|
backward_data)
|
||||||
|
&& one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
|
||||||
|
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
|
||||||
|
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
|
||||||
|
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto ed = eltwise_desc_t();
|
||||||
|
ed.primitive_kind = primitive_kind::eltwise;
|
||||||
|
ed.prop_kind = prop_kind;
|
||||||
|
ed.alg_kind = alg_kind;
|
||||||
|
|
||||||
|
ed.data_desc = *data_desc;
|
||||||
|
ed.diff_data_desc =
|
||||||
|
(ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
|
||||||
|
|
||||||
|
ed.alpha = alpha;
|
||||||
|
ed.beta = beta;
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& IMPLICATION(ed.prop_kind == backward_data,
|
||||||
|
array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
|
||||||
|
ed.diff_data_desc.ndims));
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*eltwise_desc = ed;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *data_desc, float alpha, float beta) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
|
||||||
|
nullptr, alpha, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
|
||||||
|
const memory_desc_t *data_desc, float alpha, float beta) {
|
||||||
|
return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
|
||||||
|
diff_data_desc, alpha, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
161
thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
vendored
Normal file
161
thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
vendored
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef ELTWISE_PD_HPP
|
||||||
|
#define ELTWISE_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct eltwise_fwd_pd_t;
|
||||||
|
|
||||||
|
struct eltwise_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::eltwise;
|
||||||
|
|
||||||
|
eltwise_pd_t(mkldnn::impl::engine_t *engine,
|
||||||
|
const eltwise_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const eltwise_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, data_md_(desc_.data_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
const eltwise_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::eltwise_d:
|
||||||
|
*(const eltwise_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common eltwise aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return data_desc().dims[0]; }
|
||||||
|
dim_t C() const { return data_desc().dims[1]; }
|
||||||
|
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
||||||
|
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
||||||
|
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
||||||
|
|
||||||
|
int ndims() const { return data_desc().ndims; }
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const
|
||||||
|
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
eltwise_desc_t desc_;
|
||||||
|
const eltwise_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
memory_desc_t data_md_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct eltwise_fwd_pd_t: public eltwise_pd_t {
|
||||||
|
typedef eltwise_fwd_pd_t base_class;
|
||||||
|
typedef eltwise_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
|
||||||
|
const eltwise_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const eltwise_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_SRC)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 1; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
bool is_zero_preserved() const
|
||||||
|
{ return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct eltwise_bwd_pd_t: public eltwise_pd_t {
|
||||||
|
typedef eltwise_bwd_pd_t base_class;
|
||||||
|
typedef eltwise_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
eltwise_bwd_pd_t(engine_t *engine,
|
||||||
|
const eltwise_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const eltwise_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_data_md_(desc_.diff_data_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
bool is_zero_preserved() const { return true; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_data_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
75
thirdparty/oidn/mkl-dnn/src/common/engine.cpp
vendored
Normal file
75
thirdparty/oidn/mkl-dnn/src/common/engine.cpp
vendored
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "../cpu/cpu_engine.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
engine_factory_t *engine_factories[] = {
|
||||||
|
&cpu::engine_factory,
|
||||||
|
nullptr,
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
|
||||||
|
for (engine_factory_t **ef = engine_factories; *ef; ef++)
|
||||||
|
if ((*ef)->kind() == kind)
|
||||||
|
return *ef;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
size_t mkldnn_engine_get_count(engine_kind_t kind) {
|
||||||
|
engine_factory_t *ef = get_engine_factory(kind);
|
||||||
|
return ef != nullptr ? ef->count() : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_engine_create(engine_t **engine,
|
||||||
|
engine_kind_t kind, size_t index) {
|
||||||
|
if (engine == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
engine_factory_t *ef = get_engine_factory(kind);
|
||||||
|
if (ef == nullptr || index >= ef->count())
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return ef->engine_create(engine, index);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
|
||||||
|
if (engine == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
*kind = engine->kind();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_engine_destroy(engine_t *engine) {
|
||||||
|
/* TODO: engine->dec_ref_count(); */
|
||||||
|
delete engine;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
119
thirdparty/oidn/mkl-dnn/src/common/engine.hpp
vendored
Normal file
119
thirdparty/oidn/mkl-dnn/src/common/engine.hpp
vendored
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef ENGINE_HPP
|
||||||
|
#define ENGINE_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
/** \brief An abstraction of an execution unit with shared resources
|
||||||
|
*
|
||||||
|
* Responsibilities:
|
||||||
|
* - Provide engine specific memory allocation
|
||||||
|
* - Provide engine specific primitive_desc_t creators
|
||||||
|
*/
|
||||||
|
struct mkldnn_engine: public mkldnn::impl::c_compatible {
|
||||||
|
mkldnn_engine(mkldnn::impl::engine_kind_t kind)
|
||||||
|
: kind_(kind)
|
||||||
|
{}
|
||||||
|
virtual ~mkldnn_engine() {}
|
||||||
|
|
||||||
|
/** get kind of the current engine */
|
||||||
|
virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
|
||||||
|
|
||||||
|
/** allocate memory */
|
||||||
|
virtual mkldnn::impl::status_t memory_create(
|
||||||
|
mkldnn::impl::memory_t **memory,
|
||||||
|
const mkldnn::impl::memory_desc_t *md,
|
||||||
|
void *handle) = 0;
|
||||||
|
|
||||||
|
/** implementation section (typedefs) */
|
||||||
|
|
||||||
|
// TODO: remove engine?
|
||||||
|
typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
|
||||||
|
mkldnn::impl::reorder_pd_t **reorder_pd,
|
||||||
|
mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
mkldnn::impl::engine_t *src_engine,
|
||||||
|
const mkldnn::impl::memory_desc_t *src_md,
|
||||||
|
mkldnn::impl::engine_t *dst_engine,
|
||||||
|
const mkldnn::impl::memory_desc_t *dst_md);
|
||||||
|
|
||||||
|
typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
|
||||||
|
mkldnn::impl::concat_pd_t **concat_pd,
|
||||||
|
mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
const mkldnn::impl::memory_desc_t *dst_md,
|
||||||
|
int n, int concat_dim,
|
||||||
|
const mkldnn::impl::memory_desc_t *src_mds);
|
||||||
|
|
||||||
|
typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
|
||||||
|
mkldnn::impl::sum_pd_t **sum_pd,
|
||||||
|
mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
const mkldnn::impl::memory_desc_t *dst_md,
|
||||||
|
int n, const float *scales,
|
||||||
|
const mkldnn::impl::memory_desc_t *src_mds);
|
||||||
|
|
||||||
|
typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
|
||||||
|
mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
|
||||||
|
|
||||||
|
/* implementation section */
|
||||||
|
|
||||||
|
/** return the list of reorder implementations. engine guarantees to return
|
||||||
|
* a NULL-terminated list */
|
||||||
|
virtual const reorder_primitive_desc_create_f*
|
||||||
|
get_reorder_implementation_list() const = 0;
|
||||||
|
|
||||||
|
/** return the list of concat implementations. engine guarantees to return
|
||||||
|
* a NULL-terminated list */
|
||||||
|
virtual const concat_primitive_desc_create_f*
|
||||||
|
get_concat_implementation_list() const = 0;
|
||||||
|
|
||||||
|
/** return the list of sum implementations. engine guarantees to return
|
||||||
|
* a NULL-terminated list */
|
||||||
|
virtual const sum_primitive_desc_create_f*
|
||||||
|
get_sum_implementation_list() const = 0;
|
||||||
|
|
||||||
|
/** return the list of implementations. engine guarantees to return a
|
||||||
|
* NULL-terminated list */
|
||||||
|
virtual const primitive_desc_create_f* get_implementation_list() const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
mkldnn::impl::engine_kind_t kind_;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct engine_factory_t: public c_compatible {
|
||||||
|
virtual size_t count() const = 0;
|
||||||
|
virtual engine_kind_t kind() const = 0;
|
||||||
|
virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
106
thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
vendored
Normal file
106
thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
vendored
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
|
||||||
|
const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
|
||||||
|
bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto id = inner_product_desc_t();
|
||||||
|
id.primitive_kind = primitive_kind::inner_product;
|
||||||
|
id.prop_kind = prop_kind;
|
||||||
|
|
||||||
|
id.diff_src_desc = id.src_desc = zero_md();
|
||||||
|
id.diff_dst_desc = id.dst_desc = zero_md();
|
||||||
|
id.diff_weights_desc = id.weights_desc = zero_md();
|
||||||
|
id.diff_bias_desc = id.bias_desc = zero_md();
|
||||||
|
|
||||||
|
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
||||||
|
const bool with_bias =
|
||||||
|
bias_desc && bias_desc->format_kind != format_kind::undef;
|
||||||
|
|
||||||
|
(prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
|
||||||
|
(is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
|
||||||
|
(prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
|
||||||
|
*weights_desc;
|
||||||
|
if (with_bias)
|
||||||
|
(prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
|
||||||
|
*bias_desc;
|
||||||
|
|
||||||
|
id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
|
||||||
|
weights_desc->data_type, dst_desc->data_type, prop_kind);
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& memory_desc_wrapper(weights_desc).nelems()
|
||||||
|
&& one_of(src_desc->ndims, 2, 3, 4, 5)
|
||||||
|
&& dst_desc->ndims == 2
|
||||||
|
&& weights_desc->ndims == src_desc->ndims
|
||||||
|
&& (with_bias ? bias_desc->ndims == 1 : true)
|
||||||
|
&& (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
|
||||||
|
&& src_desc->dims[0] == dst_desc->dims[0]
|
||||||
|
&& array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
|
||||||
|
src_desc->ndims - 1)
|
||||||
|
&& dst_desc->dims[1] == weights_desc->dims[0];
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*ip_desc = id;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
|
||||||
|
prop_kind_t prop_kind, const memory_desc_t *src_desc,
|
||||||
|
const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_desc) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
|
||||||
|
dst_desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_inner_product_backward_data_desc_init(
|
||||||
|
inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
|
||||||
|
const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
|
||||||
|
{
|
||||||
|
return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
|
||||||
|
nullptr, diff_dst_desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_inner_product_backward_weights_desc_init(
|
||||||
|
inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
|
||||||
|
const memory_desc_t *diff_weights_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc) {
|
||||||
|
return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
|
||||||
|
diff_bias_desc, diff_dst_desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
56
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
vendored
Normal file
56
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
vendored
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "inner_product_pd.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
using namespace prop_kind;
|
||||||
|
|
||||||
|
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_data
|
||||||
|
? &desc->diff_src_desc : &desc->src_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_weights
|
||||||
|
? &desc->diff_weights_desc : &desc->weights_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
|
||||||
|
return desc->prop_kind == backward_weights
|
||||||
|
? &desc->diff_bias_desc : &desc->bias_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
|
||||||
|
return utils::one_of(desc->prop_kind, forward_inference, forward_training)
|
||||||
|
? &desc->dst_desc : &desc->diff_dst_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
|
||||||
|
{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
|
||||||
|
{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
|
||||||
|
{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
|
||||||
|
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
|
||||||
|
{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
321
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
vendored
Normal file
321
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
vendored
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef INNER_PRODUCT_PD_HPP
|
||||||
|
#define INNER_PRODUCT_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
|
||||||
|
memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
|
||||||
|
memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
|
||||||
|
memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
|
||||||
|
const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
|
||||||
|
const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
|
||||||
|
const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
|
||||||
|
const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
|
||||||
|
|
||||||
|
struct inner_product_fwd_pd_t;
|
||||||
|
|
||||||
|
struct inner_product_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::inner_product;
|
||||||
|
|
||||||
|
inner_product_pd_t(engine_t *engine,
|
||||||
|
const inner_product_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const inner_product_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
const inner_product_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::inner_product_d:
|
||||||
|
*(const inner_product_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common inner_product aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
|
||||||
|
dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
|
||||||
|
dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
|
||||||
|
|
||||||
|
dim_t ID() const {
|
||||||
|
return ndims() >= 5
|
||||||
|
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t IH() const {
|
||||||
|
return ndims() >= 4
|
||||||
|
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t IW() const {
|
||||||
|
return ndims() >= 3
|
||||||
|
? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t OD() const {
|
||||||
|
return ndims() >= 5
|
||||||
|
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t OH() const {
|
||||||
|
return ndims() >= 4
|
||||||
|
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t OW() const {
|
||||||
|
return ndims() >= 3
|
||||||
|
? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t KD() const {
|
||||||
|
return ndims() >= 5
|
||||||
|
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
|
||||||
|
}
|
||||||
|
dim_t KH() const {
|
||||||
|
return ndims() >= 4
|
||||||
|
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
|
||||||
|
}
|
||||||
|
dim_t KW() const {
|
||||||
|
return ndims() >= 3
|
||||||
|
? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t IC_total() const {
|
||||||
|
return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
|
||||||
|
ndims() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t IC_total_padded() const {
|
||||||
|
auto src_d = desc()->prop_kind == prop_kind::backward_data
|
||||||
|
? memory_desc_wrapper(diff_src_md())
|
||||||
|
: memory_desc_wrapper(src_md());
|
||||||
|
assert(src_d.is_blocking_desc());
|
||||||
|
if (!src_d.is_blocking_desc()) return -1;
|
||||||
|
return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
|
||||||
|
|
||||||
|
bool with_bias() const
|
||||||
|
{ return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const {
|
||||||
|
const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
|
||||||
|
const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
|
||||||
|
return s_d.has_zero_dim() || d_d.has_zero_dim();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
inner_product_desc_t desc_;
|
||||||
|
const inner_product_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
status_t template_set_default_params(memory_desc_t &src_md,
|
||||||
|
memory_desc_t &weights_md, memory_desc_t &dst_md,
|
||||||
|
memory_desc_t *bias_md) {
|
||||||
|
using namespace format_tag;
|
||||||
|
if (src_md.format_kind == format_kind::any) {
|
||||||
|
CHECK(memory_desc_init_by_tag(src_md,
|
||||||
|
utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
|
||||||
|
}
|
||||||
|
if (dst_md.format_kind == format_kind::any)
|
||||||
|
CHECK(memory_desc_init_by_tag(dst_md, nc));
|
||||||
|
if (weights_md.format_kind == format_kind::any) {
|
||||||
|
CHECK(memory_desc_init_by_tag(weights_md,
|
||||||
|
utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
|
||||||
|
}
|
||||||
|
if (bias_md && bias_md->format_kind == format_kind::any)
|
||||||
|
CHECK(memory_desc_init_by_tag(*bias_md, x));
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct inner_product_fwd_pd_t: public inner_product_pd_t {
|
||||||
|
typedef inner_product_fwd_pd_t base_class;
|
||||||
|
typedef inner_product_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
inner_product_fwd_pd_t(engine_t *engine,
|
||||||
|
const inner_product_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const inner_product_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, bias_md_(desc_.bias_desc)
|
||||||
|
, dst_md_(desc_.dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2 + with_bias(); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t bias_md_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
|
||||||
|
status_t set_default_params() {
|
||||||
|
return template_set_default_params(src_md_, weights_md_, dst_md_,
|
||||||
|
&bias_md_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
|
||||||
|
typedef inner_product_bwd_data_pd_t base_class;
|
||||||
|
typedef inner_product_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
inner_product_bwd_data_pd_t(engine_t *engine,
|
||||||
|
const inner_product_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const inner_product_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_src_md_(desc_.diff_src_desc)
|
||||||
|
, weights_md_(desc_.weights_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &weights_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_src_md_;
|
||||||
|
memory_desc_t weights_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
|
||||||
|
status_t set_default_params() {
|
||||||
|
return template_set_default_params(diff_src_md_, weights_md_,
|
||||||
|
diff_dst_md_, nullptr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
|
||||||
|
typedef inner_product_bwd_weights_pd_t base_class;
|
||||||
|
typedef inner_product_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
inner_product_bwd_weights_pd_t(engine_t *engine,
|
||||||
|
const inner_product_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const inner_product_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, diff_weights_md_(desc_.diff_weights_desc)
|
||||||
|
, diff_bias_md_(desc_.diff_bias_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_weights_md_;
|
||||||
|
if (index == 1 && with_bias()) return &diff_bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 2; }
|
||||||
|
virtual int n_outputs() const override { return 1 + with_bias(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t diff_weights_md_;
|
||||||
|
memory_desc_t diff_bias_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
|
||||||
|
status_t set_default_params() {
|
||||||
|
return template_set_default_params(src_md_, diff_weights_md_,
|
||||||
|
diff_dst_md_, &diff_bias_md_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
91
thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
vendored
Normal file
91
thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
vendored
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t lrn_desc_init(lrn_desc_t *lrn_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
|
||||||
|
dim_t local_size, float alpha, float beta, float k) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(lrn_desc, data_desc)
|
||||||
|
&& one_of(alg_kind, lrn_within_channel, lrn_across_channels)
|
||||||
|
&& one_of(prop_kind, forward_training, forward_inference, backward_data)
|
||||||
|
&& IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto ld = lrn_desc_t();
|
||||||
|
ld.primitive_kind = primitive_kind::lrn;
|
||||||
|
ld.prop_kind = prop_kind;
|
||||||
|
ld.alg_kind = alg_kind;
|
||||||
|
|
||||||
|
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
||||||
|
|
||||||
|
ld.data_desc = *data_desc;
|
||||||
|
if (!is_fwd)
|
||||||
|
ld.diff_data_desc = *diff_data_desc;
|
||||||
|
else
|
||||||
|
ld.diff_data_desc = zero_md();
|
||||||
|
ld.local_size = local_size;
|
||||||
|
ld.lrn_alpha = alpha;
|
||||||
|
ld.lrn_beta = beta;
|
||||||
|
ld.lrn_k = k;
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& ld.data_desc.ndims == 4;
|
||||||
|
if (ld.prop_kind == backward_data)
|
||||||
|
consistency = consistency
|
||||||
|
&& ld.diff_data_desc.ndims == 4
|
||||||
|
&& array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*lrn_desc = ld;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *data_desc, dim_t local_size, float alpha,
|
||||||
|
float beta, float k) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
|
||||||
|
local_size, alpha, beta, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *data_desc,
|
||||||
|
const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
|
||||||
|
float beta, float k) {
|
||||||
|
return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
|
||||||
|
diff_data_desc, local_size, alpha, beta, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
170
thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
vendored
Normal file
170
thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
vendored
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef LRN_PD_HPP
|
||||||
|
#define LRN_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct lrn_fwd_pd_t;
|
||||||
|
|
||||||
|
struct lrn_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::lrn;
|
||||||
|
|
||||||
|
lrn_pd_t(engine_t *engine,
|
||||||
|
const lrn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const lrn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, data_md_(desc_.data_desc)
|
||||||
|
, ws_md_()
|
||||||
|
{}
|
||||||
|
|
||||||
|
const lrn_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::lrn_d:
|
||||||
|
*(const lrn_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common lrn aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return data_desc().dims[0]; }
|
||||||
|
dim_t C() const { return data_desc().dims[1]; }
|
||||||
|
dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
|
||||||
|
dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
|
||||||
|
dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
|
||||||
|
|
||||||
|
int ndims() const { return data_desc().ndims; }
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const
|
||||||
|
{ return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
lrn_desc_t desc_;
|
||||||
|
const lrn_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
memory_desc_t data_md_;
|
||||||
|
memory_desc_t ws_md_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const memory_desc_t &data_desc() const { return desc_.data_desc; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct lrn_fwd_pd_t: public lrn_pd_t {
|
||||||
|
typedef lrn_fwd_pd_t base_class;
|
||||||
|
typedef lrn_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
lrn_fwd_pd_t(engine_t *engine,
|
||||||
|
const lrn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const lrn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_SRC)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 1; }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 1 + (workspace_md() != nullptr); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct lrn_bwd_pd_t: public lrn_pd_t {
|
||||||
|
typedef lrn_bwd_pd_t base_class;
|
||||||
|
typedef lrn_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
lrn_bwd_pd_t(engine_t *engine,
|
||||||
|
const lrn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const lrn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_data_md_(desc_.diff_data_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_data_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 2 + (workspace_md() != nullptr); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_data_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
280
thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
vendored
Normal file
280
thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
vendored
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2017-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MATH_UTILS_HPP
|
||||||
|
#define MATH_UTILS_HPP
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "mkldnn_traits.hpp"
|
||||||
|
|
||||||
|
#if defined(MKLDNN_X86_64)
|
||||||
|
#include "immintrin.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
namespace math {
|
||||||
|
|
||||||
|
/** rounds @p f to an integer according to the mxcsr register */
|
||||||
|
inline int mxcsr_round(float f) {
|
||||||
|
#if defined(MKLDNN_X86_64)
|
||||||
|
return _mm_cvtss_si32(_mm_load_ss(&f));
|
||||||
|
#else
|
||||||
|
return (int)nearbyintf(f); // optimism
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename data_t, typename acc_t>
|
||||||
|
inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
|
||||||
|
typename utils::remove_reference<data_t>::type>::type
|
||||||
|
saturate(const acc_t &x) {
|
||||||
|
return (typename utils::remove_reference<data_t>::type)x;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename data_t, typename acc_t>
|
||||||
|
inline typename utils::enable_if<nstl::is_integral<data_t>::value,
|
||||||
|
typename utils::remove_reference<data_t>::type>::type
|
||||||
|
saturate(const acc_t &x) {
|
||||||
|
acc_t v = x;
|
||||||
|
if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
|
||||||
|
v = (acc_t)nstl::numeric_limits<data_t>::lowest();
|
||||||
|
if (v > (acc_t)nstl::numeric_limits<data_t>::max())
|
||||||
|
v = (acc_t)nstl::numeric_limits<data_t>::max();
|
||||||
|
return (typename utils::remove_reference<data_t>::type)v;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename data_t>
|
||||||
|
double saturate(const double &x) {
|
||||||
|
double v = x;
|
||||||
|
if (v < (double)nstl::numeric_limits<data_t>::lowest())
|
||||||
|
v = (double)nstl::numeric_limits<data_t>::lowest();
|
||||||
|
if (v > (double)nstl::numeric_limits<data_t>::max())
|
||||||
|
v = (double)nstl::numeric_limits<data_t>::max();
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
|
||||||
|
return x <= 127u ? x : 127;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
|
||||||
|
return x >= 0 ? x : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename out_t>
|
||||||
|
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
|
||||||
|
out_round(float v) { return (out_t)mxcsr_round(v); }
|
||||||
|
|
||||||
|
template <typename out_t>
|
||||||
|
typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
|
||||||
|
out_round(double v) { return (out_t)mxcsr_round((float)v); }
|
||||||
|
|
||||||
|
template <typename out_t>
|
||||||
|
typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
|
||||||
|
out_round(float v) { return v; }
|
||||||
|
|
||||||
|
inline int gcd(int a, int b) {
|
||||||
|
a = impl::nstl::abs(a);
|
||||||
|
b = impl::nstl::abs(b);
|
||||||
|
if (a < b) { int x = a; a = b; b = x; }
|
||||||
|
|
||||||
|
if (b == 0) return a;
|
||||||
|
|
||||||
|
int r;
|
||||||
|
while ((r = a % b) != 0) { a = b; b = r; }
|
||||||
|
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
|
||||||
|
|
||||||
|
/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
|
||||||
|
inline int ilog2q(size_t v) {
|
||||||
|
if (v == 0)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
int p = 0;
|
||||||
|
# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
|
||||||
|
CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
|
||||||
|
# undef CP
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U one_m_square(T x) {
|
||||||
|
return (U)(1 - x) * (1 + x);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U x_m_square(T x) {
|
||||||
|
return (U)(1 - x) * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* activation */
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U relu_fwd(T s, A alpha) {
|
||||||
|
return s > 0 ? s : (U)(s * alpha);
|
||||||
|
}
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U relu_bwd(T dd, T s, A alpha) {
|
||||||
|
return s > 0 ? dd : (U)(dd * alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U tanh_fwd(T s) {
|
||||||
|
const float e = tanhf((float) s);
|
||||||
|
return (U)e;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U tanh_bwd(T dd, T s) {
|
||||||
|
const float e = tanh_fwd<float>((float) s);
|
||||||
|
return (U)(dd * (1 - e) * (1 + e));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U elu_fwd(T s, A alpha) {
|
||||||
|
return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
|
||||||
|
}
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U elu_bwd(T dd, T s, A alpha) {
|
||||||
|
return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U square_fwd(T s) {
|
||||||
|
return s * s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U square_bwd(T dd, T s) {
|
||||||
|
return dd * 2 * s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U abs_fwd(T s) {
|
||||||
|
return s > 0 ? s : -s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U abs_bwd(T dd, T s) {
|
||||||
|
return s > 0 ? dd : s < 0 ? -dd : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U sqrt_fwd(T s) {
|
||||||
|
return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U sqrt_bwd(T dd, T s) {
|
||||||
|
return s > 0
|
||||||
|
? (U)(dd / (2 * ::sqrtf((float)(s))))
|
||||||
|
: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U linear_fwd(T s, A alpha, A beta) {
|
||||||
|
return (U)(alpha * s + beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U linear_bwd(T dd, T s, A alpha, A beta) {
|
||||||
|
(void) s;
|
||||||
|
(void) beta;
|
||||||
|
return (U)(dd * alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U bounded_relu_fwd(T s, A alpha) {
|
||||||
|
s = s > 0 ? s : 0;
|
||||||
|
return s > alpha ? (U)(alpha) : s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename A,
|
||||||
|
typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U bounded_relu_bwd(T dd, T s, A alpha) {
|
||||||
|
return dd * (0 < s && s < alpha ? 1 : 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U soft_relu_fwd(T s) {
|
||||||
|
float max_logf = 8.872284e+01; //::logf(FLT_MAX)
|
||||||
|
return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U soft_relu_bwd(T dd, T s) {
|
||||||
|
return (U)(dd / (1 + ::expf((float)(-s))));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U logistic_fwd(T s) {
|
||||||
|
U v = (U)(::expf((float) -s));
|
||||||
|
return 1 / (1 + v);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = typename utils::remove_reference<T>::type>
|
||||||
|
inline U logistic_bwd(T dd, T s) {
|
||||||
|
U v = logistic_fwd<T, U>(s);
|
||||||
|
return dd * v * (1 - v);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
|
||||||
|
using namespace alg_kind;
|
||||||
|
using namespace utils;
|
||||||
|
const bool preserves_zero = true
|
||||||
|
&& !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
|
||||||
|
&& IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
|
||||||
|
return preserves_zero;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
|
||||||
|
{
|
||||||
|
if (!bias)
|
||||||
|
return 0.0f;
|
||||||
|
|
||||||
|
#define CASE(dt) \
|
||||||
|
case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
|
||||||
|
|
||||||
|
switch (data_type) {
|
||||||
|
CASE(data_type::s8);
|
||||||
|
CASE(data_type::u8);
|
||||||
|
CASE(data_type::s32);
|
||||||
|
CASE(data_type::f32);
|
||||||
|
default: assert(!"unimplemented");
|
||||||
|
}
|
||||||
|
return 0; // never happens (should probably be a NaN)
|
||||||
|
#undef CASE
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
238
thirdparty/oidn/mkl-dnn/src/common/memory.cpp
vendored
Normal file
238
thirdparty/oidn/mkl-dnn/src/common/memory.cpp
vendored
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::data_type;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
bool memory_desc_sanity_check(int ndims,const dims_t dims,
|
||||||
|
data_type_t data_type, format_kind_t format_kind) {
|
||||||
|
if (ndims == 0) return true;
|
||||||
|
|
||||||
|
bool ok = true
|
||||||
|
&& dims != nullptr
|
||||||
|
&& 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
|
||||||
|
&& one_of(data_type, f32, s32, s8, u8)
|
||||||
|
&& format_kind != format_kind::undef;
|
||||||
|
if (!ok) return false;
|
||||||
|
for (int d = 0; d < ndims; ++d)
|
||||||
|
if (dims[d] < 0) return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool memory_desc_sanity_check(const memory_desc_t *md) {
|
||||||
|
if (md == nullptr) return false;
|
||||||
|
return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
|
||||||
|
format_kind::any);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
|
||||||
|
const dims_t dims, data_type_t data_type, format_tag_t tag) {
|
||||||
|
if (any_null(memory_desc)) return invalid_arguments;
|
||||||
|
if (ndims == 0 || tag == format_tag::undef) {
|
||||||
|
*memory_desc = types::zero_md();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
format_kind_t format_kind = types::format_tag_to_kind(tag);
|
||||||
|
|
||||||
|
/* memory_desc != 0 */
|
||||||
|
bool args_ok = !any_null(memory_desc)
|
||||||
|
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto md = memory_desc_t();
|
||||||
|
md.ndims = ndims;
|
||||||
|
array_copy(md.dims, dims, ndims);
|
||||||
|
md.data_type = data_type;
|
||||||
|
array_copy(md.padded_dims, dims, ndims);
|
||||||
|
md.format_kind = format_kind;
|
||||||
|
|
||||||
|
status_t status = success;
|
||||||
|
if (tag == format_tag::undef) {
|
||||||
|
status = invalid_arguments;
|
||||||
|
} else if (tag == format_tag::any) {
|
||||||
|
// nop
|
||||||
|
} else if (format_kind == format_kind::blocked) {
|
||||||
|
status = memory_desc_wrapper::compute_blocking(md, tag);
|
||||||
|
} else {
|
||||||
|
assert(!"unreachable");
|
||||||
|
status = invalid_arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status == success)
|
||||||
|
*memory_desc = md;
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
|
||||||
|
int ndims, const dims_t dims, data_type_t data_type,
|
||||||
|
const dims_t strides) {
|
||||||
|
if (any_null(memory_desc)) return invalid_arguments;
|
||||||
|
if (ndims == 0) {
|
||||||
|
*memory_desc = types::zero_md();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* memory_desc != 0 */
|
||||||
|
bool args_ok = !any_null(memory_desc)
|
||||||
|
&& memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
auto md = memory_desc_t();
|
||||||
|
md.ndims = ndims;
|
||||||
|
array_copy(md.dims, dims, ndims);
|
||||||
|
md.data_type = data_type;
|
||||||
|
array_copy(md.padded_dims, dims, ndims);
|
||||||
|
md.format_kind = format_kind::blocked;
|
||||||
|
|
||||||
|
dims_t default_strides = {0};
|
||||||
|
if (strides == nullptr) {
|
||||||
|
default_strides[md.ndims - 1] = 1;
|
||||||
|
for (int d = md.ndims - 2; d >= 0; --d)
|
||||||
|
default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
|
||||||
|
strides = default_strides;
|
||||||
|
} else {
|
||||||
|
/* TODO: add sanity check for the provided strides */
|
||||||
|
}
|
||||||
|
|
||||||
|
array_copy(md.format_desc.blocking.strides, strides, md.ndims);
|
||||||
|
|
||||||
|
*memory_desc = md;
|
||||||
|
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
|
||||||
|
const memory_desc_t *parent_md, const dims_t dims,
|
||||||
|
const dims_t offsets) {
|
||||||
|
if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
const memory_desc_wrapper src_d(parent_md);
|
||||||
|
|
||||||
|
for (int d = 0; d < src_d.ndims(); ++d) {
|
||||||
|
if (dims[d] < 0 || offsets[d] < 0
|
||||||
|
|| (offsets[d] + dims[d] > src_d.dims()[d]))
|
||||||
|
return invalid_arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (src_d.format_kind() != format_kind::blocked)
|
||||||
|
return unimplemented;
|
||||||
|
|
||||||
|
dims_t blocks;
|
||||||
|
src_d.compute_blocks(blocks);
|
||||||
|
|
||||||
|
memory_desc_t dst_d = *parent_md;
|
||||||
|
auto &dst_d_blk = dst_d.format_desc.blocking;
|
||||||
|
|
||||||
|
/* TODO: put this into memory_desc_wrapper */
|
||||||
|
for (int d = 0; d < src_d.ndims(); ++d) {
|
||||||
|
/* very limited functionality for now */
|
||||||
|
const bool ok = true
|
||||||
|
&& offsets[d] % blocks[d] == 0 /* [r1] */
|
||||||
|
&& src_d.padded_offsets()[d] == 0
|
||||||
|
&& (false
|
||||||
|
|| dims[d] % blocks[d] == 0
|
||||||
|
|| dims[d] < blocks[d]);
|
||||||
|
if (!ok)
|
||||||
|
return unimplemented;
|
||||||
|
|
||||||
|
const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
|
||||||
|
|
||||||
|
dst_d.dims[d] = dims[d];
|
||||||
|
dst_d.padded_dims[d] = is_right_border
|
||||||
|
? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
|
||||||
|
dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
|
||||||
|
dst_d.offset0 += /* [r1] */
|
||||||
|
offsets[d] / blocks[d] * dst_d_blk.strides[d];
|
||||||
|
}
|
||||||
|
|
||||||
|
*md = dst_d;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
|
||||||
|
const memory_desc_t *rhs) {
|
||||||
|
if (lhs == rhs) return 1;
|
||||||
|
if (any_null(lhs, rhs)) return 0;
|
||||||
|
return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
|
||||||
|
if (md == nullptr) return 0;
|
||||||
|
return memory_desc_wrapper(*md).size();
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
|
||||||
|
engine_t *engine, void *handle) {
|
||||||
|
if (any_null(memory, engine)) return invalid_arguments;
|
||||||
|
memory_desc_t z_md = types::zero_md();
|
||||||
|
return engine->memory_create(memory, md ? md : &z_md, handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
|
||||||
|
const memory_desc_t **md) {
|
||||||
|
if (any_null(memory, md)) return invalid_arguments;
|
||||||
|
*md = memory->md();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
|
||||||
|
if (any_null(memory, engine)) return invalid_arguments;
|
||||||
|
*engine = memory->engine();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_get_data_handle(const memory_t *memory,
|
||||||
|
void **handle) {
|
||||||
|
if (any_null(handle))
|
||||||
|
return invalid_arguments;
|
||||||
|
if (memory == nullptr) {
|
||||||
|
*handle = nullptr;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
return memory->get_data_handle(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
|
||||||
|
if (any_null(memory)) return invalid_arguments;
|
||||||
|
return memory->set_data_handle(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_memory_destroy(memory_t *memory) {
|
||||||
|
delete memory;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
63
thirdparty/oidn/mkl-dnn/src/common/memory.hpp
vendored
Normal file
63
thirdparty/oidn/mkl-dnn/src/common/memory.hpp
vendored
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MEMORY_HPP
|
||||||
|
#define MEMORY_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
|
||||||
|
struct mkldnn_memory: public mkldnn::impl::c_compatible {
|
||||||
|
mkldnn_memory(mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::memory_desc_t *md)
|
||||||
|
: engine_(engine), md_(*md) {}
|
||||||
|
virtual ~mkldnn_memory() {}
|
||||||
|
|
||||||
|
/** allocates/initializes memory */
|
||||||
|
virtual mkldnn::impl::status_t init() = 0;
|
||||||
|
|
||||||
|
/** returns memory's engine */
|
||||||
|
mkldnn::impl::engine_t *engine() const { return engine_; }
|
||||||
|
/** returns memory's description */
|
||||||
|
const mkldnn::impl::memory_desc_t *md() const { return &md_; }
|
||||||
|
|
||||||
|
/** returns data handle */
|
||||||
|
virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
|
||||||
|
|
||||||
|
/** sets data handle */
|
||||||
|
virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
|
||||||
|
|
||||||
|
/** zeros padding */
|
||||||
|
virtual mkldnn::impl::status_t zero_pad() const
|
||||||
|
{ return mkldnn::impl::status::success; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
mkldnn::impl::engine_t *engine_;
|
||||||
|
const mkldnn::impl::memory_desc_t md_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mkldnn_memory() = delete;
|
||||||
|
mkldnn_memory(const mkldnn_memory &) = delete;
|
||||||
|
mkldnn_memory(mkldnn_memory &&) = delete;
|
||||||
|
mkldnn_memory &operator=(const mkldnn_memory &) = delete;
|
||||||
|
mkldnn_memory &operator=(mkldnn_memory &&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
212
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
vendored
Normal file
212
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
vendored
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "memory_desc_wrapper.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
status_t fill_blocked(memory_desc_t &md,
|
||||||
|
std::initializer_list<int> perm,
|
||||||
|
std::initializer_list<int> inner_blks,
|
||||||
|
std::initializer_list<int> inner_idxs) {
|
||||||
|
const bool ok = true
|
||||||
|
&& perm.size() == (size_t)md.ndims
|
||||||
|
&& inner_blks.size() == inner_idxs.size();
|
||||||
|
if (!ok) return status::invalid_arguments;
|
||||||
|
|
||||||
|
md.offset0 = 0;
|
||||||
|
|
||||||
|
blocking_desc_t &blk = md.format_desc.blocking;
|
||||||
|
|
||||||
|
dim_t block_size = 1;
|
||||||
|
dims_t blocks = {0};
|
||||||
|
utils::array_set(blocks, 1, md.ndims);
|
||||||
|
|
||||||
|
blk.inner_nblks = (int)inner_blks.size();
|
||||||
|
|
||||||
|
int iblk = 0;
|
||||||
|
for (const auto &b: inner_idxs)
|
||||||
|
blk.inner_idxs[iblk++] = b;
|
||||||
|
|
||||||
|
iblk = 0;
|
||||||
|
for (const auto &b: inner_blks) {
|
||||||
|
int dim = blk.inner_idxs[iblk];
|
||||||
|
block_size *= b;
|
||||||
|
blocks[dim] *= b;
|
||||||
|
blk.inner_blks[iblk++] = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
utils::array_set(md.padded_offsets, 0, md.ndims);
|
||||||
|
for (int d = 0; d < md.ndims; ++d)
|
||||||
|
md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
|
||||||
|
|
||||||
|
dim_t stride = block_size;
|
||||||
|
// if only we use C++14, the initializer_list would have rbegin()/rend()...
|
||||||
|
for (int d = 0; d < md.ndims; ++d)
|
||||||
|
stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
|
||||||
|
|
||||||
|
for (const auto &d: perm) {
|
||||||
|
if (md.padded_dims[d] == 0) {
|
||||||
|
blk.strides[d] = 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
stride /= md.padded_dims[d] / blocks[d];
|
||||||
|
blk.strides[d] = stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(stride == block_size);
|
||||||
|
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
|
||||||
|
format_tag_t tag)
|
||||||
|
{
|
||||||
|
using namespace format_tag;
|
||||||
|
|
||||||
|
if (memory_desc.ndims == 0) return status::invalid_arguments;
|
||||||
|
|
||||||
|
# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
|
||||||
|
case tag: return fill_blocked(memory_desc, __VA_ARGS__)
|
||||||
|
|
||||||
|
switch (tag) {
|
||||||
|
C(a, {0}, {}, {});
|
||||||
|
C(ab, {0, 1}, {}, {});
|
||||||
|
C(abc, {0, 1, 2}, {}, {});
|
||||||
|
C(abcd, {0, 1, 2, 3}, {}, {});
|
||||||
|
C(abcde, {0, 1, 2, 3, 4}, {}, {});
|
||||||
|
C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
|
||||||
|
C(abdec, {0, 1, 3, 4, 2}, {}, {});
|
||||||
|
C(acb, {0, 2, 1}, {}, {});
|
||||||
|
C(acbde, {0, 2, 1, 3, 4}, {}, {});
|
||||||
|
C(acdb, {0, 2, 3, 1}, {}, {});
|
||||||
|
C(acdeb, {0, 2, 3, 4, 1}, {}, {});
|
||||||
|
C(ba, {1, 0}, {}, {});
|
||||||
|
C(bac, {1, 0, 2}, {}, {});
|
||||||
|
C(bacd, {1, 0, 2, 3}, {}, {});
|
||||||
|
C(bcda, {1, 2, 3, 0}, {}, {});
|
||||||
|
C(cba, {2, 1, 0}, {}, {});
|
||||||
|
C(cdba, {2, 3, 1, 0}, {}, {});
|
||||||
|
C(cdeba, {2, 3, 4, 1, 0}, {}, {});
|
||||||
|
C(decab, {3, 4, 2, 0, 1}, {}, {});
|
||||||
|
|
||||||
|
C(Abc4a, {0, 1, 2}, {4}, {0});
|
||||||
|
C(aBc4b, {0, 1, 2}, {4}, {1});
|
||||||
|
C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
|
||||||
|
C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
|
||||||
|
C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
|
||||||
|
C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
|
||||||
|
C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
|
||||||
|
C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
|
||||||
|
C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
|
||||||
|
C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
|
||||||
|
C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
|
||||||
|
C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
|
||||||
|
C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
|
||||||
|
C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
|
||||||
|
C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
|
||||||
|
C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
|
||||||
|
C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
|
||||||
|
C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
|
||||||
|
C(Acb4a, {0, 2, 1}, {4}, {0});
|
||||||
|
C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
|
||||||
|
C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
|
||||||
|
|
||||||
|
C(Abc16a, {0, 1, 2}, {16}, {0});
|
||||||
|
C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
|
||||||
|
C(aBc16b, {0, 1, 2}, {16}, {1});
|
||||||
|
C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
|
||||||
|
C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
|
||||||
|
C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
|
||||||
|
C(aBc8b, {0, 1, 2}, {8}, {1});
|
||||||
|
C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
|
||||||
|
C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
|
||||||
|
C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
|
||||||
|
C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
|
||||||
|
C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
|
||||||
|
C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
|
||||||
|
C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
|
||||||
|
C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
|
||||||
|
C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
|
||||||
|
C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
|
||||||
|
C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
|
||||||
|
C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
|
||||||
|
C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
|
||||||
|
C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
|
||||||
|
C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
|
||||||
|
C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
|
||||||
|
C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
|
||||||
|
C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
|
||||||
|
C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
|
||||||
|
C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
|
||||||
|
C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
|
||||||
|
C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
|
||||||
|
C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
|
||||||
|
C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
|
||||||
|
C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
|
||||||
|
C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
|
||||||
|
C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
|
||||||
|
C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
|
||||||
|
C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
|
||||||
|
C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
|
||||||
|
C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
|
||||||
|
C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
|
||||||
|
C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
|
||||||
|
C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
|
||||||
|
C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
|
||||||
|
C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
|
||||||
|
C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
|
||||||
|
C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
|
||||||
|
C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
|
||||||
|
C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
|
||||||
|
C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
|
||||||
|
C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
|
||||||
|
C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
|
||||||
|
C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
|
||||||
|
C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
|
||||||
|
C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
|
||||||
|
C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
|
||||||
|
C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
|
||||||
|
C(Acb16a, {0, 2, 1}, {16}, {0});
|
||||||
|
C(Acb8a, {0, 2, 1}, {8}, {0});
|
||||||
|
C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
|
||||||
|
C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
|
||||||
|
C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
|
||||||
|
C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
|
||||||
|
C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
|
||||||
|
C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
|
||||||
|
C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
|
||||||
|
C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef C
|
||||||
|
|
||||||
|
return status::invalid_arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
400
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
vendored
Normal file
400
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
vendored
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MEMORY_DESC_WRAPPER_HPP
|
||||||
|
#define MEMORY_DESC_WRAPPER_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
/** thin wrapper class over \struct memory_desc_t which allows easy
|
||||||
|
* manipulations with underlying C structure, which is taken by reference */
|
||||||
|
struct memory_desc_wrapper: public c_compatible {
|
||||||
|
const memory_desc_t *md_;
|
||||||
|
|
||||||
|
/** constructor which takes a reference to a constant underlying C memory
|
||||||
|
* descriptor \param md */
|
||||||
|
memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
|
||||||
|
memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
|
||||||
|
|
||||||
|
/* implementing attributes */
|
||||||
|
int ndims() const { return md_->ndims; }
|
||||||
|
const dims_t &dims() const { return md_->dims; }
|
||||||
|
data_type_t data_type() const { return md_->data_type; }
|
||||||
|
|
||||||
|
const dims_t &padded_dims() const { return md_->padded_dims; }
|
||||||
|
const dims_t &padded_offsets() const { return md_->padded_offsets; }
|
||||||
|
dim_t offset0() const { return md_->offset0; }
|
||||||
|
|
||||||
|
format_kind_t format_kind() const { return md_->format_kind; }
|
||||||
|
|
||||||
|
bool is_blocking_desc() const
|
||||||
|
{ return format_kind() == format_kind::blocked; }
|
||||||
|
bool is_wino_desc() const
|
||||||
|
{ return format_kind() == format_kind::wino; }
|
||||||
|
bool is_rnn_packed_desc() const
|
||||||
|
{ return format_kind() == format_kind::rnn_packed; }
|
||||||
|
|
||||||
|
const blocking_desc_t &blocking_desc() const {
|
||||||
|
assert(is_blocking_desc());
|
||||||
|
return md_->format_desc.blocking;
|
||||||
|
}
|
||||||
|
const wino_desc_t &wino_desc() const {
|
||||||
|
assert(is_wino_desc());
|
||||||
|
return md_->format_desc.wino_desc;
|
||||||
|
}
|
||||||
|
const rnn_packed_desc_t &rnn_packed_desc() const {
|
||||||
|
assert(is_rnn_packed_desc());
|
||||||
|
return md_->format_desc.rnn_packed_desc;
|
||||||
|
}
|
||||||
|
|
||||||
|
const memory_extra_desc_t &extra() const { return md_->extra; }
|
||||||
|
|
||||||
|
/* some useful function */
|
||||||
|
|
||||||
|
/** returns the number of elements including padding if \param with_padding
|
||||||
|
* is true, and the number of data elements otherwise */
|
||||||
|
dim_t nelems(bool with_padding = false) const {
|
||||||
|
if (is_zero()) return 0;
|
||||||
|
return utils::array_product(
|
||||||
|
with_padding ? padded_dims() : dims(), ndims());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns true if memory descriptor is zero */
|
||||||
|
bool is_zero() const { return ndims() == 0; }
|
||||||
|
|
||||||
|
/** returns true if memory descriptor contains zero as one of its dim */
|
||||||
|
bool has_zero_dim() const { return nelems() == 0; }
|
||||||
|
|
||||||
|
/** return the size of data type (a shortcut) */
|
||||||
|
size_t data_type_size() const
|
||||||
|
{ return types::data_type_size(data_type()); }
|
||||||
|
|
||||||
|
/** return the size of data type of additional buffer */
|
||||||
|
size_t additional_buffer_data_size() const {
|
||||||
|
if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
|
||||||
|
return sizeof(int32_t);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** return true if memory format has additional buffer */
|
||||||
|
bool is_additional_buffer() const {
|
||||||
|
return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns the size of additional buffer */
|
||||||
|
size_t additional_buffer_size() const {
|
||||||
|
if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
|
||||||
|
int cmask = extra().compensation_mask;
|
||||||
|
assert(cmask == 1 || cmask == 3);
|
||||||
|
dim_t prod = 1;
|
||||||
|
for (int d = 0; d < ndims(); ++d)
|
||||||
|
if (cmask & (1<<d)) prod *= padded_dims()[d];
|
||||||
|
return prod * additional_buffer_data_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns the size required to store described memory
|
||||||
|
* note: if offset0 != 0 returns 0 (need to specify the behavior) */
|
||||||
|
size_t size() const {
|
||||||
|
if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
if (format_kind() == format_kind::wino) {
|
||||||
|
return wino_desc().size;
|
||||||
|
} else if (format_kind() == format_kind::rnn_packed) {
|
||||||
|
return rnn_packed_desc().size;
|
||||||
|
} else {
|
||||||
|
if (offset0() != 0) return 0;
|
||||||
|
|
||||||
|
dims_t blocks = {0};
|
||||||
|
compute_blocks(blocks);
|
||||||
|
|
||||||
|
const auto &bd = blocking_desc();
|
||||||
|
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (int d = 0; d < ndims(); ++d)
|
||||||
|
max_size = nstl::max<size_t>(max_size,
|
||||||
|
padded_dims()[d] / blocks[d] * bd.strides[d]);
|
||||||
|
|
||||||
|
if (max_size == 1 && bd.inner_nblks != 0) {
|
||||||
|
max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
|
||||||
|
}
|
||||||
|
|
||||||
|
return max_size * data_type_size() + additional_buffer_size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns true if data is dense in memory */
|
||||||
|
bool is_dense(bool with_padding = false) const {
|
||||||
|
if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
|
||||||
|
return false;
|
||||||
|
return nelems(with_padding) * data_type_size() == size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns true if memory desc is fully defined */
|
||||||
|
bool is_defined() const { return format_kind() != format_kind::any; }
|
||||||
|
|
||||||
|
/** returns true if the only (potentially) padded dim is \param dim */
|
||||||
|
bool only_padded_dim(int dim) const {
|
||||||
|
for (int d = 0; d < ndims(); ++d)
|
||||||
|
if (d != dim && dims()[d] != padded_dims()[d])
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns true if memory desc has blocked layout and block dims are 1s */
|
||||||
|
bool is_plain() const {
|
||||||
|
if (!is_blocking_desc()) return false;
|
||||||
|
return blocking_desc().inner_nblks == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns overall block sizes */
|
||||||
|
void compute_blocks(dims_t blocks) const {
|
||||||
|
if (!is_blocking_desc()) {
|
||||||
|
utils::array_set(blocks, 0, ndims());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
utils::array_set(blocks, 1, ndims());
|
||||||
|
|
||||||
|
const auto &bd = blocking_desc();
|
||||||
|
for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
|
||||||
|
blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
|
||||||
|
}
|
||||||
|
|
||||||
|
/* comparison section */
|
||||||
|
|
||||||
|
bool operator==(const memory_desc_wrapper &rhs) const
|
||||||
|
{ return *this->md_ == *rhs.md_; }
|
||||||
|
bool operator!=(const memory_desc_wrapper &rhs) const
|
||||||
|
{ return !operator==(rhs); }
|
||||||
|
bool operator==(const memory_desc_t &rhs) const
|
||||||
|
{ return operator==(memory_desc_wrapper(rhs)); }
|
||||||
|
bool operator!=(const memory_desc_t &rhs) const
|
||||||
|
{ return !operator==(rhs); }
|
||||||
|
|
||||||
|
/** returns true if data (w/o padding if with_padding == false and w/
|
||||||
|
* padding otherwise) have the same physical structure, i.e. dimensions,
|
||||||
|
* strides, and blocked structure. Depending on with_data_type flag
|
||||||
|
* data_type is taken or not taken into account. dim_start allows to check
|
||||||
|
* similarity for the logical part of data [dim_start .. ndims()].
|
||||||
|
* CAUTION: format kind any and undef are not similar to whatever, hence the
|
||||||
|
* following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
|
||||||
|
/* TODO: revise */
|
||||||
|
bool similar_to(const memory_desc_wrapper &rhs,
|
||||||
|
bool with_padding = true, bool with_data_type = true,
|
||||||
|
int dim_start = 0) const;
|
||||||
|
|
||||||
|
/** returns true if one memory can be reordered to another */
|
||||||
|
bool consistent_with(const memory_desc_wrapper &rhs) const;
|
||||||
|
|
||||||
|
/** returns true if the memory desc corresponds to the given format tag and
|
||||||
|
* strides.
|
||||||
|
* @sa memory_desc_matches_tag */
|
||||||
|
bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
|
||||||
|
return memory_desc_matches_tag(*md_, tag, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns matching tag (or undef if match is not found)
|
||||||
|
* XXX: This is a workaround that eventually should go away! */
|
||||||
|
template <typename... Tags>
|
||||||
|
format_tag_t matches_one_of_tag(Tags ...tags) const {
|
||||||
|
for (const auto tag: {tags...}) {
|
||||||
|
if (memory_desc_matches_tag(*md_, tag))
|
||||||
|
return tag;
|
||||||
|
}
|
||||||
|
return format_tag::undef;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* offset section */
|
||||||
|
|
||||||
|
/** returns physical offset by logical one. logical offset is represented by
|
||||||
|
* an array \param pos. if \param is_pos_padded is true \param pos
|
||||||
|
* represents the position in already padded area */
|
||||||
|
dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
|
||||||
|
assert(is_blocking_desc());
|
||||||
|
const blocking_desc_t &blk = blocking_desc();
|
||||||
|
|
||||||
|
dims_t pos_copy = {0};
|
||||||
|
for (int d = 0; d < ndims(); ++d)
|
||||||
|
pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
|
||||||
|
|
||||||
|
dim_t phys_offset = offset0();
|
||||||
|
|
||||||
|
if (blk.inner_nblks > 0) {
|
||||||
|
dim_t blk_stride = 1;
|
||||||
|
for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
|
||||||
|
const int d = blk.inner_idxs[iblk];
|
||||||
|
const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
|
||||||
|
|
||||||
|
phys_offset += p * blk_stride;
|
||||||
|
|
||||||
|
pos_copy[d] /= blk.inner_blks[iblk];
|
||||||
|
|
||||||
|
blk_stride *= blk.inner_blks[iblk];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int d = 0; d < ndims(); ++d) {
|
||||||
|
const dim_t p = pos_copy[d];
|
||||||
|
phys_offset += p * blk.strides[d];
|
||||||
|
}
|
||||||
|
|
||||||
|
return phys_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns physical offset by logical one. logical offset is represented by
|
||||||
|
* a scalar \param l_offset. if \param is_pos_padded is true, \param
|
||||||
|
* l_offset represents logical offset in already padded area */
|
||||||
|
dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
|
||||||
|
assert(is_blocking_desc());
|
||||||
|
dims_t pos;
|
||||||
|
for (int rd = 0; rd < ndims(); ++rd) {
|
||||||
|
const int d = ndims() - 1 - rd;
|
||||||
|
const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
|
||||||
|
pos[d] = l_offset % cur_dim;
|
||||||
|
l_offset /= cur_dim;
|
||||||
|
}
|
||||||
|
return off_v(pos, is_pos_padded);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns physical offset by logical one. logical offset is represented by
|
||||||
|
* a tuple of indices (\param xn, ..., \param x1, \param x0) */
|
||||||
|
template<typename... Args>
|
||||||
|
dim_t off(Args... args) const {
|
||||||
|
assert(sizeof...(args) == ndims());
|
||||||
|
dims_t pos = { args... };
|
||||||
|
return off_v(pos, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns physical offset by logical one. logical offset is represented by
|
||||||
|
* a tuple of indices (\param xn, ..., \param x1, \param x0) in already
|
||||||
|
* padded area */
|
||||||
|
template<typename... Args>
|
||||||
|
dim_t off_padding(Args... args) const {
|
||||||
|
assert(sizeof...(args) == ndims());
|
||||||
|
dims_t pos = { args... };
|
||||||
|
return off_v(pos, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns physical offset by logical one. Logical offset is represented by
|
||||||
|
* a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
|
||||||
|
* user responsibility to adjust the result to get offset within blocks */
|
||||||
|
template<typename ...Args>
|
||||||
|
dim_t blk_off(Args... args) const {
|
||||||
|
return _blk_off<sizeof...(args), Args...>(args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool skip_first, typename T, typename ...Args>
|
||||||
|
dim_t blk_off(T xn, Args... args) const {
|
||||||
|
return skip_first
|
||||||
|
? blk_off<Args...>(args...)
|
||||||
|
: blk_off<T, Args...>(xn, args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* static functions section */
|
||||||
|
/* TODO: replace with non-static, once md_ becomes non-const ref */
|
||||||
|
|
||||||
|
static status_t compute_blocking(memory_desc_t &memory_desc,
|
||||||
|
format_tag_t tag);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/* TODO: put logical_offset in utils */
|
||||||
|
template<typename T>
|
||||||
|
dim_t logical_offset(T x0) const { return x0; }
|
||||||
|
|
||||||
|
template<typename T, typename... Args>
|
||||||
|
dim_t logical_offset(T xn, Args... args) const {
|
||||||
|
const size_t n_args = sizeof...(args);
|
||||||
|
return xn * utils::array_product<n_args>(
|
||||||
|
&dims()[ndims() - n_args]) + logical_offset(args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int ORIG_LEN, typename ...Void>
|
||||||
|
dim_t _blk_off() const { return offset0(); }
|
||||||
|
|
||||||
|
template<int ORIG_LEN, typename T, typename ...Args>
|
||||||
|
dim_t _blk_off(T xc, Args ...args) const {
|
||||||
|
assert(is_blocking_desc());
|
||||||
|
constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
|
||||||
|
return xc * blocking_desc().strides[dc]
|
||||||
|
+ _blk_off<ORIG_LEN, Args...>(args...);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
|
||||||
|
bool with_padding, bool with_data_type, int dim_start) const {
|
||||||
|
using namespace utils;
|
||||||
|
|
||||||
|
if (one_of(format_kind(), format_kind::undef, format_kind::any))
|
||||||
|
return false;
|
||||||
|
if (is_wino_desc() || is_rnn_packed_desc())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
const int ds = dim_start;
|
||||||
|
const auto &blk = blocking_desc();
|
||||||
|
const auto &r_blk = rhs.blocking_desc();
|
||||||
|
|
||||||
|
return ndims() == rhs.ndims()
|
||||||
|
&& dim_start <= ndims() /* guard */
|
||||||
|
&& format_kind() == rhs.format_kind()
|
||||||
|
&& IMPLICATION(with_data_type, data_type() == rhs.data_type())
|
||||||
|
&& array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
|
||||||
|
&& array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
|
||||||
|
&& blk.inner_nblks == r_blk.inner_nblks
|
||||||
|
&& array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
|
||||||
|
&& array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
|
||||||
|
&& IMPLICATION(with_padding, true
|
||||||
|
&& array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
|
||||||
|
ndims() - ds)
|
||||||
|
&& array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
|
||||||
|
ndims() - ds));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool memory_desc_wrapper::consistent_with(
|
||||||
|
const memory_desc_wrapper &rhs) const {
|
||||||
|
if (ndims() == rhs.ndims()) {
|
||||||
|
for (int d = 0; d < ndims(); ++d) {
|
||||||
|
if (dims()[d] != rhs.dims()[d]) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
/* TODO: revise.
|
||||||
|
* is the following possible?
|
||||||
|
* [1, a, b] <--reorder--> [a, b]
|
||||||
|
* [a, 1, b] <--reorder--> [a, b]
|
||||||
|
* not, at least for now */
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
295
thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
vendored
Normal file
295
thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
vendored
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MEMORY_TRACKING_HPP
|
||||||
|
#define MEMORY_TRACKING_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
namespace memory_tracking {
|
||||||
|
|
||||||
|
/* Memory tracking capabilities
|
||||||
|
*
|
||||||
|
* The main purpose of this header file is to provide uniform way to register
|
||||||
|
* required memory for a scratchpad at a primitive descriptor creation time
|
||||||
|
* and then easily access it having only the base address of the scratchpad.
|
||||||
|
*
|
||||||
|
* Primitives might contain multiple disjoint parts that require temporary
|
||||||
|
* buffers (known as scratchpad) during their execution. A primitive descriptor
|
||||||
|
* should summarize all the needs into one single number -- the buffer size
|
||||||
|
* that would be requested from a user. At execution time, the corresponding
|
||||||
|
* primitive will receive a base pointer to a scratchpad. It then needs to
|
||||||
|
* provide each part of algorithm the corresponding piece of memory. Three main
|
||||||
|
* challenges here are:
|
||||||
|
* 1. Track correct offset (from the base scratchpad address) for each piece
|
||||||
|
* 2. Algorithm might require that different memory pieces to be aligned, so
|
||||||
|
* the scratchpad size is no more just a sum of size of the corresponding
|
||||||
|
* subparts.
|
||||||
|
* 3. While a primitive is responsible for its scratchpad, the implementation
|
||||||
|
* might use some other basic blocks (e.g. cpu_reducer) that also require
|
||||||
|
* scratchpad memory. So there should be a simple way of passing the
|
||||||
|
* information back and force between the main algorithm (a primitive) and
|
||||||
|
* auxiliary stuff that lives completely separately from it (e.g. reducer).
|
||||||
|
*
|
||||||
|
* To address these challenges this header file provides 3 structures:
|
||||||
|
* 1. registry_t -- the class the stores the information about requested
|
||||||
|
* memory. The information includes required size and desired
|
||||||
|
* alignment for each piece. This class is also responsible
|
||||||
|
* for computing the right offset to a given piece using the
|
||||||
|
* base pointer.
|
||||||
|
* This class is basically a ledger with all entries.
|
||||||
|
* Lives in primitive descriptors.
|
||||||
|
*
|
||||||
|
* 2. registrar_t -- the interface to a registry_t to book memory. Used at
|
||||||
|
* primitive descriptor creation time only. Contains a
|
||||||
|
* reference to the corresponding *mutable* registry.
|
||||||
|
* Always modifiable.
|
||||||
|
* Allows chaining (using prefixes).
|
||||||
|
*
|
||||||
|
* 3. grantor_t -- the interface to a registry_t to access memory. Used at
|
||||||
|
* primitive execution time only. Contains a reference to
|
||||||
|
* the corresponding *constant* registry and base pointer.
|
||||||
|
* Always constant.
|
||||||
|
* Allows chaining (using prefixes).
|
||||||
|
*
|
||||||
|
* Both registrar_t and grantor_t allow chaining with extra prefix provided.
|
||||||
|
* The feature is useful when a primitive offload a part of computations to
|
||||||
|
* some other primitives which require their own scratchpad space
|
||||||
|
* (e.g. reducer). Prefixes are used to avoid key collision in cases when
|
||||||
|
* multiple sub-primitive (e.g. multiple reducers) are used.
|
||||||
|
*
|
||||||
|
* A short example below demonstrates how to use aforementioned classes. In it
|
||||||
|
* the main primitive is convolution that uses scratchpad for keeping padded
|
||||||
|
* bias. It also needs a reducer, that needs its own space as well.
|
||||||
|
*
|
||||||
|
* ``` c++
|
||||||
|
* struct reducer_t {
|
||||||
|
* static void init(registrar_t &scratchpad) {
|
||||||
|
* // preserve space for the reduction (one page aligned)
|
||||||
|
* scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* void exec(const grantor_t &scratchpad) {
|
||||||
|
* // get the pointer to preserved space. scratchpad came from
|
||||||
|
* // upper primitive (convolution in this example)
|
||||||
|
* auto space = scratchpad.get<float>(key_reducer_space);
|
||||||
|
*
|
||||||
|
* space[:] += ...;
|
||||||
|
* }
|
||||||
|
* };
|
||||||
|
*
|
||||||
|
* struct conv_t {
|
||||||
|
* struct pd_t {
|
||||||
|
* void init() {
|
||||||
|
* registrar_t scratchpad(scratchpad_registry_);
|
||||||
|
*
|
||||||
|
* // preserve a space for padded bias (using default alignment)
|
||||||
|
* scratchpad.book(key_conv_padded_bias, 128);
|
||||||
|
*
|
||||||
|
* // create a proxy registrar for the reducer All entries made
|
||||||
|
* // by reducer would live in convolution's registry, but would
|
||||||
|
* // have their own `prefix`, so no interference with conv's
|
||||||
|
* // buffers.
|
||||||
|
* registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
|
||||||
|
*
|
||||||
|
* reducer_t::init(reducer_scratchpad);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* registry_t scratchpad_registry_;
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* void exec() {
|
||||||
|
* // get the base pointer to a scratchpad memory from a user
|
||||||
|
* void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
|
||||||
|
*
|
||||||
|
* // create a grantor to the scratchpad (and provide the base
|
||||||
|
* // pointer).
|
||||||
|
* grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
|
||||||
|
*
|
||||||
|
* // access the padded_bias (need only key name and the grantor)
|
||||||
|
* auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
|
||||||
|
*
|
||||||
|
* // to give the `right` grantor to reducer we need to add the
|
||||||
|
* // corresponding prefix, so that reducer would be able to access
|
||||||
|
* // its keys. The call is very similar to the one in pd_t::init
|
||||||
|
* // with only difference in types: grantor_t vs registrar_t.
|
||||||
|
* grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
|
||||||
|
* reducer->exec(reducer_scratchpad);
|
||||||
|
* }
|
||||||
|
* };
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
/* namespace with common keys and prefixes */
|
||||||
|
namespace names {
|
||||||
|
enum {
|
||||||
|
key_none = 0,
|
||||||
|
key_bnorm_tmp_mean,
|
||||||
|
key_bnorm_tmp_var,
|
||||||
|
key_bnorm_tmp_diff_ss,
|
||||||
|
key_bnorm_tmp_stats,
|
||||||
|
key_bnorm_reduction,
|
||||||
|
key_concat_iptrs,
|
||||||
|
key_concat_istrides,
|
||||||
|
key_concat_nelems,
|
||||||
|
key_concat_optrs,
|
||||||
|
key_conv_adjusted_scales,
|
||||||
|
key_conv_bia_reduction,
|
||||||
|
key_conv_gemm_col,
|
||||||
|
key_conv_gemm_imtr,
|
||||||
|
key_conv_int_dat_in_acc_dt,
|
||||||
|
key_conv_padded_bias,
|
||||||
|
key_conv_rtus_space,
|
||||||
|
key_conv_tr_diff_dst,
|
||||||
|
key_conv_tr_diff_dst_bctx,
|
||||||
|
key_conv_tr_src,
|
||||||
|
key_conv_tr_src_bctx,
|
||||||
|
key_conv_wei_reduction,
|
||||||
|
key_conv_wei_bia_reduction,
|
||||||
|
key_conv_wei_bia_reduction_bctx,
|
||||||
|
key_iprod_int_dat_in_acc_dt,
|
||||||
|
key_reducer_space,
|
||||||
|
key_reducer_space_bctx,
|
||||||
|
key_reorder_wino_plain,
|
||||||
|
key_reorder_wino_transform_space,
|
||||||
|
key_reorder_rnn_weights_quantization,
|
||||||
|
key_reorder_rnn_weights_reduction,
|
||||||
|
key_rnn_space,
|
||||||
|
key_rnn_ptrs_bia,
|
||||||
|
key_rnn_ptrs_wei_layer,
|
||||||
|
key_rnn_ptrs_wei_iter,
|
||||||
|
key_softmax_reduction,
|
||||||
|
key_wino_U,
|
||||||
|
key_wino_V,
|
||||||
|
key_wino_M,
|
||||||
|
key_barrier,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum {
|
||||||
|
prefix_none = 0,
|
||||||
|
prefix_reducer_bia,
|
||||||
|
prefix_reducer_wei,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// level 0: 00 00 00 xxx
|
||||||
|
// level 1: 00 00 aa xxx
|
||||||
|
// level 2: 00 aa bb xxx
|
||||||
|
// level 3: aa bb cc xxx
|
||||||
|
// max # of levels: 3 + 1 (base_level)
|
||||||
|
// here:
|
||||||
|
// xxx : [1 .. MAX_KEY) : key
|
||||||
|
// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
|
||||||
|
|
||||||
|
using key_t = uint32_t;
|
||||||
|
enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
|
||||||
|
|
||||||
|
/// generates global key based on a prefix and a local key
|
||||||
|
inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
|
||||||
|
|
||||||
|
/// generates global prefix based on the global parent and the local ones
|
||||||
|
inline key_t make_prefix(key_t parent_prefix, key_t prefix)
|
||||||
|
{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
|
||||||
|
|
||||||
|
struct registrar_t;
|
||||||
|
struct grantor_t;
|
||||||
|
|
||||||
|
struct registry_t {
|
||||||
|
void book(const key_t &key, size_t size, size_t alignment) {
|
||||||
|
if (size == 0) return;
|
||||||
|
assert(offset_map_.count(key) == 0);
|
||||||
|
|
||||||
|
size = utils::rnd_up(size, minimal_alignment);
|
||||||
|
alignment = nstl::max<size_t>(alignment, minimal_alignment);
|
||||||
|
offset_map_[key] = entry_t{size_, size, alignment};
|
||||||
|
|
||||||
|
size_ += size + alignment - minimal_alignment;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *get(const key_t &key, void *base_ptr) const {
|
||||||
|
if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
|
||||||
|
if (offset_map_.count(key) != 1) return nullptr;
|
||||||
|
|
||||||
|
const auto &e = offset_map_.at(key);
|
||||||
|
base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
|
||||||
|
char *ptr = (char *)base_ptr + e.offset;
|
||||||
|
return utils::align_ptr<void>(ptr, e.alignment);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const
|
||||||
|
{ return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
|
||||||
|
|
||||||
|
registrar_t registrar();
|
||||||
|
grantor_t grantor(void *base_ptr) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
enum { minimal_alignment = 64 };
|
||||||
|
struct entry_t { size_t offset, size, alignment; };
|
||||||
|
|
||||||
|
std::unordered_map<key_t, entry_t> offset_map_;
|
||||||
|
size_t size_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct registrar_t {
|
||||||
|
enum { default_alignment = 64 };
|
||||||
|
|
||||||
|
registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {}
|
||||||
|
registrar_t(registrar_t &parent, const key_t &prefix)
|
||||||
|
: registry_(parent.registry_)
|
||||||
|
, prefix_(make_prefix(parent.prefix_, prefix)) {}
|
||||||
|
|
||||||
|
void book(const key_t &key, size_t size,
|
||||||
|
size_t alignment = default_alignment)
|
||||||
|
{ registry_.book(make_key(prefix_, key), size, alignment); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
registry_t ®istry_;
|
||||||
|
const key_t prefix_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct grantor_t {
|
||||||
|
grantor_t(const registry_t ®istry, void *base_ptr)
|
||||||
|
: registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
|
||||||
|
grantor_t(const grantor_t &parent, const key_t &prefix)
|
||||||
|
: registry_(parent.registry_)
|
||||||
|
, prefix_(make_prefix(parent.prefix_, prefix))
|
||||||
|
, base_ptr_(parent.base_ptr_) {}
|
||||||
|
|
||||||
|
template <typename T = void> T *get(const key_t &key) const
|
||||||
|
{ return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const registry_t ®istry_;
|
||||||
|
const key_t prefix_;
|
||||||
|
void *base_ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline registrar_t registry_t::registrar() { return registrar_t(*this); }
|
||||||
|
inline grantor_t registry_t::grantor(void *base_ptr) const
|
||||||
|
{ return grantor_t(*this, base_ptr); }
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
131
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
vendored
Normal file
131
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
vendored
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <cinttypes>
|
||||||
|
|
||||||
|
#include "mkldnn_debug.h"
|
||||||
|
#include "mkldnn_types.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#define DPRINT(...) do { \
|
||||||
|
int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
|
||||||
|
if (l < 0) return l; \
|
||||||
|
if ((size_t)l >= str_len) return -1; \
|
||||||
|
written_len += l; str_len -= l; \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
int mkldnn_md2fmt_str(char *str, size_t str_len,
|
||||||
|
const mkldnn_memory_desc_t *mdesc) {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
|
||||||
|
if (str == nullptr || str_len <= 1u)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
int written_len = 0;
|
||||||
|
|
||||||
|
if (mdesc == nullptr) {
|
||||||
|
DPRINT("%s::%s::",
|
||||||
|
mkldnn_dt2str(data_type::undef),
|
||||||
|
mkldnn_fmt_kind2str(format_kind::undef));
|
||||||
|
return written_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_wrapper md(mdesc);
|
||||||
|
|
||||||
|
DPRINT("%s:", mkldnn_dt2str(md.data_type()));
|
||||||
|
|
||||||
|
bool padded_dims = false, padded_offsets = false;
|
||||||
|
for (int d = 0; d < md.ndims(); ++d) {
|
||||||
|
if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
|
||||||
|
if (md.padded_offsets()[d] != 0) padded_offsets = true;
|
||||||
|
}
|
||||||
|
bool offset0 = md.offset0();
|
||||||
|
DPRINT("%s%s%s:",
|
||||||
|
padded_dims ? "p" : "",
|
||||||
|
padded_offsets ? "o" : "",
|
||||||
|
offset0 ? "0" : "");
|
||||||
|
|
||||||
|
DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
|
||||||
|
|
||||||
|
if (!md.is_blocking_desc()) {
|
||||||
|
/* TODO: extend */
|
||||||
|
DPRINT("%s:", "");
|
||||||
|
} else {
|
||||||
|
const auto &blk = md.blocking_desc();
|
||||||
|
|
||||||
|
dims_t blocks;
|
||||||
|
md.compute_blocks(blocks);
|
||||||
|
|
||||||
|
char dim_chars[MKLDNN_MAX_NDIMS + 1];
|
||||||
|
|
||||||
|
bool plain = true;
|
||||||
|
for (int d = 0; d < md.ndims(); ++d) {
|
||||||
|
dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
|
||||||
|
if (blocks[d] != 1) plain = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
dims_t strides;
|
||||||
|
utils::array_copy(strides, blk.strides, md.ndims());
|
||||||
|
utils::simultaneous_sort(strides, dim_chars, md.ndims(),
|
||||||
|
[](dim_t a, dim_t b) { return b - a; });
|
||||||
|
|
||||||
|
dim_chars[md.ndims()] = '\0';
|
||||||
|
DPRINT("%s", dim_chars);
|
||||||
|
|
||||||
|
if (!plain) {
|
||||||
|
for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
|
||||||
|
DPRINT("%d%c", (int)blk.inner_blks[iblk],
|
||||||
|
'a' + (char)blk.inner_idxs[iblk]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DPRINT("%s", ":");
|
||||||
|
}
|
||||||
|
|
||||||
|
DPRINT("f%lx", (long)md.extra().flags);
|
||||||
|
|
||||||
|
return written_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_md2dim_str(char *str, size_t str_len,
|
||||||
|
const mkldnn_memory_desc_t *mdesc) {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
|
||||||
|
if (str == nullptr || str_len <= 1)
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
int written_len = 0;
|
||||||
|
|
||||||
|
if (mdesc == nullptr || mdesc->ndims == 0) {
|
||||||
|
DPRINT("%s", "");
|
||||||
|
return written_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_desc_wrapper md(mdesc);
|
||||||
|
|
||||||
|
for (int d = 0; d < md.ndims() - 1; ++d)
|
||||||
|
DPRINT("%" PRId64 "x", md.dims()[d]);
|
||||||
|
DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
|
||||||
|
|
||||||
|
return written_len;
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef DPRINT
|
365
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
vendored
Normal file
365
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
vendored
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018-2019 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
/* DO NOT EDIT, AUTO-GENERATED */
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "mkldnn_debug.h"
|
||||||
|
#include "mkldnn_types.h"
|
||||||
|
|
||||||
|
const char *mkldnn_status2str(mkldnn_status_t v) {
|
||||||
|
if (v == mkldnn_success) return "success";
|
||||||
|
if (v == mkldnn_out_of_memory) return "out_of_memory";
|
||||||
|
if (v == mkldnn_try_again) return "try_again";
|
||||||
|
if (v == mkldnn_invalid_arguments) return "invalid_arguments";
|
||||||
|
if (v == mkldnn_not_ready) return "not_ready";
|
||||||
|
if (v == mkldnn_unimplemented) return "unimplemented";
|
||||||
|
if (v == mkldnn_iterator_ends) return "iterator_ends";
|
||||||
|
if (v == mkldnn_runtime_error) return "runtime_error";
|
||||||
|
if (v == mkldnn_not_required) return "not_required";
|
||||||
|
assert(!"unknown status");
|
||||||
|
return "unknown status";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_dt2str(mkldnn_data_type_t v) {
|
||||||
|
if (v == mkldnn_data_type_undef) return "undef";
|
||||||
|
if (v == mkldnn_f32) return "f32";
|
||||||
|
if (v == mkldnn_s32) return "s32";
|
||||||
|
if (v == mkldnn_s8) return "s8";
|
||||||
|
if (v == mkldnn_u8) return "u8";
|
||||||
|
assert(!"unknown dt");
|
||||||
|
return "unknown dt";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
|
||||||
|
if (v == mkldnn_format_kind_undef) return "undef";
|
||||||
|
if (v == mkldnn_format_kind_any) return "any";
|
||||||
|
if (v == mkldnn_blocked) return "blocked";
|
||||||
|
if (v == mkldnn_format_kind_wino) return "wino";
|
||||||
|
if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
|
||||||
|
assert(!"unknown fmt_kind");
|
||||||
|
return "unknown fmt_kind";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
|
||||||
|
if (v == mkldnn_format_tag_undef) return "undef";
|
||||||
|
if (v == mkldnn_format_tag_any) return "format_tag_any";
|
||||||
|
if (v == mkldnn_a) return "a";
|
||||||
|
if (v == mkldnn_ab) return "ab";
|
||||||
|
if (v == mkldnn_abc) return "abc";
|
||||||
|
if (v == mkldnn_abcd) return "abcd";
|
||||||
|
if (v == mkldnn_abcde) return "abcde";
|
||||||
|
if (v == mkldnn_abcdef) return "abcdef";
|
||||||
|
if (v == mkldnn_abdec) return "abdec";
|
||||||
|
if (v == mkldnn_acb) return "acb";
|
||||||
|
if (v == mkldnn_acbde) return "acbde";
|
||||||
|
if (v == mkldnn_acdb) return "acdb";
|
||||||
|
if (v == mkldnn_acdeb) return "acdeb";
|
||||||
|
if (v == mkldnn_ba) return "ba";
|
||||||
|
if (v == mkldnn_bac) return "bac";
|
||||||
|
if (v == mkldnn_bacd) return "bacd";
|
||||||
|
if (v == mkldnn_bcda) return "bcda";
|
||||||
|
if (v == mkldnn_cba) return "cba";
|
||||||
|
if (v == mkldnn_cdba) return "cdba";
|
||||||
|
if (v == mkldnn_cdeba) return "cdeba";
|
||||||
|
if (v == mkldnn_decab) return "decab";
|
||||||
|
if (v == mkldnn_Abc16a) return "Abc16a";
|
||||||
|
if (v == mkldnn_ABc16a16b) return "ABc16a16b";
|
||||||
|
if (v == mkldnn_aBc16b) return "aBc16b";
|
||||||
|
if (v == mkldnn_ABc16b16a) return "ABc16b16a";
|
||||||
|
if (v == mkldnn_Abc4a) return "Abc4a";
|
||||||
|
if (v == mkldnn_aBc4b) return "aBc4b";
|
||||||
|
if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
|
||||||
|
if (v == mkldnn_ABc4b4a) return "ABc4b4a";
|
||||||
|
if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
|
||||||
|
if (v == mkldnn_ABc8a8b) return "ABc8a8b";
|
||||||
|
if (v == mkldnn_aBc8b) return "aBc8b";
|
||||||
|
if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
|
||||||
|
if (v == mkldnn_ABc8b8a) return "ABc8b8a";
|
||||||
|
if (v == mkldnn_Abcd16a) return "Abcd16a";
|
||||||
|
if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
|
||||||
|
if (v == mkldnn_aBcd16b) return "aBcd16b";
|
||||||
|
if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
|
||||||
|
if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
|
||||||
|
if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
|
||||||
|
if (v == mkldnn_Abcd4a) return "Abcd4a";
|
||||||
|
if (v == mkldnn_aBcd4b) return "aBcd4b";
|
||||||
|
if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
|
||||||
|
if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
|
||||||
|
if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
|
||||||
|
if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
|
||||||
|
if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
|
||||||
|
if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
|
||||||
|
if (v == mkldnn_aBcd8b) return "aBcd8b";
|
||||||
|
if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
|
||||||
|
if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
|
||||||
|
if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
|
||||||
|
if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
|
||||||
|
if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
|
||||||
|
if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
|
||||||
|
if (v == mkldnn_Abcde16a) return "Abcde16a";
|
||||||
|
if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
|
||||||
|
if (v == mkldnn_aBcde16b) return "aBcde16b";
|
||||||
|
if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
|
||||||
|
if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
|
||||||
|
if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
|
||||||
|
if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
|
||||||
|
if (v == mkldnn_Abcde4a) return "Abcde4a";
|
||||||
|
if (v == mkldnn_aBcde4b) return "aBcde4b";
|
||||||
|
if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
|
||||||
|
if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
|
||||||
|
if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
|
||||||
|
if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
|
||||||
|
if (v == mkldnn_Abcde8a) return "Abcde8a";
|
||||||
|
if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
|
||||||
|
if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
|
||||||
|
if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
|
||||||
|
if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
|
||||||
|
if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
|
||||||
|
if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
|
||||||
|
if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
|
||||||
|
if (v == mkldnn_aBcdef16b) return "aBcdef16b";
|
||||||
|
if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
|
||||||
|
if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
|
||||||
|
if (v == mkldnn_aBcdef4b) return "aBcdef4b";
|
||||||
|
if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
|
||||||
|
if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
|
||||||
|
if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
|
||||||
|
if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
|
||||||
|
if (v == mkldnn_aBdc16b) return "aBdc16b";
|
||||||
|
if (v == mkldnn_aBdc4b) return "aBdc4b";
|
||||||
|
if (v == mkldnn_aBdc8b) return "aBdc8b";
|
||||||
|
if (v == mkldnn_aBdec16b) return "aBdec16b";
|
||||||
|
if (v == mkldnn_aBdec4b) return "aBdec4b";
|
||||||
|
if (v == mkldnn_aBdec8b) return "aBdec8b";
|
||||||
|
if (v == mkldnn_aBdefc16b) return "aBdefc16b";
|
||||||
|
if (v == mkldnn_aBdefc4b) return "aBdefc4b";
|
||||||
|
if (v == mkldnn_aBdefc8b) return "aBdefc8b";
|
||||||
|
if (v == mkldnn_Acb16a) return "Acb16a";
|
||||||
|
if (v == mkldnn_Acb4a) return "Acb4a";
|
||||||
|
if (v == mkldnn_Acb8a) return "Acb8a";
|
||||||
|
if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
|
||||||
|
if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
|
||||||
|
if (v == mkldnn_Acdb16a) return "Acdb16a";
|
||||||
|
if (v == mkldnn_Acdb4a) return "Acdb4a";
|
||||||
|
if (v == mkldnn_Acdb8a) return "Acdb8a";
|
||||||
|
if (v == mkldnn_Acdeb16a) return "Acdeb16a";
|
||||||
|
if (v == mkldnn_Acdeb4a) return "Acdeb4a";
|
||||||
|
if (v == mkldnn_Acdeb8a) return "Acdeb8a";
|
||||||
|
if (v == mkldnn_BAc16a16b) return "BAc16a16b";
|
||||||
|
if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
|
||||||
|
if (v == mkldnn_format_tag_last) return "format_tag_last";
|
||||||
|
if (v == mkldnn_x) return "x";
|
||||||
|
if (v == mkldnn_nc) return "nc";
|
||||||
|
if (v == mkldnn_cn) return "cn";
|
||||||
|
if (v == mkldnn_ncw) return "ncw";
|
||||||
|
if (v == mkldnn_nwc) return "nwc";
|
||||||
|
if (v == mkldnn_nchw) return "nchw";
|
||||||
|
if (v == mkldnn_nhwc) return "nhwc";
|
||||||
|
if (v == mkldnn_chwn) return "chwn";
|
||||||
|
if (v == mkldnn_ncdhw) return "ncdhw";
|
||||||
|
if (v == mkldnn_ndhwc) return "ndhwc";
|
||||||
|
if (v == mkldnn_oi) return "oi";
|
||||||
|
if (v == mkldnn_io) return "io";
|
||||||
|
if (v == mkldnn_oiw) return "oiw";
|
||||||
|
if (v == mkldnn_wio) return "wio";
|
||||||
|
if (v == mkldnn_oihw) return "oihw";
|
||||||
|
if (v == mkldnn_hwio) return "hwio";
|
||||||
|
if (v == mkldnn_ihwo) return "ihwo";
|
||||||
|
if (v == mkldnn_iohw) return "iohw";
|
||||||
|
if (v == mkldnn_oidhw) return "oidhw";
|
||||||
|
if (v == mkldnn_dhwio) return "dhwio";
|
||||||
|
if (v == mkldnn_goiw) return "goiw";
|
||||||
|
if (v == mkldnn_goihw) return "goihw";
|
||||||
|
if (v == mkldnn_hwigo) return "hwigo";
|
||||||
|
if (v == mkldnn_giohw) return "giohw";
|
||||||
|
if (v == mkldnn_goidhw) return "goidhw";
|
||||||
|
if (v == mkldnn_tnc) return "tnc";
|
||||||
|
if (v == mkldnn_ntc) return "ntc";
|
||||||
|
if (v == mkldnn_ldsnc) return "ldsnc";
|
||||||
|
if (v == mkldnn_ldigo) return "ldigo";
|
||||||
|
if (v == mkldnn_ldgoi) return "ldgoi";
|
||||||
|
if (v == mkldnn_ldgo) return "ldgo";
|
||||||
|
if (v == mkldnn_nCdhw16c) return "nCdhw16c";
|
||||||
|
if (v == mkldnn_nCdhw4c) return "nCdhw4c";
|
||||||
|
if (v == mkldnn_nCdhw8c) return "nCdhw8c";
|
||||||
|
if (v == mkldnn_nChw16c) return "nChw16c";
|
||||||
|
if (v == mkldnn_nChw4c) return "nChw4c";
|
||||||
|
if (v == mkldnn_nChw8c) return "nChw8c";
|
||||||
|
if (v == mkldnn_nCw16c) return "nCw16c";
|
||||||
|
if (v == mkldnn_nCw4c) return "nCw4c";
|
||||||
|
if (v == mkldnn_nCw8c) return "nCw8c";
|
||||||
|
if (v == mkldnn_IOw16o16i) return "IOw16o16i";
|
||||||
|
if (v == mkldnn_OIw16i16o) return "OIw16i16o";
|
||||||
|
if (v == mkldnn_OIw16o16i) return "OIw16o16i";
|
||||||
|
if (v == mkldnn_Oiw16o) return "Oiw16o";
|
||||||
|
if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
|
||||||
|
if (v == mkldnn_OIw4i4o) return "OIw4i4o";
|
||||||
|
if (v == mkldnn_Oiw4o) return "Oiw4o";
|
||||||
|
if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
|
||||||
|
if (v == mkldnn_OIw8i8o) return "OIw8i8o";
|
||||||
|
if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
|
||||||
|
if (v == mkldnn_OIw8o8i) return "OIw8o8i";
|
||||||
|
if (v == mkldnn_Owi16o) return "Owi16o";
|
||||||
|
if (v == mkldnn_Owi4o) return "Owi4o";
|
||||||
|
if (v == mkldnn_Owi8o) return "Owi8o";
|
||||||
|
if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
|
||||||
|
if (v == mkldnn_Ohwi16o) return "Ohwi16o";
|
||||||
|
if (v == mkldnn_Ohwi4o) return "Ohwi4o";
|
||||||
|
if (v == mkldnn_Ohwi8o) return "Ohwi8o";
|
||||||
|
if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
|
||||||
|
if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
|
||||||
|
if (v == mkldnn_Oihw16o) return "Oihw16o";
|
||||||
|
if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
|
||||||
|
if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
|
||||||
|
if (v == mkldnn_Oihw4o) return "Oihw4o";
|
||||||
|
if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
|
||||||
|
if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
|
||||||
|
if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
|
||||||
|
if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
|
||||||
|
if (v == mkldnn_Odhwi16o) return "Odhwi16o";
|
||||||
|
if (v == mkldnn_Odhwi4o) return "Odhwi4o";
|
||||||
|
if (v == mkldnn_Odhwi8o) return "Odhwi8o";
|
||||||
|
if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
|
||||||
|
if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
|
||||||
|
if (v == mkldnn_Oidhw16o) return "Oidhw16o";
|
||||||
|
if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
|
||||||
|
if (v == mkldnn_Oidhw4o) return "Oidhw4o";
|
||||||
|
if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
|
||||||
|
if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
|
||||||
|
if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
|
||||||
|
if (v == mkldnn_Goiw16g) return "Goiw16g";
|
||||||
|
if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
|
||||||
|
if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
|
||||||
|
if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
|
||||||
|
if (v == mkldnn_gOiw16o) return "gOiw16o";
|
||||||
|
if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
|
||||||
|
if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
|
||||||
|
if (v == mkldnn_gOiw4o) return "gOiw4o";
|
||||||
|
if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
|
||||||
|
if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
|
||||||
|
if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
|
||||||
|
if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
|
||||||
|
if (v == mkldnn_gOwi16o) return "gOwi16o";
|
||||||
|
if (v == mkldnn_gOwi4o) return "gOwi4o";
|
||||||
|
if (v == mkldnn_gOwi8o) return "gOwi8o";
|
||||||
|
if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
|
||||||
|
if (v == mkldnn_gOhwi16o) return "gOhwi16o";
|
||||||
|
if (v == mkldnn_gOhwi4o) return "gOhwi4o";
|
||||||
|
if (v == mkldnn_gOhwi8o) return "gOhwi8o";
|
||||||
|
if (v == mkldnn_Goihw16g) return "Goihw16g";
|
||||||
|
if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
|
||||||
|
if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
|
||||||
|
if (v == mkldnn_gOihw16o) return "gOihw16o";
|
||||||
|
if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
|
||||||
|
if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
|
||||||
|
if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
|
||||||
|
if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
|
||||||
|
if (v == mkldnn_gOihw4o) return "gOihw4o";
|
||||||
|
if (v == mkldnn_Goihw8g) return "Goihw8g";
|
||||||
|
if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
|
||||||
|
if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
|
||||||
|
if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
|
||||||
|
if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
|
||||||
|
if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
|
||||||
|
if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
|
||||||
|
if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
|
||||||
|
if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
|
||||||
|
if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
|
||||||
|
if (v == mkldnn_gOidhw16o) return "gOidhw16o";
|
||||||
|
if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
|
||||||
|
if (v == mkldnn_gOidhw4o) return "gOidhw4o";
|
||||||
|
if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
|
||||||
|
if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
|
||||||
|
if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
|
||||||
|
assert(!"unknown fmt_tag");
|
||||||
|
return "unknown fmt_tag";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
|
||||||
|
if (v == mkldnn_prop_kind_undef) return "undef";
|
||||||
|
if (v == mkldnn_forward_training) return "forward_training";
|
||||||
|
if (v == mkldnn_forward_inference) return "forward_inference";
|
||||||
|
if (v == mkldnn_forward_scoring) return "forward_scoring";
|
||||||
|
if (v == mkldnn_forward) return "forward";
|
||||||
|
if (v == mkldnn_backward) return "backward";
|
||||||
|
if (v == mkldnn_backward_data) return "backward_data";
|
||||||
|
if (v == mkldnn_backward_weights) return "backward_weights";
|
||||||
|
if (v == mkldnn_backward_bias) return "backward_bias";
|
||||||
|
assert(!"unknown prop_kind");
|
||||||
|
return "unknown prop_kind";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
|
||||||
|
if (v == mkldnn_undefined_primitive) return "undef";
|
||||||
|
if (v == mkldnn_reorder) return "reorder";
|
||||||
|
if (v == mkldnn_shuffle) return "shuffle";
|
||||||
|
if (v == mkldnn_concat) return "concat";
|
||||||
|
if (v == mkldnn_sum) return "sum";
|
||||||
|
if (v == mkldnn_convolution) return "convolution";
|
||||||
|
if (v == mkldnn_deconvolution) return "deconvolution";
|
||||||
|
if (v == mkldnn_eltwise) return "eltwise";
|
||||||
|
if (v == mkldnn_softmax) return "softmax";
|
||||||
|
if (v == mkldnn_pooling) return "pooling";
|
||||||
|
if (v == mkldnn_lrn) return "lrn";
|
||||||
|
if (v == mkldnn_batch_normalization) return "batch_normalization";
|
||||||
|
if (v == mkldnn_inner_product) return "inner_product";
|
||||||
|
if (v == mkldnn_rnn) return "rnn";
|
||||||
|
assert(!"unknown prim_kind");
|
||||||
|
return "unknown prim_kind";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
|
||||||
|
if (v == mkldnn_alg_kind_undef) return "undef";
|
||||||
|
if (v == mkldnn_convolution_direct) return "convolution_direct";
|
||||||
|
if (v == mkldnn_convolution_winograd) return "convolution_winograd";
|
||||||
|
if (v == mkldnn_convolution_auto) return "convolution_auto";
|
||||||
|
if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
|
||||||
|
if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
|
||||||
|
if (v == mkldnn_eltwise_relu) return "eltwise_relu";
|
||||||
|
if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
|
||||||
|
if (v == mkldnn_eltwise_elu) return "eltwise_elu";
|
||||||
|
if (v == mkldnn_eltwise_square) return "eltwise_square";
|
||||||
|
if (v == mkldnn_eltwise_abs) return "eltwise_abs";
|
||||||
|
if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
|
||||||
|
if (v == mkldnn_eltwise_linear) return "eltwise_linear";
|
||||||
|
if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
|
||||||
|
if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
|
||||||
|
if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
|
||||||
|
if (v == mkldnn_pooling_max) return "pooling_max";
|
||||||
|
if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
|
||||||
|
if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
|
||||||
|
if (v == mkldnn_pooling_avg) return "pooling_avg";
|
||||||
|
if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
|
||||||
|
if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
|
||||||
|
if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
|
||||||
|
if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
|
||||||
|
if (v == mkldnn_vanilla_gru) return "vanilla_gru";
|
||||||
|
if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
|
||||||
|
assert(!"unknown alg_kind");
|
||||||
|
return "unknown alg_kind";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
|
||||||
|
if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
|
||||||
|
if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
|
||||||
|
if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
|
||||||
|
if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
|
||||||
|
if (v == mkldnn_unidirectional) return "unidirectional";
|
||||||
|
assert(!"unknown rnn_direction");
|
||||||
|
return "unknown rnn_direction";
|
||||||
|
}
|
115
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
vendored
Normal file
115
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
vendored
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2017-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MKLDNN_THREAD_HPP
|
||||||
|
#define MKLDNN_THREAD_HPP
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "z_magic.hpp"
|
||||||
|
|
||||||
|
#define MKLDNN_THR_SEQ 0
|
||||||
|
#define MKLDNN_THR_OMP 1
|
||||||
|
#define MKLDNN_THR_TBB 2
|
||||||
|
|
||||||
|
/* Ideally this condition below should never happen (if the library is built
|
||||||
|
* using regular cmake). For the 3rd-party projects that build the library
|
||||||
|
* from the sources on their own try to guess the right threading... */
|
||||||
|
#if !defined(MKLDNN_THR)
|
||||||
|
# define MKLDNN_THR MKLDNN_THR_TBB
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
||||||
|
#define MKLDNN_THR_SYNC 1
|
||||||
|
inline int mkldnn_get_max_threads() { return 1; }
|
||||||
|
inline int mkldnn_get_num_threads() { return 1; }
|
||||||
|
inline int mkldnn_get_thread_num() { return 0; }
|
||||||
|
inline int mkldnn_in_parallel() { return 0; }
|
||||||
|
inline void mkldnn_thr_barrier() {}
|
||||||
|
|
||||||
|
#define PRAGMA_OMP(...)
|
||||||
|
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
||||||
|
#include <omp.h>
|
||||||
|
#define MKLDNN_THR_SYNC 1
|
||||||
|
|
||||||
|
inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
|
||||||
|
inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
|
||||||
|
inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
|
||||||
|
inline int mkldnn_in_parallel() { return omp_in_parallel(); }
|
||||||
|
inline void mkldnn_thr_barrier() {
|
||||||
|
# pragma omp barrier
|
||||||
|
}
|
||||||
|
|
||||||
|
#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
|
||||||
|
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
||||||
|
#include "tbb/task_arena.h"
|
||||||
|
#include "tbb/parallel_for.h"
|
||||||
|
#define MKLDNN_THR_SYNC 0
|
||||||
|
|
||||||
|
inline int mkldnn_get_max_threads()
|
||||||
|
{ return tbb::this_task_arena::max_concurrency(); }
|
||||||
|
inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
|
||||||
|
inline int mkldnn_get_thread_num()
|
||||||
|
{ return tbb::this_task_arena::current_thread_index(); }
|
||||||
|
inline int mkldnn_in_parallel() { return 0; }
|
||||||
|
inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
|
||||||
|
|
||||||
|
#define PRAGMA_OMP(...)
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* MSVC still supports omp 2.0 only */
|
||||||
|
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
|
||||||
|
# define collapse(x)
|
||||||
|
# define PRAGMA_OMP_SIMD(...)
|
||||||
|
#else
|
||||||
|
# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
|
||||||
|
#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER)
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
|
||||||
|
T n_min = 1;
|
||||||
|
T &n_my = n_end;
|
||||||
|
if (team <= 1 || n == 0) {
|
||||||
|
n_start = 0;
|
||||||
|
n_my = n;
|
||||||
|
} else if (n_min == 1) {
|
||||||
|
// team = T1 + T2
|
||||||
|
// n = T1*n1 + T2*n2 (n1 - n2 = 1)
|
||||||
|
T n1 = utils::div_up(n, (T)team);
|
||||||
|
T n2 = n1 - 1;
|
||||||
|
T T1 = n - n2 * (T)team;
|
||||||
|
n_my = (T)tid < T1 ? n1 : n2;
|
||||||
|
n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
|
||||||
|
}
|
||||||
|
|
||||||
|
n_end += n_start;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impl
|
||||||
|
} // namespace mkldnn
|
||||||
|
|
||||||
|
#include "mkldnn_thread_parallel_nd.hpp"
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
277
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
vendored
Normal file
277
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
vendored
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
|
||||||
|
#define MKLDNN_THREAD_PARALLEL_ND_HPP
|
||||||
|
|
||||||
|
/* This header must be included by mkldnn_thread.hpp only */
|
||||||
|
|
||||||
|
/* Functions:
|
||||||
|
* - parallel(nthr, f) - executes f in parallel using at most
|
||||||
|
* nthr threads. If nthr equals 0
|
||||||
|
* mkldnn_get_max_threads() threads is
|
||||||
|
* used
|
||||||
|
* - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
|
||||||
|
* created threads
|
||||||
|
* - parallel_nd(dims..., f) - creates a parallel section and then
|
||||||
|
* calls for_nd
|
||||||
|
* - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
|
||||||
|
* calls for_nd (mostly for convenience)
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
/* general parallelization */
|
||||||
|
template <typename F>
|
||||||
|
void parallel(int nthr, F f) {
|
||||||
|
if (nthr == 0) nthr = mkldnn_get_max_threads();
|
||||||
|
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
||||||
|
assert(nthr == 1);
|
||||||
|
f(0, 1);
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
||||||
|
if (nthr == 1) { f(0, 1); return; }
|
||||||
|
# pragma omp parallel num_threads(nthr)
|
||||||
|
f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
||||||
|
if (nthr == 1) { f(0, 1); return; }
|
||||||
|
tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/* for_nd section */
|
||||||
|
|
||||||
|
template <typename T0, typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
|
||||||
|
T0 start{0}, end{0};
|
||||||
|
balance211(D0, nthr, ithr, start, end);
|
||||||
|
for (T0 d0 = start; d0 < end; ++d0) f(d0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
size_t start{0}, end{0};
|
||||||
|
balance211(work_amount, nthr, ithr, start, end);
|
||||||
|
|
||||||
|
T0 d0{0}; T1 d1{0};
|
||||||
|
utils::nd_iterator_init(start, d0, D0, d1, D1);
|
||||||
|
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||||
|
f(d0, d1);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
||||||
|
const T2 &D2, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
size_t start{0}, end{0};
|
||||||
|
balance211(work_amount, nthr, ithr, start, end);
|
||||||
|
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0};
|
||||||
|
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
|
||||||
|
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||||
|
f(d0, d1, d2);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
||||||
|
const T2 &D2, const T3 &D3, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
size_t start{0}, end{0};
|
||||||
|
balance211(work_amount, nthr, ithr, start, end);
|
||||||
|
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
|
||||||
|
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
|
||||||
|
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||||
|
f(d0, d1, d2, d3);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
||||||
|
typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
||||||
|
const T2 &D2, const T3 &D3, const T4 &D4, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
size_t start{0}, end{0};
|
||||||
|
balance211(work_amount, nthr, ithr, start, end);
|
||||||
|
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
|
||||||
|
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
||||||
|
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||||
|
f(d0, d1, d2, d3, d4);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
||||||
|
typename T5, typename F>
|
||||||
|
void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
|
||||||
|
const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
size_t start{0}, end{0};
|
||||||
|
balance211(work_amount, nthr, ithr, start, end);
|
||||||
|
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
|
||||||
|
utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
|
||||||
|
d5, D5);
|
||||||
|
for (size_t iwork = start; iwork < end; ++iwork) {
|
||||||
|
f(d0, d1, d2, d3, d4, d5);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip a lambda function in the parameter pack.
|
||||||
|
template <typename T>
|
||||||
|
constexpr size_t get_work_amount(const T &v) { return 1; }
|
||||||
|
template <typename T, typename ...Args>
|
||||||
|
constexpr size_t get_work_amount(const T &v, Args &&...args)
|
||||||
|
{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
|
||||||
|
|
||||||
|
/* parallel_nd and parallel_nd_in_omp section */
|
||||||
|
|
||||||
|
#if MKLDNN_THR != MKLDNN_THR_TBB
|
||||||
|
template <typename ...Args>
|
||||||
|
void parallel_nd(Args &&...args) {
|
||||||
|
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
||||||
|
for_nd(0, 1, utils::forward<Args>(args)...);
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
||||||
|
const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
|
||||||
|
# pragma omp parallel if (do_parallel)
|
||||||
|
{
|
||||||
|
const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
|
||||||
|
const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
|
||||||
|
for_nd(ithr, nthr, utils::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
#else // MKLDNN_THR != MKLDNN_THR_TBB
|
||||||
|
|
||||||
|
// gcc 4.8 has a bug with passing parameter pack to lambdas.
|
||||||
|
// So have to explicitly instantiate all the cases.
|
||||||
|
|
||||||
|
template <typename T0, typename F>
|
||||||
|
void parallel_nd(const T0 &D0, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(T0(iwork));
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename F>
|
||||||
|
void parallel_nd(const T0 &D0, const T1 &D1, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
T0 d0{0}; T1 d1{0};
|
||||||
|
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(d0, d1);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1);
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename F>
|
||||||
|
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0};
|
||||||
|
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(d0, d1, d2);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename F>
|
||||||
|
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
|
||||||
|
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(d0, d1, d2, d3);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
||||||
|
typename F>
|
||||||
|
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
|
||||||
|
const T4 &D4, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
|
||||||
|
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(d0, d1, d2, d3, d4);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T0, typename T1, typename T2, typename T3, typename T4,
|
||||||
|
typename T5, typename F>
|
||||||
|
void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
|
||||||
|
const T4 &D4, const T5 &D5, F f) {
|
||||||
|
const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
|
||||||
|
if (work_amount == 0) return;
|
||||||
|
tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
|
||||||
|
T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
|
||||||
|
utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
|
||||||
|
d5, D5);
|
||||||
|
for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
|
||||||
|
f(d0, d1, d2, d3, d4, d5);
|
||||||
|
utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
|
||||||
|
}
|
||||||
|
}, tbb::static_partitioner());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename ...Args>
|
||||||
|
void parallel_nd_in_omp(Args &&...args) {
|
||||||
|
#if MKLDNN_THR == MKLDNN_THR_SEQ
|
||||||
|
for_nd(0, 1, utils::forward<Args>(args)...);
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_OMP
|
||||||
|
for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
|
||||||
|
utils::forward<Args>(args)...);
|
||||||
|
#elif MKLDNN_THR == MKLDNN_THR_TBB
|
||||||
|
assert(!"unsupported parallel_nd_in_omp()");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impl
|
||||||
|
} // namespace mkldnn
|
||||||
|
|
||||||
|
#endif
|
77
thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
vendored
Normal file
77
thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
vendored
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef MKLDNN_TRAITS_HPP
|
||||||
|
#define MKLDNN_TRAITS_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "z_magic.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
template <data_type_t> struct prec_traits {}; /* ::type -> float */
|
||||||
|
template <typename> struct data_traits {}; /* ::data_type -> f32 */
|
||||||
|
template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */
|
||||||
|
template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
|
||||||
|
|
||||||
|
template <> struct prec_traits<data_type::f32> { typedef float type; };
|
||||||
|
template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
|
||||||
|
template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
|
||||||
|
template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
|
||||||
|
|
||||||
|
template <> struct data_traits<float>
|
||||||
|
{ static constexpr data_type_t data_type = data_type::f32; };
|
||||||
|
template <> struct data_traits<int32_t>
|
||||||
|
{ static constexpr data_type_t data_type = data_type::s32; };
|
||||||
|
template <> struct data_traits<int8_t>
|
||||||
|
{ static constexpr data_type_t data_type = data_type::s8; };
|
||||||
|
template <> struct data_traits<uint8_t>
|
||||||
|
{ static constexpr data_type_t data_type = data_type::u8; };
|
||||||
|
|
||||||
|
template <> struct typesize_traits<4> { typedef float type; };
|
||||||
|
template <> struct typesize_traits<2> { typedef int16_t type; };
|
||||||
|
template <> struct typesize_traits<1> { typedef uint8_t type; };
|
||||||
|
|
||||||
|
#define PKIND_TRAITS_INST(op) \
|
||||||
|
template <> struct pkind_traits<primitive_kind::op> { \
|
||||||
|
typedef CONCAT2(op, _desc_t) desc_type; \
|
||||||
|
static constexpr query_t query_d = query::CONCAT2(op, _d); \
|
||||||
|
}
|
||||||
|
PKIND_TRAITS_INST(convolution);
|
||||||
|
PKIND_TRAITS_INST(deconvolution);
|
||||||
|
PKIND_TRAITS_INST(shuffle);
|
||||||
|
PKIND_TRAITS_INST(eltwise);
|
||||||
|
PKIND_TRAITS_INST(softmax);
|
||||||
|
PKIND_TRAITS_INST(pooling);
|
||||||
|
PKIND_TRAITS_INST(lrn);
|
||||||
|
PKIND_TRAITS_INST(batch_normalization);
|
||||||
|
PKIND_TRAITS_INST(inner_product);
|
||||||
|
PKIND_TRAITS_INST(rnn);
|
||||||
|
#undef PKIND_TRAITS_INST
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
193
thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
vendored
Normal file
193
thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
vendored
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef NSTL_HPP
|
||||||
|
#define NSTL_HPP
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <limits.h>
|
||||||
|
#include <float.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "z_magic.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
void *malloc(size_t size, int alignment);
|
||||||
|
void free(void *p);
|
||||||
|
|
||||||
|
struct c_compatible {
|
||||||
|
enum { default_alignment = 64 };
|
||||||
|
static void *operator new(size_t sz) {
|
||||||
|
return malloc(sz, default_alignment);
|
||||||
|
}
|
||||||
|
static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
|
||||||
|
static void *operator new[](size_t sz) {
|
||||||
|
return malloc(sz, default_alignment);
|
||||||
|
}
|
||||||
|
static void operator delete(void *p) { free(p); }
|
||||||
|
static void operator delete[](void *p) { free(p); }
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace nstl {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline const T abs(const T& a) {
|
||||||
|
return a >= 0 ? a : -a;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline const T& max(const T& a, const T& b) {
|
||||||
|
return a > b ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline const T& min(const T& a, const T& b) {
|
||||||
|
return a < b ? a : b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T> void swap(T& t1, T& t2) {
|
||||||
|
T tmp(t1);
|
||||||
|
t1 = t2;
|
||||||
|
t2 = tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rationale: MKL-DNN needs numeric limits implementation that does not
|
||||||
|
// generate dependencies on C++ run-time libraries.
|
||||||
|
|
||||||
|
template<typename T> struct numeric_limits;
|
||||||
|
|
||||||
|
template<> struct numeric_limits<float> {
|
||||||
|
static constexpr float lowest() { return -FLT_MAX; }
|
||||||
|
static constexpr float max() { return FLT_MAX; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct numeric_limits<int32_t> {
|
||||||
|
static constexpr int lowest() { return INT32_MIN; }
|
||||||
|
static constexpr int max() { return INT32_MAX; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct numeric_limits<int16_t> {
|
||||||
|
static constexpr int16_t lowest() { return INT16_MIN; }
|
||||||
|
static constexpr int16_t max() { return INT16_MAX; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct numeric_limits<int8_t> {
|
||||||
|
static constexpr int8_t lowest() { return INT8_MIN; }
|
||||||
|
static constexpr int8_t max() { return INT8_MAX; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct numeric_limits<uint8_t> {
|
||||||
|
static constexpr uint8_t lowest() { return 0; }
|
||||||
|
static constexpr uint8_t max() { return UINT8_MAX; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T> struct is_integral
|
||||||
|
{ static constexpr bool value = false; };
|
||||||
|
template<> struct is_integral<int32_t> { static constexpr bool value = true; };
|
||||||
|
template<> struct is_integral<int16_t> { static constexpr bool value = true; };
|
||||||
|
template<> struct is_integral<int8_t> { static constexpr bool value = true; };
|
||||||
|
template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
|
||||||
|
|
||||||
|
template <typename T, typename U> struct is_same
|
||||||
|
{ static constexpr bool value = false; };
|
||||||
|
template <typename T> struct is_same<T, T>
|
||||||
|
{ static constexpr bool value = true; };
|
||||||
|
|
||||||
|
// Rationale: MKL-DNN needs container implementations that do not generate
|
||||||
|
// dependencies on C++ run-time libraries.
|
||||||
|
//
|
||||||
|
// Implementation philosophy: caller is responsible to check if the operation
|
||||||
|
// is valid. The only functions that have to return status are those that
|
||||||
|
// depend on memory allocation or similar operations.
|
||||||
|
//
|
||||||
|
// This means that e.g. an operator [] does not have to check for boundaries.
|
||||||
|
// The caller should have checked the boundaries. If it did not we crash and
|
||||||
|
// burn: this is a bug in MKL-DNN and throwing an exception would not have been
|
||||||
|
// recoverable.
|
||||||
|
//
|
||||||
|
// On the other hand, insert() or resize() or a similar operation needs to
|
||||||
|
// return a status because the outcome depends on factors external to the
|
||||||
|
// caller. The situation is probably also not recoverable also, but MKL-DNN
|
||||||
|
// needs to be nice and report "out of memory" to the users.
|
||||||
|
|
||||||
|
enum nstl_status_t {
|
||||||
|
success = 0,
|
||||||
|
out_of_memory
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T> class vector: public c_compatible {
|
||||||
|
private:
|
||||||
|
std::vector<T> _impl;
|
||||||
|
public:
|
||||||
|
typedef typename std::vector<T>::iterator iterator;
|
||||||
|
typedef typename std::vector<T>::const_iterator const_iterator;
|
||||||
|
typedef typename std::vector<T>::size_type size_type;
|
||||||
|
vector() {}
|
||||||
|
vector(size_type n): _impl(n) {}
|
||||||
|
vector(size_type n, const T &value): _impl(n, value) {}
|
||||||
|
template <typename input_iterator>
|
||||||
|
vector(input_iterator first, input_iterator last): _impl(first, last) {}
|
||||||
|
~vector() {}
|
||||||
|
size_type size() const { return _impl.size(); }
|
||||||
|
T& operator[] (size_type i) { return _impl[i]; }
|
||||||
|
const T& operator[] (size_type i) const { return _impl[i]; }
|
||||||
|
iterator begin() { return _impl.begin(); }
|
||||||
|
const_iterator begin() const { return _impl.begin(); }
|
||||||
|
iterator end() { return _impl.end(); }
|
||||||
|
const_iterator end() const { return _impl.end(); }
|
||||||
|
template <typename input_iterator>
|
||||||
|
nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
|
||||||
|
{
|
||||||
|
_impl.insert(pos, begin, end);
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
void clear() { _impl.clear(); }
|
||||||
|
void push_back(const T& t) { _impl.push_back(t); }
|
||||||
|
void resize(size_type count) { _impl.resize(count); }
|
||||||
|
void reserve(size_type count) { _impl.reserve(count); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Key, typename T> class map: public c_compatible {
|
||||||
|
private:
|
||||||
|
std::map<Key, T> _impl;
|
||||||
|
public:
|
||||||
|
typedef typename std::map<Key, T>::iterator iterator;
|
||||||
|
typedef typename std::map<Key, T>::const_iterator const_iterator;
|
||||||
|
typedef typename std::map<Key, T>::size_type size_type;
|
||||||
|
map() {}
|
||||||
|
~map() {}
|
||||||
|
size_type size() const { return _impl.size(); }
|
||||||
|
T& operator[](const Key &k) { return _impl[k]; }
|
||||||
|
const T& operator[](const Key &k) const { return _impl[k]; }
|
||||||
|
iterator begin() { return _impl.begin(); }
|
||||||
|
const_iterator begin() const { return _impl.begin(); }
|
||||||
|
iterator end() { return _impl.end(); }
|
||||||
|
const_iterator end() const { return _impl.end(); }
|
||||||
|
template <typename input_iterator>
|
||||||
|
void clear() { _impl.clear(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
114
thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
vendored
Normal file
114
thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
vendored
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::prop_kind;
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
status_t pooling_desc_init(pooling_desc_t *pool_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t kernel, const dims_t padding_l,
|
||||||
|
const dims_t padding_r, padding_kind_t padding_kind) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
|
||||||
|
&& one_of(alg_kind, pooling_max,
|
||||||
|
pooling_avg_include_padding,
|
||||||
|
pooling_avg_exclude_padding)
|
||||||
|
&& one_of(padding_kind, padding_kind::padding_zero);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
if (padding_r == nullptr) padding_r = padding_l;
|
||||||
|
|
||||||
|
auto pd = pooling_desc_t();
|
||||||
|
pd.primitive_kind = primitive_kind::pooling;
|
||||||
|
pd.prop_kind = prop_kind;
|
||||||
|
pd.alg_kind = alg_kind;
|
||||||
|
pd.src_desc.ndims = src_desc->ndims;
|
||||||
|
|
||||||
|
const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
|
||||||
|
|
||||||
|
pd.diff_src_desc = pd.src_desc = zero_md();
|
||||||
|
pd.diff_dst_desc = pd.dst_desc = zero_md();
|
||||||
|
|
||||||
|
(is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
|
||||||
|
(is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
|
||||||
|
|
||||||
|
int sp_dims = src_desc->ndims - 2;
|
||||||
|
utils::array_copy(pd.strides, strides, sp_dims);
|
||||||
|
utils::array_copy(pd.kernel, kernel, sp_dims);
|
||||||
|
utils::array_copy(pd.padding[0], padding_l, sp_dims);
|
||||||
|
utils::array_copy(pd.padding[1], padding_r, sp_dims);
|
||||||
|
|
||||||
|
pd.padding_kind = padding_kind;
|
||||||
|
if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
|
||||||
|
pooling_avg_exclude_padding)) {
|
||||||
|
pd.accum_data_type = types::default_accum_data_type(
|
||||||
|
src_desc->data_type, dst_desc->data_type);
|
||||||
|
} else {
|
||||||
|
pd.accum_data_type = dst_desc->data_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool consistency = true
|
||||||
|
&& utils::one_of(src_desc->ndims, 4, 5)
|
||||||
|
&& utils::one_of(dst_desc->ndims, 4, 5)
|
||||||
|
&& src_desc->dims[0] == dst_desc->dims[0]
|
||||||
|
&& src_desc->dims[1] == dst_desc->dims[1];
|
||||||
|
for (int i = 2; i < src_desc->ndims; ++i)
|
||||||
|
consistency = consistency && (
|
||||||
|
(src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
|
||||||
|
+ padding_r[i - 2]) / strides[i - 2] + 1
|
||||||
|
== dst_desc->dims[i]);
|
||||||
|
if (!consistency) return invalid_arguments;
|
||||||
|
|
||||||
|
*pool_desc = pd;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
|
||||||
|
prop_kind_t prop_kind, alg_kind_t alg_kind,
|
||||||
|
const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
|
||||||
|
const dims_t strides, const dims_t kernel, const dims_t padding_l,
|
||||||
|
const dims_t padding_r, padding_kind_t padding_kind) {
|
||||||
|
if (!one_of(prop_kind, forward_training, forward_inference))
|
||||||
|
return invalid_arguments;
|
||||||
|
return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
|
||||||
|
dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
|
||||||
|
alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
|
||||||
|
const memory_desc_t *diff_dst_desc, const dims_t strides,
|
||||||
|
const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
|
||||||
|
padding_kind_t padding_kind) {
|
||||||
|
return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
|
||||||
|
diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
|
||||||
|
padding_r, padding_kind);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
238
thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
vendored
Normal file
238
thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
vendored
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef POOLING_PD_HPP
|
||||||
|
#define POOLING_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct pooling_fwd_pd_t;
|
||||||
|
|
||||||
|
struct pooling_pd_t: public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::pooling;
|
||||||
|
|
||||||
|
pooling_pd_t(engine_t *engine,
|
||||||
|
const pooling_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const pooling_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, ws_md_()
|
||||||
|
{}
|
||||||
|
|
||||||
|
const pooling_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::pooling_d:
|
||||||
|
*(const pooling_desc_t**)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* common pooling aux functions */
|
||||||
|
|
||||||
|
dim_t MB() const { return src_desc().dims[0]; }
|
||||||
|
dim_t C() const { return src_desc().dims[1]; }
|
||||||
|
|
||||||
|
dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
|
||||||
|
dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
|
||||||
|
dim_t IW() const { return src_desc().dims[ndims() - 1]; }
|
||||||
|
|
||||||
|
dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
|
||||||
|
dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
|
||||||
|
dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
|
||||||
|
|
||||||
|
dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
|
||||||
|
dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
|
||||||
|
dim_t KW() const { return desc_.kernel[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
|
||||||
|
dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
|
||||||
|
dim_t KSW() const { return desc_.strides[ndims() - 3]; }
|
||||||
|
|
||||||
|
dim_t padFront() const
|
||||||
|
{ return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
|
||||||
|
dim_t padBack() const
|
||||||
|
{ return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
|
||||||
|
dim_t padT() const
|
||||||
|
{ return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
|
||||||
|
dim_t padB() const
|
||||||
|
{ return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
|
||||||
|
dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
|
||||||
|
dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
|
||||||
|
|
||||||
|
int ndims() const { return src_desc().ndims; }
|
||||||
|
bool is_3d() const { return ndims() == 5; }
|
||||||
|
|
||||||
|
bool has_zero_dim_memory() const
|
||||||
|
{ return memory_desc_wrapper(src_desc()).has_zero_dim(); }
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
pooling_desc_t desc_;
|
||||||
|
const pooling_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
memory_desc_t ws_md_;
|
||||||
|
|
||||||
|
void init_default_ws() {
|
||||||
|
ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
|
||||||
|
ws_md_.data_type = indices_data_type();
|
||||||
|
}
|
||||||
|
|
||||||
|
data_type_t indices_data_type() const {
|
||||||
|
/* the simplest way to express 256... */
|
||||||
|
const int u8_max = nstl::numeric_limits<
|
||||||
|
typename prec_traits<data_type::u8>::type>::max();
|
||||||
|
return utils::array_product(desc()->kernel, ndims()) <= u8_max
|
||||||
|
? data_type::u8 : data_type::s32;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const memory_desc_t &src_desc() const
|
||||||
|
{ return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
|
||||||
|
const memory_desc_t &dst_desc() const
|
||||||
|
{ return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct pooling_fwd_pd_t: public pooling_pd_t {
|
||||||
|
typedef pooling_fwd_pd_t base_class;
|
||||||
|
typedef pooling_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
pooling_fwd_pd_t(engine_t *engine,
|
||||||
|
const pooling_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const pooling_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, src_md_(desc_.src_desc)
|
||||||
|
, dst_md_(desc_.dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_SRC)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 1; }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 1 + (workspace_md() != nullptr); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
|
||||||
|
virtual status_t set_default_params() {
|
||||||
|
if (dst_md()->format_kind != format_kind::any)
|
||||||
|
return status::success;
|
||||||
|
|
||||||
|
if (src_md()->format_kind != format_kind::blocked)
|
||||||
|
return status::unimplemented;
|
||||||
|
|
||||||
|
return memory_desc_init_by_blocking_desc(dst_md_,
|
||||||
|
src_md_.format_desc.blocking);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct pooling_bwd_pd_t: public pooling_pd_t {
|
||||||
|
typedef pooling_bwd_pd_t base_class;
|
||||||
|
typedef pooling_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
pooling_bwd_pd_t(engine_t *engine,
|
||||||
|
const pooling_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const pooling_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_src_md_(desc_.diff_src_desc)
|
||||||
|
, diff_dst_md_(desc_.diff_dst_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_DST)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &diff_dst_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 1 + (workspace_md() != nullptr); }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_src_md_;
|
||||||
|
memory_desc_t diff_dst_md_;
|
||||||
|
|
||||||
|
virtual status_t set_default_params() {
|
||||||
|
if (diff_src_md()->format_kind != format_kind::any)
|
||||||
|
return status::success;
|
||||||
|
|
||||||
|
if (diff_dst_md()->format_kind != format_kind::blocked)
|
||||||
|
return status::unimplemented;
|
||||||
|
|
||||||
|
return memory_desc_init_by_blocking_desc(diff_src_md_,
|
||||||
|
diff_dst_md_.format_desc.blocking);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
103
thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
vendored
Normal file
103
thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
vendored
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "primitive.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "stream.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::primitive_kind;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// XXX: this is a huge hammer. This disables all and any msan checks on
|
||||||
|
// primitives outputs.
|
||||||
|
//
|
||||||
|
// A proper approach would be an implementation-specific unpoisoning.
|
||||||
|
void unpoison_outputs(const exec_args_t &args) {
|
||||||
|
for(const auto &arg: args) {
|
||||||
|
if (arg.second.is_const) continue;
|
||||||
|
auto *mem = arg.second.mem;
|
||||||
|
void *p;
|
||||||
|
mem->get_data_handle(&p);
|
||||||
|
size_t s = memory_desc_wrapper(*mem->md()).size();
|
||||||
|
msan_unpoison(p, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
|
||||||
|
if (primitive_desc) delete primitive_desc;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_create(primitive_t **primitive,
|
||||||
|
const primitive_desc_t *primitive_desc) {
|
||||||
|
if (utils::any_null(primitive, primitive_desc))
|
||||||
|
return invalid_arguments;
|
||||||
|
return primitive_desc->create_primitive(primitive);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_execute(const primitive_t *primitive,
|
||||||
|
stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
|
||||||
|
bool ok = true
|
||||||
|
&& !utils::any_null(primitive, stream)
|
||||||
|
&& primitive->engine() == stream->engine()
|
||||||
|
&& IMPLICATION(nargs > 0, c_args != nullptr);
|
||||||
|
if (!ok) return invalid_arguments;
|
||||||
|
|
||||||
|
exec_args_t args;
|
||||||
|
status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
|
||||||
|
if (status != status::success) return status;
|
||||||
|
|
||||||
|
exec_ctx_t ctx(stream, std::move(args));
|
||||||
|
|
||||||
|
if (mkldnn_verbose()->level) {
|
||||||
|
double ms = get_msec();
|
||||||
|
status = primitive->execute(ctx);
|
||||||
|
ms = get_msec() - ms;
|
||||||
|
printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
|
||||||
|
fflush(0);
|
||||||
|
} else {
|
||||||
|
status = primitive->execute(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (msan_enabled) unpoison_outputs(ctx.args());
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
|
||||||
|
const primitive_desc_t **primitive_desc) {
|
||||||
|
if (utils::any_null(primitive, primitive_desc))
|
||||||
|
return invalid_arguments;
|
||||||
|
return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
|
||||||
|
primitive->pd());
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_destroy(primitive_t *primitive) {
|
||||||
|
if (primitive != nullptr)
|
||||||
|
delete primitive;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
76
thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
vendored
Normal file
76
thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
vendored
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_HPP
|
||||||
|
#define PRIMITIVE_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "primitive_exec_types.hpp"
|
||||||
|
|
||||||
|
/** \brief A pure virtual primitive class
|
||||||
|
*
|
||||||
|
* Primitive contains links to its inputs & outputs, though it does not track
|
||||||
|
* their readiness on execution step.
|
||||||
|
*
|
||||||
|
* @remark @b Rational.
|
||||||
|
* Dependencies are essential through-out the whole MKL-DNN library, so it
|
||||||
|
* makes sense to include them on the very low level. On the other hand,
|
||||||
|
* tracking them should be a task for corresponding essence, like scheduler,
|
||||||
|
* stream or whatever. Primitive itself should know nothing about the
|
||||||
|
* environment it is running in.
|
||||||
|
*
|
||||||
|
* @note
|
||||||
|
* To make user experience better we should provide API which allows
|
||||||
|
* achieving the best (or good enough) performance when creating primitives
|
||||||
|
* in natural order: i.e. from bottom to top for forward pass and from top to
|
||||||
|
* bottom for backward pass. Please consider restriction [1] in Level 0.
|
||||||
|
*/
|
||||||
|
struct mkldnn_primitive: public mkldnn::impl::c_compatible {
|
||||||
|
mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
|
||||||
|
: pd_(pd->clone()) {}
|
||||||
|
virtual ~mkldnn_primitive() { delete pd_; }
|
||||||
|
|
||||||
|
/** returns primitive's engine */
|
||||||
|
mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
|
||||||
|
/** returns primitive's inputs */
|
||||||
|
const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
|
||||||
|
/** returns primitive's kind */
|
||||||
|
mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
|
||||||
|
|
||||||
|
/** executes primitive with execution context @p ctx */
|
||||||
|
virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
|
||||||
|
const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const mkldnn::impl::primitive_desc_t *pd_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mkldnn_primitive() = delete;
|
||||||
|
mkldnn_primitive(const mkldnn_primitive &) = delete;
|
||||||
|
mkldnn_primitive(mkldnn_primitive &&) = delete;
|
||||||
|
mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
|
||||||
|
mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
290
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
vendored
Normal file
290
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
vendored
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2017-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_attr.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
status_t scales_t::set(dim_t count, int mask, const float *scales) {
|
||||||
|
cleanup();
|
||||||
|
|
||||||
|
count_ = count;
|
||||||
|
mask_ = mask;
|
||||||
|
|
||||||
|
if (count_ == 1) {
|
||||||
|
scales_ = scales_buf_;
|
||||||
|
utils::array_set(scales_, scales[0], scales_buf_size);
|
||||||
|
} else {
|
||||||
|
scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
|
||||||
|
if (scales_ == nullptr)
|
||||||
|
return status::out_of_memory;
|
||||||
|
|
||||||
|
for (dim_t c = 0; c < count_; ++c)
|
||||||
|
scales_[c] = scales[c];
|
||||||
|
}
|
||||||
|
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t post_ops_t::append_sum(float scale) {
|
||||||
|
if (len_ == capacity)
|
||||||
|
return out_of_memory;
|
||||||
|
|
||||||
|
entry_[len_].kind = primitive_kind::sum;
|
||||||
|
entry_[len_].sum.scale = scale;
|
||||||
|
|
||||||
|
len_++;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
|
||||||
|
float beta) {
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
|
||||||
|
eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
|
||||||
|
eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
|
||||||
|
if (!known_alg)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
if (len_ == capacity)
|
||||||
|
return out_of_memory;
|
||||||
|
|
||||||
|
entry_[len_].kind = primitive_kind::eltwise;
|
||||||
|
entry_[len_].eltwise.scale = scale;
|
||||||
|
entry_[len_].eltwise.alg = alg;
|
||||||
|
entry_[len_].eltwise.alpha = alpha;
|
||||||
|
entry_[len_].eltwise.beta = beta;
|
||||||
|
|
||||||
|
len_++;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t primitive_attr_t::set_scratchpad_mode(
|
||||||
|
scratchpad_mode_t scratchpad_mode) {
|
||||||
|
using namespace mkldnn::impl::scratchpad_mode;
|
||||||
|
|
||||||
|
const bool ok = one_of(scratchpad_mode, library, user);
|
||||||
|
if (!ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
scratchpad_mode_ = scratchpad_mode;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
|
||||||
|
this->post_ops_ = post_ops;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Public C API */
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
|
||||||
|
if (attr == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
|
||||||
|
new mkldnn_primitive_attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
|
||||||
|
const primitive_attr_t *existing_attr) {
|
||||||
|
if (any_null(attr, existing_attr))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
|
||||||
|
existing_attr->clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
|
||||||
|
if (attr)
|
||||||
|
delete attr;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_get_scratchpad_mode(
|
||||||
|
const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
|
||||||
|
if (any_null(attr, scratchpad_mode))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*scratchpad_mode = attr->scratchpad_mode_;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_set_scratchpad_mode(
|
||||||
|
primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
|
||||||
|
if (any_null(attr))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return attr->set_scratchpad_mode(scratchpad_mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
|
||||||
|
dim_t *count, int *mask, const float **scales) {
|
||||||
|
if (any_null(attr, count, mask, scales))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*count = attr->output_scales_.count_;
|
||||||
|
*mask = attr->output_scales_.mask_;
|
||||||
|
*scales = attr->output_scales_.scales_;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
|
||||||
|
dim_t count, int mask, const float *scales) {
|
||||||
|
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
|
||||||
|
if (!ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return attr->output_scales_.set(count, mask, scales);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
|
||||||
|
const post_ops_t **post_ops) {
|
||||||
|
if (any_null(attr, post_ops))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*post_ops = &attr->post_ops_;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
|
||||||
|
const post_ops_t *post_ops) {
|
||||||
|
if (any_null(attr, post_ops))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return attr->set_post_ops(*post_ops);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
|
||||||
|
if (post_ops == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
|
||||||
|
if (post_ops)
|
||||||
|
delete post_ops;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_post_ops_len(const post_ops_t *post_ops) {
|
||||||
|
if (post_ops)
|
||||||
|
return post_ops->len_;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
|
||||||
|
int index) {
|
||||||
|
bool ok = post_ops && 0 <= index && index < post_ops->len_;
|
||||||
|
if (!ok)
|
||||||
|
return primitive_kind::undefined;
|
||||||
|
|
||||||
|
return post_ops->entry_[index].kind;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
|
||||||
|
if (post_ops == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return post_ops->append_sum(scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
bool simple_get_params_check(const post_ops_t *post_ops, int index,
|
||||||
|
primitive_kind_t kind) {
|
||||||
|
bool ok = true
|
||||||
|
&& post_ops != nullptr
|
||||||
|
&& 0 <= index
|
||||||
|
&& index < post_ops->len_
|
||||||
|
&& post_ops->entry_[index].kind == kind;
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
|
||||||
|
float *scale) {
|
||||||
|
bool ok = true
|
||||||
|
&& simple_get_params_check(post_ops, index, primitive_kind::sum)
|
||||||
|
&& !any_null(scale);
|
||||||
|
if (!ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*scale = post_ops->entry_[index].sum.scale;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
|
||||||
|
alg_kind_t kind, float alpha, float beta) {
|
||||||
|
if (post_ops == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return post_ops->append_eltwise(scale, kind, alpha, beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
|
||||||
|
int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
|
||||||
|
bool ok = true
|
||||||
|
&& simple_get_params_check(post_ops, index, primitive_kind::eltwise)
|
||||||
|
&& !any_null(scale, alpha, beta);
|
||||||
|
if (!ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
const auto &e = post_ops->entry_[index].eltwise;
|
||||||
|
*scale = e.scale;
|
||||||
|
*alg = e.alg;
|
||||||
|
*alpha = e.alpha;
|
||||||
|
*beta = e.beta;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_set_rnn_data_qparams(
|
||||||
|
primitive_attr_t *attr, const float scale, const float shift) {
|
||||||
|
if (attr == nullptr)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return attr->rnn_data_qparams_.set(scale, shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
|
||||||
|
primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
|
||||||
|
bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
|
||||||
|
if (!ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return attr->rnn_weights_qparams_.set(count, mask, scales);
|
||||||
|
}
|
183
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
vendored
Normal file
183
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
vendored
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2017-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_ATTR_HPP
|
||||||
|
#define PRIMITIVE_ATTR_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct rnn_data_qparams_t : public c_compatible {
|
||||||
|
rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
|
||||||
|
bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
|
||||||
|
|
||||||
|
status_t set(float scale, float shift) {
|
||||||
|
scale_ = scale;
|
||||||
|
shift_ = shift;
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
float scale_;
|
||||||
|
float shift_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct scales_t: public c_compatible {
|
||||||
|
scales_t(): count_(1), mask_(0), scales_(scales_buf_)
|
||||||
|
{ set(1.); }
|
||||||
|
|
||||||
|
scales_t(const scales_t &rhs): scales_t()
|
||||||
|
{ set(rhs.count_, rhs.mask_, rhs.scales_); }
|
||||||
|
|
||||||
|
~scales_t() { cleanup(); }
|
||||||
|
|
||||||
|
scales_t &operator=(const scales_t &rhs) {
|
||||||
|
if (&rhs == this)
|
||||||
|
return *this;
|
||||||
|
status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
|
||||||
|
assert(status == status::success);
|
||||||
|
(void)status;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_default_values() const {
|
||||||
|
for (dim_t c = 0; c < count_; ++c) {
|
||||||
|
if(scales_[c] != 1.) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t set(dim_t count, int mask, const float *scales);
|
||||||
|
status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
|
||||||
|
|
||||||
|
dim_t count_;
|
||||||
|
int mask_;
|
||||||
|
float *scales_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
enum { scales_buf_size = 16 };
|
||||||
|
float scales_buf_[scales_buf_size];
|
||||||
|
|
||||||
|
void cleanup() {
|
||||||
|
if (scales_ != scales_buf_ && scales_ != nullptr)
|
||||||
|
impl::free(scales_);
|
||||||
|
|
||||||
|
count_ = 1;
|
||||||
|
mask_ = 0;
|
||||||
|
scales_ = scales_buf_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
|
||||||
|
struct entry_t {
|
||||||
|
struct eltwise_t {
|
||||||
|
mkldnn::impl::alg_kind_t alg;
|
||||||
|
float scale, alpha, beta;
|
||||||
|
};
|
||||||
|
|
||||||
|
mkldnn::impl::primitive_kind_t kind;
|
||||||
|
union {
|
||||||
|
struct { float scale; } sum;
|
||||||
|
eltwise_t eltwise;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool is_eltwise(bool require_scale_one = true) const {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
return kind == primitive_kind::eltwise
|
||||||
|
&& IMPLICATION(require_scale_one, eltwise.scale == 1.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_relu(bool require_scale_one = true,
|
||||||
|
bool require_nslope_zero = true) const {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
return is_eltwise(require_scale_one)
|
||||||
|
&& eltwise.alg == alg_kind::eltwise_relu
|
||||||
|
&& IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_sum(bool require_scale_one = true) const {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
return kind == primitive_kind::sum
|
||||||
|
&& IMPLICATION(require_scale_one, sum.scale == 1.f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
mkldnn_post_ops(): len_(0) {}
|
||||||
|
|
||||||
|
mkldnn::impl::status_t append_sum(float scale);
|
||||||
|
mkldnn::impl::status_t append_eltwise(float scale,
|
||||||
|
mkldnn::impl::alg_kind_t alg, float alpha, float beta);
|
||||||
|
|
||||||
|
int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
|
||||||
|
int stop = -1) const {
|
||||||
|
if (stop == -1) stop = len_;
|
||||||
|
stop = mkldnn::impl::nstl::min(stop, len_);
|
||||||
|
for (int idx = start; idx < stop; ++idx)
|
||||||
|
if (entry_[idx].kind == kind) return idx;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_default_values() const { return len_ == 0; }
|
||||||
|
|
||||||
|
bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
|
||||||
|
{ return find(kind, index, index + 1) == index; }
|
||||||
|
|
||||||
|
enum { capacity = 4 };
|
||||||
|
|
||||||
|
int len_;
|
||||||
|
entry_t entry_[capacity];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
|
||||||
|
mkldnn_primitive_attr()
|
||||||
|
: scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
|
||||||
|
{}
|
||||||
|
|
||||||
|
mkldnn_primitive_attr *clone() const
|
||||||
|
{ return new mkldnn_primitive_attr(*this); }
|
||||||
|
|
||||||
|
/** Returns true if the attributes have default values.
|
||||||
|
*
|
||||||
|
* @note The scratchpad_mode_ is not take into account */
|
||||||
|
bool has_default_values() const {
|
||||||
|
return true
|
||||||
|
&& output_scales_.has_default_values()
|
||||||
|
&& post_ops_.has_default_values()
|
||||||
|
&& rnn_data_qparams_.has_default_values()
|
||||||
|
&& rnn_weights_qparams_.has_default_values();
|
||||||
|
}
|
||||||
|
|
||||||
|
mkldnn::impl::status_t set_scratchpad_mode(
|
||||||
|
mkldnn::impl::scratchpad_mode_t scratchpad_mode);
|
||||||
|
mkldnn::impl::status_t set_post_ops(
|
||||||
|
const mkldnn::impl::post_ops_t &post_ops);
|
||||||
|
|
||||||
|
mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
|
||||||
|
mkldnn::impl::scales_t output_scales_;
|
||||||
|
mkldnn::impl::post_ops_t post_ops_;
|
||||||
|
mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
|
||||||
|
mkldnn::impl::scales_t rnn_weights_qparams_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
78
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
vendored
Normal file
78
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
vendored
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
|
||||||
|
auto safe_ret_md = [&](const memory_desc_t *_) {
|
||||||
|
if (_ == nullptr) return not_required;
|
||||||
|
*(const memory_desc_t **)result = _;
|
||||||
|
return success;
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (what) {
|
||||||
|
case query::engine: *(engine_t**)result = engine(); break;
|
||||||
|
case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
|
||||||
|
|
||||||
|
case query::scratchpad_engine:
|
||||||
|
*(engine_t**)result = scratchpad_engine(); break;
|
||||||
|
|
||||||
|
case query::memory_consumption_s64:
|
||||||
|
*(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
|
||||||
|
|
||||||
|
case query::op_d:
|
||||||
|
if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
|
||||||
|
*(const_c_op_desc_t *)result
|
||||||
|
= static_cast<const_c_op_desc_t>(op_desc()); break;
|
||||||
|
|
||||||
|
case query::src_md: return safe_ret_md(src_md(idx));
|
||||||
|
case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
|
||||||
|
case query::dst_md: return safe_ret_md(dst_md(idx));
|
||||||
|
case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
|
||||||
|
case query::weights_md: return safe_ret_md(weights_md(idx));
|
||||||
|
case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
|
||||||
|
case query::workspace_md:
|
||||||
|
if (idx != 0) return status::invalid_arguments;
|
||||||
|
return safe_ret_md(workspace_md(idx));
|
||||||
|
case query::scratchpad_md:
|
||||||
|
if (idx != 0) return status::invalid_arguments;
|
||||||
|
return safe_ret_md(scratchpad_md(idx));
|
||||||
|
|
||||||
|
case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
|
||||||
|
case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
|
||||||
|
|
||||||
|
case query::impl_info_str: *(const char **)result = name(); break;
|
||||||
|
|
||||||
|
default: return unimplemented;
|
||||||
|
}
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
|
||||||
|
const primitive_attr_t **attr) {
|
||||||
|
if (utils::any_null(primitive_desc, attr))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
*attr = primitive_desc->attr();
|
||||||
|
return success;
|
||||||
|
}
|
174
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
vendored
Normal file
174
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
vendored
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_DESC_HPP
|
||||||
|
#define PRIMITIVE_DESC_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "memory_tracking.hpp"
|
||||||
|
#include "nstl.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "primitive_attr.hpp"
|
||||||
|
#include "verbose.hpp"
|
||||||
|
|
||||||
|
struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
|
||||||
|
using md_t = mkldnn::impl::memory_desc_t;
|
||||||
|
|
||||||
|
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
mkldnn::impl::primitive_kind_t kind)
|
||||||
|
: engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
|
||||||
|
|
||||||
|
mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
|
||||||
|
mkldnn::impl::primitive_kind_t kind)
|
||||||
|
: engine_(engine), kind_(kind) { info_[0] = '\0'; }
|
||||||
|
|
||||||
|
virtual mkldnn_primitive_desc *clone() const = 0;
|
||||||
|
virtual ~mkldnn_primitive_desc() {}
|
||||||
|
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
|
||||||
|
mkldnn::impl::engine_t *engine() const { return engine_; }
|
||||||
|
mkldnn::impl::primitive_kind_t kind() const { return kind_; }
|
||||||
|
|
||||||
|
virtual void init_info() {}
|
||||||
|
const char *info() const { return info_; }
|
||||||
|
|
||||||
|
mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
|
||||||
|
{ return scratchpad_registry_; }
|
||||||
|
const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
|
||||||
|
{ return scratchpad_registry_; }
|
||||||
|
virtual mkldnn::impl::engine_t *scratchpad_engine() const
|
||||||
|
{ return engine_; }
|
||||||
|
|
||||||
|
virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
|
||||||
|
|
||||||
|
enum class arg_usage_t { unused, input, output };
|
||||||
|
virtual arg_usage_t arg_usage(
|
||||||
|
mkldnn::impl::primitive_arg_index_t arg) const {
|
||||||
|
using mkldnn::impl::types::is_zero_md;
|
||||||
|
if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
|
||||||
|
return arg_usage_t::output;
|
||||||
|
return arg_usage_t::unused;
|
||||||
|
}
|
||||||
|
|
||||||
|
# define DECLARE_MD_STUB(stub) \
|
||||||
|
virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
|
||||||
|
{ return nullptr; }
|
||||||
|
|
||||||
|
DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
|
||||||
|
DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
|
||||||
|
DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
|
||||||
|
DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
|
||||||
|
DECLARE_MD_STUB(workspace_md);
|
||||||
|
# undef DECLARE_MD_STUB
|
||||||
|
|
||||||
|
const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
|
||||||
|
return idx == 0 ? &scratchpad_md_ : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void init_scratchpad_md() {
|
||||||
|
auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
|
||||||
|
mkldnn::impl::dims_t dims = { size };
|
||||||
|
mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
|
||||||
|
mkldnn::impl::data_type::u8, mkldnn_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** returns the scratchpad size for the given scratchpad mode. */
|
||||||
|
mkldnn::impl::dim_t scratchpad_size(
|
||||||
|
mkldnn::impl::scratchpad_mode_t mode) const {
|
||||||
|
if (mode != attr_.scratchpad_mode_) return 0;
|
||||||
|
return scratchpad_registry().size();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const { return 0; }
|
||||||
|
virtual int n_outputs() const { return 0; }
|
||||||
|
|
||||||
|
virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
|
||||||
|
void *result) const;
|
||||||
|
|
||||||
|
virtual mkldnn::impl::status_t create_primitive(
|
||||||
|
mkldnn::impl::primitive_t **primitive) const = 0;
|
||||||
|
|
||||||
|
virtual const char *name() const { return "mkldnn_primitive_desc"; }
|
||||||
|
|
||||||
|
/* static magic */
|
||||||
|
|
||||||
|
template<typename pd_t>
|
||||||
|
static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
|
||||||
|
const mkldnn::impl::op_desc_t *adesc,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr,
|
||||||
|
mkldnn::impl::engine_t *engine,
|
||||||
|
const mkldnn::impl::primitive_desc_t *hint_fwd) {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
|
||||||
|
if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
|
||||||
|
assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
|
||||||
|
auto hint =
|
||||||
|
reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
|
||||||
|
auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
|
||||||
|
if (_pd == nullptr) return out_of_memory;
|
||||||
|
if (_pd->init() != success) { delete _pd; return unimplemented; }
|
||||||
|
_pd->init_info();
|
||||||
|
_pd->init_scratchpad_md();
|
||||||
|
*pd = _pd;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
mkldnn::impl::engine_t *engine_;
|
||||||
|
mkldnn::impl::primitive_attr_t attr_;
|
||||||
|
mkldnn::impl::primitive_kind_t kind_;
|
||||||
|
|
||||||
|
mkldnn::impl::memory_desc_t scratchpad_md_;
|
||||||
|
|
||||||
|
char info_[MKLDNN_VERBOSE_BUF_LEN];
|
||||||
|
|
||||||
|
mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/** compares ws between fwd_pd and this (make sense to use for bwd_pd)
|
||||||
|
* Expectation: this already set workspace, and this workspace should
|
||||||
|
* exactly match the one from fwd_pd */
|
||||||
|
bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
if (!workspace_md()) return true; // the impl lives fine w/o workspace
|
||||||
|
return fwd_pd && fwd_pd->workspace_md()
|
||||||
|
&& *fwd_pd->workspace_md() == *workspace_md();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define DECLARE_COMMON_PD_t(impl_name, ...) \
|
||||||
|
virtual pd_t *clone() const override { return new pd_t(*this); } \
|
||||||
|
virtual status_t create_primitive(primitive_t **p) const override { \
|
||||||
|
double ms = get_msec(); \
|
||||||
|
auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
|
||||||
|
ms = get_msec() - ms; \
|
||||||
|
if (mkldnn_verbose()->level >= 2) { \
|
||||||
|
printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
|
||||||
|
fflush(0); \
|
||||||
|
} \
|
||||||
|
return ret; \
|
||||||
|
} \
|
||||||
|
virtual const char *name() const override { return impl_name; }
|
||||||
|
#define DECLARE_COMMON_PD_T(impl_name, ...) \
|
||||||
|
DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
90
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
vendored
Normal file
90
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
vendored
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "memory.hpp"
|
||||||
|
#include "primitive.hpp"
|
||||||
|
#include "primitive_exec_types.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
|
||||||
|
const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
|
||||||
|
using namespace status;
|
||||||
|
|
||||||
|
if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
|
||||||
|
|
||||||
|
int n_inputs = 0;
|
||||||
|
int n_outputs = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
primitive_arg_index_t arg = c_args[i].arg;
|
||||||
|
auto *mem = c_args[i].memory;
|
||||||
|
|
||||||
|
switch (pd->arg_usage(arg)) {
|
||||||
|
case primitive_desc_t::arg_usage_t::input:
|
||||||
|
if (args.count(arg) != 0) return invalid_arguments;
|
||||||
|
args[arg] = {mem, true};
|
||||||
|
n_inputs++;
|
||||||
|
break;
|
||||||
|
case primitive_desc_t::arg_usage_t::output:
|
||||||
|
if (args.count(arg) != 0) return invalid_arguments;
|
||||||
|
args[arg] = {mem, false};
|
||||||
|
n_outputs++;
|
||||||
|
break;
|
||||||
|
case primitive_desc_t::arg_usage_t::unused:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
|
||||||
|
|
||||||
|
if (n_inputs != pd->n_inputs()) return invalid_arguments;
|
||||||
|
if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
|
||||||
|
if (args_.count(arg) != 1) return nullptr;
|
||||||
|
const auto ma = args_.at(arg);
|
||||||
|
assert(ma.is_const);
|
||||||
|
void *ptr;
|
||||||
|
status_t status = ma.mem->get_data_handle(&ptr);
|
||||||
|
assert(status == status::success); MAYBE_UNUSED(status);
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *exec_ctx_t::output(primitive_arg_index_t arg) const {
|
||||||
|
if (args_.count(arg) != 1) return nullptr;
|
||||||
|
const auto ma = args_.at(arg);
|
||||||
|
assert(!ma.is_const);
|
||||||
|
void *ptr;
|
||||||
|
status_t status = ma.mem->get_data_handle(&ptr);
|
||||||
|
assert(status == status::success); MAYBE_UNUSED(status);
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
|
||||||
|
assert(args_.count(arg) == 1);
|
||||||
|
const auto ma = args_.at(arg);
|
||||||
|
assert(!ma.is_const);
|
||||||
|
return ma.mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
68
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
vendored
Normal file
68
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
vendored
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef PRIMITIVE_EXEC_TYPES_HPP
|
||||||
|
#define PRIMITIVE_EXEC_TYPES_HPP
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mkldnn_types.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "memory.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct memory_arg_t {
|
||||||
|
memory_t *mem;
|
||||||
|
bool is_const;
|
||||||
|
};
|
||||||
|
|
||||||
|
using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
|
||||||
|
|
||||||
|
status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
|
||||||
|
const mkldnn_exec_arg_t *c_args, exec_args_t &args);
|
||||||
|
|
||||||
|
/** Primitive execution context (helps passing stream, memories, and events. */
|
||||||
|
struct exec_ctx_t {
|
||||||
|
exec_ctx_t(const exec_ctx_t &) = default;
|
||||||
|
exec_ctx_t(exec_ctx_t &&) = default;
|
||||||
|
|
||||||
|
exec_ctx_t(stream_t *stream): stream_(stream) {}
|
||||||
|
exec_ctx_t(stream_t *stream, exec_args_t &&args)
|
||||||
|
: stream_(stream)
|
||||||
|
, args_(std::move(args)) {}
|
||||||
|
|
||||||
|
stream_t *stream() const { return stream_; }
|
||||||
|
const exec_args_t &args() const { return args_; }
|
||||||
|
|
||||||
|
/* tentative solution... TODO: replace with functions return memory_t */
|
||||||
|
const void *input(primitive_arg_index_t arg) const;
|
||||||
|
void *output(primitive_arg_index_t arg) const;
|
||||||
|
|
||||||
|
const memory_t *memory(primitive_arg_index_t arg) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
stream_t *stream_;
|
||||||
|
exec_args_t args_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
89
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
vendored
Normal file
89
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
vendored
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "primitive_iterator.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_iterator_create(
|
||||||
|
primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
|
||||||
|
const primitive_attr_t *attr, engine_t *engine,
|
||||||
|
const primitive_desc_t *hint_fwd_pd) {
|
||||||
|
const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
|
||||||
|
|
||||||
|
auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
|
||||||
|
if (it == nullptr) return out_of_memory;
|
||||||
|
|
||||||
|
++(*it);
|
||||||
|
if (*it == it->end()) {
|
||||||
|
delete it;
|
||||||
|
return unimplemented;
|
||||||
|
}
|
||||||
|
|
||||||
|
*iterator = it;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_iterator_next(
|
||||||
|
primitive_desc_iterator_t *iterator) {
|
||||||
|
if (iterator == nullptr) return invalid_arguments;
|
||||||
|
++(*iterator);
|
||||||
|
return *iterator == iterator->end() ? iterator_ends : success;
|
||||||
|
}
|
||||||
|
|
||||||
|
primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
|
||||||
|
const primitive_desc_iterator_t *iterator) {
|
||||||
|
if (iterator == nullptr) return nullptr;
|
||||||
|
return *(*iterator);
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
|
||||||
|
const primitive_desc_t *existing_primitive_desc) {
|
||||||
|
if (utils::any_null(primitive_desc, existing_primitive_desc))
|
||||||
|
return invalid_arguments;
|
||||||
|
return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
|
||||||
|
existing_primitive_desc->clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_iterator_destroy(
|
||||||
|
primitive_desc_iterator_t *iterator) {
|
||||||
|
if (iterator != nullptr)
|
||||||
|
delete iterator;
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
|
||||||
|
const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
|
||||||
|
engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
|
||||||
|
const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
|
||||||
|
|
||||||
|
mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
|
||||||
|
++it;
|
||||||
|
if (it == it.end()) return unimplemented;
|
||||||
|
|
||||||
|
return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
79
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
vendored
Normal file
79
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
vendored
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
#ifndef PRIMITIVE_ITERATOR_HPP
|
||||||
|
#define PRIMITIVE_ITERATOR_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
|
||||||
|
struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
|
||||||
|
using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
|
||||||
|
|
||||||
|
mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
|
||||||
|
const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
|
||||||
|
: idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
|
||||||
|
, attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, impl_list_(engine_->get_implementation_list()), last_idx_(0)
|
||||||
|
{
|
||||||
|
while (impl_list_[last_idx_] != nullptr) ++last_idx_;
|
||||||
|
}
|
||||||
|
~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
|
||||||
|
|
||||||
|
bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
|
||||||
|
{ return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
|
||||||
|
bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
|
||||||
|
{ return !operator==(rhs); }
|
||||||
|
|
||||||
|
mkldnn::impl::primitive_desc_iterator_t end() const
|
||||||
|
{ return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
|
||||||
|
|
||||||
|
mkldnn::impl::primitive_desc_iterator_t &operator++() {
|
||||||
|
if (pd_) { delete pd_; pd_ = nullptr; }
|
||||||
|
while (++idx_ != last_idx_) {
|
||||||
|
auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
|
||||||
|
hint_fwd_pd_);
|
||||||
|
if (s == mkldnn::impl::status::success) break;
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
mkldnn::impl::primitive_desc_t *operator*() const {
|
||||||
|
if (*this == end() || pd_ == nullptr) return nullptr;
|
||||||
|
return pd_->clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int idx_;
|
||||||
|
mkldnn::impl::engine_t *engine_;
|
||||||
|
mkldnn::impl::primitive_desc_t *pd_;
|
||||||
|
const mkldnn::impl::op_desc_t *op_desc_;
|
||||||
|
const mkldnn::impl::primitive_attr_t attr_;
|
||||||
|
const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
|
||||||
|
const pd_create_f *impl_list_;
|
||||||
|
int last_idx_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
|
||||||
|
: idx_(last_idx), engine_(engine), pd_(nullptr)
|
||||||
|
, op_desc_(nullptr), hint_fwd_pd_(nullptr)
|
||||||
|
, impl_list_(nullptr), last_idx_(last_idx) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
59
thirdparty/oidn/mkl-dnn/src/common/query.cpp
vendored
Normal file
59
thirdparty/oidn/mkl-dnn/src/common/query.cpp
vendored
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
|
||||||
|
query_t what, int index, void *result) {
|
||||||
|
if (any_null(primitive_desc, result))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
return primitive_desc->query(what, index, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
const memory_desc_t *mkldnn_primitive_desc_query_md(
|
||||||
|
const primitive_desc_t *primitive_desc, query_t what, int index) {
|
||||||
|
const memory_desc_t *res_md = nullptr;
|
||||||
|
bool args_ok = true
|
||||||
|
&& primitive_desc != nullptr
|
||||||
|
&& (what & query::some_md) == query::some_md
|
||||||
|
&& what != query::some_md
|
||||||
|
&& mkldnn_primitive_desc_query(primitive_desc,
|
||||||
|
what, index, &res_md) == success;
|
||||||
|
return args_ok ? res_md : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
|
||||||
|
query_t what, int index) {
|
||||||
|
int res_s32;
|
||||||
|
bool args_ok = primitive_desc != nullptr
|
||||||
|
&& one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
|
||||||
|
&& mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
|
||||||
|
== success;
|
||||||
|
return args_ok ? res_s32 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
68
thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
vendored
Normal file
68
thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
vendored
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "engine.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "reorder_pd.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
|
||||||
|
status_t mkldnn_reorder_primitive_desc_create(
|
||||||
|
primitive_desc_t **reorder_pd,
|
||||||
|
engine_t *src_engine, const memory_desc_t *src_md,
|
||||||
|
engine_t *dst_engine, const memory_desc_t *dst_md,
|
||||||
|
const primitive_attr_t *attr) {
|
||||||
|
if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
auto s_ek = src_engine->kind();
|
||||||
|
auto d_ek = dst_engine->kind();
|
||||||
|
if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
|
||||||
|
auto s_mdw = memory_desc_wrapper(*src_md);
|
||||||
|
auto d_mdw = memory_desc_wrapper(*dst_md);
|
||||||
|
|
||||||
|
if (!s_mdw.consistent_with(d_mdw))
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
|
||||||
|
|
||||||
|
const primitive_attr_t dummy_attr;
|
||||||
|
if (attr == NULL)
|
||||||
|
attr = &dummy_attr;
|
||||||
|
|
||||||
|
for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
|
||||||
|
if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
|
||||||
|
== success) {
|
||||||
|
(*r_pd)->init_info();
|
||||||
|
(*r_pd)->init_scratchpad_md();
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return unimplemented;
|
||||||
|
}
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
85
thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
vendored
Normal file
85
thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
vendored
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2016-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef REORDER_PD_HPP
|
||||||
|
#define REORDER_PD_HPP
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_attr.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct reorder_pd_t: public primitive_desc_t {
|
||||||
|
reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
|
||||||
|
engine_t *src_engine, const memory_desc_t *src_md,
|
||||||
|
engine_t *dst_engine, const memory_desc_t *dst_md)
|
||||||
|
: primitive_desc_t(engine, attr, primitive_kind::reorder)
|
||||||
|
, src_engine_(src_engine)
|
||||||
|
, dst_engine_(dst_engine)
|
||||||
|
, scratchpad_engine_(nullptr)
|
||||||
|
, src_md_(*src_md)
|
||||||
|
, dst_md_(*dst_md)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual const op_desc_t *op_desc() const override { return nullptr; }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_FROM)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_TO)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &src_md_ : nullptr; }
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override
|
||||||
|
{ return index == 0 ? &dst_md_ : nullptr; }
|
||||||
|
|
||||||
|
virtual int n_inputs() const override { return 1; }
|
||||||
|
virtual int n_outputs() const override { return 1; }
|
||||||
|
|
||||||
|
float alpha() const { return attr()->output_scales_.scales_[0]; }
|
||||||
|
float beta() const {
|
||||||
|
const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
|
||||||
|
return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
|
||||||
|
}
|
||||||
|
virtual mkldnn::impl::engine_t *scratchpad_engine() const override
|
||||||
|
{ return scratchpad_engine_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
engine_t *src_engine_;
|
||||||
|
engine_t *dst_engine_;
|
||||||
|
engine_t *scratchpad_engine_;
|
||||||
|
|
||||||
|
memory_desc_t src_md_;
|
||||||
|
memory_desc_t dst_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
|
400
thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
vendored
Normal file
400
thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
vendored
Normal file
@ -0,0 +1,400 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "cpu/gemm/os_blas.hpp"
|
||||||
|
|
||||||
|
using namespace mkldnn::impl;
|
||||||
|
using namespace mkldnn::impl::status;
|
||||||
|
using namespace mkldnn::impl::types;
|
||||||
|
using namespace mkldnn::impl::utils;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
memory_desc_t copy_maybe_null(const memory_desc_t *md) {
|
||||||
|
return md ? *md : zero_md();
|
||||||
|
}
|
||||||
|
|
||||||
|
rnn_desc_t zero_rnn_desc() {
|
||||||
|
auto rd = rnn_desc_t();
|
||||||
|
rd.src_layer_desc = zero_md();
|
||||||
|
rd.src_iter_desc = zero_md();
|
||||||
|
rd.weights_layer_desc = zero_md();
|
||||||
|
rd.weights_iter_desc = zero_md();
|
||||||
|
rd.bias_desc = zero_md();
|
||||||
|
rd.dst_layer_desc = zero_md();
|
||||||
|
rd.dst_iter_desc = zero_md();
|
||||||
|
rd.diff_src_layer_desc = zero_md();
|
||||||
|
rd.diff_src_iter_desc = zero_md();
|
||||||
|
rd.diff_weights_layer_desc = zero_md();
|
||||||
|
rd.diff_weights_iter_desc = zero_md();
|
||||||
|
rd.diff_bias_desc = zero_md();
|
||||||
|
rd.diff_dst_layer_desc = zero_md();
|
||||||
|
rd.diff_dst_iter_desc = zero_md();
|
||||||
|
return rd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Public C Api */
|
||||||
|
|
||||||
|
status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
|
||||||
|
mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
|
||||||
|
unsigned int flags, float alpha, float clipping) {
|
||||||
|
using namespace mkldnn::impl::alg_kind;
|
||||||
|
|
||||||
|
bool args_ok = true
|
||||||
|
&& one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
|
||||||
|
gru_linear_before_reset)
|
||||||
|
&& IMPLICATION(cell_kind == vanilla_rnn,
|
||||||
|
one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
|
||||||
|
if (!args_ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
auto rcd = mkldnn_rnn_cell_desc_t();
|
||||||
|
|
||||||
|
rcd.cell_kind = cell_kind;
|
||||||
|
rcd.activation_kind = act_f;
|
||||||
|
rcd.flags = flags;
|
||||||
|
rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
|
||||||
|
rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
|
||||||
|
|
||||||
|
*rnn_cell_desc = rcd;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
|
||||||
|
switch (rnn_cell_desc->cell_kind) {
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_gru: return 3;
|
||||||
|
case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
|
||||||
|
default: assert(!"unknown cell kind"); return 0;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
|
||||||
|
switch (rnn_cell_desc->cell_kind) {
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_gru: return 1;
|
||||||
|
case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
|
||||||
|
case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
|
||||||
|
default: assert(!"unknown cell kind"); return 0;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
|
||||||
|
prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
|
||||||
|
const memory_desc_t *src_iter_desc,
|
||||||
|
const memory_desc_t *weights_layer_desc,
|
||||||
|
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_layer_desc,
|
||||||
|
const memory_desc_t *dst_iter_desc) {
|
||||||
|
using namespace data_type;
|
||||||
|
data_type_t src_layer_dt = src_layer_desc->data_type;
|
||||||
|
data_type_t dst_layer_dt = dst_layer_desc->data_type;
|
||||||
|
data_type_t weights_iter_dt = weights_iter_desc->data_type;
|
||||||
|
data_type_t weights_layer_dt = weights_layer_desc->data_type;
|
||||||
|
|
||||||
|
bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
|
||||||
|
weights_layer_dt)
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
||||||
|
src_iter_desc->data_type == f32)
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
||||||
|
dst_iter_desc->data_type == f32)
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
||||||
|
|
||||||
|
#if USE_MKL_PACKED_GEMM
|
||||||
|
bool is_u8u8u8 = src_layer_dt == u8
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
||||||
|
src_iter_desc->data_type == u8)
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
||||||
|
dst_iter_desc->data_type == u8)
|
||||||
|
&& one_of(dst_layer_dt, u8, f32)
|
||||||
|
&& everyone_is(s8, weights_iter_dt, weights_layer_dt)
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
||||||
|
|
||||||
|
bool is_f32u8f32 = src_layer_dt == u8
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
||||||
|
src_iter_desc->data_type == f32)
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
||||||
|
dst_iter_desc->data_type == f32)
|
||||||
|
&& one_of(dst_layer_dt, u8, f32)
|
||||||
|
&& everyone_is(s8, weights_iter_dt, weights_layer_dt)
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
|
||||||
|
|
||||||
|
bool is_inference = prop_kind == prop_kind::forward_inference;
|
||||||
|
bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
|
||||||
|
|
||||||
|
return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
|
||||||
|
? success
|
||||||
|
: unimplemented;
|
||||||
|
#else
|
||||||
|
return is_f32 ? success : unimplemented;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
|
||||||
|
rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
|
||||||
|
int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
|
||||||
|
const memory_desc_t *src_iter_desc,
|
||||||
|
const memory_desc_t *weights_layer_desc,
|
||||||
|
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_layer_desc,
|
||||||
|
const memory_desc_t *dst_iter_desc) {
|
||||||
|
bool args_ok;
|
||||||
|
|
||||||
|
// * algorithm specific
|
||||||
|
args_ok = true
|
||||||
|
&& IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
|
||||||
|
DIC == SIC);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
int extra_bias =
|
||||||
|
rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
|
||||||
|
|
||||||
|
// * on num layers
|
||||||
|
args_ok = true
|
||||||
|
&& L == weights_layer_desc->dims[0]
|
||||||
|
&& L == weights_iter_desc->dims[0]
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on num directions
|
||||||
|
args_ok = true
|
||||||
|
&& D == weights_layer_desc->dims[1]
|
||||||
|
&& D == weights_iter_desc->dims[1]
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on num iterations
|
||||||
|
args_ok = true
|
||||||
|
&& T == src_layer_desc->dims[0]
|
||||||
|
&& T == dst_layer_desc->dims[0];
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on mb
|
||||||
|
args_ok = true
|
||||||
|
&& N == src_layer_desc->dims[1]
|
||||||
|
&& N == dst_layer_desc->dims[1]
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on num gates
|
||||||
|
args_ok = true
|
||||||
|
&& G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
|
||||||
|
&& G == weights_layer_desc->dims[3]
|
||||||
|
&& G == weights_iter_desc->dims[3]
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc),
|
||||||
|
G + extra_bias == bias_desc->dims[2]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on num states
|
||||||
|
args_ok = true
|
||||||
|
&& S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on slc
|
||||||
|
args_ok = true
|
||||||
|
&& SLC == weights_layer_desc->dims[2]
|
||||||
|
&& SLC == src_layer_desc->dims[2];
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on sic
|
||||||
|
args_ok = true
|
||||||
|
&& SIC == weights_iter_desc->dims[2]
|
||||||
|
&& IMPLICATION(!is_zero_md(src_iter_desc),
|
||||||
|
SIC == src_iter_desc->dims[4]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on dlc
|
||||||
|
int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
|
||||||
|
args_ok = true
|
||||||
|
&& DLC == dlc_multiplier * DIC
|
||||||
|
&& DLC == dst_layer_desc->dims[2];
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * on dic
|
||||||
|
args_ok = true
|
||||||
|
&& DIC == weights_layer_desc->dims[4]
|
||||||
|
&& DIC == weights_iter_desc->dims[4]
|
||||||
|
&& IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
|
||||||
|
&& IMPLICATION(!is_zero_md(dst_iter_desc),
|
||||||
|
DIC == dst_iter_desc->dims[4]);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
// * unrolling/fusion conditions
|
||||||
|
args_ok = true
|
||||||
|
&& IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
|
||||||
|
&& IMPLICATION(T > 1, SIC == DIC);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
|
||||||
|
prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
|
||||||
|
const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
|
||||||
|
const memory_desc_t *src_iter_desc,
|
||||||
|
const memory_desc_t *weights_layer_desc,
|
||||||
|
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_layer_desc,
|
||||||
|
const memory_desc_t *dst_iter_desc) {
|
||||||
|
bool args_ok = true && rnn_cell_desc != nullptr
|
||||||
|
&& !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
|
||||||
|
dst_layer_desc);
|
||||||
|
if (!args_ok) return invalid_arguments;
|
||||||
|
|
||||||
|
//check dimensions consistency
|
||||||
|
int L = weights_layer_desc->dims[0];
|
||||||
|
int T = src_layer_desc->dims[0];
|
||||||
|
int N = src_layer_desc->dims[1];
|
||||||
|
const int D = one_of(direction, mkldnn_unidirectional_left2right,
|
||||||
|
mkldnn_unidirectional_right2left) ?
|
||||||
|
1 :
|
||||||
|
2;
|
||||||
|
int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
|
||||||
|
int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
|
||||||
|
int SLC = src_layer_desc->dims[2];
|
||||||
|
int SIC = weights_iter_desc->dims[2];
|
||||||
|
int DLC = dst_layer_desc->dims[2];
|
||||||
|
int DIC = weights_layer_desc->dims[4];
|
||||||
|
|
||||||
|
CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
||||||
|
G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
|
||||||
|
weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
|
||||||
|
dst_iter_desc));
|
||||||
|
|
||||||
|
CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
|
||||||
|
src_layer_desc, src_iter_desc, weights_layer_desc,
|
||||||
|
weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
|
||||||
|
|
||||||
|
// Create the descriptor
|
||||||
|
mkldnn_rnn_desc_t rd = zero_rnn_desc();
|
||||||
|
|
||||||
|
rd.primitive_kind = primitive_kind::rnn;
|
||||||
|
rd.prop_kind = prop_kind;
|
||||||
|
rd.cell_desc = *rnn_cell_desc;
|
||||||
|
rd.direction = direction;
|
||||||
|
rd.src_layer_desc = copy_maybe_null(src_layer_desc);
|
||||||
|
rd.src_iter_desc = copy_maybe_null(src_iter_desc);
|
||||||
|
rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
|
||||||
|
rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
|
||||||
|
rd.bias_desc = copy_maybe_null(bias_desc);
|
||||||
|
rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
|
||||||
|
rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
|
||||||
|
|
||||||
|
*rnn_desc = rd;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
|
||||||
|
prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
|
||||||
|
const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
|
||||||
|
const memory_desc_t *src_iter_desc,
|
||||||
|
const memory_desc_t *weights_layer_desc,
|
||||||
|
const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
|
||||||
|
const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
|
||||||
|
const memory_desc_t *diff_src_layer_desc,
|
||||||
|
const memory_desc_t *diff_src_iter_desc,
|
||||||
|
const memory_desc_t *diff_weights_layer_desc,
|
||||||
|
const memory_desc_t *diff_weights_iter_desc,
|
||||||
|
const memory_desc_t *diff_bias_desc,
|
||||||
|
const memory_desc_t *diff_dst_layer_desc,
|
||||||
|
const memory_desc_t *diff_dst_iter_desc) {
|
||||||
|
bool args_ok = true
|
||||||
|
&& !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
|
||||||
|
dst_layer_desc, diff_src_layer_desc,
|
||||||
|
diff_weights_layer_desc, diff_weights_iter_desc,
|
||||||
|
diff_dst_layer_desc);
|
||||||
|
if (!args_ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
|
||||||
|
return is_zero_md(a_md) == is_zero_md(b_md);
|
||||||
|
};
|
||||||
|
|
||||||
|
args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
|
||||||
|
&& xnor_md(dst_iter_desc, diff_dst_iter_desc)
|
||||||
|
&& xnor_md(src_iter_desc, diff_src_iter_desc);
|
||||||
|
if (!args_ok)
|
||||||
|
return invalid_arguments;
|
||||||
|
|
||||||
|
//check dimensions consistency
|
||||||
|
int L = weights_layer_desc->dims[0];
|
||||||
|
int T = src_layer_desc->dims[0];
|
||||||
|
int N = src_layer_desc->dims[1];
|
||||||
|
const int D = one_of(direction, mkldnn_unidirectional_left2right,
|
||||||
|
mkldnn_unidirectional_right2left) ?
|
||||||
|
1 :
|
||||||
|
2;
|
||||||
|
int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
|
||||||
|
int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
|
||||||
|
int SLC = src_layer_desc->dims[2];
|
||||||
|
int SIC = weights_iter_desc->dims[2];
|
||||||
|
int DLC = dst_layer_desc->dims[2];
|
||||||
|
int DIC = weights_layer_desc->dims[4];
|
||||||
|
|
||||||
|
status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
||||||
|
G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
|
||||||
|
weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
|
||||||
|
dst_iter_desc);
|
||||||
|
if (st != success) return st;
|
||||||
|
|
||||||
|
st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
|
||||||
|
G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
|
||||||
|
diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
|
||||||
|
diff_dst_layer_desc, diff_dst_iter_desc);
|
||||||
|
if (st != success) return st;
|
||||||
|
|
||||||
|
mkldnn_rnn_desc_t rd = zero_rnn_desc();
|
||||||
|
|
||||||
|
rd.primitive_kind = primitive_kind::rnn;
|
||||||
|
rd.prop_kind = prop_kind;
|
||||||
|
rd.cell_desc = *rnn_cell_desc;
|
||||||
|
rd.direction = direction;
|
||||||
|
|
||||||
|
rd.src_layer_desc = copy_maybe_null(src_layer_desc);
|
||||||
|
rd.src_iter_desc = copy_maybe_null(src_iter_desc);
|
||||||
|
rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
|
||||||
|
rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
|
||||||
|
rd.bias_desc = copy_maybe_null(bias_desc);
|
||||||
|
rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
|
||||||
|
rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
|
||||||
|
rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
|
||||||
|
rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
|
||||||
|
rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
|
||||||
|
rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
|
||||||
|
rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
|
||||||
|
rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
|
||||||
|
rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
|
||||||
|
|
||||||
|
*rnn_desc = rd;
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
280
thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
vendored
Normal file
280
thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
vendored
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#ifndef RNN_PD_HPP
|
||||||
|
#define RNN_PD_HPP
|
||||||
|
|
||||||
|
#include "mkldnn.h"
|
||||||
|
|
||||||
|
#include "c_types_map.hpp"
|
||||||
|
#include "primitive_desc.hpp"
|
||||||
|
#include "type_helpers.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
struct rnn_fwd_pd_t;
|
||||||
|
|
||||||
|
struct rnn_pd_t : public primitive_desc_t {
|
||||||
|
static constexpr auto base_pkind = primitive_kind::rnn;
|
||||||
|
|
||||||
|
rnn_pd_t(engine_t *engine,
|
||||||
|
const rnn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const rnn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: primitive_desc_t(engine, attr, base_pkind)
|
||||||
|
, desc_(*adesc)
|
||||||
|
, hint_fwd_pd_(hint_fwd_pd)
|
||||||
|
, src_layer_md_(desc_.src_layer_desc)
|
||||||
|
, src_iter_md_(desc_.src_iter_desc)
|
||||||
|
, weights_layer_md_(desc_.weights_layer_desc)
|
||||||
|
, weights_iter_md_(desc_.weights_iter_desc)
|
||||||
|
, bias_md_(desc_.bias_desc)
|
||||||
|
, dst_layer_md_(desc_.dst_layer_desc)
|
||||||
|
, dst_iter_md_(desc_.dst_iter_desc)
|
||||||
|
, ws_md_()
|
||||||
|
{}
|
||||||
|
|
||||||
|
const rnn_desc_t *desc() const { return &desc_; }
|
||||||
|
virtual const op_desc_t *op_desc() const override
|
||||||
|
{ return reinterpret_cast<const op_desc_t *>(this->desc()); }
|
||||||
|
virtual void init_info() override { impl::init_info(this, this->info_); }
|
||||||
|
|
||||||
|
virtual status_t query(query_t what, int idx, void *result) const override {
|
||||||
|
switch (what) {
|
||||||
|
case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
|
||||||
|
default: return primitive_desc_t::query(what, idx, result);
|
||||||
|
}
|
||||||
|
return status::success;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *src_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &src_layer_md_;
|
||||||
|
if (index == 1 && with_src_iter()) return &src_iter_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
virtual const memory_desc_t *weights_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &weights_layer_md_;
|
||||||
|
if (index == 1) return &weights_iter_md_;
|
||||||
|
if (index == 2 && with_bias()) return &bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
virtual const memory_desc_t *dst_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &dst_layer_md_;
|
||||||
|
if (index == 1 && with_dst_iter()) return &dst_iter_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
virtual const memory_desc_t *workspace_md(int index = 0) const override
|
||||||
|
{ return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
|
||||||
|
|
||||||
|
/* common pooling aux functions */
|
||||||
|
|
||||||
|
bool is_training() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::backward);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_fwd() const {
|
||||||
|
return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
|
||||||
|
prop_kind::forward_inference);
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t T() const { return desc_.src_layer_desc.dims[0]; }
|
||||||
|
dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
|
||||||
|
|
||||||
|
dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
|
||||||
|
dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
|
||||||
|
|
||||||
|
dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
|
||||||
|
|
||||||
|
dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
|
||||||
|
dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
|
||||||
|
dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
|
||||||
|
|
||||||
|
dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
|
||||||
|
|
||||||
|
bool with_bias() const
|
||||||
|
{ return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
|
||||||
|
|
||||||
|
bool with_src_iter() const
|
||||||
|
{ return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
|
||||||
|
|
||||||
|
bool with_dst_iter() const
|
||||||
|
{ return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
|
||||||
|
|
||||||
|
mkldnn::impl::alg_kind_t cell_kind() const
|
||||||
|
{ return desc_.cell_desc.cell_kind; }
|
||||||
|
mkldnn::impl::alg_kind_t activation_kind() const
|
||||||
|
{ return desc_.cell_desc.activation_kind; }
|
||||||
|
|
||||||
|
bool is_lbr() const
|
||||||
|
{ return cell_kind() == mkldnn_gru_linear_before_reset; }
|
||||||
|
|
||||||
|
mkldnn_rnn_direction_t direction() const { return desc_.direction; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
rnn_desc_t desc_;
|
||||||
|
const rnn_fwd_pd_t *hint_fwd_pd_;
|
||||||
|
|
||||||
|
memory_desc_t src_layer_md_;
|
||||||
|
memory_desc_t src_iter_md_;
|
||||||
|
memory_desc_t weights_layer_md_;
|
||||||
|
memory_desc_t weights_iter_md_;
|
||||||
|
memory_desc_t bias_md_;
|
||||||
|
memory_desc_t dst_layer_md_;
|
||||||
|
memory_desc_t dst_iter_md_;
|
||||||
|
|
||||||
|
memory_desc_t ws_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct rnn_fwd_pd_t: public rnn_pd_t {
|
||||||
|
typedef rnn_fwd_pd_t base_class;
|
||||||
|
typedef rnn_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
rnn_fwd_pd_t(engine_t *engine,
|
||||||
|
const rnn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const rnn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (arg == MKLDNN_ARG_SRC_LAYER)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
|
||||||
|
MKLDNN_ARG_WEIGHTS_ITER))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_BIAS && with_bias())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST_LAYER)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE && is_training())
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 3 + with_bias() + with_src_iter(); }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 1 + with_dst_iter() + is_training(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct rnn_bwd_pd_t : public rnn_pd_t {
|
||||||
|
typedef rnn_bwd_pd_t base_class;
|
||||||
|
typedef rnn_fwd_pd_t hint_class;
|
||||||
|
|
||||||
|
rnn_bwd_pd_t(engine_t *engine,
|
||||||
|
const rnn_desc_t *adesc,
|
||||||
|
const primitive_attr_t *attr,
|
||||||
|
const rnn_fwd_pd_t *hint_fwd_pd)
|
||||||
|
: rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
|
||||||
|
, diff_src_layer_md_(desc_.diff_src_layer_desc)
|
||||||
|
, diff_src_iter_md_(desc_.diff_src_iter_desc)
|
||||||
|
, diff_weights_layer_md_(desc_.diff_weights_layer_desc)
|
||||||
|
, diff_weights_iter_md_(desc_.diff_weights_iter_desc)
|
||||||
|
, diff_bias_md_(desc_.diff_bias_desc)
|
||||||
|
, diff_dst_layer_md_(desc_.diff_dst_layer_desc)
|
||||||
|
, diff_dst_iter_md_(desc_.diff_dst_iter_desc)
|
||||||
|
{}
|
||||||
|
|
||||||
|
virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
|
||||||
|
MKLDNN_ARG_DIFF_DST_LAYER))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (with_src_iter()) {
|
||||||
|
if (arg == MKLDNN_ARG_SRC_ITER)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
|
||||||
|
MKLDNN_ARG_WEIGHTS_ITER))
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (with_bias()) {
|
||||||
|
if (arg == MKLDNN_ARG_BIAS)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_DIFF_BIAS)
|
||||||
|
return arg_usage_t::output;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
|
||||||
|
&& with_dst_iter())
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (arg == MKLDNN_ARG_WORKSPACE)
|
||||||
|
return arg_usage_t::input;
|
||||||
|
|
||||||
|
if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
|
||||||
|
MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
|
||||||
|
MKLDNN_ARG_DIFF_WEIGHTS_ITER))
|
||||||
|
return arg_usage_t::output;
|
||||||
|
|
||||||
|
return primitive_desc_t::arg_usage(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const memory_desc_t *diff_src_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_src_layer_md_;
|
||||||
|
if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
virtual const memory_desc_t *diff_weights_md(
|
||||||
|
int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_weights_layer_md_;
|
||||||
|
if (index == 1) return &diff_weights_iter_md_;
|
||||||
|
if (index == 2 && with_bias()) return &diff_bias_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
|
||||||
|
if (index == 0) return &diff_dst_layer_md_;
|
||||||
|
if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual int n_inputs() const override
|
||||||
|
{ return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
|
||||||
|
virtual int n_outputs() const override
|
||||||
|
{ return 3 + with_src_iter() + with_bias(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
memory_desc_t diff_src_layer_md_;
|
||||||
|
memory_desc_t diff_src_iter_md_;
|
||||||
|
memory_desc_t diff_weights_layer_md_;
|
||||||
|
memory_desc_t diff_weights_iter_md_;
|
||||||
|
memory_desc_t diff_bias_md_;
|
||||||
|
memory_desc_t diff_dst_layer_md_;
|
||||||
|
memory_desc_t diff_dst_iter_md_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
112
thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
vendored
Normal file
112
thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
vendored
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/*******************************************************************************
|
||||||
|
* Copyright 2017-2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*******************************************************************************/
|
||||||
|
|
||||||
|
#include "mkldnn_thread.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
#include "scratchpad.hpp"
|
||||||
|
|
||||||
|
namespace mkldnn {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
/* Allocating memory buffers on a page boundary to reduce TLB/page misses */
|
||||||
|
const size_t page_size = 2097152;
|
||||||
|
|
||||||
|
/*
|
||||||
|
Implementation of the scratchpad_t interface that is compatible with
|
||||||
|
a concurrent execution
|
||||||
|
*/
|
||||||
|
struct concurent_scratchpad_t : public scratchpad_t {
|
||||||
|
concurent_scratchpad_t(size_t size) {
|
||||||
|
size_ = size;
|
||||||
|
scratchpad_ = (char *) malloc(size, page_size);
|
||||||
|
assert(scratchpad_ != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
~concurent_scratchpad_t() {
|
||||||
|
free(scratchpad_);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual char *get() const {
|
||||||
|
return scratchpad_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
char *scratchpad_;
|
||||||
|
size_t size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Implementation of the scratchpad_t interface that uses a global
|
||||||
|
scratchpad
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct global_scratchpad_t : public scratchpad_t {
|
||||||
|
global_scratchpad_t(size_t size) {
|
||||||
|
if (size > size_) {
|
||||||
|
if (scratchpad_ != nullptr) free(scratchpad_);
|
||||||
|
size_ = size;
|
||||||
|
scratchpad_ = (char *) malloc(size, page_size);
|
||||||
|
assert(scratchpad_ != nullptr);
|
||||||
|
}
|
||||||
|
reference_count_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
~global_scratchpad_t() {
|
||||||
|
reference_count_--;
|
||||||
|
if (reference_count_ == 0) {
|
||||||
|
free(scratchpad_);
|
||||||
|
scratchpad_ = nullptr;
|
||||||
|
size_ = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual char *get() const {
|
||||||
|
return scratchpad_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/*
|
||||||
|
Using thread-local here is unnecessary and even buggy! All threads
|
||||||
|
actually share the same scratchpad, which is created and queried only
|
||||||
|
on the main thread. If the scratchpad is queried on some thread other
|
||||||
|
than the one it was created on (e.g. the application calls the API from
|
||||||
|
multiple threads), thread-local causes a segfault because the scratchpad
|
||||||
|
is uninitialized on the current thread.
|
||||||
|
*/
|
||||||
|
/*thread_local*/ static char *scratchpad_;
|
||||||
|
/*thread_local*/ static size_t size_;
|
||||||
|
/*thread_local*/ static unsigned int reference_count_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr;
|
||||||
|
/*thread_local*/ size_t global_scratchpad_t::size_ = 0;
|
||||||
|
/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0;
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
Scratchpad creation routine
|
||||||
|
*/
|
||||||
|
scratchpad_t *create_scratchpad(size_t size) {
|
||||||
|
#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC
|
||||||
|
return new global_scratchpad_t(size);
|
||||||
|
#else
|
||||||
|
return new concurent_scratchpad_t(size);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user