ND4J arrays & their shapes: getting data into a list

Submitted by 风格不统一 on 2019-12-24 19:03:46

Question


Consider the following code, which uses the ND4J library to create a simpler version of the "moons" test data set:

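import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.transforms.{Cos, Sin} // package as of ND4J 0.9.x (assumption); it differs in later releases
import org.nd4j.linalg.factory.Nd4j

val n = 100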
val n1: Int = n/2
val n2: Int = n-n1
val outerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n1)))
val outerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n1)))
val innerX = Nd4j.getExecutioner.execAndReturn(new Cos(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val innerY = Nd4j.getExecutioner.execAndReturn(new Sin(Nd4j.linspace(0, Math.PI, n2))).mul(-1).add(1)
val X: INDArray = Nd4j.vstack(
  Nd4j.concat(1, outerX, innerX), // 1 x n
  Nd4j.concat(1, outerY, innerY)  // 1 x n
) // 2 x n
val y: INDArray = Nd4j.hstack(
  Nd4j.zeros(n1), // 1 x n1
  Nd4j.ones(n2)   // 1 x n2
) // 1 x n
println(s"# y shape: ${y.shape().toList}")                        // 1x100
println(s"# y data length: ${y.data().length()}")                 // 100
println(s"# X shape: ${X.shape().toList}")                        // 2x100
println(s"# X row 0 shape: ${X.getRow(0).shape().toList}")        // 1x100
println(s"# X row 1 shape: ${X.getRow(1).shape().toList}")        // 1x100
println(s"# X row 0 data length: ${X.getRow(0).data().length()}") // 200    <- !
println(s"# X row 1 data length: ${X.getRow(1).data().length()}") // 100

On the second-to-last line, X.getRow(0).data().length() is, surprisingly, 200, not 100. On inspection, this is because the structure returned by data() contains the entire matrix, i.e. both rows concatenated.
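One way to see what is going on is a toy model in which an array is a flat buffer plus an offset, a shape and strides. The sketch below is plain Scala under that assumption; `StridedView` and its fields are hypothetical names, not ND4J's classes. It builds the 2 x 100 matrix stored row by row, takes row 0 as a view that shares the full 200-element buffer, and copies out only the view's own elements by walking its logical indices instead of dumping the buffer.

```scala
// Toy model (hypothetical, not ND4J's classes): a 2-D view over a shared flat buffer.
final case class StridedView(
    buffer: Array[Double], // shared storage; may hold more than this view covers
    offset: Int,           // buffer index of element (0, 0)
    rows: Int,
    cols: Int,
    rowStride: Int,        // buffer step between consecutive rows
    colStride: Int         // buffer step between consecutive columns
) {
  def apply(i: Int, j: Int): Double = buffer(offset + i * rowStride + j * colStride)

  // A row view: same buffer, shifted offset, nothing copied.
  def row(i: Int): StridedView =
    StridedView(buffer, offset + i * rowStride, 1, cols, rowStride, colStride)

  // Logical number of elements in the view, as opposed to buffer.length.
  def length: Int = rows * cols

  // Walk the view in logical (row-major) order and copy just its elements.
  def toList: List[Double] =
    (for (i <- 0 until rows; j <- 0 until cols) yield apply(i, j)).toList
}

object StridedViewDemo {
  // A 2 x 100 matrix stored row by row: buffer(k) == k, so row 1 starts at index 100.
  val matrix: StridedView =
    StridedView(Array.tabulate(200)(_.toDouble), offset = 0, rows = 2, cols = 100, rowStride = 100, colStride = 1)

  def main(args: Array[String]): Unit = {
    val row0 = matrix.row(0)
    println(row0.buffer.length)           // 200: the whole shared buffer
    println(row0.length)                  // 100: what the view actually covers
    println(matrix.row(1).toList.take(3)) // List(100.0, 101.0, 102.0)
  }
}
```

Slicing "the first 100 items" of the buffer only works for row 0, because row 1 starts at offset 100; walking the view's indices works for any row.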

How do I get just the actual first row of the X matrix into a Java (or Scala) List? I could take just the first 100 items of the 200-element "first row", but that doesn't seem very elegant.


Answer 1:


.data() gives you the underlying buffer as a straight row. See: http://nd4j.org/tensor

The shape of an array is just a view of the underlying data buffer. I typically don't recommend doing what you're trying to do without good reason: all of the data is stored off-heap, so copying it onto the heap is expensive.
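In a strided layout of this kind, a view's offset, shape and strides fully determine where each element lives in the buffer. For a 2-D array with \(r\) rows and \(c\) columns, the standard index formula is as follows (a general description of strided arrays, not a statement about ND4J's internals):

```latex
\[
\operatorname{index}(i, j) = \text{offset} + i\,s_0 + j\,s_1,
\qquad
(s_0, s_1) =
\begin{cases}
(c,\ 1) & \text{C (row-major) order}\\
(1,\ r) & \text{F (column-major) order}
\end{cases}
\]
```

For the question's C-ordered 2 x 100 X, row 1 is a view with offset 100 and strides (100, 1); a transpose only swaps the shape and the two strides, so no data moves.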

On-heap data is bad for doing any kind of math; the only real use case for it is integration with other code. I would suggest operating on the arrays directly as much as possible: everything from serialization to indexing is handled for you.

If you really need it for an integration of some kind, use Guava and you can do it in one line: Doubles.asList(arr.data().dup().asDouble());

where arr is your ndarray to operate on.
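A caveat, going by the question's own observation that X.getRow(0).data().length() is 200: for a row view, data() appears to expose the whole shared buffer, so duplicating the buffer copies the whole matrix. Copying the array rather than its buffer (in ND4J, presumably arr.dup().data().asDouble(), since dup() should copy only the array's own elements into a new buffer; verify this on your version) avoids that. The plain-Scala sketch below contrasts the two kinds of copy; `CopyViewDemo` and its helpers are hypothetical, and `toJavaList` is a stdlib stand-in for Guava's Doubles.asList, not Guava itself.

```scala
// Hypothetical helpers contrasting "copy the buffer" with "copy the view".
object CopyViewDemo {
  // Analogue of copying the whole buffer: every element, including other rows.
  def copyBuffer(buffer: Array[Double]): Array[Double] = buffer.clone()

  // Analogue of copying only the view: `length` elements starting at `offset`, `stride` apart.
  def copyView(buffer: Array[Double], offset: Int, length: Int, stride: Int): Array[Double] =
    Array.tabulate(length)(k => buffer(offset + k * stride))

  // Box a double array into a java.util.List, similar in spirit to Guava's Doubles.asList
  // (Guava's list is a fixed-size view backed by the array; this one is an independent copy).
  def toJavaList(values: Array[Double]): java.util.List[java.lang.Double] = {
    val list = new java.util.ArrayList[java.lang.Double](values.length)
    values.foreach(v => list.add(java.lang.Double.valueOf(v)))
    list
  }

  def main(args: Array[String]): Unit = {
    val buffer = Array.tabulate(200)(_.toDouble) // 2 x 100 matrix, row-major
    println(copyBuffer(buffer).length)                                     // 200
    println(copyView(buffer, offset = 0, length = 100, stride = 1).length) // 100
    println(toJavaList(copyView(buffer, 100, 3, 1)))                       // [100.0, 101.0, 102.0]
  }
}
```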




Answer 2:


Yeah, it turns out .data() with ND4J is not something you should really use for anything very serious. This is a bit of a shame for what I was trying to do: writing unit tests that don't depend on ND4J or on how it handles data internally.

As another example of the issue here, consider this code:

import org.nd4j.linalg.factory.Nd4j

object foo extends App {

  val x = Nd4j.create(Array[Double](1,2, 3,4, 5,6), Array(3,2))
  // 1,2
  // 3,4
  // 5,6
  println(x)
  val xArr = x.data().asDouble().toList
  // 1,2,  3,4,  5,6 - row-wise
  println(xArr)

  val w = Nd4j.create(Array[Double](10,20,30, 40,50,60), Array(2,3))
  // 10,20,30
  // 40,50,60
  println(w)
  val wArr = w.data().asDouble().toList
  // 10,20,30,  40,50,60 - row-wise
  println(wArr)

  val wx = w.mmul(x)
  /*
   *  10,20,30   1,2     10*1+20*3+30*5  10*2+20*4+30*6      220  280
   *  40,50,60   3,4  =  40*1+50*3+60*5  40*2+50*4+60*6  =   490  640
   *             5,6
   */
  println(wx)
  val wxArr = wx.data().asDouble().toList
  // 220, 490,  280, 640 - column-wise
  println(wxArr)
  val wxTArr = wx.transpose().data().asDouble().toList
  // 220, 490,  280, 640 - still column-wise
  println(wxTArr)
  val wxTIArr = wx.transposei().data().asDouble().toList
  // 220, 490,  280, 640 - still column-wise
  println(wxTIArr)

}

As you can see, ND4J basically does what it wants internally, and when you use .data() it simply gives you its internal representation. That representation isn't altered by transposes or other operations that only change the view, since those don't actually move the underlying data around.
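The layout point can be sketched with a hypothetical toy model in plain Scala (not ND4J's API), in which a layout is a shape plus strides over a flat buffer. It assumes wx is stored column by column (F order), as the output above suggests, reads it back in logical row-major order, and shows that a transpose just swaps shape and strides without touching the buffer.

```scala
// Hypothetical toy model (not ND4J's API): a 2-D layout defined by shape and strides.
object LayoutDemo {
  final case class Layout(rows: Int, cols: Int, rowStride: Int, colStride: Int) {
    // Transposing swaps shape and strides; the buffer itself is never touched.
    def transpose: Layout = Layout(cols, rows, colStride, rowStride)
  }

  // Read the buffer in logical row-major order, whatever its physical order.
  def rowMajor(buffer: Vector[Double], layout: Layout): List[Double] =
    (for {
      i <- 0 until layout.rows
      j <- 0 until layout.cols
    } yield buffer(i * layout.rowStride + j * layout.colStride)).toList

  // wx = [[220, 280], [490, 640]], stored column by column (F order).
  val wxBuffer: Vector[Double] = Vector(220.0, 490.0, 280.0, 640.0)
  val wxLayout: Layout = Layout(rows = 2, cols = 2, rowStride = 1, colStride = 2)

  def main(args: Array[String]): Unit = {
    println(wxBuffer.toList)                        // List(220.0, 490.0, 280.0, 640.0): raw buffer
    println(rowMajor(wxBuffer, wxLayout))           // List(220.0, 280.0, 490.0, 640.0)
    println(rowMajor(wxBuffer, wxLayout.transpose)) // List(220.0, 490.0, 280.0, 640.0)
  }
}
```

Converting to row-major order like this before comparing against a plain Scala list is one way to keep test expectations independent of the physical layout. In ND4J itself, copying in an explicit order, such as wx.dup('c').data().asDouble(), is probably the nearest equivalent, but treat that call as an assumption to check against your ND4J version.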

This is all fine, but what I wanted to do was basically: make a Scala list of ordinary doubles; give that to my custom library; ask the library to do its thing; convert its output to another Scala list of doubles; and verify that these doubles are what I expected it to calculate. Instead, I'm having to put the expected values in an ND4J array so I can properly compare them to the actual output, so my tests now depend on ND4J, which is an internal technical choice of my library.

Anyway, this is a relatively minor complaint, and the lesson is: avoid .data(), and if you're using ND4J, use it throughout (even if that seems a bit inelegant).



Source: https://stackoverflow.com/questions/47839617/nd4j-arrays-their-shapes-getting-data-into-a-list
