Inheritance: SqlExpressionVisitor
		/// <summary>
		/// Collects every node of <paramref name="expression"/> that satisfies <paramref name="isMatch"/>.
		/// </summary>
		/// <param name="expression">Root of the expression tree to search.</param>
		/// <param name="isMatch">Predicate the finder applies to visited nodes.</param>
		/// <returns>The list of matches gathered by the finder.</returns>
		public static List<Expression> FindAll(Expression expression, Predicate<Expression> isMatch)
		{
			// Passed as the finder's second argument; FindFirst/FindExists pass true,
			// presumably meaning "stop after the first match" — here every match is wanted.
			const bool stopAtFirstMatch = false;

			var collector = new SqlExpressionFinder(isMatch, stopAtFirstMatch);
			collector.Visit(expression);

			var matches = collector.results;
			return matches;
		}
		/// <summary>
		/// Returns the first node of <paramref name="expression"/> that satisfies <paramref name="isMatch"/>.
		/// </summary>
		/// <param name="expression">Root of the expression tree to search.</param>
		/// <param name="isMatch">Predicate the finder applies to visited nodes.</param>
		/// <returns>The first match, or <c>null</c> when nothing matched.</returns>
		public static Expression FindFirst(Expression expression, Predicate<Expression> isMatch)
		{
			var finder = new SqlExpressionFinder(isMatch, true);
			finder.Visit(expression);

			var found = finder.results;

			return found.Count == 0 ? null : found[0];
		}
		/// <summary>
		/// Reports whether any node of <paramref name="expression"/> satisfies <paramref name="isMatch"/>.
		/// </summary>
		/// <param name="expression">Root of the expression tree to search.</param>
		/// <param name="isMatch">Predicate the finder applies to visited nodes.</param>
		/// <returns><c>true</c> if the finder recorded at least one match.</returns>
		public static bool FindExists(Expression expression, Predicate<Expression> isMatch)
		{
			var probe = new SqlExpressionFinder(isMatch, true);
			probe.Visit(expression);

			return probe.results.Count != 0;
		}
        /// <summary>
        /// Moves ORDER BY terms outward through nested selects.
        /// </summary>
        /// <remarks>
        /// Orderings found on selects are accumulated in <c>gatheredOrderings</c>. While the visitor
        /// is inside an inner (non-outermost) select, those orderings are rebound to that select's
        /// alias and passed further out. Only the outermost select keeps an ORDER BY, and only when
        /// no <c>Aggregate</c>/<c>AggregateSubquery</c> node exists anywhere inside it. Inner selects
        /// have their ORDER BY set to <c>null</c>.
        /// </remarks>
        /// <param name="select">The select expression being visited.</param>
        /// <returns>The visited select, rebuilt when its orderings or columns changed.</returns>
        protected override Expression VisitSelect(SqlSelectExpression select)
        {
            // Whether *this* select is the outermost one; everything visited beneath it is not.
            var saveIsOuterMostSelect = this.isOuterMostSelect;

            try
            {
                this.isOuterMostSelect = false;

                // Visit children first (presumably including nested selects in FROM) so that their
                // orderings have already been gathered when this select is processed.
                select = (SqlSelectExpression)base.VisitSelect(select);

                var hasOrderBy = select.OrderBy != null && select.OrderBy.Count > 0;

                if (hasOrderBy)
                {
                    // Add this select's own orderings to the gathered set (per the method name,
                    // in front of those gathered from inner selects).
                    this.PrependOrderings(select.OrderBy.Select(c => (SqlOrderByExpression)c));
                }

                // ORDER BY is only emitted on the outermost select, and not when it contains aggregates.
                var canHaveOrderBy     = saveIsOuterMostSelect && !SqlExpressionFinder.FindExists(select, c => c.NodeType == (ExpressionType)SqlExpressionType.Aggregate || c.NodeType == (ExpressionType)SqlExpressionType.AggregateSubquery);
                // Inner selects hand their gathered orderings on to the enclosing select.
                var canPassOnOrderings = !saveIsOuterMostSelect;

                var columns = select.Columns;
                // Captured before gatheredOrderings is modified below; null strips ORDER BY from this select.
                IEnumerable <Expression> orderings = (canHaveOrderBy) ? this.gatheredOrderings : null;

                if (this.gatheredOrderings != null)
                {
                    if (canPassOnOrderings)
                    {
                        // Rebind the orderings so they refer to this select's alias. The returned
                        // column list may differ from select.Columns (presumably extra columns needed
                        // by the orderings are projected — confirm in RebindOrderings).
                        var producedAliases = AliasesProduced.Gather(select.From);
                        var project         = this.RebindOrderings(this.gatheredOrderings, select.Alias, producedAliases, select.Columns);

                        this.gatheredOrderings = project.Orderings;

                        columns = project.Columns;
                    }
                    else
                    {
                        // Outermost select: the gathered orderings have been consumed (or dropped).
                        this.gatheredOrderings = null;
                    }
                }

                // Reference comparisons: a new select is built whenever a different orderings or
                // columns instance is in play, even if the contents are equal.
                if (orderings != select.OrderBy || columns != select.Columns)
                {
                    select = new SqlSelectExpression(select.Type, select.Alias, columns, select.From, select.Where, orderings, select.GroupBy, select.Distinct, select.Skip, select.Take, select.ForUpdate);
                }

                return(select);
            }
            finally
            {
                // Restore so sibling selects see the correct outermost flag.
                this.isOuterMostSelect = saveIsOuterMostSelect;
            }
        }
