Skip to content

Commit

Permalink
Merge pull request #2076 from linq2db/regression/cte_error
Browse files Browse the repository at this point in the history
Fixes NRE when building CTE. CTE refactoring.
  • Loading branch information
MaceWindu authored Feb 5, 2020
2 parents 1e66a48 + 882de84 commit bd1bd26
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
41 changes: 30 additions & 11 deletions Source/LinqToDB/Linq/Builder/TableBuilder.CteTableContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ static string GetColumnFriendlyAlias(SqlColumn column)
return alias;
}

void UpdateMissingFields()
{
// Collecting missed fields which has field in query. Should never happen.

if (_cteQueryContext != null)
{
for (int i = 0; i < _cte.Fields.Length; i++)
{
if (_cte.Fields[i] == null)
{
var column = _cte.Body.Select.Columns[i];
_cte.Fields[i] = new SqlField { Name = column.Alias, PhysicalName = column.Alias };
}
}
}
}

public override int ConvertToParentIndex(int index, IBuildContext context)
{
if (context == _cteQueryContext)
Expand All @@ -196,6 +213,8 @@ public override int ConvertToParentIndex(int index, IBuildContext context)
var field = RegisterCteField(null, queryColumn, index, alias);

index = SelectQuery.Select.Add(field);

UpdateMissingFields();
}

return base.ConvertToParentIndex(index, context);
Expand All @@ -205,20 +224,20 @@ SqlField RegisterCteField(ISqlExpression baseExpression, [NotNull] ISqlExpressio
{
if (expression == null) throw new ArgumentNullException(nameof(expression));

var cteField = _cte.RegisterFieldMapping(baseExpression, expression, index, () =>
{
var f = QueryHelper.GetUnderlyingField(baseExpression ?? expression);
var cteField = _cte.RegisterFieldMapping(index, () =>
{
var f = QueryHelper.GetUnderlyingField(baseExpression ?? expression);

var newField = f == null
? new SqlField { SystemType = expression.SystemType, CanBeNull = expression.CanBeNull, Name = alias }
: new SqlField(f);
var newField = f == null
? new SqlField { SystemType = expression.SystemType, CanBeNull = expression.CanBeNull, Name = alias }
: new SqlField(f);

if (alias != null)
newField.Name = alias;
if (alias != null)
newField.Name = alias;

newField.PhysicalName = newField.Name;
return newField;
});
newField.PhysicalName = newField.Name;
return newField;
});

if (!SqlTable.Fields.TryGetValue(cteField.Name, out var field))
{
Expand Down
22 changes: 1 addition & 21 deletions Source/LinqToDB/SqlQuery/CteClause.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ namespace LinqToDB.SqlQuery
[DebuggerDisplay("CTE({CteID}, {Name})")]
public class CteClause : IQueryElement, ICloneableElement, ISqlExpressionWalkable
{
Dictionary<ISqlExpression, SqlField> FieldsByExpression { get; } = new Dictionary<ISqlExpression, SqlField>();
Dictionary<string, SqlField> FieldByName { get; } = new Dictionary<string, SqlField>();

SqlField[] _fields = new SqlField[0];

public static int CteIDCounter;
Expand Down Expand Up @@ -97,18 +94,11 @@ public ISqlExpression Walk(WalkOptions options, Func<ISqlExpression,ISqlExpressi
return null;
}

public SqlField RegisterFieldMapping(ISqlExpression baseExpression, ISqlExpression expression, int index, Func<SqlField> fieldFactory)
public SqlField RegisterFieldMapping(int index, Func<SqlField> fieldFactory)
{
if (Fields.Length > index && Fields[index] != null)
return Fields[index];

if (expression != null && FieldsByExpression.TryGetValue(expression, out var value))
return value;

var baseField = baseExpression as SqlField;
if (baseField != null && FieldByName.TryGetValue(baseField.Name, out value))
return value;

var newField = fieldFactory();

Utils.MakeUniqueNames(new[] { newField }, Fields.Where(f => f != null).Select(t => t.Name), f => f.Name, (f, n) =>
Expand All @@ -122,17 +112,7 @@ public SqlField RegisterFieldMapping(ISqlExpression baseExpression, ISqlExpressi

Fields[index] = newField;

if (expression != null && !FieldsByExpression.ContainsKey(expression))
FieldsByExpression.Add(expression, newField);
if (baseField != null)
FieldByName.Add(baseField.Name, newField);
else
{
if (expression is SqlField field && !FieldByName.ContainsKey(field.Name))
FieldByName.Add(field.Name, newField);
}
return newField;

}
}
}

0 comments on commit bd1bd26

Please sign in to comment.