Grouping by a single dimension works fine for xarray DataArrays:
d = xr.DataArray([1, 2, 3], coords={'a': ['x', 'x', 'y']}, dims=['a'])
d.groupby(\
I don't know how this compares speed-wise, and I don't have time to put together a full solution for this particular question, but this question and answer helped me when I was looking for ways to iterate over multiple dimensions in xarray, so I wanted to share the approach I ended up taking. I used dimension stacking, based on this example code by @RyanAbernathy:
import xarray as xr
import numpy as np

# create an example dataset
da = xr.DataArray(np.random.rand(10,30,40), dims=['dtime', 'x', 'y'])

# define a function to compute a linear trend of a timeseries
def linear_trend(x):
    # fit against the 'dtime' dimension (the example array has no 'time' dimension)
    pf = np.polyfit(x['dtime'], x, 1)
    # we need to return a dataarray or else xarray's groupby won't be happy
    return xr.DataArray(pf[0])

# stack x and y into a single dimension called allpoints
stacked = da.stack(allpoints=['x','y'])
# apply the function over allpoints to calculate the trend at each point
trend = stacked.groupby('allpoints').apply(linear_trend)
# unstack back to x y coordinates
trend_unstacked = trend.unstack('allpoints')
I combined this with a groupby wrapper that nests one groupby inside another:
def _calc_allpoints(ds, function):
    """
    Helper function to do a pixel-wise calculation that requires using x and y dimension values
    as inputs. This version does the computation over all available timesteps as well.
    """
    # note: the below code will need to be generalized for other dimensions
    def _time_wrapper(gb):
        gb = gb.groupby('dtime', squeeze=False).apply(function)
        return gb

    # stack x and y into a single dimension called allpoints
    stacked = ds.stack(allpoints=['x','y'])
    # group by allpoints, then by dtime within each point, and apply the function to each group
    newelev = stacked.groupby('allpoints', squeeze=False).apply(_time_wrapper)
    # unstack back to x y coordinates
    ds = newelev.unstack('allpoints')
    return ds
Here, function is whatever function you are applying (e.g. linear_trend).
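If it helps to see what the stack / compute / unstack pattern does on its own, here is a minimal sketch with a tiny array whose slopes are known in advance. It is not the groupby approach above: instead of grouping, it hands the stacked 2D values straight to np.polyfit, which accepts a 2D y and fits every column at once. The array sizes, coordinate values and the names slopes and coeffs are my own assumptions for illustration.

```python
import numpy as np
import xarray as xr

# Hypothetical test data: a known slope at each (x, y) point, plus a constant offset.
t = np.arange(5.0)
slopes = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (x, y)
data = t[:, None, None] * slopes[None, :, :] + 0.5      # shape (dtime, x, y)

da = xr.DataArray(
    data,
    dims=["dtime", "x", "y"],
    coords={"dtime": t, "x": [10, 20], "y": [1, 2, 3]},
)

# Stack x and y into one dimension, so every column is one pixel's timeseries.
stacked = da.stack(allpoints=["x", "y"]).transpose("dtime", "allpoints")

# np.polyfit fits each column of a 2D y; row 0 of the result holds the slopes.
coeffs = np.polyfit(stacked["dtime"].values, stacked.values, 1)

# Reuse the stacked coordinates by copying one time slice and swapping in the slopes.
trend = stacked.isel(dtime=0, drop=True).copy(data=coeffs[0])

# Unstack back to x / y coordinates.
trend_unstacked = trend.unstack("allpoints")
```

The result has the x and y dimensions of the input, and its values match slopes up to floating-point error. The same stacked object is what the groupby versions above iterate over, one allpoints label at a time.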