From 43c64d5c11cc7d68fcf9dcab59f517adad0fe658 Mon Sep 17 00:00:00 2001 From: Andrew Peters Date: Mon, 28 Nov 2016 15:38:59 -0800 Subject: [PATCH] Fix #7115 - DB query is executed with the parameters of the previous query - Fixes an issue in ExpressionEqualityComparer where we would incorrectly determine equality for constant EnumerableQuery nodes. --- .../QueryTestBase.cs | 18 ++++++++++++++++++ .../ParameterExtractingExpressionVisitor.cs | 14 +++++++------- .../Internal/ExpressionEqualityComparer.cs | 6 ++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryTestBase.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryTestBase.cs index 0c01009a17f..6fb85d6ccd1 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryTestBase.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/QueryTestBase.cs @@ -42,6 +42,24 @@ public virtual void Local_array() cs.Single(c => c.CustomerID == (string)context.Arguments["customerId"])); } + [ConditionalFact] + public virtual void Method_with_constant_queryable_arg() + { + using (var context = CreateContext()) + { + var count = QueryableArgQuery(context, new [] { "ALFKI" }.AsQueryable()).Count(); + + Assert.Equal(1, count); + + count = QueryableArgQuery(context, new [] { "FOO" }.AsQueryable()).Count(); + + Assert.Equal(0, count); + } + } + + private static IQueryable QueryableArgQuery(NorthwindContext context, IQueryable ids) + => context.Customers.Where(c => ids.Contains(c.CustomerID)); + [ConditionalFact] public void Query_composition_against_ienumerable_set() { diff --git a/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/ParameterExtractingExpressionVisitor.cs b/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/ParameterExtractingExpressionVisitor.cs index e34104e5c17..beb4f307ed5 100644 --- a/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/Microsoft.EntityFrameworkCore/Query/ExpressionVisitors/Internal/ParameterExtractingExpressionVisitor.cs @@ -91,9 +91,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } if (declaringType == typeof(Queryable) - || (declaringType == typeof(EntityFrameworkQueryableExtensions) - && (!methodInfo.IsGenericMethod - || methodInfo.GetGenericMethodDefinition() != EntityFrameworkQueryableExtensions.StringIncludeMethodInfo))) + || declaringType == typeof(EntityFrameworkQueryableExtensions) + && (!methodInfo.IsGenericMethod + || methodInfo.GetGenericMethodDefinition() != EntityFrameworkQueryableExtensions.StringIncludeMethodInfo)) { return base.VisitMethodCall(methodCallExpression); } @@ -261,8 +261,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (leftConstantExpression != null) { var constantValue = (bool)leftConstantExpression.Value; - if ((constantValue && binaryExpression.NodeType == ExpressionType.OrElse) - || (!constantValue && binaryExpression.NodeType == ExpressionType.AndAlso)) + if (constantValue && binaryExpression.NodeType == ExpressionType.OrElse + || !constantValue && binaryExpression.NodeType == ExpressionType.AndAlso) { return newLeftExpression; } @@ -274,8 +274,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) if (rightConstantExpression != null) { var constantValue = (bool)rightConstantExpression.Value; - if ((constantValue && binaryExpression.NodeType == ExpressionType.OrElse) - || (!constantValue && binaryExpression.NodeType == ExpressionType.AndAlso)) + if (constantValue && binaryExpression.NodeType == ExpressionType.OrElse + || !constantValue && binaryExpression.NodeType == ExpressionType.AndAlso) { return newRightExpression; } diff --git a/src/Microsoft.EntityFrameworkCore/Query/Internal/ExpressionEqualityComparer.cs b/src/Microsoft.EntityFrameworkCore/Query/Internal/ExpressionEqualityComparer.cs index 1a86695dab2..65919184140 100644 --- a/src/Microsoft.EntityFrameworkCore/Query/Internal/ExpressionEqualityComparer.cs +++ b/src/Microsoft.EntityFrameworkCore/Query/Internal/ExpressionEqualityComparer.cs @@ -420,6 +420,12 @@ private static bool CompareConstant(ConstantExpression a, ConstantExpression b) return false; } + if (a.Value is EnumerableQuery + && b.Value is EnumerableQuery) + { + return false; // EnumerableQueries are opaque + } + if (a.Value is IQueryable && b.Value is IQueryable && a.Value.GetType() == b.Value.GetType())