Skip to content

Commit

Permalink
Query: Fix dotnet#3668 - Subqueries with parameters skip parameteriza…
Browse files Browse the repository at this point in the history
…tion

- When queryables are captures in the closure, during parameterization we need to eval the member expression and then visit the sub-expression.
  • Loading branch information
anpete committed Dec 10, 2015
1 parent f272e36 commit 0561ceb
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,55 @@ namespace Microsoft.Data.Entity.Query.ExpressionVisitors.Internal
{
public class ParameterExtractingExpressionVisitor : ExpressionVisitor
{
private static readonly TypeInfo _queryableTypeInfo = typeof(IQueryable).GetTypeInfo();

public static Expression ExtractParameters(
[NotNull] Expression expression,
[NotNull] QueryContext queryContext,
[NotNull] IEvaluatableExpressionFilter evaluatableExpressionFilter,
[NotNull] ISensitiveDataLogger logger)
{
var partialEvaluationInfo
= EvaluatableTreeFindingExpressionVisitor
.Analyze(expression, evaluatableExpressionFilter);
var visitor = new ParameterExtractingExpressionVisitor(evaluatableExpressionFilter, queryContext, logger);

var visitor = new ParameterExtractingExpressionVisitor(partialEvaluationInfo, queryContext, logger);

return visitor.Visit(expression);
return visitor.ExtractParameters(expression);
}

private readonly PartialEvaluationInfo _partialEvaluationInfo;
private readonly IEvaluatableExpressionFilter _evaluatableExpressionFilter;
private readonly QueryContext _queryContext;
private readonly ISensitiveDataLogger _logger;

private PartialEvaluationInfo _partialEvaluationInfo;

private bool _inLambda;

private ParameterExtractingExpressionVisitor(
PartialEvaluationInfo partialEvaluationInfo,
IEvaluatableExpressionFilter evaluatableExpressionFilter,
QueryContext queryContext,
ISensitiveDataLogger logger)
{
_partialEvaluationInfo = partialEvaluationInfo;
_evaluatableExpressionFilter = evaluatableExpressionFilter;
_queryContext = queryContext;
_logger = logger;
}

public Expression ExtractParameters([NotNull] Expression expression)
{
var oldPartialEvaluationInfo = _partialEvaluationInfo;

_partialEvaluationInfo
= EvaluatableTreeFindingExpressionVisitor
.Analyze(expression, _evaluatableExpressionFilter);

try
{
return Visit(expression);
}
finally
{
_partialEvaluationInfo = oldPartialEvaluationInfo;
}
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
var methodInfo = methodCallExpression.Method;
Expand Down Expand Up @@ -86,7 +104,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
newArguments[j] = arguments[j];
}
}

if (parameterInfos[i].GetCustomAttribute<NotParameterizedAttribute>() != null)
{
var parameter = newArgument as ParameterExpression;
Expand Down Expand Up @@ -114,16 +132,27 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

protected override Expression VisitMember(MemberExpression memberExpression)
=> !_partialEvaluationInfo.IsEvaluatableExpression(memberExpression)
? base.VisitMember(memberExpression)
: !typeof(IQueryable).GetTypeInfo().IsAssignableFrom(memberExpression.Type.GetTypeInfo())
? TryExtractParameter(memberExpression)
: memberExpression;
{
if (!_partialEvaluationInfo.IsEvaluatableExpression(memberExpression))
{
return base.VisitMember(memberExpression);
}

if (!_queryableTypeInfo.IsAssignableFrom(memberExpression.Type.GetTypeInfo()))
{
return TryExtractParameter(memberExpression);
}

string _;
var queryable = (IQueryable)Evaluate(memberExpression, out _);

return ExtractParameters(queryable.Expression);
}

protected override Expression VisitConstant(ConstantExpression constantExpression)
=> !_inLambda
&& _partialEvaluationInfo.IsEvaluatableExpression(constantExpression)
&& !typeof(IQueryable).GetTypeInfo().IsAssignableFrom(constantExpression.Type.GetTypeInfo())
&& !_queryableTypeInfo.IsAssignableFrom(constantExpression.Type.GetTypeInfo())
? TryExtractParameter(constantExpression)
: constantExpression;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,29 +687,35 @@ join t in tags on g.FullName equals t.Gear?.FullName
}
}

//[Fact]
[Fact]
public virtual void Collection_with_inheritance_and_join_include_joined()
{
using (var context = CreateContext())
{
var query = (from t in context.Tags
join g in context.Gears.OfType<Officer>() on new { id1 = t.GearSquadId, id2 = t.GearNickName } equals new { id1 = (int?)g.SquadId, id2 = g.Nickname }
join g in context.Gears.OfType<Officer>() on new { id1 = t.GearSquadId, id2 = t.GearNickName }
equals new { id1 = (int?)g.SquadId, id2 = g.Nickname }
select g).Include(g => g.Tag);

var result = query.ToList();

Assert.NotNull(result);
}
}

