sklearn.LabelEncoder with never seen before values

执笔经年 2020-11-27 10:37

If a sklearn.LabelEncoder has been fitted on a training set, its transform raises a ValueError when it is used on a test set that contains values it did not see during fitting.
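
To make the failure concrete, here is a minimal reproduction. The label values are made up for illustration, and the exact error text may vary between scikit-learn versions.

```python
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
encoder.fit(['Canada', 'France', 'Italy'])  # labels seen in the "training set"

try:
    # 'India' was not present when the encoder was fitted
    encoder.transform(['Canada', 'India'])
    failed = False
except ValueError as err:
    failed = True
    print(err)
```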

The only solution I c

12 Answers
  •  离开以前
    2020-11-27 11:34

    I have created a class to support this. If a new label comes in, it is assigned to an 'Unknown' class.

    from sklearn.preprocessing import LabelEncoder
    import numpy as np
    
    
    class LabelEncoderExt(object):
        def __init__(self):
            """
            Differs from LabelEncoder by handling new classes: it reserves an extra 'Unknown' class.
            'Unknown' is added in fit, and transform maps any unseen item to the id of that class.
            """
            self.label_encoder = LabelEncoder()
            # self.classes_ = self.label_encoder.classes_
    
        def fit(self, data_list):
            """
            Fit the encoder on all the unique values plus the extra 'Unknown' label
            :param data_list: A list of strings
            :return: self
            """
            self.label_encoder = self.label_encoder.fit(list(data_list) + ['Unknown'])
            self.classes_ = self.label_encoder.classes_
    
            return self
    
        def transform(self, data_list):
            """
            Transform data_list to an array of ids; values not seen during fit are assigned to the 'Unknown' class
            :param data_list: A list of strings
            :return: numpy array of integer ids
            """
            # Labels in the input that the fitted encoder does not know
            unseen = set(np.unique(data_list)) - set(self.label_encoder.classes_)
            # Replace every unseen label with 'Unknown' in a single pass
            new_data_list = ['Unknown' if x in unseen else x for x in data_list]
    
            return self.label_encoder.transform(new_data_list)
    

    Sample usage:

    country_list = ['Argentina', 'Australia', 'Canada', 'France', 'Italy', 'Spain', 'US', 'Canada', 'Argentina', 'US']
    
    label_encoder = LabelEncoderExt()
    
    label_encoder.fit(country_list)
    print(label_encoder.classes_) # the classes now include the extra 'Unknown' class
    print(label_encoder.transform(country_list))
    
    
    new_country_list = ['Canada', 'France', 'Italy', 'Spain', 'US', 'India', 'Pakistan', 'South Africa']
    print(label_encoder.transform(new_country_list))
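
The idea behind the class can also be seen without scikit-learn. The following is only a rough, dependency-free sketch of the same approach (reserve an 'Unknown' label at fit time, then send anything unseen to its id). The names `UNKNOWN`, `fit_mapping` and `encode` are hypothetical and used only for illustration; they are not part of any library.

```python
# Hypothetical helper names; this mirrors the approach above, not sklearn's internals.
UNKNOWN = 'Unknown'


def fit_mapping(labels):
    """Map each distinct label, plus the reserved UNKNOWN label, to an integer id."""
    classes = sorted(set(labels) | {UNKNOWN})
    return {label: idx for idx, label in enumerate(classes)}


def encode(mapping, labels):
    """Encode labels, sending anything absent from the mapping to UNKNOWN's id."""
    unknown_id = mapping[UNKNOWN]
    return [mapping.get(label, unknown_id) for label in labels]


mapping = fit_mapping(['Argentina', 'Australia', 'Canada', 'France', 'Italy', 'Spain', 'US'])
print(encode(mapping, ['Canada', 'France', 'India']))  # [2, 3, 7]
```

As with the class above, every unseen label collapses to a single id, so distinct new values such as 'India' and 'Pakistan' become indistinguishable after encoding.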
    
