Plot dataset and labels over multiple rows (Jupyter notebook)

与世无争的帅哥 submitted on 2021-02-20 05:02:14

Question


The code below plots a batch of images from my dataset (dogs and cats) together with their labels. I'm using Jupyter Notebook:

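import numpy as np
import matplotlib.pyplot as plt
# import path assumed; the older standalone Keras uses keras.preprocessing.image instead
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_path = 'dataset/train'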
valid_path = 'dataset/valid'
test_path = 'dataset/test'

train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=10)
valid_batches = ImageDataGenerator().flow_from_directory(valid_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=4)
test_batches = ImageDataGenerator().flow_from_directory(test_path, target_size=(224,224), classes=['dog', 'cat'], batch_size=10)

# plot function, used to plot images with labels within jupyter notebook
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None):
    if type(ims[0]) is np.ndarray:
        ims= np.array(ims).astype(np.uint8)
        if (ims.shape[-1] != 3):
            ims = ims.transpose((0,2,3,1))
    f = plt.figure(figsize=figsize)
    # round up so that rows * cols always covers every image
    cols = len(ims)//rows if len(ims) % rows == 0 else len(ims)//rows + 1
    for i in range(len(ims)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i], interpolation=None if interp else 'none')

imgs, labels = next(train_batches)

# we plot these samples of images and their labels 1 batch at a time.
plots(imgs, titles=labels)

With only 10 samples per batch, as in the code above, the images fit comfortably across the width of the notebook page:

But if I increase the batch size, say to 100 samples (that is, set batch_size=100 in the train_batches = ImageDataGenerator() line), and plot the batch, it still tries to squeeze everything inline on 1 row, as per the screenshot below:

How do I change the code so that it plots all 100 samples across multiple rows at a readable size, rather than squeezing them all (at a tiny size) into a single row?

Many thanks in advance.
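
One possible approach, offered as a sketch rather than a definitive answer: the plots function above already takes a rows argument, so a call such as plots(imgs, rows=10, figsize=(20,20), titles=labels) should spread 100 images over 10 rows. Alternatively, the row count can be derived from a fixed number of columns and the figure size scaled to match. The helper below does that with plt.subplots. The name plot_grid and the parameters cols and cell_size are hypothetical and not part of the original code, and it keeps the original assumption that a batch whose last axis is not 3 is channels-first.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def plot_grid(ims, titles=None, cols=10, cell_size=1.5):
    """Plot images on a grid with `cols` images per row (hypothetical helper)."""
    ims = np.asarray(ims)
    if ims.shape[-1] != 3:
        # Assumption carried over from the original code: anything that is not
        # channels-last RGB is treated as channels-first (N, 3, H, W).
        ims = ims.transpose((0, 2, 3, 1))
    ims = ims.astype(np.uint8)

    n = len(ims)
    rows = math.ceil(n / cols)  # enough rows to hold every image

    # Grow the figure with the grid so each cell keeps roughly the same size.
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(cols * cell_size, rows * cell_size),
        squeeze=False,  # always return a 2-D array of axes
    )

    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # also blanks the unused cells in the last row
        if i >= n:
            continue
        ax.imshow(ims[i], interpolation="none")
        if titles is not None:
            ax.set_title(str(titles[i]), fontsize=8)

    fig.tight_layout()
    return fig
```

With the batch from the question, plot_grid(imgs, titles=labels, cols=10) would lay 100 images out as 10 rows of 10. In a notebook it may help to assign the result or end the line with a semicolon, since a returned figure can otherwise be rendered twice.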

Source: https://stackoverflow.com/questions/57548144/plot-dataset-and-labels-over-multiple-rows-jupyter-notebook
