EntityFunctions.TruncateTime and unit tests

后端 未结 6 1595
悲&欢浪女
悲&欢浪女 2020-12-14 15:57

I'm using the System.Data.Objects.EntityFunctions.TruncateTime method to get the date part of a DateTime in my query:

if (searchOptions.Date.HasValue)
         


        
6条回答
  •  隐瞒了意图╮
    2020-12-14 16:36

    As outlined in my answer to "How to Unit Test GetNewValues() which contains EntityFunctions.AddDays function", you can use a query expression visitor to replace calls to EntityFunctions methods with your own implementations that LINQ to Objects can execute.

    The implementation would look like:

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Data.Objects;
    using System.Linq;
    using System.Linq.Expressions;
    
    /// <summary>
    /// LINQ to Objects stand-ins for <c>EntityFunctions</c> methods. Calls are redirected here by
    /// method name and argument types, so each stand-in must share its name and parameter type with
    /// the <c>EntityFunctions</c> overload it replaces.
    /// </summary>
    static class EntityFunctionsFake
    {
        /// <summary>Fake for <c>EntityFunctions.TruncateTime(DateTime?)</c>.</summary>
        /// <param name="original">The value to truncate; may be null.</param>
        /// <returns>
        /// The date part of <paramref name="original"/> (time of day set to midnight, <c>Kind</c> kept),
        /// or null when <paramref name="original"/> is null.
        /// </returns>
        public static DateTime? TruncateTime(DateTime? original)
        {
            if (!original.HasValue) return null;
            return original.Value.Date;
        }

        /// <summary>Fake for <c>EntityFunctions.TruncateTime(DateTimeOffset?)</c>.</summary>
        /// <param name="original">The value to truncate; may be null.</param>
        /// <returns>
        /// Midnight of the same local date, with the same offset as <paramref name="original"/>,
        /// or null when <paramref name="original"/> is null.
        /// </returns>
        public static DateTimeOffset? TruncateTime(DateTimeOffset? original)
        {
            if (!original.HasValue) return null;
            DateTimeOffset value = original.Value;
            // Keep the offset; only the clock time is cleared.
            return new DateTimeOffset(value.Date, value.Offset);
        }
    }
    /// <summary>
    /// Expression visitor that redirects every call to a method declared on <c>EntityFunctions</c>
    /// to the method with the same name on <c>EntityFunctionsFake</c>, so that the rewritten tree
    /// can run under LINQ to Objects.
    /// </summary>
    public class EntityFunctionsFakerVisitor : ExpressionVisitor
    {
        /// <summary>
        /// Rewrites an <c>EntityFunctions</c> call into an equivalent <c>EntityFunctionsFake</c> call,
        /// using the visited arguments and the same generic type arguments. Any other method call is
        /// left to the base visitor.
        /// </summary>
        /// <remarks>
        /// <c>Expression.Call</c> throws <see cref="InvalidOperationException"/> when
        /// <c>EntityFunctionsFake</c> has no method with a matching name and compatible parameters.
        /// </remarks>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Method.DeclaringType != typeof(EntityFunctions))
            {
                return base.VisitMethodCall(node);
            }

            Expression[] fakeArguments = Visit(node.Arguments).ToArray();
            Type[] genericTypeArguments = node.Method.GetGenericArguments();
            return Expression.Call(typeof(EntityFunctionsFake), node.Method.Name, genericTypeArguments, fakeArguments);
        }
    }
    /// <summary>
    /// <see cref="IQueryProvider"/> decorator that rewrites every expression with a new
    /// <typeparamref name="TVisitor"/> before passing it to the wrapped provider. Queries it creates
    /// are wrapped in <see cref="VisitedQueryable{T, TExpressionVisitor}"/>, so LINQ operators
    /// composed on them are routed back through this provider.
    /// </summary>
    /// <typeparam name="TVisitor">Visitor applied to each expression; a new instance is created per call.</typeparam>
    class VisitedQueryProvider<TVisitor> : IQueryProvider
        where TVisitor : ExpressionVisitor, new()
    {
        private readonly IQueryProvider _underlyingQueryProvider;

        /// <param name="underlyingQueryProvider">Provider that builds and executes the rewritten expressions.</param>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQueryProvider"/> is null.</exception>
        public VisitedQueryProvider(IQueryProvider underlyingQueryProvider)
        {
            if (underlyingQueryProvider == null) throw new ArgumentNullException("underlyingQueryProvider");
            _underlyingQueryProvider = underlyingQueryProvider;
        }

        // A new visitor is created for every expression (hence the new() constraint on TVisitor).
        private static Expression Visit(Expression expression)
        {
            return new TVisitor().Visit(expression);
        }

        /// <summary>Creates a rewritten, still-wrapped query whose element type is known statically.</summary>
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            return new VisitedQueryable<TElement, TVisitor>(_underlyingQueryProvider.CreateQuery<TElement>(Visit(expression)));
        }

        /// <summary>Creates a rewritten, still-wrapped query whose element type is only known at runtime.</summary>
        public IQueryable CreateQuery(Expression expression)
        {
            var sourceQueryable = _underlyingQueryProvider.CreateQuery(Visit(expression));
            // Close VisitedQueryable<,> over the runtime element type via reflection.
            var visitedQueryableType = typeof(VisitedQueryable<,>).MakeGenericType(
                sourceQueryable.ElementType,
                typeof(TVisitor)
                );

            return (IQueryable)Activator.CreateInstance(visitedQueryableType, sourceQueryable);
        }

        /// <summary>Executes the rewritten expression on the wrapped provider (e.g. <c>Count()</c>, <c>First()</c>).</summary>
        public TResult Execute<TResult>(Expression expression)
        {
            return _underlyingQueryProvider.Execute<TResult>(Visit(expression));
        }

        /// <summary>Non-generic counterpart of <see cref="Execute{TResult}(Expression)"/>.</summary>
        public object Execute(Expression expression)
        {
            return _underlyingQueryProvider.Execute(Visit(expression));
        }
    }

    /// <summary>
    /// <see cref="IQueryable{T}"/> wrapper whose <c>Provider</c> is a
    /// <see cref="VisitedQueryProvider{TVisitor}"/>, so operators composed on it have their expression
    /// rewritten by <typeparamref name="TExpressionVisitor"/>. Enumeration, <c>Expression</c> and
    /// <c>ElementType</c> delegate directly to the wrapped query.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the wrapped query itself is enumerated as-is. If the query passed to the
    /// constructor already contains calls that need rewriting, they are not rewritten unless another
    /// operator is composed on top first.
    /// </remarks>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <typeparam name="TExpressionVisitor">Visitor applied to composed expressions.</typeparam>
    public class VisitedQueryable<T, TExpressionVisitor> : IQueryable<T>
        where TExpressionVisitor : ExpressionVisitor, new()
    {
        private readonly IQueryable<T> _underlyingQuery;
        private readonly VisitedQueryProvider<TExpressionVisitor> _queryProviderWrapper;

        /// <param name="underlyingQuery">Query to wrap; its provider is used for all rewritten queries.</param>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
        public VisitedQueryable(IQueryable<T> underlyingQuery)
        {
            if (underlyingQuery == null) throw new ArgumentNullException("underlyingQuery");
            _underlyingQuery = underlyingQuery;
            _queryProviderWrapper = new VisitedQueryProvider<TExpressionVisitor>(underlyingQuery.Provider);
        }

        public IEnumerator<T> GetEnumerator()
        {
            return _underlyingQuery.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public Expression Expression
        {
            get { return _underlyingQuery.Expression; }
        }

        public Type ElementType
        {
            get { return _underlyingQuery.ElementType; }
        }

        public IQueryProvider Provider
        {
            get { return _queryProviderWrapper; }
        }
    }
    

    And here is a usage sample with TruncateTime:

    // Wrap an in-memory source so that EntityFunctions calls in composed queries are
    // redirected to EntityFunctionsFake before LINQ to Objects runs them.
    var linq2ObjectsSource = new List<DateTime?>() { null, new DateTime(2020, 12, 14, 15, 57, 0) }.AsQueryable();
    var visitedSource = new VisitedQueryable<DateTime?, EntityFunctionsFakerVisitor>(linq2ObjectsSource);
    // Use a lambda on the following line: only a lambda converts to an expression tree, so only then
    // is Queryable.Select (and therefore the visitor) used. A method group binds to Enumerable.Select,
    // which runs the LINQ to Objects implementation directly. I have not found a way around it.
    var visitedQuery = visitedSource.Select(dt => EntityFunctions.TruncateTime(dt));
    var results = visitedQuery.ToList();
    Assert.AreEqual(2, results.Count);
    Assert.AreEqual(null, results[0]);
    Assert.AreEqual((DateTime?)new DateTime(2020, 12, 14), results[1]);
    

提交回复
热议问题