How to get data from 2D tensor?

Submitted by 情到浓时终转凉″ on 2019-12-01 21:22:31

Question


I would like to get the data from a 2D tensor with tensorflow.js. I tried to use the data() method like this:

const X = tf.tensor2d([[1, 2, 3, 4], [2, 2, 5, 3]]);
X.data().then(data => console.log(data));

But the result is a flattened 1D array:

Float32Array(8) [1, 2, 3, 4, 2, 2, 5, 3]

Is there a way to keep the shape of the array?


Answer 1:


Data in a tensor is always stored flattened, as a 1-dimensional typed array, for speed.

The second parameter to tensor2d is the shape, so the rows cannot be passed as two separate arrays. To make it work, you either need to wrap them in another (outer) array:

const x = tf.tensor2d([[1, 2, 3, 4], [2, 2, 5, 3]]); //shape inferred as [2, 4]

or you can provide the shape explicitly:

const x = tf.tensor2d([1, 2, 3, 4, 2, 2, 5, 3], [2, 4]); // shape explicitly passed

As you noticed, though, when you inspect data() you always get a 1-dimensional array, regardless of the original shape:

await x.data() // Float32Array(8) [1, 2, 3, 4, 2, 2, 5, 3]
x.shape // [2, 4]
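Because data() returns the values flattened in row-major order (as the Float32Array output above shows) and the shape is available separately, the nested form can be rebuilt by hand. The following is a minimal sketch, assuming row-major layout; `unflatten` is a hypothetical helper, not a TensorFlow.js function. It runs without TensorFlow.js, so a plain Float32Array stands in for the result of `await x.data()`.

```javascript
// Hypothetical helper (not part of the TensorFlow.js API): rebuild a nested
// array from a flat, row-major buffer and a shape, e.g. the values from
// `await x.data()` together with `x.shape`.
function unflatten(flat, shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  if (flat.length !== size) {
    throw new Error(`Expected ${size} values for shape [${shape}], got ${flat.length}`);
  }
  // Recursively split: along dimension `dim`, consecutive entries are
  // `stride` values apart in the flat buffer.
  function build(offset, dim) {
    if (dim === shape.length) {
      return flat[offset];
    }
    const stride = shape.slice(dim + 1).reduce((a, b) => a * b, 1);
    const out = [];
    for (let i = 0; i < shape[dim]; i++) {
      out.push(build(offset + i * stride, dim + 1));
    }
    return out;
  }
  return build(0, 0);
}

const flat = new Float32Array([1, 2, 3, 4, 2, 2, 5, 3]);
console.log(unflatten(flat, [2, 4])); // [ [ 1, 2, 3, 4 ], [ 2, 2, 5, 3 ] ]
```

With a real tensor this would be called as `unflatten(await x.data(), x.shape)`. In TensorFlow.js 1.x and later, arraySync() (and its asynchronous counterpart array()) return this nested form directly, so a helper like this is mainly useful for seeing how the flat buffer maps onto the shape.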

If, however, you print() the tensor, the shape is taken into account and it appears as:

Tensor
    [[1, 2, 3, 4],
     [2, 2, 5, 3]]



Answer 2:


I use a function to display a 2D tensor as an HTML table in a web page:

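// Appends myTitle and an HTML table of the tensor's values to the element with id myDiv.
// myCols is the number of columns (for a 2D tensor, myOutTensor.shape[1]).
async function myTensorTable(myDiv, myOutTensor, myCols, myTitle) {
  const container = document.getElementById(myDiv);
  container.innerHTML += myTitle + '<br>';
  const myOutput = await myOutTensor.data(); // flat typed array
  let myTemp = '<table border=3><tr>';
  for (let myCount = 0; myCount < myOutTensor.size; myCount++) {
    myTemp += '<td>' + myOutput[myCount] + '</td>';
    // Start a new row after every myCols values.
    if (myCount % myCols === myCols - 1) {
      myTemp += '</tr><tr>';
    }
  }
  myTemp += '</tr></table>';
  container.innerHTML += myTemp + '<br>';
}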
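The row-splitting logic can also be separated from the DOM so it is easy to check on its own. This is a sketch under that assumption; `tensorTableHtml` is a hypothetical helper, not part of the original answer or of TensorFlow.js, and `cols` plays the role of myCols.

```javascript
// Hypothetical helper: build an HTML table string from flat, row-major
// tensor values without touching the DOM.
function tensorTableHtml(values, cols) {
  if (!(cols > 0)) {
    throw new Error('cols must be a positive number');
  }
  const rows = [];
  for (let start = 0; start < values.length; start += cols) {
    // Each row is the next `cols` values of the flat array.
    const cells = Array.from(values.slice(start, start + cols), (v) => `<td>${v}</td>`);
    rows.push(`<tr>${cells.join('')}</tr>`);
  }
  return `<table border=3>${rows.join('')}</table>`;
}

console.log(tensorTableHtml(new Float32Array([1, 2, 3, 4, 2, 2, 5, 3]), 4));
```

In a page it could be used as `container.innerHTML += tensorTableHtml(await t.data(), t.shape[1]);`. Slicing whole rows avoids the empty trailing `<tr></tr>` that the modulo approach leaves when the size is an exact multiple of the column count.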

Examples of usage are at:

https://hpssjellis.github.io/beginner-tensorflowjs-examples-in-javascript/beginner-examples/tfjs02-basics.html




Answer 3:


You can use the arraySync() method on the tensor object. It synchronously returns the data as a nested array with the same shape as the tensor.

const X = tf.tensor2d([[1, 2, 3, 4], [2, 2, 5, 3]]); 
console.log(X.arraySync())
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7/dist/tf.min.js"></script>


Source: https://stackoverflow.com/questions/49646291/how-to-get-data-from-2d-tensor
