Applying a function along an axis of a dask array
I'm analyzing ocean temperature data from a climate model simulation. The 4D data arrays (time, depth, latitude, longitude; denoted dask_array below) typically have a shape of (6000, 31, 189, 192) and a size of ~25GB, which is why I want to use dask: I've been getting memory errors trying to process these arrays with numpy. I need to fit a cubic polynomial along the time axis at each depth / latitude / longitude point and store the resulting 4 coefficients. I've therefore set chunksize=(6000, 1, 1, 1) so that I have a separate chunk for each grid point. This is my function for getting the