Beispiel #5
0
        /// <summary>
        /// For the select produced by the current projection, appends the ORDER BY terms contributed
        /// by collection includes. These are include joins whose left-side property is a
        /// <c>RelatedDataAccessObjects&lt;&gt;</c>.
        /// </summary>
        /// <remarks>
        /// Any other select is handed to the base visitor unchanged. This covers selects that are not
        /// <c>currentProjection.Select</c> and selects whose element type is not a data access object
        /// type.
        /// NOTE(review): when this method does handle the select it never calls
        /// <c>base.VisitSelect</c>, so nested expressions are not visited here. Confirm this is intended.
        /// </remarks>
        /// <param name="selectExpression">The select being visited.</param>
        /// <returns>
        /// The select with its ORDER BY replaced by the de-duplicated amended list, or with its
        /// original ORDER BY when no collection include join was found.
        /// </returns>
        protected override Expression VisitSelect(SqlSelectExpression selectExpression)
        {
            if (this.currentProjection == null || this.currentProjection.Select != selectExpression)
            {
                return(base.VisitSelect(selectExpression));
            }

            var objectType = selectExpression.Type.GetSequenceElementType();

            if (!objectType.IsDataAccessObjectType())
            {
                return(base.VisitSelect(selectExpression));
            }

            // Map each alias collected from the select to the type descriptor of its element type
            // (or of the type itself when it is not a sequence).
            var aliasesAndTypes = SqlAliasTypeCollector.Collect(selectExpression)
                                  .ToDictionary(c => c.Item1, c => typeDescriptorProvider.GetTypeDescriptor(c.Item2.GetSequenceElementType() ?? c.Item2));

            List <SqlOrderByExpression> orderBys = null;
            var includeJoins = selectExpression.From.GetIncludeJoins().ToList();
            List <SqlColumnExpression> leftMostColumns = null;

            foreach (var includeJoin in includeJoins)
            {
                // Assumes the join contains an Equal node comparing two columns. FindFirst returns
                // null when there is none, and the casts below throw if a side is not a column.
                var equalsExpression = (BinaryExpression)SqlExpressionFinder.FindFirst(includeJoin, c => c.NodeType == ExpressionType.Equal);

                var left  = (SqlColumnExpression)equalsExpression.Left;
                var right = (SqlColumnExpression)equalsExpression.Right;

                var leftType  = aliasesAndTypes[left.SelectAlias];
                var rightType = aliasesAndTypes[right.SelectAlias];

                if (leftMostColumns == null)
                {
                    // First include join only: the projected object's primary-key columns, bound to
                    // this join's left alias.
                    var typeDescriptor    = this.typeDescriptorProvider.GetTypeDescriptor(objectType);
                    var primaryKeyColumns = new HashSet <string>(QueryBinder.GetPrimaryKeyColumnInfos(this.typeDescriptorProvider, typeDescriptor).Select(c => c.ColumnName));

                    leftMostColumns = primaryKeyColumns
                                      .Select(c => new SqlColumnExpression(objectType, left.SelectAlias, c))
                                      .ToList();
                }

                // The left property prefers the relationship's target property when the right
                // property declares one; otherwise it is looked up by the left column's name.
                var rightProperty = rightType.GetPropertyDescriptorByColumnName(right.Name);
                var leftProperty  = rightProperty.RelationshipInfo?.TargetProperty ?? leftType.GetPropertyDescriptorByColumnName(left.Name);

                // Only collection includes (RelatedDataAccessObjects<>) contribute orderings.
                if (leftProperty.PropertyType.GetGenericTypeDefinitionOrNull() == typeof(RelatedDataAccessObjects <>))
                {
                    // Every column referenced in the join, split by which side's alias it belongs to.
                    var rightColumns = SqlExpressionFinder.FindAll(includeJoin, c => c.NodeType == (ExpressionType)SqlExpressionType.Column && ((SqlColumnExpression)c).SelectAlias == right.SelectAlias);
                    var leftColumns  = SqlExpressionFinder.FindAll(includeJoin, c => c.NodeType == (ExpressionType)SqlExpressionType.Column && ((SqlColumnExpression)c).SelectAlias == left.SelectAlias);

                    if (orderBys == null)
                    {
                        orderBys = new List <SqlOrderByExpression>();

                        // Keep any existing ORDER BY first, then add the left-most primary key ascending.
                        // Presumably this keeps rows belonging to the same root object adjacent; confirm.
                        if (selectExpression.OrderBy?.Count > 0)
                        {
                            orderBys.AddRange(selectExpression.OrderBy);
                            orderBys.AddRange(leftMostColumns.Select(c => new SqlOrderByExpression(OrderType.Ascending, c)));
                        }
                    }

                    // Then order by the right-side join columns, then the left-side ones, all ascending.
                    orderBys.AddRange(rightColumns.Select(c => new SqlOrderByExpression(OrderType.Ascending, c)));
                    orderBys.AddRange(leftColumns.Select(c => new SqlOrderByExpression(OrderType.Ascending, c)));
                }
            }

            // Drop orderings that SqlExpressionEqualityComparer considers equal. Fall back to the
            // original ORDER BY when no collection include was seen.
            return(selectExpression.ChangeOrderBy(orderBys?.Distinct(SqlExpressionEqualityComparer <SqlOrderByExpression> .Default) ?? selectExpression.OrderBy));
        }
        /// <summary>
        /// Hoists collection <c>Include</c> calls out of <c>Skip</c>/<c>Take</c> subtrees and
        /// re-applies them on the result of the outermost <c>Skip</c>/<c>Take</c>.
        /// </summary>
        /// <remarks>
        /// Only methods declared on <see cref="Queryable"/>, <see cref="Enumerable"/> or
        /// <c>QueryableExtensions</c> are handled; all other calls go to the base visitor.
        /// <list type="bullet">
        /// <item><c>Skip</c>/<c>Take</c>: the subtree is visited with <c>insideSkipTake</c> set. On
        /// the outermost one, each selector collected in <c>includeSelectors</c> is wrapped around
        /// the visited result as an <c>Include</c> call.</item>
        /// <item><c>Select</c> inside Skip/Take: the projection is composed into <c>selector</c> while
        /// its subtree is visited. This lets include paths be traced through the projection.</item>
        /// <item><c>Include</c> inside Skip/Take: the call is removed from the tree. If its selector
        /// touches a <c>RelatedDataAccessObjects&lt;&gt;</c> member, a copy rebased onto the
        /// projected element type is added to <c>includeSelectors</c>.</item>
        /// </list>
        /// </remarks>
        /// <param name="methodCallExpression">The method call being visited.</param>
        /// <returns>The rewritten expression.</returns>
        protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
        {
            var declaringType = methodCallExpression.Method.DeclaringType;

            if (!(declaringType == typeof(Queryable) ||
                  declaringType == typeof(Enumerable) ||
                  declaringType == typeof(QueryableExtensions)))
            {
                return(base.VisitMethodCall(methodCallExpression));
            }

            var methodName = methodCallExpression.Method.Name;
            var saveInsideSkipTake = this.insideSkipTake;

            if (methodName == "Skip" || methodName == "Take")
            {
                this.insideSkipTake = true;

                try
                {
                    var retval = base.VisitMethodCall(methodCallExpression);

                    // Only the outermost Skip/Take re-applies the hoisted includes, so they wrap the
                    // paged result exactly once. Assumes retval's type is generic over the element
                    // type (e.g. IQueryable<T>).
                    if (!saveInsideSkipTake)
                    {
                        foreach (var includeSelector in this.includeSelectors)
                        {
                            retval = Expression.Call(MethodInfoFastRef.QueryableExtensionsIncludeMethod.MakeGenericMethod(retval.Type.GetGenericArguments()[0], includeSelector.Body.Type), retval, includeSelector);
                        }
                    }

                    return(retval);
                }
                finally
                {
                    this.insideSkipTake = saveInsideSkipTake;
                }
            }
            else if (methodName == "Select" && this.insideSkipTake)
            {
                var lambda = methodCallExpression.Arguments[1].StripQuotes();

                var saveSelector = this.selector;

                if (this.selector != null)
                {
                    // Compose: the already-recorded (outer) projection is applied to this (inner)
                    // projection's body, so the selector maps from the inner element type.
                    var body = SqlExpressionReplacer.Replace(this.selector.Body, this.selector.Parameters[0], lambda.Body);

                    this.selector = Expression.Lambda(body, lambda.Parameters[0]);
                }
                else
                {
                    this.selector = lambda;
                }

                try
                {
                    return(base.VisitMethodCall(methodCallExpression));
                }
                finally
                {
                    this.selector = saveSelector;
                }
            }
            else if (methodName == "Include" && this.insideSkipTake)
            {
                // The Include call itself is always dropped here; only its visited source survives.
                var source = Visit(methodCallExpression.Arguments[0]);

                List <MemberInfo> path;

                if (this.selector == null)
                {
                    // No enclosing projection: the include applies to the element type directly.
                    path = new List <MemberInfo>();
                }
                else
                {
                    path = ParameterPathFinder.Find(this.selector);
                }

                if (path == null)
                {
                    // No member path through the projection was found, so the include cannot be
                    // rebased and is discarded.
                    return(source);
                }

                bool IsRelatedObjectsMemberExpression(Expression expression)
                {
                    if (!(expression is MemberExpression memberExpression))
                    {
                        return(false);
                    }

                    var memberType = memberExpression.Member.GetMemberReturnType();

                    return(memberType.IsGenericType && memberType.GetGenericTypeDefinition() == typeof(RelatedDataAccessObjects <>));
                }

                var includeSelector = methodCallExpression.Arguments[1].StripQuotes();
                var includesRelatedDataAccessObjects = SqlExpressionFinder.FindExists(includeSelector.Body, IsRelatedObjectsMemberExpression);

                if (includesRelatedDataAccessObjects)
                {
                    // Rebase the include selector onto the projected element type. The original
                    // parameter is replaced by newParam followed by the member accesses in `path`.
                    var oldParam = includeSelector.Parameters[0];
                    var oldBody  = includeSelector.Body;

                    var newParam = path.Count > 0 ? Expression.Parameter(path.First().DeclaringType) : oldParam;

                    var replacement = path.Aggregate((Expression)newParam, Expression.MakeMemberAccess, c => c);

                    var newBody = SqlExpressionReplacer.Replace(oldBody, oldParam, replacement);

                    this.includeSelectors.Add(Expression.Lambda(newBody, newParam));
                }

                // NOTE(review): includes that do not reference a RelatedDataAccessObjects<> member are
                // not re-added by this method. Confirm that dropping them here is intended.
                return(source);
            }

            return(base.VisitMethodCall(methodCallExpression));
        }