How to standard scale a 3D matrix?

随声附和 submitted on 2019-12-19 18:55:30

Question


I am working on a signal classification problem and would like to scale the dataset matrix first, but my data is in a 3D format (batch, length, channels).
I tried to use scikit-learn's StandardScaler:

from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

But I got this error message:

Found array with dim 3. StandardScaler expected <= 2

I think one solution would be to split the matrix by channel into multiple 2D matrices, scale them separately, and then put them back into 3D format, but I wonder if there is a better solution.
Thank you very much.


Answer 1:


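You'll have to fit and store a scaler for each channel. Note that the loop below iterates over axis 1, so it matches "one scaler per channel" when the data is laid out as (batch, channels, length); with the (batch, length, channels) layout from the question it fits one scaler per time step instead, and each scaler still standardizes every channel separately: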

from sklearn.preprocessing import StandardScaler

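# One fitted scaler per index along axis 1, kept so the test set can reuse it
scalers = {}
for i in range(X_train.shape[1]):
    scalers[i] = StandardScaler()
    X_train[:, i, :] = scalers[i].fit_transform(X_train[:, i, :])

# Apply the statistics learned on the training set; do not refit on test data
for i in range(X_test.shape[1]):
    X_test[:, i, :] = scalers[i].transform(X_test[:, i, :])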
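
If what you want is one mean and one standard deviation per channel, pooled over the batch and length axes, the same idea can be sketched with plain NumPy broadcasting instead of a dictionary of scalers. This is only a minimal sketch: it assumes the (batch, length, channels) layout from the question, and the helper names fit_channel_stats and apply_channel_stats as well as the random data are made up for illustration. It is not identical to the loop above, which computes separate statistics for each slice along axis 1.

```python
import numpy as np


def fit_channel_stats(X):
    # Mean and std per channel, pooled over the batch and length axes.
    mean = X.mean(axis=(0, 1), keepdims=True)
    std = X.std(axis=(0, 1), keepdims=True)
    # Avoid division by zero for constant channels.
    std[std == 0] = 1.0
    return mean, std


def apply_channel_stats(X, mean, std):
    # Broadcasting applies the (1, 1, channels) stats to every sample and step.
    return (X - mean) / std


# Made-up data in (batch, length, channels) layout.
rng = np.random.default_rng(0)
X_train = rng.normal(5.0, 3.0, size=(32, 100, 4))
X_test = rng.normal(5.0, 3.0, size=(8, 100, 4))

# Fit on the training set only, then reuse the statistics for the test set.
mean, std = fit_channel_stats(X_train)
X_train_scaled = apply_channel_stats(X_train, mean, std)
X_test_scaled = apply_channel_stats(X_test, mean, std)
```

As with the scaler objects, the statistics come from the training set only and are reused unchanged on the test set.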



Answer 2:


If you want to scale each feature differently, like StandardScaler does, you can use this:

import numpy as np
from sklearn.base import TransformerMixin
from sklearn.preprocessing import StandardScaler


class NDStandardScaler(TransformerMixin):
    def __init__(self, **kwargs):
        self._scaler = StandardScaler(copy=True, **kwargs)
        self._orig_shape = None

    def fit(self, X, **kwargs):
        X = np.array(X)
        # Save the original shape to reshape the flattened X later
        # back to its original shape
        if len(X.shape) > 1:
            self._orig_shape = X.shape[1:]
        X = self._flatten(X)
        self._scaler.fit(X, **kwargs)
        return self

    def transform(self, X, **kwargs):
        X = np.array(X)
        X = self._flatten(X)
        X = self._scaler.transform(X, **kwargs)
        X = self._reshape(X)
        return X

    def _flatten(self, X):
        # Reshape X to <= 2 dimensions
        if len(X.shape) > 2:
            n_dims = np.prod(self._orig_shape)
            X = X.reshape(-1, n_dims)
        return X

    def _reshape(self, X):
        # Reshape X back to its original shape
        if len(X.shape) >= 2:
            X = X.reshape(-1, *self._orig_shape)
        return X

It simply flattens the features of the input before giving it to sklearn's StandardScaler. Then, it reshapes them back. The usage is the same as for the StandardScaler:

data = [[[0, 1], [2, 3]], [[1, 5], [2, 9]]]
scaler = NDStandardScaler()
print(scaler.fit_transform(data))

prints

[[[-1. -1.]
  [ 0. -1.]]

 [[ 1.  1.]
  [ 0.  1.]]]

The arguments with_mean and with_std are passed directly to StandardScaler and thus work as expected. copy=False won't work, since the reshaping does not happen in place. For 2D inputs, NDStandardScaler works like StandardScaler:

data = [[0, 0], [0, 0], [1, 1], [1, 1]]
scaler = NDStandardScaler()
scaler.fit(data)
print(scaler.transform(data))
print(scaler.transform([[2, 2]]))

prints

[[-1. -1.]
 [-1. -1.]
 [ 1.  1.]
 [ 1.  1.]]
[[3. 3.]]

just like in the sklearn example for StandardScaler.
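
One consequence of flattening everything after the first axis is that every position becomes its own feature, so a (batch, length, channels) input gets length × channels separate means rather than one per channel. The sketch below compares that with merging the batch and length axes, which pools the statistics per channel. The array shape and variable names are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 10, 3))  # (batch, length, channels), made-up data

# Flatten everything after the batch axis: one feature per (step, channel).
per_position = StandardScaler().fit(X.reshape(X.shape[0], -1))

# Merge the batch and length axes: one feature per channel.
per_channel = StandardScaler().fit(X.reshape(-1, X.shape[-1]))

print(per_position.mean_.shape)  # (30,)
print(per_channel.mean_.shape)   # (3,)
```

Which of the two is appropriate depends on whether each time step should be treated as a separate feature or as another observation of the same channel.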




Answer 3:


from sklearn.preprocessing import MinMaxScaler

minMaxScaler = MinMaxScaler()

# Merge the first two axes into rows, scale each column, then restore the shape
s0, s1, s2 = y_train.shape
y_train = y_train.reshape(s0 * s1, s2)
y_train = minMaxScaler.fit_transform(y_train)
y_train = y_train.reshape(s0, s1, s2)

# Reuse the scaler fitted on the training data; do not refit on the test set
s0, s1, s2 = y_test.shape
y_test = y_test.reshape(s0 * s1, s2)
y_test = minMaxScaler.transform(y_test)
y_test = y_test.reshape(s0, s1, s2)

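I just reshaped the data as shown above. For zero-padded data, use something similar, but fit the scaler only on the first row of each sequence (rows 0, s1, 2*s1, ... of the reshaped array) and then transform everything: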

s0, s1, s2 = x_train.shape[0], x_train.shape[1], x_train.shape[2]
x_train = x_train.reshape(s0 * s1, s2)
# Every s1-th row is the first time step of a sequence
minMaxScaler.fit(x_train[0::s1])
x_train = minMaxScaler.transform(x_train)
x_train = x_train.reshape(s0, s1, s2)

s0, s1, s2 = x_test.shape[0], x_test.shape[1], x_test.shape[2]
x_test = x_test.reshape(s0 * s1, s2)
x_test = minMaxScaler.transform(x_test)
x_test = x_test.reshape(s0, s1, s2)
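
For zero-padded sequences, another option is to fit on every row that is not padding and to leave the padding at zero after the transform. This is a rough sketch, not the method used in the answer above: it assumes padded steps are exactly zero in every channel and that no real time step is all zeros, and the helper name scale_padded and the random data are hypothetical.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def scale_padded(X, scaler, fit=False):
    # X has shape (batch, length, channels); padded steps are all-zero rows.
    flat = X.reshape(-1, X.shape[-1]).astype(float)
    # True for rows with at least one non-zero value (assumed to be real data).
    real = np.any(flat != 0, axis=1)
    if fit:
        scaler.fit(flat[real])
    # Padding rows stay at zero; only real rows are transformed.
    out = np.zeros_like(flat)
    out[real] = scaler.transform(flat[real])
    return out.reshape(X.shape)


# Made-up data: 6 sequences, 5 steps, 2 channels, last two steps are padding.
rng = np.random.default_rng(0)
x_train = rng.normal(10.0, 2.0, size=(6, 5, 2))
x_train[:, 3:, :] = 0.0

scaler = StandardScaler()
x_train_scaled = scale_padded(x_train, scaler, fit=True)
```

The test set would then go through the same helper with the already fitted scaler, for example scale_padded(x_test, scaler), so that no statistics are learned from it.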


Source: https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
