This code works fine to produce a single row of 3 subplots:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

g = np.random.choice([1, 2, 3], 10)
s = np.random.normal(size=10)
s2 = np.random.normal(size=10)
df = pd.DataFrame([g, s, s2]).T
df.columns = ['key', 's1', 's2']
gb = df.groupby('key')
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
i = 0
for key, df2 in gb:
    df2.plot(ax=axes[i], x='s1', y='s2', title=key)
    i = i + 1
But if I try to add a second row (nrows=2), it blows up with an AttributeError:
AttributeError: 'numpy.ndarray' object has no attribute 'get_figure'
fig, axes = plt.subplots(nrows=2, ncols=3)
fig.tight_layout()  # Or equivalently, "plt.tight_layout()"
i = 1
for key, df2 in gb:
    df2.plot(ax=axes[i])
    i = i + 1
That's because axes is now a 2-d array of matplotlib axes:
In [35]: fig, axes = plt.subplots(nrows=2, ncols=3)
In [36]: axes
Out[36]:
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x10c094ac8>,
<matplotlib.axes._subplots.AxesSubplot object at 0x10ee40b70>,
<matplotlib.axes._subplots.AxesSubplot object at 0x10c0ac240>],
[<matplotlib.axes._subplots.AxesSubplot object at 0x10edf8a90>,
<matplotlib.axes._subplots.AxesSubplot object at 0x10ec27630>,
<matplotlib.axes._subplots.AxesSubplot object at 0x10edc9128>]], dtype=object)
In [37]: axes.shape
Out[37]: (2, 3)
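A minimal sketch of why the loop broke, assuming a recent matplotlib; the non-interactive Agg backend is selected only so it runs without a display. With the default squeeze=True, plt.subplots returns a 1-d array when there is a single row and a 2-d array otherwise, so axes[i] on a 2x3 grid is a whole row of Axes rather than one Axes, and that row has no get_figure method:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (an assumption, just so this runs anywhere)
import matplotlib.pyplot as plt

# One row: squeeze=True (the default) collapses the grid to a 1-d array.
fig1, axes1 = plt.subplots(nrows=1, ncols=3)
print(axes1.shape)  # (3,)

# Two rows: the result stays 2-d and is indexed as axes[row][col].
fig2, axes2 = plt.subplots(nrows=2, ncols=3)
print(axes2.shape)  # (2, 3)

# axes2[1] is the entire second row (an ndarray of 3 Axes), not an Axes,
# so it lacks get_figure(), which is what DataFrame.plot(ax=...) needed.
row = axes2[1]
print(type(row).__name__)                  # ndarray
print(hasattr(row, "get_figure"))          # False
print(hasattr(axes2[1][0], "get_figure"))  # True

# squeeze=False always returns a 2-d array, even for a single row.
fig3, axes3 = plt.subplots(nrows=1, ncols=3, squeeze=False)
print(axes3.shape)  # (1, 3)

plt.close("all")
```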
Try something like:
In [38]: for i, (key, df2) in enumerate(gb):
   ...:     df2.plot(ax=axes[0][i])
If you want to wrap around to the second row, you'll have to do something like axes[i // 3][i % 3].
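The expression axes[i // 3][i % 3] is row-major index arithmetic: integer division by the number of columns picks the row, and the remainder picks the column. A small sketch generalizing the 3 to any column count with Python's built-in divmod; the helper name grid_position is hypothetical:

```python
def grid_position(i, ncols):
    """Map a flat, 0-based plot index to (row, col) in a grid filled row by row."""
    # divmod(i, ncols) is the same as (i // ncols, i % ncols)
    return divmod(i, ncols)

# Six groups on a 2x3 grid: the first three land in row 0, the next three in row 1.
positions = [grid_position(i, 3) for i in range(6)]
print(positions)
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

With a 2-d NumPy array of axes, axes[row, col] is equivalent to axes[row][col].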
(I expanded your example to have 6 groups)
In [67]: df
Out[67]:
key s1 s2
0 3 -1.452043 -0.119374
1 1 0.603860 -1.635034
2 3 0.964165 -0.043124
3 2 0.459628 -0.538155
4 3 0.398761 -0.195261
5 1 0.085750 -0.116766
6 2 -0.397419 -0.140660
7 3 -0.053209 1.547755
8 1 -0.634555 -0.509077
9 3 0.138808 0.608165
10 6 -1.452043 -0.119374
11 4 0.603860 -1.635034
12 6 0.964165 -0.043124
13 5 0.459628 -0.538155
14 6 0.398761 -0.195261
15 4 0.085750 -0.116766
16 5 -0.397419 -0.140660
17 6 -0.053209 1.547755
18 4 -0.634555 -0.509077
19 6 0.138808 0.608165
In [63]: for i, (key, df2) in enumerate(gb):
   ...:     df2.plot(ax=axes[i // 3][i % 3])
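An alternative that avoids the index arithmetic (a different technique from the reply above) is to zip the groups with axes.flat, NumPy's row-major flat iterator over the 2-d array. A self-contained sketch, assuming made-up data with six keys (mirroring the expanded example) and the Agg backend; any grid cells left over when there are fewer groups than subplots are hidden:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical data: 30 rows spread evenly over keys 1..6.
rng = np.random.default_rng(0)
n = 30
df = pd.DataFrame({
    "key": np.arange(n) % 6 + 1,
    "s1": rng.normal(size=n),
    "s2": rng.normal(size=n),
})
gb = df.groupby("key")

ncols = 3
nrows = -(-gb.ngroups // ncols)  # ceiling division: enough rows for every group
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                         figsize=(12, 4 * nrows), squeeze=False)

# axes.flat walks the grid row by row, so no i // ncols, i % ncols bookkeeping.
for ax, (key, df2) in zip(axes.flat, gb):
    df2.plot(ax=ax, x="s1", y="s2", title=str(key))

# Hide any cells that did not receive a group (none here, since 6 fills 2x3).
for ax in axes.flat[gb.ngroups:]:
    ax.set_visible(False)

fig.tight_layout()  # called after plotting so it accounts for titles and labels
```

Passing squeeze=False keeps axes 2-d even when one row would suffice, so the same loop works for any number of groups.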
Thank you so much, Tom. This now works perfectly. I sincerely appreciate your helpful reply.