How to stop plots printing twice in Jupyter when using subplots?

Posted by 佐手、 on 2020-02-28 07:27:09

Question


I'm working with the Titanic data and I'm trying to use a combination of pyplot and seaborn to produce some subplots. I've written the following code to create 6 subplots in a 3x2 grid:

plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()

_ = sns.catplot(x='Pclass', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 0])
_ = sns.catplot(x='Embarked', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 1])
_ = sns.catplot(x='Sex', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 0])
_ = sns.catplot(x='Sex', y='Age', hue='Pclass', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 1])
_ = sns.catplot(x='SibSp', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 0])
_ = sns.catplot(x='Parch', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 1])
plt.show()

When I run this in my notebook, it successfully creates the desired plot. However, it also prints out 6 blank plots afterwards.

How can I suppress these empty plots from printing into my output?


Answer 1:


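Unlike most other seaborn plotting functions, catplot is a figure-level function: each call creates its own figure (a FacetGrid) rather than only drawing on an existing Axes. Those extra figures are the blank plots you see. To get rid of them, call plt.close() after each catplot call: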

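import seaborn as sns
import matplotlib.pyplot as plt

# Assumption: `data` is seaborn's bundled Titanic dataset (lowercase column names)
data = sns.load_dataset("titanic")

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(8, 12))
fig.tight_layout()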

sns.catplot(x='pclass', y='age', data=data, kind='box', ax=axes[0, 0])
plt.close()
sns.catplot(x='embarked', y='age', data=data, kind='box', ax=axes[0, 1])
plt.close()
sns.catplot(x='sex', y='age', data=data, kind='box', ax=axes[1, 0])
plt.close()
sns.catplot(x='sex', y='age', hue='pclass', data=data, kind='box', ax=axes[1, 1])
plt.close()
sns.catplot(x='sibsp', y='age', data=data, kind='box', ax=axes[2, 0])
plt.close()
sns.catplot(x='parch', y='age', data=data, kind='box', ax=axes[2, 1])
plt.close()

plt.show()
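
Since the extra figures come from catplot being figure-level, another way to avoid them is to call the axes-level function directly, here sns.boxplot, which draws on the Axes it is given and creates no figure of its own. The following is only a sketch of that idea, not the answerer's method: the synthetic train_df and the panels list are stand-ins introduced so the example runs without downloading anything. Newer seaborn releases also appear to warn that catplot does not accept target axes, in which case passing ax= to catplot may not draw on your grid at all.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Synthetic stand-in for the Titanic data (an assumption, so the sketch runs
# offline); with the real data, use sns.load_dataset("titanic") instead.
rng = np.random.default_rng(0)
n = 200
train_df = pd.DataFrame({
    "pclass": rng.integers(1, 4, n),
    "embarked": rng.choice(["C", "Q", "S"], n),
    "sex": rng.choice(["male", "female"], n),
    "sibsp": rng.integers(0, 4, n),
    "parch": rng.integers(0, 3, n),
    "age": rng.uniform(1, 80, n),
})

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(12, 8))

# (x column, hue column or None) for each panel, in row-major order
panels = [
    ("pclass", None),
    ("embarked", None),
    ("sex", None),
    ("sex", "pclass"),
    ("sibsp", None),
    ("parch", None),
]

# sns.boxplot is axes-level: it draws on the given Axes and opens no new figure
for ax, (x, hue) in zip(axes.flat, panels):
    sns.boxplot(x=x, y="age", hue=hue, data=train_df, ax=ax)

fig.tight_layout()
```

In a notebook, the single figure is displayed at the end of the cell (or after plt.show()), and there are no leftover figures to close.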

Answer 2:


Assign each of your plots to a variable such as g, and use plt.close(g.fig) to close the unwanted extra figure it creates. Alternatively, iterate over all variables of type sns.axisgrid.FacetGrid and close them like so:

for p in plots_names:
    plt.close(vars()[p].fig)

The complete snippet below does just that. Note that I'm loading the titanic dataset using train_df = sns.load_dataset("titanic"). Here, all column names are lower case unlike in your example. I've also removed the palette=col_pal argument since col_pal is not defined in your snippet.

Code:

import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()

train_df = sns.load_dataset("titanic")

g = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
h = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
i = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
j = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
k = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
l = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])

# iterate over plots and run
# plt.close() to prevent duplicate
# subplot setup
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)

Please note that you will have to assign your plots to variable names for this to work. If you just add the snippet that closes the plots to the end of your original snippet, the duplicate subplot setup will remain untouched.

Code 2 (plots not assigned to separate variable names):

import seaborn as sns
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()

train_df = sns.load_dataset("titanic")

_ = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
_ = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
_ = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
_ = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
_ = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
_ = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])

# iterate over plots and run
# plt.close() to prevent duplicate
# subplot setup
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)
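
If keeping track of variable names is inconvenient, one possible alternative is to close every open figure except the one you want to keep. The sketch below uses only matplotlib calls (plt.get_fignums and plt.close); close_all_except is a hypothetical helper name, and the six plt.figure() calls merely stand in for the extra figures that figure-level calls such as sns.catplot leave behind.

```python
import matplotlib.pyplot as plt


def close_all_except(keep):
    """Close every open pyplot figure other than `keep` (hypothetical helper)."""
    for num in plt.get_fignums():
        if num != keep.number:
            plt.close(num)


fig, axes = plt.subplots(nrows=3, ncols=2)

# Stand-ins for the stray figures created by figure-level plotting calls;
# here they are opened directly, purely for illustration.
for _ in range(6):
    plt.figure()

# Only the 3x2 grid remains open after this call
close_all_except(fig)
```

Note that this also closes any unrelated figures that happen to be open, so it fits best in a notebook cell that owns all of its figures.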

Source: https://stackoverflow.com/questions/60042218/how-to-stop-plots-printing-twice-in-jupyter-when-using-subplots
