Scikit Learn OneHotEncoder fit and transform Error: ValueError: X has different shape than during fitting

Submitted by 余生长醉 on 2019-11-27 05:33:23

Instead of pd.get_dummies(), use LabelEncoder together with OneHotEncoder. Both are fitted objects that remember the categories seen during fitting, so the same encoding can be applied to new data.

Changing your code as shown below gives the required result.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                             color=['Red', 'Orange','Green'],
                             is_sweet = [0,0,1],
                             country=['USA','India','Asia']))

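# Every column already has the right dtype here, so a plain copy is enough
# (pd.to_numeric(..., errors='ignore') is deprecated in recent pandas).
filtered_df = input_df.copy()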

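# This is what you need: fit one LabelEncoder per column and keep it,
# so the same label-to-integer mapping can be reused on new data.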
le_dict = {}
for col in filtered_df.columns:
    le_dict[col] = LabelEncoder().fit(filtered_df[col])
    filtered_df[col] = le_dict[col].transform(filtered_df[col])

enc = OneHotEncoder()
enc.fit(filtered_df)
refreshed_df = enc.transform(filtered_df).toarray()

new_df = pd.DataFrame(dict(fruit=['Apple'],
                           color=['Red'],
                           is_sweet=[0],
                           country=['USA']))
# Reuse the already fitted encoders; do not refit them on the new data.
for col in new_df.columns:
    new_df[col] = le_dict[col].transform(new_df[col])

new_refreshed_df = enc.transform(new_df).toarray()

print(filtered_df)
      color  country  fruit  is_sweet
0      2        2      0         0
1      1        1      1         0
2      0        0      2         1

print(refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  1.  0.  0.  1.  0.  1.  0.]
 [ 1.  0.  0.  1.  0.  0.  0.  0.  1.  0.  1.]]

print(new_df)
      color  country  fruit  is_sweet
0      2        2      0         0

print(new_refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]]
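
As a side note that is not part of the original answer: from scikit-learn 0.20 onward, OneHotEncoder accepts string columns directly, so the LabelEncoder step can usually be skipped. The sketch below assumes such a version is installed. The handle_unknown="ignore" setting is an assumption about the desired behaviour (categories not seen during fitting are encoded as all zeros instead of raising an error), and the names train, new, encoded_train and encoded_new are purely illustrative.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Training data: same example values as in the answer above.
train = pd.DataFrame({
    "fruit": ["Apple", "Orange", "Pine"],
    "color": ["Red", "Orange", "Green"],
    "is_sweet": [0, 0, 1],
    "country": ["USA", "India", "Asia"],
})

# Fit once on the training data; the encoder remembers every category it saw.
# handle_unknown="ignore" is an assumed choice, not something the answer requires.
enc = OneHotEncoder(handle_unknown="ignore")
encoded_train = enc.fit_transform(train).toarray()

# New data must have the same columns in the same order as the training data.
new = pd.DataFrame({
    "fruit": ["Apple"],
    "color": ["Red"],
    "is_sweet": [0],
    "country": ["USA"],
})

# transform() (not fit_transform) reuses the stored categories,
# so the output always has the same width as encoded_train.
encoded_new = enc.transform(new).toarray()

print(encoded_train.shape)
print(encoded_new.shape)
```

Because the encoder is fitted only once, the number of output columns is fixed by the training data, which is the property pd.get_dummies() lacks.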

Your encoder was fitted on refreshed_df, which contains 10 columns, whereas refreshed_df1 contains only 4; that is exactly what the error reports. Either drop the columns that do not appear in refreshed_df1, or fit the encoder on a version of refreshed_df that contains only the 4 columns present in refreshed_df1.
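
The column mismatch described above can be reproduced with a small sketch. It uses a reduced two-column example rather than the asker's actual data, so the column counts (6 and 2) differ from the 10 and 4 in the question. Aligning the new frame with reindex is one possible way to give both frames the same columns; it is an assumption of this sketch, not something the answer prescribes.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

train = pd.DataFrame({"fruit": ["Apple", "Orange", "Pine"],
                      "color": ["Red", "Orange", "Green"]})
new = pd.DataFrame({"fruit": ["Apple"], "color": ["Red"]})

# get_dummies only creates columns for the categories present in each frame,
# so the two results have different widths.
train_dummies = pd.get_dummies(train, dtype=int)  # one column per category in train
new_dummies = pd.get_dummies(new, dtype=int)      # only the categories found in new

enc = OneHotEncoder().fit(train_dummies)

# Transforming a frame with fewer columns than the fitted one raises ValueError.
try:
    enc.transform(new_dummies)
    mismatch_error = None
except ValueError as exc:
    mismatch_error = exc

# One possible fix: give the new frame exactly the training columns,
# filling the categories it never contained with 0.
aligned = new_dummies.reindex(columns=train_dummies.columns, fill_value=0)
encoded = enc.transform(aligned).toarray()

print(train_dummies.shape, new_dummies.shape, aligned.shape)
```

Once both frames share the same columns in the same order, transform succeeds; the exact wording of the ValueError depends on the scikit-learn version.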
