// ======================================================================== // // Copyright 2009-2019 Intel Corporation // // // // Licensed under the Apache License, Version 2.0 (the "License"); // // you may not use this file except in compliance with the License. // // You may obtain a copy of the License at // // // // http://www.apache.org/licenses/LICENSE-2.0 // // // // Unless required by applicable law or agreed to in writing, software // // distributed under the License is distributed on an "AS IS" BASIS, // // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // // See the License for the specific language governing permissions and // // limitations under the License. // // ======================================================================== // #pragma once #include "node.h" #include "image.h" namespace oidn { // Input reorder node template<int K, class TransferFunction> class InputReorderNode : public Node { private: // Source Image color; Image albedo; Image normal; // Destination std::shared_ptr<memory> dst; float* dstPtr; int C2; int H2; int W2; // Tile int h1Begin; int w1Begin; int h2Begin; int w2Begin; int H; int W; std::shared_ptr<TransferFunction> transferFunc; public: InputReorderNode(const Image& color, const Image& albedo, const Image& normal, const std::shared_ptr<memory>& dst, const std::shared_ptr<TransferFunction>& transferFunc) : color(color), albedo(albedo), normal(normal), dst(dst), h1Begin(0), w1Begin(0), H(color.height), W(color.width), transferFunc(transferFunc) { const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); assert(dstDesc.ndims == 4); assert(dstDesc.data_type == memory::data_type::f32); assert(dstDesc.dims[0] == 1); //assert(dstDesc.dims[1] >= getPadded<K>(C1)); dstPtr = (float*)dst->get_data_handle(); C2 = dstDesc.dims[1]; H2 = dstDesc.dims[2]; W2 = dstDesc.dims[3]; } void setTile(int h1, int w1, int h2, int w2, int H, int W) override { h1Begin = h1; w1Begin = w1; h2Begin = h2; w2Begin = w2; this->H = H; this->W = W; } void execute(stream& sm) override { assert(H + h1Begin <= color.height); assert(W + w1Begin <= color.width); assert(H + h2Begin <= H2); assert(W + w2Begin <= W2); parallel_nd(H2, [&](int h2) { const int h = h2 - h2Begin; if (h >= 0 && h < H) { const int h1 = h + h1Begin; // Zero pad for (int w2 = 0; w2 < w2Begin; ++w2) { int c = 0; while (c < C2) store(h2, w2, c, 0.f); } // Reorder for (int w = 0; w < W; ++w) { const int w1 = w + w1Begin; const int w2 = w + w2Begin; int c = 0; storeColor(h2, w2, c, (float*)color.get(h1, w1)); if (albedo) storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); if (normal) storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); while (c < C2) store(h2, w2, c, 0.f); } // Zero pad for (int w2 = W + w2Begin; w2 < W2; ++w2) { int c = 0; while (c < C2) store(h2, w2, c, 0.f); } } else { // Zero pad for (int w2 = 0; w2 < W2; ++w2) { int c = 0; while (c < C2) store(h2, w2, c, 0.f); } } }); } std::shared_ptr<memory> getDst() const override { return dst; } private: // Stores a single value __forceinline void store(int h, int w, int& c, float value) { // Destination is in nChwKc format float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); *dst_c = value; c++; } // Stores a color __forceinline void storeColor(int h, int w, int& c, const float* values) { #pragma unroll for (int i = 0; i < 3; ++i) { // Load the value float x = values[i]; // Sanitize the value x = maxSafe(x, 0.f); // Apply the transfer function x = transferFunc->forward(x); // Store the value store(h, w, c, x); } } // Stores an albedo __forceinline void storeAlbedo(int h, int w, int& c, const float* values) { #pragma unroll for (int i = 0; i < 3; ++i) { // Load the value float x = values[i]; // Sanitize the value x = clampSafe(x, 0.f, 1.f); // Store the value store(h, w, c, x); } } // Stores a normal __forceinline void storeNormal(int h, int w, int& c, const float* values) { // Load the normal float x = values[0]; float y = values[1]; float z = values[2]; // Compute the length of the normal const float lengthSqr = sqr(x) + sqr(y) + sqr(z); // Normalize the normal and transform it to [0..1] if (isfinite(lengthSqr)) { const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; const float scale = invLength * 0.5f; const float offset = 0.5f; x = x * scale + offset; y = y * scale + offset; z = z * scale + offset; } else { x = 0.f; y = 0.f; z = 0.f; } // Store the normal store(h, w, c, x); store(h, w, c, y); store(h, w, c, z); } }; } // namespace oidn