Question
I'm working with the Titanic data and trying to use a combination of pyplot and seaborn to produce some subplots. I've written the following code to create 6 subplots in a 3x2 grid:
plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()
_ = sns.catplot(x='Pclass', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 0])
_ = sns.catplot(x='Embarked', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[0, 1])
_ = sns.catplot(x='Sex', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 0])
_ = sns.catplot(x='Sex', y='Age', hue='Pclass', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[1, 1])
_ = sns.catplot(x='SibSp', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 0])
_ = sns.catplot(x='Parch', y='Age', data=train_df, kind='box', height=8, palette=col_pal, ax=axes[2, 1])
plt.show()
When I run this in my notebook, it successfully creates the desired plot; however, it also prints out 6 blank plots afterwards.
How can I suppress these empty plots from printing into my output?
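Answer 1:
Unlike the axes-level seaborn functions, catplot is a figure-level function: it creates a new figure (and returns a FacetGrid) instead of only drawing onto an existing axes. That extra figure is what shows up as a blank plot, so to fix this behavior you need to call plt.close() after each catplot call: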
import matplotlib.pyplot as plt
import seaborn as sns
# assumed: seaborn's copy of the Titanic data (lowercase column names)
data = sns.load_dataset("titanic")
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(8, 12))
sns.catplot(x='pclass', y='age', data=data, kind='box', ax=axes[0, 0])
plt.close()  # close the extra figure that catplot just opened
sns.catplot(x='embarked', y='age', data=data, kind='box', ax=axes[0, 1])
plt.close()
sns.catplot(x='sex', y='age', data=data, kind='box', ax=axes[1, 0])
plt.close()
sns.catplot(x='sex', y='age', hue='pclass', data=data, kind='box', ax=axes[1, 1])
plt.close()
sns.catplot(x='sibsp', y='age', data=data, kind='box', ax=axes[2, 0])
plt.close()
sns.catplot(x='parch', y='age', data=data, kind='box', ax=axes[2, 1])
plt.close()
fig.tight_layout()  # adjust spacing after the plots are drawn
plt.show()
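To see the mechanism this answer relies on, here is a minimal sketch that counts the open pyplot figures. It assumes a small made-up DataFrame in place of the Titanic data and the non-interactive Agg backend. Each catplot call opens one more figure, and plt.close() with no argument closes the current figure, which is the one catplot just created; ax is left out so the count does not depend on the seaborn version.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; nothing is displayed
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the Titanic data (same lowercase column names).
df = pd.DataFrame({
    "pclass": [1, 1, 2, 2, 3, 3],
    "age": [38.0, 35.0, 27.0, 14.0, 22.0, 4.0],
})

plt.close("all")  # start from a clean slate
fig, axes = plt.subplots(nrows=1, ncols=2)
before = len(plt.get_fignums())         # only the subplot grid is open

g = sns.catplot(x="pclass", y="age", data=df, kind="box")
after_catplot = len(plt.get_fignums())  # catplot opened a figure of its own

plt.close()                             # closes the current figure, i.e. g.fig
after_close = len(plt.get_fignums())    # only the subplot grid is left

print(before, after_catplot, after_close)
```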
Answer 2:
Assign each of your plots to a variable such as g, and use plt.close(g.fig) to close the unwanted extra figures. Alternatively, iterate over all variables of type sns.axisgrid.FacetGrid and close them like so:
for p in plots_names:
    plt.close(vars()[p].fig)
The complete snippet below does just that. Note that I'm loading the Titanic dataset using train_df = sns.load_dataset("titanic"), so all column names are lowercase, unlike in your example. I've also removed the palette=col_pal argument since col_pal is not defined in your snippet.
Code:
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
train_df = sns.load_dataset("titanic")
g = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
h = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
i = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
j = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
k = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
l = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])
# iterate over plots and run
# plt.close() to prevent duplicate
# subplot setup
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)
# lay out the subplot grid once the plots are in place
fig.tight_layout()
Please note that you have to assign each plot to its own variable name for this to work. If you just add the closing snippet to the end of your original code, where every plot is assigned to _, each new catplot result overwrites the previous one, so the loop can find and close at most the last figure and the other blank figures remain.
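If you would rather not scan vars(), a minimal alternative sketch is to keep every FacetGrid that catplot returns in a list and close each grid's figure afterwards; this does not depend on the results being bound to distinct names. The DataFrame is a hypothetical stand-in, and ax is left out so the sketch only exercises the closing step.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; nothing is displayed
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the Titanic data (same lowercase column names).
df = pd.DataFrame({
    "pclass": [1, 2, 3, 1, 2, 3],
    "sex": ["male", "male", "male", "female", "female", "female"],
    "age": [40.0, 30.0, 22.0, 38.0, 27.0, 14.0],
})

plt.close("all")  # start from a clean slate
fig, axes = plt.subplots(nrows=1, ncols=2)

grids = []  # every FacetGrid returned by catplot, in creation order
for column in ["pclass", "sex"]:
    grids.append(sns.catplot(x=column, y="age", data=df, kind="box"))

# Close each catplot figure explicitly; the subplot grid stays open.
for grid in grids:
    plt.close(grid.fig)

remaining = plt.get_fignums()
```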
Code 2 (your original snippet with the closing loop appended, for comparison):
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [12, 8]
fig, axes = plt.subplots(nrows=3, ncols=2)
plt.tight_layout()
train_df = sns.load_dataset("titanic")
_ = sns.catplot(x='pclass', y='age', data=train_df, kind='box', height=8, ax=axes[0, 0])
_ = sns.catplot(x='embarked', y='age', data=train_df, kind='box', height=8, ax=axes[0, 1])
_ = sns.catplot(x='sex', y='age', data=train_df, kind='box', height=8, ax=axes[1, 0])
_ = sns.catplot(x='sex', y='age', hue='pclass', data=train_df, kind='box', height=8, ax=axes[1, 1])
_ = sns.catplot(x='sibsp', y='age', data=train_df, kind='box', height=8, ax=axes[2, 0])
_ = sns.catplot(x='parch', y='age', data=train_df, kind='box', height=8, ax=axes[2, 1])
# iterate over plots and run
# plt.close() to prevent duplicate
# subplot setup
var_dict = vars().copy()
var_keys = var_dict.keys()
plots_names = [x for x in var_keys if isinstance(var_dict[x], sns.axisgrid.FacetGrid)]
for p in plots_names:
    plt.close(vars()[p].fig)
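An alternative that avoids the extra figures altogether (a swapped-in technique, not what either answer uses) is to fill the plt.subplots grid with seaborn's axes-level functions such as sns.boxplot, which take ax and draw onto it without opening a figure of their own, so there is nothing to close. This may also matter because newer seaborn releases warn that catplot does not accept target axes and drop ax, in which case the catplot-plus-close approach can leave the grid empty. A minimal sketch, again with a hypothetical stand-in DataFrame:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; nothing is displayed
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the Titanic data (same lowercase column names).
df = pd.DataFrame({
    "pclass": [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],
    "sex": ["male"] * 6 + ["female"] * 6,
    "age": [40.0, 30.0, 22.0, 54.0, 28.0, 19.0,
            38.0, 27.0, 14.0, 35.0, 33.0, 4.0],
})

plt.close("all")  # start from a clean slate
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))

# Axes-level functions draw onto the Axes passed via ax=
# and do not create a figure of their own.
sns.boxplot(x="pclass", y="age", data=df, ax=axes[0])
sns.boxplot(x="sex", y="age", data=df, ax=axes[1])
sns.boxplot(x="sex", y="age", hue="pclass", data=df, ax=axes[2])

fig.tight_layout()  # called after plotting so the labels are accounted for
n_figures = len(plt.get_fignums())
```

For the question's 3x2 grid, the six catplot calls would become six boxplot calls with ax=axes[row, col].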
Source: https://stackoverflow.com/questions/60042218/how-to-stop-plots-printing-twice-in-jupyter-when-using-subplots