An expression tree lambda may not contain a null propagating operator

前端 未结 4 1945
眼角桃花
眼角桃花 2020-11-28 10:49

Question: The line `price = co?.price ?? 0,` in the following code gives me the above error, but if I remove the `?` from `co?.price`, the error goes away.

4条回答
  •  轻奢々
    轻奢々 (楼主)
    2020-11-28 11:33

    Although expression trees do not support the C# 6.0 null-propagating operator (`?.`), we can write an expression visitor that rewrites the tree to perform safe null propagation, just like the operator does!

    Here is mine:

    /// <summary>
    /// Expression visitor that emulates the C# 6 null-conditional operator (<c>?.</c>), which
    /// expression-tree lambdas cannot contain. It rewrites every instance member access, every
    /// instance method call and every array <c>Length</c>, so that a <c>null</c> receiver
    /// produces <c>null</c> instead of throwing.
    /// </summary>
    /// <remarks>
    /// Value-typed results of rewritten accesses are widened to <see cref="Nullable{T}"/>
    /// (for example, <c>s.Length</c> becomes an <c>int?</c>). A rewritten body may therefore
    /// need to be wrapped in a new lambda with a nullable return type.
    /// Arguments of instance method calls and the test of a conditional expression are not
    /// rewritten. Binary operators get no special handling, so mixing a widened <c>T?</c>
    /// operand with a <c>T</c> operand may fail when the tree is rebuilt.
    /// </remarks>
    public class NullPropagationVisitor : ExpressionVisitor
    {
        private readonly bool _recursive;

        /// <summary>Creates the visitor.</summary>
        /// <param name="recursive">
        /// When <c>true</c>, the receiver of each access is rewritten as well, so a whole chain
        /// such as <c>a.B.C</c> is null-safe. When <c>false</c>, only the outermost access in
        /// each chain gets a null check.
        /// </param>
        public NullPropagationVisitor(bool recursive)
        {
            _recursive = recursive;
        }

        /// <summary>
        /// Handles unary nodes in three ways:
        /// <list type="bullet">
        /// <item>A lifting conversion from T to T? around a rewritten operand is dropped,
        /// because the operand is already T?.</item>
        /// <item><c>array.Length</c> (an <see cref="ExpressionType.ArrayLength"/> node) is
        /// null-protected like a member access.</item>
        /// <item>Every other operator is kept and rebuilt over its rewritten operand.</item>
        /// </list>
        /// </summary>
        protected override Expression VisitUnary(UnaryExpression node)
        {
            if (node.NodeType == ExpressionType.ArrayLength)
                return PropagateNull(node.Operand, node);

            if (IsLiftingConversion(node) && node.Operand is ConditionalExpression cond)
            {
                // Widen each branch separately so both arms agree on T?.
                // The test is deliberately left as-is.
                return Expression.Condition(
                    test: cond.Test,
                    ifTrue: MakeNullable(Visit(cond.IfTrue)),
                    ifFalse: MakeNullable(Visit(cond.IfFalse)));
            }

            var operand = Visit(node.Operand);
            if (operand == node.Operand)
                return node;

            // The compiler's implicit T -> T? conversion is redundant once the operand is T?.
            if (IsLiftingConversion(node) && operand.Type == node.Type)
                return operand;

            // Converting a widened (possibly null) operand to a non-nullable value type would
            // throw on null. Convert to the nullable target instead so the null keeps flowing.
            var widened = IsNullableStruct(operand) && !IsNullableStruct(node.Operand);
            if (widened && IsConversion(node) && node.Method == null && CanBeMadeNullable(node.Type))
                return Expression.MakeUnary(node.NodeType, operand, typeof(Nullable<>).MakeGenericType(node.Type));

            // Keep the operator itself (Negate, Not, TypeAs, ...). Update rebuilds it over the
            // rewritten operand, lifting it when the operand became nullable.
            return node.Update(operand);
        }

        /// <summary>Null-protects instance member accesses. Static members are left alone.</summary>
        protected override Expression VisitMember(MemberExpression node)
        {
            // A static member has no receiver that could be null.
            if (node.Expression == null)
                return base.VisitMember(node);

            return PropagateNull(node.Expression, node);
        }

        /// <summary>Null-protects instance method calls. Static calls use the default traversal.</summary>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Object == null)
                return base.VisitMethodCall(node);

            return PropagateNull(node.Object, node);
        }

        /// <summary>
        /// Rewrites <paramref name="access"/> so that it evaluates <paramref name="instance"/>
        /// once, into a block variable. When that value is null, the rewritten access yields
        /// <c>null</c>; a void access is simply skipped.
        /// </summary>
        /// <param name="instance">The receiver sub-expression of <paramref name="access"/>.</param>
        /// <param name="access">The member access, call or array-length node to protect.</param>
        /// <returns>
        /// An expression whose type is <paramref name="access"/>'s type made nullable. The type
        /// is unchanged for reference types and <c>void</c>.
        /// </returns>
        private Expression PropagateNull(Expression instance, Expression access)
        {
            var safe = _recursive ? Visit(instance) : instance;

            // A non-nullable value-typed receiver can never be null, and Expression.Equal against
            // a null constant is not defined for it, so no check is emitted.
            if (!IsNullable(safe))
                return MakeNullable(new ExpressionReplacer(instance, safe).Visit(access));

            var caller = Expression.Variable(safe.Type, "caller");
            var assign = Expression.Assign(caller, safe);

            // If the original receiver was T but the rewritten one is T?, unwrap it for the access.
            // The null test below guarantees that it has a value at that point.
            var receiver = IsNullableStruct(instance) ? caller : RemoveNullable(caller);
            var rewritten = new ExpressionReplacer(instance, receiver).Visit(access);
            var isNull = Expression.Equal(caller, Expression.Constant(null));

            Expression guarded;
            if (rewritten.Type == typeof(void))
            {
                // There is no value to propagate: just skip the call when the receiver is null.
                guarded = Expression.IfThen(Expression.Not(isNull), rewritten);
            }
            else
            {
                var result = MakeNullable(rewritten);
                guarded = Expression.Condition(
                    test: isNull,
                    ifTrue: Expression.Constant(null, result.Type),
                    ifFalse: result);
            }

            return Expression.Block(
                type: guarded.Type,
                variables: new[] { caller },
                expressions: new Expression[] { assign, guarded });
        }

        /// <summary>True for a Convert/ConvertChecked node without a user-defined operator.</summary>
        private static bool IsConversion(UnaryExpression node)
        {
            return node.NodeType == ExpressionType.Convert || node.NodeType == ExpressionType.ConvertChecked;
        }

        /// <summary>True for the built-in conversion from T to T? (same underlying type).</summary>
        private static bool IsLiftingConversion(UnaryExpression node)
        {
            return IsConversion(node)
                && node.Method == null
                && Nullable.GetUnderlyingType(node.Type) == node.Operand.Type;
        }

        /// <summary>True for value types other than <c>void</c> that are not already nullable.</summary>
        private static bool CanBeMadeNullable(Type type)
        {
            return type.IsValueType && type != typeof(void) && Nullable.GetUnderlyingType(type) == null;
        }

        /// <summary>
        /// Wraps a non-nullable value-typed expression in a conversion to T?. Reference-typed,
        /// already-nullable and void expressions are returned unchanged.
        /// </summary>
        private static Expression MakeNullable(Expression ex)
        {
            if (!CanBeMadeNullable(ex.Type))
                return ex;

            return Expression.Convert(ex, typeof(Nullable<>).MakeGenericType(ex.Type));
        }

        /// <summary>True when the expression's type can hold null (a reference type or T?).</summary>
        private static bool IsNullable(Expression ex)
        {
            return !ex.Type.IsValueType || (Nullable.GetUnderlyingType(ex.Type) != null);
        }

        /// <summary>True when the expression's type is <see cref="Nullable{T}"/>.</summary>
        private static bool IsNullableStruct(Expression ex)
        {
            return ex.Type.IsValueType && (Nullable.GetUnderlyingType(ex.Type) != null);
        }

        /// <summary>Converts a T? expression to T. Any other expression is returned unchanged.</summary>
        private static Expression RemoveNullable(Expression ex)
        {
            if (IsNullableStruct(ex))
                return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);

            return ex;
        }

        /// <summary>Replaces every occurrence (by reference) of one node with another.</summary>
        private sealed class ExpressionReplacer : ExpressionVisitor
        {
            private readonly Expression _oldEx;
            private readonly Expression _newEx;

            internal ExpressionReplacer(Expression oldEx, Expression newEx)
            {
                _oldEx = oldEx;
                _newEx = newEx;
            }

            public override Expression Visit(Expression node)
            {
                // Expression does not overload ==, so this is a reference comparison.
                if (node == _oldEx)
                    return _newEx;

                return base.Visit(node);
            }
        }
    }
    

    It passes the following tests:

    /// <summary>Identity helper used by the tests: returns <paramref name="s"/> unchanged.</summary>
    private static string Foo(string s)
    {
        return s;
    }
    
    /// <summary>
    /// Smoke tests for <c>NullPropagationVisitor</c>. Each test rewrites a lambda and checks
    /// that a null input yields null while non-null inputs keep their normal result.
    /// </summary>
    /// <remarks>
    /// <see cref="Debug.Assert(bool)"/> is compiled out of Release builds, so run this under a
    /// Debug configuration for the checks to take effect.
    /// </remarks>
    static void Main(string[] _)
    {
        var visitor = new NullPropagationVisitor(recursive: true);

        Test1();
        Test2();
        Test3();

        // A conditional whose non-constant branch chains a static call, a property,
        // ToString() and the string indexer. The whole lambda is visited.
        void Test1()
        {
            Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];

            var fBody = (Expression<Func<string, char?>>)visitor.Visit(f);

            var fFunc = fBody.Compile();

            Debug.Assert(fFunc(null) == null);
            Debug.Assert(fFunc("bar") == '3');
            Debug.Assert(fFunc("foo") == 'X');
        }

        // A single property access. Only the body is visited and then wrapped in a new lambda.
        void Test2()
        {
            Expression<Func<string, int?>> y = s => s.Length;

            var yBody = visitor.Visit(y.Body);
            var yFunc = Expression.Lambda<Func<string, int?>>(
                                        body: yBody,
                                        parameters: y.Parameters)
                                .Compile();

            Debug.Assert(yFunc(null) == null);
            Debug.Assert(yFunc("bar") == 3);
        }

        // A chain starting from a Nullable<char> receiver.
        void Test3()
        {
            Expression<Func<char?, string>> y = s => s.Value.ToString()[0].ToString();

            var yBody = visitor.Visit(y.Body);
            var yFunc = Expression.Lambda<Func<char?, string>>(
                                        body: yBody,
                                        parameters: y.Parameters)
                                .Compile();

            Debug.Assert(yFunc(null) == null);
            Debug.Assert(yFunc('A') == "A");
        }
    }
    

提交回复
热议问题