Reshaping tensors in C++

一世执手 提交于 2019-12-05 19:41:15

Solution with checking whether reshaped tensor has the same number of elements of the source tensor:

// Extracted image features from MobileNet_224
tensorflow::Tensor image_features(tensorflow::DT_FLOAT,
                                  tensorflow::TensorShape({1, 14, 14, 512}));

tensorflow::Tensor image_features_reshaped(tensorflow::DT_FLOAT,
                                           tensorflow::TensorShape({1, 196, 512}));

// Reshape tensor from [1, 14, 14, 512] to [1, 196, 512]
if(!image_features_reshaped.CopyFrom(image_features, tensorflow::TensorShape({1, 196, 512})))
{
  LOG(ERROR) << "Unsuccessfully reshaped image features tensor [" << image_features.DebugString() << "] to [1, 196, 512]";
  return false;
}

LOG(INFO) << "Reshaped features tensor: " << image_features_reshaped.DebugString();

This should work:

Tensor my_tensor; // [A, B, C, D]
Tensor reshaped_tensor = my_tensor.shaped<float, 3>({A*B, C, D});  //[A*B, C, D]
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!