Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes NRE when building CTE. CTE refactoring. #2076

Merged
merged 1 commit into from
Feb 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions Source/LinqToDB/Linq/Builder/TableBuilder.CteTableContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ static string GetColumnFriendlyAlias(SqlColumn column)
return alias;
}

void UpdateMissingFields()
{
// Collecting missed fields which has field in query. Should never happen.

if (_cteQueryContext != null)
{
for (int i = 0; i < _cte.Fields.Length; i++)
{
if (_cte.Fields[i] == null)
{
var column = _cte.Body.Select.Columns[i];
_cte.Fields[i] = new SqlField { Name = column.Alias, PhysicalName = column.Alias };
}
}
}
}

public override int ConvertToParentIndex(int index, IBuildContext context)
{
if (context == _cteQueryContext)
Expand All @@ -196,6 +213,8 @@ public override int ConvertToParentIndex(int index, IBuildContext context)
var field = RegisterCteField(null, queryColumn, index, alias);

index = SelectQuery.Select.Add(field);

UpdateMissingFields();
}

return base.ConvertToParentIndex(index, context);
Expand All @@ -205,20 +224,20 @@ SqlField RegisterCteField(ISqlExpression baseExpression, [NotNull] ISqlExpressio
{
if (expression == null) throw new ArgumentNullException(nameof(expression));

var cteField = _cte.RegisterFieldMapping(baseExpression, expression, index, () =>
{
var f = QueryHelper.GetUnderlyingField(baseExpression ?? expression);
var cteField = _cte.RegisterFieldMapping(index, () =>
{
var f = QueryHelper.GetUnderlyingField(baseExpression ?? expression);

var newField = f == null
? new SqlField { SystemType = expression.SystemType, CanBeNull = expression.CanBeNull, Name = alias }
: new SqlField(f);
var newField = f == null
? new SqlField { SystemType = expression.SystemType, CanBeNull = expression.CanBeNull, Name = alias }
: new SqlField(f);

if (alias != null)
newField.Name = alias;
if (alias != null)
newField.Name = alias;

newField.PhysicalName = newField.Name;
return newField;
});
newField.PhysicalName = newField.Name;
return newField;
});

if (!SqlTable.Fields.TryGetValue(cteField.Name, out var field))
{
Expand Down
22 changes: 1 addition & 21 deletions Source/LinqToDB/SqlQuery/CteClause.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ namespace LinqToDB.SqlQuery
[DebuggerDisplay("CTE({CteID}, {Name})")]
public class CteClause : IQueryElement, ICloneableElement, ISqlExpressionWalkable
{
Dictionary<ISqlExpression, SqlField> FieldsByExpression { get; } = new Dictionary<ISqlExpression, SqlField>();
Dictionary<string, SqlField> FieldByName { get; } = new Dictionary<string, SqlField>();

SqlField[] _fields = new SqlField[0];

public static int CteIDCounter;
Expand Down Expand Up @@ -97,18 +94,11 @@ public ISqlExpression Walk(WalkOptions options, Func<ISqlExpression,ISqlExpressi
return null;
}

public SqlField RegisterFieldMapping(ISqlExpression baseExpression, ISqlExpression expression, int index, Func<SqlField> fieldFactory)
public SqlField RegisterFieldMapping(int index, Func<SqlField> fieldFactory)
{
if (Fields.Length > index && Fields[index] != null)
return Fields[index];

if (expression != null && FieldsByExpression.TryGetValue(expression, out var value))
return value;

var baseField = baseExpression as SqlField;
if (baseField != null && FieldByName.TryGetValue(baseField.Name, out value))
return value;

var newField = fieldFactory();

Utils.MakeUniqueNames(new[] { newField }, Fields.Where(f => f != null).Select(t => t.Name), f => f.Name, (f, n) =>
Expand All @@ -122,17 +112,7 @@ public SqlField RegisterFieldMapping(ISqlExpression baseExpression, ISqlExpressi

Fields[index] = newField;

if (expression != null && !FieldsByExpression.ContainsKey(expression))
FieldsByExpression.Add(expression, newField);
if (baseField != null)
FieldByName.Add(baseField.Name, newField);
else
{
if (expression is SqlField field && !FieldByName.ContainsKey(field.Name))
FieldByName.Add(field.Name, newField);
}
return newField;

}
}
}