Grouping by multiple dimensions

花落未央 2021-01-12 20:01

Grouping by a single dimension works fine for xarray DataArrays:

d = xr.DataArray([1, 2, 3], coords={'a': ['x', 'x', 'y']}, dims=['a'])
d.groupby(\         


        
2 answers
  •  半阙折子戏
    2021-01-12 20:26

    I don't know how this compares speed-wise, and I don't have time to put together a full solution for this particular question, but I found this question and its answer helpful when I was searching for ways to iterate over multiple dimensions in xarray, so I wanted to share the approach I ended up taking. I used dimension stacking, based on this example code by @RyanAbernathy:

    import xarray as xr
    import numpy as np
    
    # create an example dataset
    da = xr.DataArray(np.random.rand(10,30,40), dims=['dtime', 'x', 'y'])
    
    # define a function to compute a linear trend of a timeseries
    # (ts is the timeseries at a single stacked point)
    def linear_trend(ts):
        # fit against the 'dtime' dimension defined above (the array has no 'time' dimension)
        pf = np.polyfit(ts.dtime, ts, 1)
        # we need to return a dataarray or else xarray's groupby won't be happy
        return xr.DataArray(pf[0])
    
    # stack x and y into a single dimension called allpoints
    stacked = da.stack(allpoints=['x','y'])
    # apply the function over allpoints to calculate the trend at each point
    trend = stacked.groupby('allpoints').apply(linear_trend)
    # unstack back to x y coordinates
    trend_unstacked = trend.unstack('allpoints')
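
    On the speed question: for the specific case of a linear trend, the per-point groupby can be avoided altogether, because np.polyfit accepts a 2-D y array and fits every column in one call. The following is only a sketch of that alternative, not the approach from the example above. The small array size and the names rng, t, coeffs, slope and trend_vectorized are made up for illustration, and it assumes the time axis is simply 0, 1, 2, ... because the example array has no dtime coordinate.

```python
import numpy as np
import xarray as xr

# small example array with the same dimension names as above (sizes are arbitrary)
rng = np.random.default_rng(0)
da = xr.DataArray(rng.random((10, 3, 4)), dims=['dtime', 'x', 'y'])

# stack x and y into a single dimension, giving dims (dtime, allpoints)
stacked = da.stack(allpoints=['x', 'y'])

# assumed time axis: the integer positions along dtime
t = np.arange(stacked.sizes['dtime'])

# with a 2-D y of shape (M, K), np.polyfit fits each of the K columns separately;
# the result has shape (2, K) and row 0 holds the slopes
coeffs = np.polyfit(t, stacked.values, 1)

# reuse the allpoints index from the stacked array, swapping in the slopes as data
slope = stacked.isel(dtime=0, drop=True).copy(data=coeffs[0])

# unstack back to x y coordinates
trend_vectorized = slope.unstack('allpoints')
```

    This only works because polyfit happens to be vectorized over columns; an arbitrary per-point function still needs the groupby approach.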
    

    I combined this with a wrapper that nests a second groupby inside the first, so the function is applied for each point and each timestep:

    def _calc_allpoints(ds, function):
        """
        Helper function to do a pixel-wise calculation that requires using x and y dimension values
        as inputs. This version does the computation over all available timesteps as well.
        """
    
        # note: the below code will need to be generalized for other dimensions
    
        def _time_wrapper(gb):
            # within a single point, group by time and apply the function to each timestep
            gb = gb.groupby('dtime', squeeze=False).apply(function)
            return gb
    
        # stack x and y into a single dimension called allpoints
        stacked = ds.stack(allpoints=['x','y'])
        # group by allpoints and apply the time wrapper to each point
        newelev = stacked.groupby('allpoints', squeeze=False).apply(_time_wrapper)
        # unstack back to x y coordinates
        ds = newelev.unstack('allpoints')
    
        return ds
    

    where function is whatever function you are using (e.g. linear_trend)
