How to return an empty IQueryable from an async repository method

梦想的初衷 提交于 2019-11-30 20:51:54

If you don't want to hit the database, you will most likely have to provide your own implementation of an empty IQueryable that also implements IDbAsyncEnumerable. That is not too hard: in the enumerator, return default(T) (null for reference types) from Current and false from MoveNext/MoveNextAsync, and make Dispose a no-op. Try it. Enumerable.Empty<MyObject>().AsQueryable() has nothing to do with the database and definitely does not implement IDbAsyncEnumerable, so Entity Framework's async extension methods (ToListAsync, FirstOrDefaultAsync, and so on) refuse to work with it. You need an implementation that does implement it, as described in the referenced documentation.

I ended up implementing an extension method that returns a wrapper implementing IDbAsyncEnumerable. It is based on the boilerplate implementation commonly used for mocking Entity Framework 6 async queries.

With this extension method I can use

return Enumerable.Empty<MyObject>().AsAsyncQueryable();

which works great.

Implementation:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace MyProject.MyDatabase.Extensions
{
    /// <summary>
    /// Extension methods that expose sequences as queryables usable with Entity Framework 6
    /// asynchronous operators (for example <c>ToListAsync</c> or <c>FirstOrDefaultAsync</c>).
    /// </summary>
    public static class EnumerableExtensions
    {
        /// <summary>
        /// Returns an <see cref="IQueryable{T}"/> over <paramref name="source"/> that also implements
        /// <see cref="IDbAsyncEnumerable{T}"/> and exposes an <see cref="IDbAsyncQueryProvider"/>.
        /// </summary>
        /// <param name="source">The sequence to expose.</param>
        /// <returns>
        /// <paramref name="source"/> itself when it already supports EF6 async operations;
        /// otherwise an async-capable wrapper around it.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IQueryable<T> AsAsyncQueryable<T>(this IEnumerable<T> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            // Do not re-wrap a query that is already async-capable: composing operators onto the
            // wrapper would route them through an in-memory provider instead of the original one.
            var queryable = source as IQueryable<T>;
            if (queryable != null && SupportsAsync(queryable))
            {
                return queryable;
            }

            return new AsyncQueryableWrapper<T>(source);
        }

        /// <summary>
        /// Returns <paramref name="source"/> unchanged when it already supports EF6 async operations;
        /// otherwise wraps it so that it implements <see cref="IDbAsyncEnumerable{T}"/> and exposes an
        /// <see cref="IDbAsyncQueryProvider"/>.
        /// </summary>
        /// <param name="source">The query to expose.</param>
        /// <returns>An async-capable queryable over <paramref name="source"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public static IQueryable<T> AsAsyncQueryable<T>(this IQueryable<T> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            if (SupportsAsync(source))
            {
                return source;
            }

            return new AsyncQueryableWrapper<T>(source);
        }

        // True when the query can be enumerated asynchronously AND its provider can execute
        // expressions asynchronously; async operators need one or the other of these.
        private static bool SupportsAsync<T>(IQueryable<T> source)
        {
            return source is IDbAsyncEnumerable<T> && source.Provider is IDbAsyncQueryProvider;
        }
    }

    /// <summary>
    /// Adapts an <see cref="IQueryable{T}"/> so that it also implements
    /// <see cref="IDbAsyncEnumerable{T}"/> and exposes an <see cref="IDbAsyncQueryProvider"/>.
    /// Enumeration and execution are delegated to the wrapped query; async members complete synchronously.
    /// </summary>
    internal class AsyncQueryableWrapper<T> : IDbAsyncEnumerable<T>, IQueryable<T>
    {
        private readonly IQueryable<T> _source;

        // Created once so every read of Provider returns the same instance.
        private readonly IQueryProvider _provider;

        /// <summary>Wraps an existing query.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public AsyncQueryableWrapper(IQueryable<T> source)
        {
            if (source == null)
            {
                throw new ArgumentNullException(nameof(source));
            }

            _source = source;
            _provider = new AsyncQueryProvider<T>(source.Provider);
        }

        /// <summary>Wraps an in-memory sequence, converting it with <see cref="Queryable.AsQueryable{TElement}(IEnumerable{TElement})"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is <c>null</c>.</exception>
        public AsyncQueryableWrapper(IEnumerable<T> source)
            : this(source?.AsQueryable())
        {
        }

        /// <summary>Returns an async enumerator over the wrapped query.</summary>
        public IDbAsyncEnumerator<T> GetAsyncEnumerator()
        {
            return new AsyncEnumerator<T>(_source.GetEnumerator());
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            return GetAsyncEnumerator();
        }

        public IEnumerator<T> GetEnumerator()
        {
            return _source.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public Expression Expression => _source.Expression;

        public Type ElementType => _source.ElementType;

        /// <summary>
        /// An async-capable provider that delegates execution to the wrapped query's provider.
        /// </summary>
        public IQueryProvider Provider => _provider;
    }

    /// <summary>
    /// An in-memory <see cref="EnumerableQuery{T}"/> that additionally implements
    /// <see cref="IDbAsyncEnumerable{T}"/>. <see cref="AsyncQueryProvider{TEntity}"/> creates
    /// these when LINQ operators are composed onto an async-capable queryable.
    /// </summary>
    internal class AsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
    {
        /// <summary>Creates a query over an existing in-memory sequence.</summary>
        public AsyncEnumerable(IEnumerable<T> enumerable)
            : base(enumerable)
        {
        }

        /// <summary>Creates a query represented by an expression tree.</summary>
        public AsyncEnumerable(Expression expression)
            : base(expression)
        {
        }

        /// <summary>Returns an async enumerator over this query's synchronous enumerator.</summary>
        public IDbAsyncEnumerator<T> GetAsyncEnumerator()
        {
            IEnumerable<T> sequence = this;
            IEnumerator<T> syncEnumerator = sequence.GetEnumerator();
            return new AsyncEnumerator<T>(syncEnumerator);
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            IDbAsyncEnumerable<T> typed = this;
            return typed.GetAsyncEnumerator();
        }

        // Re-implemented (IQueryable<T> is re-declared above) so that queries composed on top of
        // this one are created by the async-aware provider instead of the base EnumerableQuery.
        IQueryProvider IQueryable.Provider
        {
            get
            {
                return new AsyncQueryProvider<T>(this);
            }
        }
    }

    /// <summary>
    /// An <see cref="IDbAsyncQueryProvider"/> that delegates execution to an inner synchronous
    /// <see cref="IQueryProvider"/> and wraps composed queries in <see cref="AsyncEnumerable{T}"/>
    /// so they remain usable with EF6 async operators. Async members complete synchronously.
    /// </summary>
    /// <typeparam name="TEntity">
    /// Element type used when the element type of a query expression cannot be determined.
    /// </typeparam>
    internal class AsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
    {
        private readonly IQueryProvider _inner;

        /// <exception cref="ArgumentNullException"><paramref name="inner"/> is <c>null</c>.</exception>
        internal AsyncQueryProvider(IQueryProvider inner)
        {
            if (inner == null)
            {
                throw new ArgumentNullException(nameof(inner));
            }

            _inner = inner;
        }

        /// <summary>
        /// Creates an <see cref="AsyncEnumerable{T}"/> whose element type is taken from the
        /// <see cref="IQueryable{T}"/> (or <see cref="IEnumerable{T}"/>) that the expression's type implements.
        /// </summary>
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException(nameof(expression));
            }

            var elementType = GetElementType(expression.Type);
            var enumerableType = typeof(AsyncEnumerable<>).MakeGenericType(elementType);

            return (IQueryable)Activator.CreateInstance(enumerableType, expression);
        }

        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            return new AsyncEnumerable<TElement>(expression);
        }

        public object Execute(Expression expression)
        {
            return _inner.Execute(expression);
        }

        public TResult Execute<TResult>(Expression expression)
        {
            return _inner.Execute<TResult>(expression);
        }

        /// <summary>Executes synchronously and returns the result as a completed task.</summary>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already canceled.</exception>
        public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();
            return Task.FromResult(Execute(expression));
        }

        /// <summary>Executes synchronously and returns the result as a completed task.</summary>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already canceled.</exception>
        public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();
            return Task.FromResult(Execute<TResult>(expression));
        }

        // Resolves the element type from the IQueryable<X> the type implements, then from
        // IEnumerable<X>; falls back to TEntity when neither is found.
        private static Type GetElementType(Type sequenceType)
        {
            var sequenceInterface = FindGenericInterface(sequenceType, typeof(IQueryable<>))
                ?? FindGenericInterface(sequenceType, typeof(IEnumerable<>));

            return sequenceInterface != null
                ? sequenceInterface.GetGenericArguments()[0]
                : typeof(TEntity);
        }

        // Returns the closed form of genericDefinition that `type` is or implements, or null.
        private static Type FindGenericInterface(Type type, Type genericDefinition)
        {
            if (type.IsGenericType && type.GetGenericTypeDefinition() == genericDefinition)
            {
                return type;
            }

            return type.GetInterfaces()
                .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == genericDefinition);
        }
    }

    /// <summary>
    /// Exposes a synchronous <see cref="IEnumerator{T}"/> through <see cref="IDbAsyncEnumerator{T}"/>.
    /// <see cref="MoveNextAsync"/> completes synchronously.
    /// </summary>
    internal class AsyncEnumerator<T> : IDbAsyncEnumerator<T>
    {
        private readonly IEnumerator<T> _inner;

        /// <exception cref="ArgumentNullException"><paramref name="inner"/> is <c>null</c>.</exception>
        public AsyncEnumerator(IEnumerator<T> inner)
        {
            if (inner == null)
            {
                throw new ArgumentNullException(nameof(inner));
            }

            _inner = inner;
        }

        /// <summary>Disposes the wrapped enumerator.</summary>
        public void Dispose()
        {
            _inner.Dispose();
        }

        /// <summary>Advances the wrapped enumerator and returns the result as a completed task.</summary>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already canceled.</exception>
        public Task<bool> MoveNextAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();
            return Task.FromResult(_inner.MoveNext());
        }

        public T Current => _inner.Current;

        object IDbAsyncEnumerator.Current => Current;
    }
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!