How to use datasets.fetch_mldata() in sklearn?

半阙折子戏 2020-12-30 05:42

I am trying to run the following code for a brief machine learning algorithm:

import re
import argparse
import csv
from collections import Counter
from sklea

11 answers
  •  甜味超标
    2020-12-30 06:06

    I was also getting an "IOError: could not read bytes" error from fetch_mldata(). Here is the solution; the relevant lines of code are:

    from sklearn.datasets.mldata import fetch_mldata
    mnist = fetch_mldata('mnist-original', data_home='/media/Vancouver/apps/mnist_dataset/')
    

    Be sure to change 'data_home' to your preferred directory.

    Here is a script:

    #!/usr/bin/python
    # coding: utf-8
    
    # Source:
    # https://stackoverflow.com/questions/19530383/how-to-use-datasets-fetch-mldata-in-sklearn
    # ... modified, below, by Victoria
    
    """
    Personal communication (Jan 27, 2016) from the MLdata.org MNIST dataset contact, Cheng Ong:
    
        The MNIST data is called 'mnist-original'. The string you pass to sklearn
        has to match the name of the URL:
    
        from sklearn.datasets.mldata import fetch_mldata
        data = fetch_mldata('mnist-original')
    """
    
    def get_data():
    
        """
        Get MNIST data; returns a dict with keys 'train' and 'test'.
        Both have the keys 'X' (features) and 'y' (labels)
        """
    
        from sklearn.datasets.mldata import fetch_mldata
    
        mnist = fetch_mldata('mnist-original', data_home='/media/Vancouver/apps/mnist_dataset/')
    
        x = mnist.data
        y = mnist.target
    
        # Scale data to [-1, 1]
        x = x/255.0*2 - 1
    
        # sklearn.cross_validation was renamed to sklearn.model_selection in scikit-learn 0.18
        from sklearn.model_selection import train_test_split
    
        x_train, x_test, y_train, y_test = train_test_split(x, y,
            test_size=0.33, random_state=42)
    
        data = {'train': {'X': x_train, 'y': y_train},
                'test': {'X': x_test, 'y': y_test}}
    
        return data
    
    data = get_data()
    # print(data) behaves the same under Python 2 and Python 3
    print(data)
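
For readers on a newer scikit-learn: fetch_mldata() was deprecated in 0.20 and removed in 0.22, so the script above will not import there. Below is a minimal sketch of the same get_data() built on fetch_openml(), assuming scikit-learn 0.22 or later and network access on the first call. The OpenML dataset name 'mnist_784', the scale_pixels() helper, and the keyword defaults are my own choices for illustration and are not part of the original answer.

```python
def scale_pixels(x):
    """Map pixel values from [0, 255] to [-1, 1]; works on floats and NumPy arrays."""
    return x / 255.0 * 2 - 1


def get_data(data_home=None, test_size=0.33, random_state=42):
    """
    Get MNIST data; returns a dict with keys 'train' and 'test'.
    Both have the keys 'X' (features) and 'y' (labels).

    Assumes scikit-learn >= 0.22 and network access on the first call.
    """
    # Imported lazily so that defining the function does not require scikit-learn.
    from sklearn.datasets import fetch_openml
    from sklearn.model_selection import train_test_split

    # 'mnist_784' is the OpenML name for MNIST (an assumption of this sketch);
    # as_frame=False asks for NumPy arrays instead of a pandas DataFrame.
    x, y = fetch_openml('mnist_784', version=1, data_home=data_home,
                        return_X_y=True, as_frame=False)

    # Same scaling and split parameters as the original script.
    x_train, x_test, y_train, y_test = train_test_split(
        scale_pixels(x), y, test_size=test_size, random_state=random_state)

    return {'train': {'X': x_train, 'y': y_train},
            'test': {'X': x_test, 'y': y_test}}
```

Calling get_data() downloads the dataset on first use and caches it under data_home (scikit-learn's default data directory when it is None). OpenML returns the labels as strings ('0' to '9'), so cast them with y.astype(int) if integer labels are needed.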
    
