How to wrap Entity Framework to intercept the LINQ expression just before execution?

后端 未结 3 503
深忆病人
深忆病人 2020-12-02 07:03

I want to rewrite certain parts of a LINQ expression just before it is executed, but I'm having trouble injecting my rewriter in the right place (or anywhere at all, actually).

3条回答
  •  不知归路
    2020-12-02 07:29

    I have exactly the source code you need, but I don't know how to attach a file.

    Here are some snippets (snippets! I had to adapt this code, so it may not compile):

    IQueryable:

    /// <summary>
    /// An <see cref="IOrderedQueryable{T}"/> wrapper around another query source. LINQ operators
    /// composed on top of it are routed through <see cref="QueryTranslatorProvider"/>, which runs
    /// the expression tree through its visitor right before handing it to the wrapped source's
    /// own provider for execution.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    public class QueryTranslator<T> : IOrderedQueryable<T>
    {
        private readonly Expression _expression;
        private readonly QueryTranslatorProvider _provider;

        /// <summary>
        /// Wraps <paramref name="source"/> as the root of a new translated query.
        /// </summary>
        /// <param name="source">The query to wrap (e.g. an Entity Framework object query).</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public QueryTranslator(IQueryable source)
        {
            if (source == null) throw new ArgumentNullException("source");
            // The root must be the wrapped source's own expression. Expression.Constant(this)
            // would hand the underlying provider a constant of this wrapper type, which it has
            // no way to translate (nothing in the visitor swaps that constant back out).
            _expression = source.Expression;
            _provider = new QueryTranslatorProvider(source);
        }

        /// <summary>
        /// Creates a translated query for an expression composed on top of <paramref name="source"/>.
        /// </summary>
        /// <param name="source">The wrapped query whose provider will execute the rewritten expression.</param>
        /// <param name="e">The composed query expression.</param>
        /// <exception cref="ArgumentNullException"><paramref name="e"/> or <paramref name="source"/> is null.</exception>
        public QueryTranslator(IQueryable source, Expression e)
        {
            if (e == null) throw new ArgumentNullException("e");
            _expression = e;
            _provider = new QueryTranslatorProvider(source);
        }

        /// <summary>
        /// Rewrites and executes the query, returning its results typed as <typeparamref name="T"/>.
        /// </summary>
        public IEnumerator<T> GetEnumerator()
        {
            // Cast<T> returns the sequence unchanged when it is already IEnumerable<T>, and
            // still works if the rewrite changed the element type to a T-compatible one.
            return _provider.ExecuteEnumerable(_expression).Cast<T>().GetEnumerator();
        }

        IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return _provider.ExecuteEnumerable(_expression).GetEnumerator();
        }

        /// <summary>The element type <typeparamref name="T"/>.</summary>
        public Type ElementType
        {
            get { return typeof(T); }
        }

        /// <summary>The (not yet rewritten) expression tree this query represents.</summary>
        public Expression Expression
        {
            get { return _expression; }
        }

        /// <summary>The translating provider that builds and executes derived queries.</summary>
        public IQueryProvider Provider
        {
            get { return _provider; }
        }
    }
    

    IQueryProvider:

    /// <summary>
    /// Query provider behind QueryTranslator&lt;T&gt;.
    /// Composing an operator creates a new translator over the same wrapped source.
    /// Executing runs the expression through this class's Visit and then delegates to the
    /// wrapped source's own provider.
    /// </summary>
    public class QueryTranslatorProvider : ExpressionTreeTranslator, IQueryProvider
    {
        private readonly IQueryable _source;

        /// <param name="source">The wrapped query whose provider ultimately executes the rewritten expression.</param>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public QueryTranslatorProvider(IQueryable source)
        {
            if (source == null) throw new ArgumentNullException("source");
            _source = source;
        }

        #region IQueryProvider Members

        /// <summary>Wraps a composed expression in a new typed translator over the same source.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            return new QueryTranslator<TElement>(_source, expression);
        }

        /// <summary>
        /// Non-generic counterpart of <see cref="CreateQuery{TElement}"/>. The element type is
        /// inferred from the expression's sequence type.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");

            // Takes the first element type reported by FindElementTypes (defined elsewhere).
            Type elementType = expression.Type.FindElementTypes().First();
            return (IQueryable)Activator.CreateInstance(
                typeof(QueryTranslator<>).MakeGenericType(elementType),
                new object[] { _source, expression });
        }

        /// <summary>Executes a scalar query (e.g. Count, First) and casts its result.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public TResult Execute<TResult>(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");
            object result = ((IQueryProvider)this).Execute(expression);
            return (TResult)result;
        }

        /// <summary>Rewrites the expression and executes it on the wrapped source's provider.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public object Execute(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");

            Expression translated = this.Visit(expression);

            return _source.Provider.Execute(translated);
        }

        /// <summary>
        /// Rewrites a sequence-valued expression and turns it into a query on the wrapped
        /// source's provider. Enumerating the result executes it.
        /// </summary>
        internal IEnumerable ExecuteEnumerable(Expression expression)
        {
            if (expression == null) throw new ArgumentNullException("expression");

            Expression translated = this.Visit(expression);

            return _source.Provider.CreateQuery(translated);
        }

        #endregion

        #region Visits
        // NOTE(review): returns the call unchanged without visiting its object or arguments.
        // Assuming the base Visit dispatches method calls here, nothing nested inside a method
        // call (which covers LINQ operators and their lambdas) ever reaches VisitUnary below.
        // This may be deliberate in the adapted snippet; confirm before relying on nested rewrites.
        protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
        {
            return m;
        }

        // Rebuilds the unary node with a visited operand and a type mapped through
        // ToImplementationType (defined elsewhere). Presumably it maps an interface or abstract
        // type to its concrete implementation; confirm.
        protected override Expression VisitUnary(UnaryExpression u)
        {
            return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
        }
        #endregion
    }
    

    Usage (warning: adapted code! May not compile):

    // Cache of one translated root query per entity CLR type.
    private readonly Dictionary<Type, object> _table = new Dictionary<Type, object>();

    /// <summary>
    /// Returns the translated root query for <typeparamref name="T"/>, creating and caching it
    /// on first request. The wrapped query is built with the query string "[" + T's type name + "]".
    /// This assumes the entity set is named after the CLR type; confirm.
    /// </summary>
    public override IQueryable<T> GetObjectQuery<T>()
    {
        Type type = typeof(T);

        object cached;
        if (!_table.TryGetValue(type, out cached))
        {
            cached = new QueryTranslator<T>(
                _ctx.CreateQuery<T>("[" + typeof(T).Name + "]"));
            _table[type] = cached;
        }

        return (IQueryable<T>)cached;
    }
    

    Expression Visitors/Translator:

    http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

    http://msdn.microsoft.com/en-us/library/bb882521.aspx

    EDIT: Added FindElementTypes() and the FindIEnumerables() helper it uses. Hopefully all methods are present now.

        /// <summary>
        /// Finds all implemented IEnumerables of the given Type: every closed
        /// <see cref="IEnumerable{T}"/> plus the non-generic <see cref="IEnumerable"/>.
        /// Generic interfaces are listed before the non-generic one.
        /// </summary>
        /// <param name="seqType">
        /// Type to inspect. <c>null</c>, <see cref="object"/> and <see cref="string"/> yield an empty result.
        /// </param>
        public static IQueryable<Type> FindIEnumerables(this Type seqType)
        {
            if (seqType == null || seqType == typeof(object) || seqType == typeof(string))
                return new Type[] { }.AsQueryable();

            if (seqType == typeof(IEnumerable))
                return new Type[] { typeof(IEnumerable) }.AsQueryable();

            if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            {
                return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable();
            }

            // Arrays go through the general walk: T[] implements IEnumerable<T> via its
            // interfaces. A shortcut returning only IEnumerable would lose the element type.
            var found = new List<Type>(FindIEnumerables(seqType.BaseType));
            foreach (var iface in seqType.GetInterfaces())
            {
                found.AddRange(FindIEnumerables(iface));
            }

            // Type.GetInterfaces returns interfaces in no particular order, so put the generic
            // IEnumerable<T> entries first for callers that take First().
            return found.Distinct().OrderBy(t => t.IsGenericType ? 0 : 1).ToArray().AsQueryable();
        }
    
        /// <summary>
        /// Finds all element types provided by a specified sequence type.
        /// "Element types" are T for <see cref="IEnumerable{T}"/> and <see cref="object"/>
        /// for the non-generic <see cref="IEnumerable"/>. Generic element types come first,
        /// so <c>First()</c> yields the most specific one.
        /// </summary>
        /// <param name="seqType">The sequence type to inspect.</param>
        public static IQueryable<Type> FindElementTypes(this Type seqType)
        {
            // Order generic interfaces first: the non-generic IEnumerable maps to object and
            // must not win when a caller takes First().
            return seqType.FindIEnumerables()
                .OrderBy(t => t.IsGenericType ? 0 : 1)
                .Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object));
        }
    

提交回复
热议问题