//[Fact]
[Fact]
public virtual void Collection_with_inheritance_and_join_include_source()
{
using (var context = CreateContext())
{
var query = (from g in context.Gears.OfType<Officer>()
join t in context.Tags on new { id1 = (int?)g.SquadId, id2 = g.Nickname } equals new { id1 = t.GearSquadId, id2 = t.GearNickName }
join t in context.Tags on new { id1 = (int?)g.SquadId, id2 = g.Nickname }
equals new { id1 = t.GearSquadId, id2 = t.GearNickName }
select g).Include(g => g.Tag);

var result = query.ToList();

Assert.NotNull(result);
}
}

Expand Down
23 changes: 23 additions & 0 deletions test/EntityFramework.Core.FunctionalTests/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,29 @@ public virtual void Where_simple_closure_via_query_cache_nullable_type_reverse()
entryCount: 5);
}

[ConditionalFact]
public virtual void Where_subquery_closure_via_query_cache()
{
using (var context = CreateContext())
{
string customerID = null;

var orders = context.Orders.Where(o => o.CustomerID == customerID);

customerID = "ALFKI";

var customers = context.Customers.Where(c => orders.Any(o => o.CustomerID == c.CustomerID)).ToList();

Assert.Equal(1, customers.Count);

customerID = "ANATR";

customers = context.Customers.Where(c => orders.Any(o => o.CustomerID == c.CustomerID)).ToList();

Assert.Equal("ANATR", customers.Single().CustomerID);
}
}

[ConditionalFact]
public virtual void Where_simple_shadow()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@ public override void Collection_with_inheritance_and_join_include_joined()
base.Collection_with_inheritance_and_join_include_joined();

Assert.Equal(
@"",
@"SELECT [t0].[Nickname], [t0].[SquadId], [t0].[AssignedCityName], [t0].[CityOrBirthName], [t0].[Discriminator], [t0].[FullName], [t0].[LeaderNickname], [t0].[LeaderSquadId], [t0].[Rank], [c].[Id], [c].[GearNickName], [c].[GearSquadId], [c].[Note]
FROM [CogTag] AS [t]
INNER JOIN (
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOrBirthName], [g].[Discriminator], [g].[FullName], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gear] AS [g]
WHERE [g].[Discriminator] = 'Officer'
) AS [t0] ON ([t].[GearSquadId] = [t0].[SquadId]) AND ([t].[GearNickName] = [t0].[Nickname])
LEFT JOIN [CogTag] AS [c] ON ([c].[GearNickName] = [t0].[Nickname]) AND ([c].[GearSquadId] = [t0].[SquadId])",
Sql);
}

Expand All @@ -692,7 +699,14 @@ public override void Collection_with_inheritance_and_join_include_source()
base.Collection_with_inheritance_and_join_include_source();

Assert.Equal(
@"",
@"SELECT [t0].[Nickname], [t0].[SquadId], [t0].[AssignedCityName], [t0].[CityOrBirthName], [t0].[Discriminator], [t0].[FullName], [t0].[LeaderNickname], [t0].[LeaderSquadId], [t0].[Rank], [c].[Id], [c].[GearNickName], [c].[GearSquadId], [c].[Note]
FROM (
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOrBirthName], [g].[Discriminator], [g].[FullName], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gear] AS [g]
WHERE [g].[Discriminator] = 'Officer'
) AS [t0]
INNER JOIN [CogTag] AS [t] ON ([t0].[SquadId] = [t].[GearSquadId]) AND ([t0].[Nickname] = [t].[GearNickName])
LEFT JOIN [CogTag] AS [c] ON ([c].[GearNickName] = [t0].[Nickname]) AND ([c].[GearSquadId] = [t0].[SquadId])",
Sql);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,41 @@ FROM [Customers] AS [c]
Sql);
}

public override void Where_subquery_closure_via_query_cache()
{
base.Where_subquery_closure_via_query_cache();

Assert.Equal(
@"@__customerID_0: ALFKI
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE (
SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Orders] AS [o]
WHERE ([o].[CustomerID] = @__customerID_0) AND ([o].[CustomerID] = [c].[CustomerID]))
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
END
) = 1
@__customerID_0: ANATR
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE (
SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Orders] AS [o]
WHERE ([o].[CustomerID] = @__customerID_0) AND ([o].[CustomerID] = [c].[CustomerID]))
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
END
) = 1",
Sql);
}

public override void Count_with_predicate()
{
base.Count_with_predicate();
Expand Down

0 comments on commit 0561ceb

Please sign in to comment.