// ======================================================================== //
// Copyright 2009-2019 Intel Corporation //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ======================================================================== //
#include "exception.h"
#include "tensor.h"

#include <cstdint>
#include <cstring>
namespace oidn {
|
||
|
|
||
|
std::map<std::string, Tensor> parseTensors(void* buffer)
|
||
|
{
|
||
|
char* input = (char*)buffer;
|
||
|
|
||
|
// Parse the magic value
|
||
|
const int magic = *(unsigned short*)input;
|
||
|
if (magic != 0x41D7)
|
||
|
throw Exception(Error::InvalidOperation, "invalid tensor archive");
|
||
|
input += sizeof(unsigned short);
|
||
|
|
||
|
// Parse the version
|
||
|
const int majorVersion = *(unsigned char*)input++;
|
||
|
const int minorVersion = *(unsigned char*)input++;
|
||
|
UNUSED(minorVersion);
|
||
|
if (majorVersion > 1)
|
||
|
throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
|
||
|
|
||
|
// Parse the number of tensors
|
||
|
const int numTensors = *(int*)input;
|
||
|
input += sizeof(int);
|
||
|
|
||
|
// Parse the tensors
|
||
|
std::map<std::string, Tensor> tensorMap;
|
||
|
for (int i = 0; i < numTensors; ++i)
|
||
|
{
|
||
|
Tensor tensor;
|
||
|
|
||
|
// Parse the name
|
||
|
const int nameLen = *(unsigned char*)input++;
|
||
|
std::string name(input, nameLen);
|
||
|
input += nameLen;
|
||
|
|
||
|
// Parse the number of dimensions
|
||
|
const int ndims = *(unsigned char*)input++;
|
||
|
|
||
|
// Parse the shape of the tensor
|
||
|
tensor.dims.resize(ndims);
|
||
|
for (int i = 0; i < ndims; ++i)
|
||
|
tensor.dims[i] = ((int*)input)[i];
|
||
|
input += ndims * sizeof(int);
|
||
|
|
||
|
// Parse the format of the tensor
|
||
|
tensor.format = std::string(input, input + ndims);
|
||
|
input += ndims;
|
||
|
|
||
|
// Parse the data type of the tensor
|
||
|
const char type = *(unsigned char*)input++;
|
||
|
if (type != 'f') // only float32 is supported
|
||
|
throw Exception(Error::InvalidOperation, "unsupported tensor data type");
|
||
|
|
||
|
// Skip the data
|
||
|
tensor.data = (float*)input;
|
||
|
input += tensor.size() * sizeof(float);
|
||
|
|
||
|
// Add the tensor to the map
|
||
|
tensorMap.emplace(name, std::move(tensor));
|
||
|
}
|
||
|
|
||
|
return tensorMap;
|
||
|
}
|
||
|
|
||
|
} // namespace oidn
|