134 lines
3.9 KiB
C++
134 lines
3.9 KiB
C++
|
// ======================================================================== //
|
||
|
// Copyright 2009-2019 Intel Corporation //
|
||
|
// //
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License"); //
|
||
|
// you may not use this file except in compliance with the License. //
|
||
|
// You may obtain a copy of the License at //
|
||
|
// //
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0 //
|
||
|
// //
|
||
|
// Unless required by applicable law or agreed to in writing, software //
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS, //
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
|
||
|
// See the License for the specific language governing permissions and //
|
||
|
// limitations under the License. //
|
||
|
// ======================================================================== //
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "common/platform.h"
|
||
|
|
||
|
#include "mkl-dnn/include/mkldnn.hpp"
|
||
|
#include "mkl-dnn/include/mkldnn_debug.h"
|
||
|
#include "mkl-dnn/src/common/mkldnn_thread.hpp"
|
||
|
#include "mkl-dnn/src/common/type_helpers.hpp"
|
||
|
#include "mkl-dnn/src/cpu/jit_generator.hpp"
|
||
|
|
||
|
#include "common/ref.h"
|
||
|
#include "common/exception.h"
|
||
|
#include "common/thread.h"
|
||
|
#include "math.h"
|
||
|
|
||
|
namespace oidn {
|
||
|
|
||
|
using namespace mkldnn;
|
||
|
using namespace mkldnn::impl::cpu;
|
||
|
using mkldnn::impl::parallel_nd;
|
||
|
using mkldnn::impl::memory_desc_matches_tag;
|
||
|
|
||
|
|
||
|
inline size_t getFormatBytes(Format format)
|
||
|
{
|
||
|
switch (format)
|
||
|
{
|
||
|
case Format::Undefined: return 1;
|
||
|
case Format::Float: return sizeof(float);
|
||
|
case Format::Float2: return sizeof(float)*2;
|
||
|
case Format::Float3: return sizeof(float)*3;
|
||
|
case Format::Float4: return sizeof(float)*4;
|
||
|
}
|
||
|
assert(0);
|
||
|
return 0;
|
||
|
}
|
||
|
|
||
|
|
||
|
inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
|
||
|
{
|
||
|
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
||
|
return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
|
||
|
}
|
||
|
|
||
|
inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
|
||
|
{
|
||
|
const mkldnn_memory_desc_t& desc = mem->get_desc().data;
|
||
|
return memory::data_type(desc.data_type);
|
||
|
}
|
||
|
|
||
|
// Returns the number of values in a tensor
|
||
|
inline size_t getTensorSize(const memory::dims& dims)
|
||
|
{
|
||
|
size_t res = 1;
|
||
|
for (int i = 0; i < (int)dims.size(); ++i)
|
||
|
res *= dims[i];
|
||
|
return res;
|
||
|
}
|
||
|
|
||
|
inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
|
||
|
{
|
||
|
memory::dims result;
|
||
|
size_t maxSize = 0;
|
||
|
|
||
|
for (const auto& d : dims)
|
||
|
{
|
||
|
const size_t size = getTensorSize(d);
|
||
|
if (size > maxSize)
|
||
|
{
|
||
|
result = d;
|
||
|
maxSize = size;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return result;
|
||
|
}
|
||
|
|
||
|
inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
|
||
|
{
|
||
|
return getTensorSize(getTensorDims(mem));
|
||
|
}
|
||
|
|
||
|
|
||
|
template<int K>
|
||
|
inline int getPadded(int dim)
|
||
|
{
|
||
|
return (dim + (K-1)) & ~(K-1);
|
||
|
}
|
||
|
|
||
|
template<int K>
|
||
|
inline memory::dims getPadded_nchw(const memory::dims& dims)
|
||
|
{
|
||
|
assert(dims.size() == 4);
|
||
|
memory::dims padDims = dims;
|
||
|
padDims[1] = getPadded<K>(dims[1]); // pad C
|
||
|
return padDims;
|
||
|
}
|
||
|
|
||
|
|
||
|
template<int K>
|
||
|
struct BlockedFormat;
|
||
|
|
||
|
template<>
|
||
|
struct BlockedFormat<8>
|
||
|
{
|
||
|
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
|
||
|
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
|
||
|
};
|
||
|
|
||
|
template<>
|
||
|
struct BlockedFormat<16>
|
||
|
{
|
||
|
static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
|
||
|
static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
|
||
|
};
|
||
|
|
||
|
} // namespace oidn
|