Problem with plotting graphs in 1 row using plot method from pandas

删除回忆录丶 submitted on 2019-12-20 03:52:27

Question


Suppose I want to plot 3 graphs in 1 row, showing how cnt depends on each of 3 other features.

Code:

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
for idx, feature in enumerate(min_regressors):
    df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax= axes[0, idx])
plt.show()

Error message:

IndexError                                Traceback (most recent call last)
<ipython-input-697-e15bcbeccfad> in <module>()
      2 fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
      3 for idx, feature in enumerate(min_regressors):
----> 4     df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax= axes[0, idx])
      5 plt.show()

IndexError: too many indices for array

But everything is ok when I'm plotting in (2,2) dimension:

Code:

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))
for idx, feature in enumerate(min_regressors):
    df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax= axes[idx / 2, idx % 2])
plt.show()
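The expression axes[idx / 2, idx % 2] turns the loop counter into a (row, column) pair. It works in Python 2.7 because / on two ints floors the result; in Python 3, / returns a float and NumPy rejects float indices, so idx // 2 (or divmod) is the spelling that works in both. A small sketch of the mapping, assuming a 2-column grid:

```python
# Map a flat loop counter onto (row, col) positions in a grid with 2 columns.
# divmod(idx, ncols) returns (idx // ncols, idx % ncols).
ncols = 2

positions = [divmod(idx, ncols) for idx in range(4)]
print(positions)  # [(0, 0), (0, 1), (1, 0), (1, 1)]

# Under Python 3, true division gives a float, floor division an int.
half_true = 3 / 2
half_floor = 3 // 2
print(half_true, half_floor)  # 1.5 1
```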


I'm using Python 2.7.


Answer 1:


The problem is not related to pandas. The IndexError comes from ax=axes[0, idx]. With a single row, plt.subplots returns a 1-dimensional array of axes, so the two-index form [0, idx] only works when you have more than one row (and more than one column).

With a single row, drop the first index:

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
for idx, feature in enumerate(min_regressors):
    df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax= axes[idx])
plt.show()
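To try the fix end to end without the original data, here is a minimal sketch. The DataFrame contents and the column names temp, hum and windspeed are hypothetical stand-ins for df_shuffled and min_regressors; the sketch also leaves out subplots=True, which is not needed here because each call is given its own ax.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the question's df_shuffled and min_regressors.
rng = np.random.default_rng(0)
df_shuffled = pd.DataFrame({
    "temp": rng.random(50),
    "hum": rng.random(50),
    "windspeed": rng.random(50),
    "cnt": rng.integers(0, 1000, 50),
})
min_regressors = ["temp", "hum", "windspeed"]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
print(axes.shape)  # (3,): one row gives a 1-D array of axes

for idx, feature in enumerate(min_regressors):
    # A single index selects the subplot in a 1-D array of axes.
    df_shuffled.plot(x=feature, y="cnt", kind="scatter", ax=axes[idx])
```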

As a recap:

Correct

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
axes[0].plot([1,2], [1,2])

Incorrect

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
axes[0, 0].plot([1,2], [1,2])

Correct

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(8, 3))
axes[0,0].plot([1,2], [1,2])
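The three recap cases can be checked in one script. This sketch assumes only matplotlib with its default settings and catches the IndexError from the incorrect form instead of letting it stop the script; the exact error text varies between NumPy versions, so it is printed rather than asserted.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

# One row: axes is 1-D, so a single index is correct.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
axes[0].plot([1, 2], [1, 2])

# Two indices on a 1-D array raise the error from the question.
raised = False
try:
    axes[0, 0].plot([1, 2], [1, 2])
except IndexError as exc:
    raised = True
    print("IndexError:", exc)

# Two rows and three columns: axes is 2-D, so [row, col] is correct.
fig2, axes2 = plt.subplots(nrows=2, ncols=3, figsize=(8, 3))
axes2[0, 0].plot([1, 2], [1, 2])
```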



Answer 2:


To understand what is happening, check the shape of axes in both situations. When either nrows or ncols is 1, axes is a 1-dimensional array; otherwise it is 2-dimensional. (If both are 1, you get a single Axes object rather than an array.)
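A quick way to inspect this for a few layouts, assuming matplotlib's default squeeze=True behaviour of plt.subplots, which drops length-1 dimensions from the returned array:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

shapes = {}
for nrows, ncols in [(1, 3), (3, 1), (2, 2), (1, 1)]:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    if isinstance(axes, np.ndarray):
        shapes[(nrows, ncols)] = axes.shape
    else:
        # A 1x1 grid returns a single Axes object, not an array.
        shapes[(nrows, ncols)] = None
    plt.close(fig)

print(shapes)
```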

You cannot index a 1-dimensional array with two indices, as in ax=axes[0, idx].

One option is numpy's atleast_2d, which turns the 1-D array of axes into a 2-D array of shape (1, 3), so axes[0, idx] works again.
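A minimal sketch of the atleast_2d approach, assuming the 1x3 layout from the question; the plotted data is placeholder data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3))
print(axes.shape)  # (3,)

# atleast_2d prepends an axis to a 1-D array: shape (3,) becomes (1, 3).
axes = np.atleast_2d(axes)
print(axes.shape)  # (1, 3)

# The original two-index loop now works unchanged.
for idx in range(3):
    axes[0, idx].plot([1, 2], [1, 2])
```

One caveat of this design: atleast_2d always adds the new axis in front, so a single-column layout (nrows=3, ncols=1) also becomes shape (1, 3), not (3, 1), and would need axes[0, idx] rather than axes[idx, 0].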

Alternatively, and more cleanly, iterate over the axes and the features together with zip:

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10))
for ax, feature in zip(axes, min_regressors):
    df_shuffled.plot(feature, "cnt", subplots=True, kind="scatter", ax=ax)
plt.show()
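The zip loop can serve both the 1x3 and the 2x2 layouts if it iterates over axes.flat (NumPy's flat iterator) instead of axes: iterating a 2-D array directly yields whole rows, not individual axes. A sketch with hypothetical feature names, setting titles in place of the scatter calls:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

features = ["a", "b", "c", "d"]  # hypothetical feature names

titles = {}
for nrows, ncols in [(1, 3), (2, 2)]:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    # .flat walks a 1-D or 2-D array of axes in row-major order.
    for ax, feature in zip(axes.flat, features):
        ax.set_title(feature)
    titles[(nrows, ncols)] = [ax.get_title() for ax in axes.flat]
    plt.close(fig)

print(titles)
```

Note that zip stops at the shorter sequence, so with 3 axes and 4 features the last feature is silently skipped.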


Source: https://stackoverflow.com/questions/54495582/problem-with-plotting-graphs-in-1-row-using-plot-method-from-pandas
