问题
I've been struggling to save my graphs, with a certain look, to a specific directory.
Here is the example data and what I've tried so far:
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
# Show which seaborn version is installed; FacetGrid/heatmap output below depends on it.
print("seaborn version {}".format(sns.__version__))
def expandgrid(*itrs):
    """Return every combination of the given iterables as column lists.

    Python counterpart of R's ``expand.grid()``, adapted from
    https://stackoverflow.com/a/12131385/1135316.

    The result maps ``'Var1'``, ``'Var2'``, ... to lists holding the first,
    second, ... element of each combination, in ``itertools.product`` order
    (unlike R, the *last* iterable varies fastest). It is shaped so it can
    be passed straight to ``pd.DataFrame``.
    """
    combos = list(itertools.product(*itrs))
    columns = {}
    for position in range(len(itrs)):
        key = 'Var{}'.format(position + 1)
        columns[key] = [combo[position] for combo in combos]
    return columns
# Factor levels; every combination of them becomes one row of `data`
# (2 x 4 x 2 x 10 x 10 = 1600 rows).
ltt = ['lt1', 'lt2']
methods = ['method 1', 'method 2', 'method 3', 'method 4']
labels = ['label1', 'label2']
times = range(0, 100, 10)

data = pd.DataFrame(expandgrid(ltt, methods, labels, times, times)).rename(
    columns={'Var1': 'ltt', 'Var2': 'method', 'Var3': 'labels',
             'Var4': 'dtsi', 'Var5': 'rtsi'})
# One random 0/1 score per row (np.random.sample(...) would give floats in
# [0, 1) instead).
data['nw_score'] = np.random.choice([0, 1], size=len(data))
data
Out[25]:
ltt method labels dtsi rtsi nw_score
0 lt1 method 1 label1 0 0 0
1 lt1 method 1 label1 0 10 1
2 lt1 method 1 label1 0 20 1
3 lt1 method 1 label1 0 30 1
4 lt1 method 1 label1 0 40 1
... ... ... ... ... ...
1595 lt2 method 4 label2 90 50 0
1596 lt2 method 4 label2 90 60 0
1597 lt2 method 4 label2 90 70 0
1598 lt2 method 4 label2 90 80 0
1599 lt2 method 4 label2 90 90 0
# Colour for each nw_score value (0 -> red, 1 -> blue). Not referenced by the
# plotting code that follows, which hard-codes the same two colours.
labels_fill = dict(zip((0, 1), ('red', 'blue')))
def facet(data, color):
    """Draw one facet of the grid as a red/blue heatmap of ``nw_score``.

    Used as the plotting function for ``FacetGrid.map_dataframe``, which
    supplies the facet's rows as ``data`` and a ``color`` keyword (unused
    here; the colours come from the fixed colormap).

    The rows are pivoted into a dtsi x rtsi table, so each (dtsi, rtsi) pair
    may occur only once per facet; otherwise ``DataFrame.pivot`` raises
    ``ValueError: Index contains duplicate entries, cannot reshape``.
    """
    scores = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    # Pin the colour scale to the two possible scores (0 -> red, 1 -> blue).
    # Without vmin/vmax, heatmap derives the limits from this facet alone; a
    # facet holding a single value gets vmin == vmax and every cell is drawn
    # in the first colour (red), even when all scores are 1.
    sns.heatmap(scores, cmap=ListedColormap(['red', 'blue']), vmin=0, vmax=1,
                cbar=False, annot=True)
# Intended: draw one grid of heatmaps per ltt level and save each one under
# save_results_to. As written it fails; see the NOTE(review) comments.
for l in data.ltt.unique():
    # print(l)
    with sns.plotting_context(font_scale=5.5):
        # NOTE(review): `data` is not filtered on the current level `l`, so each
        # facet holds rows from both ltt levels and the pivot inside facet()
        # sees repeated (dtsi, rtsi) pairs -> "Index contains duplicate entries".
        # Filter with data[data.ltt == l].
        # NOTE(review): "method+l" is not a column of `data` (its columns are
        # ltt, method, labels, dtsi, rtsi, nw_score); facet on col="method".
        g = sns.FacetGrid(data,row="labels", col="method+l", size=2, aspect=1,margin_titles=False)
        g = g.map_dataframe(facet)
        g.add_legend()
        # g.set(xlabel='common xlabel', ylabel='common ylabel')
        #g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
        # Clear seaborn's automatic facet titles; custom ones are set below.
        g.set_titles(template="")
    # Bold method names as column titles along the top row.
    for ax,m in zip(g.axes[0,:],methods):
        ax.set_title(m, fontweight='bold', fontsize=12)
    # Bold label names as horizontal y-axis labels down the first column.
    # NOTE(review): this loop rebinds `l`, so afterwards `l` is the last label
    # ('label2') instead of the ltt level, and the file name below is wrong.
    for ax,l in zip(g.axes[:,0],labels):
        ax.set_ylabel(l, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
    # g.fig.tight_layout()
    save_results_to = 'D:/plots'
    # NOTE(review): `os` is not imported in this snippet.
    if not os.path.exists(save_results_to):
        os.makedirs(save_results_to)
    # NOTE(review): plain concatenation has no path separator, giving e.g.
    # 'D:/plotslabel2.png'; os.path.join(save_results_to, ...) avoids that.
    g.savefig(save_results_to + l+ '.png', dpi = 300)
When I run the code above, I get the following error:
ValueError: Index contains duplicate entries, cannot reshape
The expected graph format:
回答1:
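The problem comes from the fact that you loop through the two ltt levels but never filter your DataFrame on those levels, so every facet still contains rows from both levels and the pivot inside facet() finds duplicate (dtsi, rtsi) pairs. Filter the data inside the loop: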
for l in data.ltt.unique():
g = sns.FacetGrid(data[data.ltt==l], ....)
Also, the variable l is used both for the ltt level and, in the inner loop, for the row labels, so the second use overwrites the first. Try using more descriptive variable names in your code. Finally, "method+l" is not a column of your DataFrame; facet on col="method" instead.
Here is the full working code:
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
# Show which seaborn version is installed; FacetGrid/heatmap output below depends on it.
print("seaborn version {}".format(sns.__version__))
def expandgrid(*itrs):
    """Cartesian product of the inputs, returned column-wise.

    Python take on R's ``expand.grid()`` (after
    https://stackoverflow.com/a/12131385/1135316). Column ``'VarK'`` holds
    the K-th component of every combination, listed in ``itertools.product``
    order (the last iterable varies fastest, unlike R).
    """
    rows = list(itertools.product(*itrs))
    if rows:
        # Transpose the list of combinations into one list per input.
        transposed = [list(column) for column in zip(*rows)]
    else:
        # Some iterable was empty: still emit one (empty) column per input.
        transposed = [[] for _ in itrs]
    return {'Var{}'.format(k): values for k, values in enumerate(transposed, start=1)}
# Factor levels; every combination of them becomes one row of `data`
# (2 x 4 x 2 x 10 x 10 = 1600 rows).
ltt = ['lt1', 'lt2']
methods = ['method 1', 'method 2', 'method 3', 'method 4']
labels = ['label1', 'label2']
times = range(0, 100, 10)

data = pd.DataFrame(expandgrid(ltt, methods, labels, times, times)).rename(
    columns={'Var1': 'ltt', 'Var2': 'method', 'Var3': 'labels',
             'Var4': 'dtsi', 'Var5': 'rtsi'})
# One random 0/1 score per row (np.random.sample(...) would give floats in
# [0, 1) instead).
data['nw_score'] = np.random.choice([0, 1], size=len(data))
# Colour for each nw_score value (0 -> red, 1 -> blue). Not referenced by the
# plotting code that follows, which hard-codes the same two colours.
labels_fill = dict(zip((0, 1), ('red', 'blue')))
def facet(data, color):
    """Draw one facet of the grid as a red/blue heatmap of ``nw_score``.

    Used as the plotting function for ``FacetGrid.map_dataframe``, which
    supplies the facet's rows as ``data`` and a ``color`` keyword (unused
    here; the colours come from the fixed colormap).

    The rows are pivoted into a dtsi x rtsi table, so each (dtsi, rtsi) pair
    may occur only once per facet; the caller filters on ltt for that reason.
    """
    scores = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    # Pin the colour scale to the two possible scores (0 -> red, 1 -> blue).
    # Without vmin/vmax, heatmap derives the limits from this facet alone; a
    # facet holding a single value gets vmin == vmax and every cell is drawn
    # in the first colour (red), even when all scores are 1.
    sns.heatmap(scores, cmap=ListedColormap(['red', 'blue']), vmin=0, vmax=1,
                cbar=False, annot=True)
import os

# Folder the figures are written to (the directory from the question; change
# as needed). Created once, up front, and reused for every ltt level.
save_results_to = 'D:/plots'
os.makedirs(save_results_to, exist_ok=True)

# Draw one FacetGrid per ltt level -- rows are labels, columns are methods,
# each cell a heatmap drawn by facet() -- and save it as <level>.png.
for lt in data.ltt.unique():
    # NOTE(review): seaborn appears to apply font_scale only when a named
    # context ("paper", "notebook", "talk", "poster") is given, so this block
    # may leave font sizes unchanged -- confirm for your seaborn version.
    with sns.plotting_context(font_scale=5.5):
        # Keep only the current level so each (dtsi, rtsi) pair is unique
        # within a facet, as the pivot in facet() requires. Explicit row/col
        # orders guarantee the facets line up with `labels` and `methods`,
        # which are used for the titles below. `height` replaces the
        # deprecated `size` argument.
        g = sns.FacetGrid(data[data.ltt == lt], row="labels", col="method",
                          row_order=labels, col_order=methods,
                          height=2, aspect=1, margin_titles=False)
        g = g.map_dataframe(facet)
        # Clear seaborn's automatic facet titles; custom ones are set below.
        g.set_titles(template="")
    # No add_legend(): there is no hue variable, so it would only draw an
    # empty legend; facet() already annotates every cell with its score.
    for ax, method in zip(g.axes[0, :], methods):
        ax.set_title(method, fontweight='bold', fontsize=12)
    for ax, label in zip(g.axes[:, 0], labels):
        ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0,
                      ha='right', va='center')
    g.fig.suptitle(lt, fontweight='bold', fontsize=12)
    g.fig.tight_layout()
    g.fig.subplots_adjust(top=0.8)  # make some room for the title
    # os.path.join inserts the separator between directory and file name.
    g.savefig(os.path.join(save_results_to, lt + '.png'), dpi=300)
lt1.png
lt2.png
来源:https://stackoverflow.com/questions/59099912/trouble-with-saving-grouped-seaborn-facetgrid-heatmap-data-into-a-directory