143 lines
4.5 KiB
C++
143 lines
4.5 KiB
C++
// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //
|
|
|
|
#pragma once
|
|
|
|
#include "common.h"
|
|
#include <vector>
|
|
|
|
namespace oidn {
|
|
|
|
class Node
|
|
{
|
|
public:
|
|
virtual ~Node() = default;
|
|
|
|
virtual void execute(stream& sm) = 0;
|
|
|
|
virtual std::shared_ptr<memory> getDst() const { return nullptr; }
|
|
|
|
virtual size_t getScratchpadSize() const { return 0; }
|
|
virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
|
|
|
|
virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
|
|
{
|
|
assert(0); // not supported
|
|
}
|
|
};
|
|
|
|
// Node wrapping an MKL-DNN primitive
|
|
class MklNode : public Node
|
|
{
|
|
private:
|
|
primitive prim;
|
|
std::unordered_map<int, memory> args;
|
|
std::shared_ptr<memory> scratchpad;
|
|
|
|
public:
|
|
MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
|
|
: prim(prim),
|
|
args(args)
|
|
{}
|
|
|
|
size_t getScratchpadSize() const override
|
|
{
|
|
const auto primDesc = prim.get_primitive_desc();
|
|
const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
|
|
if (scratchpadDesc == nullptr)
|
|
return 0;
|
|
return mkldnn_memory_desc_get_size(scratchpadDesc);
|
|
}
|
|
|
|
void setScratchpad(const std::shared_ptr<memory>& mem) override
|
|
{
|
|
scratchpad = mem;
|
|
args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
|
|
}
|
|
|
|
void execute(stream& sm) override
|
|
{
|
|
prim.execute(sm, args);
|
|
}
|
|
};
|
|
|
|
// Convolution node
|
|
class ConvNode : public MklNode
|
|
{
|
|
private:
|
|
std::shared_ptr<memory> src;
|
|
std::shared_ptr<memory> weights;
|
|
std::shared_ptr<memory> bias;
|
|
std::shared_ptr<memory> dst;
|
|
|
|
public:
|
|
ConvNode(const convolution_forward::primitive_desc& desc,
|
|
const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& weights,
|
|
const std::shared_ptr<memory>& bias,
|
|
const std::shared_ptr<memory>& dst)
|
|
: MklNode(convolution_forward(desc),
|
|
{ { MKLDNN_ARG_SRC, *src },
|
|
{ MKLDNN_ARG_WEIGHTS, *weights },
|
|
{ MKLDNN_ARG_BIAS, *bias },
|
|
{ MKLDNN_ARG_DST, *dst } }),
|
|
src(src), weights(weights), bias(bias), dst(dst)
|
|
{}
|
|
|
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
};
|
|
|
|
// Pooling node
|
|
class PoolNode : public MklNode
|
|
{
|
|
private:
|
|
std::shared_ptr<memory> src;
|
|
std::shared_ptr<memory> dst;
|
|
|
|
public:
|
|
PoolNode(const pooling_forward::primitive_desc& desc,
|
|
const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& dst)
|
|
: MklNode(pooling_forward(desc),
|
|
{ { MKLDNN_ARG_SRC, *src },
|
|
{ MKLDNN_ARG_DST, *dst } }),
|
|
src(src), dst(dst)
|
|
{}
|
|
|
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
};
|
|
|
|
// Reorder node
|
|
class ReorderNode : public MklNode
|
|
{
|
|
private:
|
|
std::shared_ptr<memory> src;
|
|
std::shared_ptr<memory> dst;
|
|
|
|
public:
|
|
ReorderNode(const std::shared_ptr<memory>& src,
|
|
const std::shared_ptr<memory>& dst)
|
|
: MklNode(reorder(reorder::primitive_desc(*src, *dst)),
|
|
{ { MKLDNN_ARG_SRC, *src },
|
|
{ MKLDNN_ARG_DST, *dst } }),
|
|
src(src), dst(dst)
|
|
{}
|
|
|
|
std::shared_ptr<memory> getDst() const override { return dst; }
|
|
};
|
|
|
|
} // namespace oidn
|