Numpy thermometer encoding

佛祖请我去吃肉 2021-01-06 00:28

I am trying to use NumPy's optimized built-in functions to generate a thermometer encoding. Thermometer encoding basically means generating n 1's in a given len

4 Answers
  •  攒了一身酷
    2021-01-06 00:47

    I'd never heard of "thermometer encoding" before, but when you realise how it's so similar to one-hot encoding, it becomes clear you can get there using bit shift ops:

    >>> a = np.array([2, 3, 4, 1], dtype=np.uint8)
    >>> print(np.fliplr(np.unpackbits((1 << a) - 1).reshape(-1,8)))
    [[1 1 0 0 0 0 0 0]
     [1 1 1 0 0 0 0 0]
     [1 1 1 1 0 0 0 0]
     [1 0 0 0 0 0 0 0]]
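To make the one-liner above easier to follow, here is a rough step-by-step sketch of the same computation. The intermediate names `masks`, `bits` and `thermo` are only illustrative and are not part of the original answer.

```python
import numpy as np

a = np.array([2, 3, 4, 1], dtype=np.uint8)

# (1 << a) sets only bit number a; subtracting 1 turns on every bit below it.
masks = (np.uint8(1) << a) - np.uint8(1)   # values 3, 7, 15, 1

# unpackbits lists the bits of each byte most-significant bit first.
bits = np.unpackbits(masks).reshape(-1, 8)

# Reversing each row moves the run of ones to the left-hand side.
thermo = np.fliplr(bits)
print(thermo)
```

Each row of `thermo` then holds as many leading ones as the corresponding entry of `a`.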
    

    Edit: You can generalise the idea to larger integers by working in 8-column chunks:

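    import numpy as np

    a = np.array([2, 13, 4, 0, 1, 17], dtype=np.uint8)
    out = np.empty((len(a), 0), dtype=np.uint8)  # start with zero columns
    while a.any():
        # Values of 8 or more fill the whole 8-column block with ones.
        block = np.fliplr(np.unpackbits((1 << a) - 1).reshape(-1,8))
        out = np.concatenate([out, block], axis=1)
        # Consume the 8 columns just emitted.
        a = np.where(a<8, 0, a-8)
    
    print(out)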
    [[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0]
     [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
     [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]]
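As a possible alternative to the chunked loop, the same table can probably be built in one step by comparing column indices against the values with broadcasting. This is a different technique from the bit-shift approach above, and the helper name `thermometer` and its `length` parameter are assumptions made for illustration, not something from the original answer.

```python
import numpy as np

def thermometer(values, length):
    """Hypothetical helper: row i has values[i] leading ones, then zeros."""
    values = np.asarray(values)
    # Broadcasting compares every column index with every value:
    # shape (length,) against shape (n, 1) gives shape (n, length).
    return (np.arange(length) < values[:, None]).astype(np.uint8)

out = thermometer([2, 13, 4, 0, 1, 17], 24)
print(out)
```

With this sketch the output width is chosen explicitly instead of being rounded up to a multiple of 8, and any value larger than `length` simply produces a row of all ones.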
    
