Example #1
0
            /// <summary>
            /// Combines a list of corresponding expressions into a single expression. It recurses
            /// structurally through the node kinds that can be merged part-by-part and falls back
            /// to <c>ExpandIntoExprSet</c> for everything else.
            /// </summary>
            /// <param name="exprs">
            /// The expressions to combine. Presumably there is one per branch of a UNION, since every
            /// failure is reported through the <c>Error.Union*</c> helpers — TODO confirm with callers.
            /// The list must be non-empty: the node type of <c>exprs[0]</c> selects the combining strategy.
            /// </param>
            /// <returns>The combined expression.</returns>
            /// <remarks>
            /// Throws the exception built by <c>Error.UnionIncompatibleConstruction</c>,
            /// <c>Error.UnionDifferentMembers</c>, <c>Error.UnionDifferentMemberOrder</c> or
            /// <c>Error.UnionWithHierarchy</c> when the inputs cannot be combined.
            /// </remarks>
            private SqlExpression ExpandTogether(List <SqlExpression> exprs)
            {
                switch (exprs[0].NodeType)
                {
                case SqlNodeType.MethodCall: {
                    // NOTE(review): assumes every entry is a SqlMethodCall with the same method and
                    // argument count as the first; only the first call's Method and Object are kept.
                    SqlMethodCall[] mcs = new SqlMethodCall[exprs.Count];
                    for (int i = 0; i < mcs.Length; ++i)
                    {
                        mcs[i] = (SqlMethodCall)exprs[i];
                    }

                    // Combine argument position i across all calls.
                    List <SqlExpression> expandedArgs = new List <SqlExpression>();

                    for (int i = 0; i < mcs[0].Arguments.Count; ++i)
                    {
                        List <SqlExpression> args = new List <SqlExpression>();
                        for (int j = 0; j < mcs.Length; ++j)
                        {
                            args.Add(mcs[j].Arguments[i]);
                        }
                        SqlExpression expanded = this.ExpandTogether(args);
                        expandedArgs.Add(expanded);
                    }
                    return(factory.MethodCall(mcs[0].Method, mcs[0].Object, expandedArgs.ToArray(), mcs[0].SourceExpression));
                }

                case SqlNodeType.ClientCase: {
                    // Cast every entry; no check is made that they share the first entry's shape.
                    SqlClientCase[] scs = new SqlClientCase[exprs.Count];
                    scs[0] = (SqlClientCase)exprs[0];
                    for (int i = 1; i < scs.Length; ++i)
                    {
                        scs[i] = (SqlClientCase)exprs[i];
                    }

                    // Expand the tested expressions together.
                    List <SqlExpression> expressions = new List <SqlExpression>();
                    for (int i = 0; i < scs.Length; ++i)
                    {
                        expressions.Add(scs[i].Expression);
                    }
                    SqlExpression expression = this.ExpandTogether(expressions);

                    // Expand the value of WHEN arm i together; the Match of the first case is reused.
                    List <SqlClientWhen> whens = new List <SqlClientWhen>();
                    for (int i = 0; i < scs[0].Whens.Count; ++i)
                    {
                        List <SqlExpression> scos = new List <SqlExpression>();
                        for (int j = 0; j < scs.Length; ++j)
                        {
                            SqlClientWhen when = scs[j].Whens[i];
                            scos.Add(when.Value);
                        }
                        whens.Add(new SqlClientWhen(scs[0].Whens[i].Match, this.ExpandTogether(scos)));
                    }

                    return(new SqlClientCase(scs[0].ClrType, expression, whens, scs[0].SourceExpression));
                }

                case SqlNodeType.TypeCase: {
                    // Cast every entry; no check is made that they share the first entry's shape.
                    SqlTypeCase[] tcs = new SqlTypeCase[exprs.Count];
                    tcs[0] = (SqlTypeCase)exprs[0];
                    for (int i = 1; i < tcs.Length; ++i)
                    {
                        tcs[i] = (SqlTypeCase)exprs[i];
                    }

                    // Expand discriminators together.
                    List <SqlExpression> discriminators = new List <SqlExpression>();
                    for (int i = 0; i < tcs.Length; ++i)
                    {
                        discriminators.Add(tcs[i].Discriminator);
                    }
                    SqlExpression discriminator = this.ExpandTogether(discriminators);
                    // Write the list entries back in. This only changes anything if the recursive call
                    // (e.g. ExpandIntoExprSet, not visible here) replaced entries of the list.
                    for (int i = 0; i < tcs.Length; ++i)
                    {
                        tcs[i].Discriminator = discriminators[i];
                    }
                    // Expand individual type bindings together.
                    List <SqlTypeCaseWhen> whens = new List <SqlTypeCaseWhen>();
                    for (int i = 0; i < tcs[0].Whens.Count; ++i)
                    {
                        List <SqlExpression> scos = new List <SqlExpression>();
                        for (int j = 0; j < tcs.Length; ++j)
                        {
                            SqlTypeCaseWhen when = tcs[j].Whens[i];
                            scos.Add(when.TypeBinding);
                        }
                        SqlExpression expanded = this.ExpandTogether(scos);
                        whens.Add(new SqlTypeCaseWhen(tcs[0].Whens[i].Match, expanded));
                    }

                    return(factory.TypeCase(tcs[0].ClrType, tcs[0].RowType, discriminator, whens, tcs[0].SourceExpression));
                }

                case SqlNodeType.New: {
                    // First verify that every entry is a SqlNew with the same shape as the first.
                    SqlNew[] cobs = new SqlNew[exprs.Count];
                    cobs[0] = (SqlNew)exprs[0];
                    for (int i = 1, n = exprs.Count; i < n; i++)
                    {
                        if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.New)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                        // Read entry i (not a fixed index) so every branch is validated and merged.
                        cobs[i] = (SqlNew)exprs[i];
                        if (cobs[i].Members.Count != cobs[0].Members.Count)
                        {
                            throw Error.UnionDifferentMembers();
                        }
                        for (int m = 0, mn = cobs[0].Members.Count; m < mn; m++)
                        {
                            if (cobs[i].Members[m].Member != cobs[0].Members[m].Member)
                            {
                                throw Error.UnionDifferentMemberOrder();
                            }
                        }
                        // Constructor arguments are merged position-by-position below, so the counts
                        // must agree. Otherwise a shorter list would be indexed past its end and a
                        // longer one silently truncated.
                        if (cobs[i].Args.Count != cobs[0].Args.Count)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                    }
                    SqlMemberAssign[] bindings = new SqlMemberAssign[cobs[0].Members.Count];
                    for (int m = 0, mn = bindings.Length; m < mn; m++)
                    {
                        List <SqlExpression> mexprs = new List <SqlExpression>();
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            mexprs.Add(cobs[i].Members[m].Expression);
                        }
                        bindings[m] = new SqlMemberAssign(cobs[0].Members[m].Member, this.ExpandTogether(mexprs));
                        // Write the list entries back into each branch. This only changes anything if
                        // the recursive call replaced entries of mexprs.
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            cobs[i].Members[m].Expression = mexprs[i];
                        }
                    }
                    SqlExpression[] arguments = new SqlExpression[cobs[0].Args.Count];
                    for (int m = 0, mn = arguments.Length; m < mn; ++m)
                    {
                        List <SqlExpression> mexprs = new List <SqlExpression>();
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            mexprs.Add(cobs[i].Args[m]);
                        }
                        arguments[m] = ExpandTogether(mexprs);
                    }
                    return(factory.New(cobs[0].MetaType, cobs[0].Constructor, arguments, cobs[0].ArgMembers, bindings, exprs[0].SourceExpression));
                }

                case SqlNodeType.Link: {
                    // Every entry must be a link to the same member, with the same key count and
                    // the same presence/absence of an expansion.
                    SqlLink[] links = new SqlLink[exprs.Count];
                    links[0] = (SqlLink)exprs[0];
                    for (int i = 1, n = exprs.Count; i < n; i++)
                    {
                        if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.Link)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                        links[i] = (SqlLink)exprs[i];
                        if (links[i].KeyExpressions.Count != links[0].KeyExpressions.Count ||
                            links[i].Member != links[0].Member ||
                            (links[i].Expansion != null) != (links[0].Expansion != null))
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                    }
                    SqlExpression[]      kexprs = new SqlExpression[links[0].KeyExpressions.Count];
                    List <SqlExpression> lexprs = new List <SqlExpression>();
                    for (int k = 0, nk = links[0].KeyExpressions.Count; k < nk; k++)
                    {
                        lexprs.Clear();
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            lexprs.Add(links[i].KeyExpressions[k]);
                        }
                        kexprs[k] = this.ExpandTogether(lexprs);
                        // Write the (possibly replaced) list entries back into each link.
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            links[i].KeyExpressions[k] = lexprs[i];
                        }
                    }
                    SqlExpression expansion = null;
                    if (links[0].Expansion != null)
                    {
                        lexprs.Clear();
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            lexprs.Add(links[i].Expansion);
                        }
                        expansion = this.ExpandTogether(lexprs);
                        for (int i = 0, n = exprs.Count; i < n; i++)
                        {
                            links[i].Expansion = lexprs[i];
                        }
                    }
                    return(new SqlLink(links[0].Id, links[0].RowType, links[0].ClrType, links[0].SqlType, links[0].Expression, links[0].Member, kexprs, expansion, links[0].SourceExpression));
                }

                case SqlNodeType.Value: {
                    // An ExprSet whose entries are all literals of the same value reduces to a single
                    // literal. Any entry that is not a SqlValue, or has a different value, forces the
                    // general ExprSet path instead of a cast failure.
                    SqlValue val0 = (SqlValue)exprs[0];
                    for (int i = 1; i < exprs.Count; ++i)
                    {
                        SqlValue val = exprs[i] as SqlValue;
                        if (val == null || !object.Equals(val.Value, val0.Value))
                        {
                            return(this.ExpandIntoExprSet(exprs));
                        }
                    }
                    return(val0);
                }

                case SqlNodeType.OptionalValue: {
                    // Values whose SQL type can be a column take the generic ExprSet path.
                    if (exprs[0].SqlType.CanBeColumn)
                    {
                        goto default;
                    }
                    List <SqlExpression> hvals = new List <SqlExpression>(exprs.Count);
                    List <SqlExpression> vals  = new List <SqlExpression>(exprs.Count);
                    for (int i = 0, n = exprs.Count; i < n; i++)
                    {
                        if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OptionalValue)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                        SqlOptionalValue sov = (SqlOptionalValue)exprs[i];
                        hvals.Add(sov.HasValue);
                        vals.Add(sov.Value);
                    }
                    return(new SqlOptionalValue(this.ExpandTogether(hvals), this.ExpandTogether(vals)));
                }

                case SqlNodeType.OuterJoinedValue: {
                    // Values whose SQL type can be a column take the generic ExprSet path.
                    if (exprs[0].SqlType.CanBeColumn)
                    {
                        goto default;
                    }
                    List <SqlExpression> values = new List <SqlExpression>(exprs.Count);
                    for (int i = 0, n = exprs.Count; i < n; i++)
                    {
                        if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OuterJoinedValue)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                        SqlUnary su = (SqlUnary)exprs[i];
                        values.Add(su.Operand);
                    }
                    return(factory.Unary(SqlNodeType.OuterJoinedValue, this.ExpandTogether(values)));
                }

                case SqlNodeType.DiscriminatedType: {
                    // All entries must target the same type; their discriminators are merged.
                    SqlDiscriminatedType sdt0 = (SqlDiscriminatedType)exprs[0];
                    List <SqlExpression> foos = new List <SqlExpression>(exprs.Count);
                    foos.Add(sdt0.Discriminator);
                    for (int i = 1, n = exprs.Count; i < n; i++)
                    {
                        SqlDiscriminatedType sdtN = (SqlDiscriminatedType)exprs[i];
                        if (sdtN.TargetType != sdt0.TargetType)
                        {
                            throw Error.UnionIncompatibleConstruction();
                        }
                        foos.Add(sdtN.Discriminator);
                    }
                    return(factory.DiscriminatedType(this.ExpandTogether(foos), sdt0.TargetType));
                }

                // Hierarchical (nested-collection) shapes cannot be combined.
                case SqlNodeType.ClientQuery:
                case SqlNodeType.Multiset:
                case SqlNodeType.Element:
                case SqlNodeType.Grouping:
                    throw Error.UnionWithHierarchy();

                default:
                    return(this.ExpandIntoExprSet(exprs));
                }
            }
            /// <summary>
            /// Visits the receiver and arguments of <paramref name="mc"/>, then rewrites recognised
            /// calls into dedicated SQL nodes:
            /// - static calls: <c>Equals(a, b)</c>, <c>string.Concat</c>, VB IIF (via <c>IsVbIIF</c> /
            ///   <c>TranslateVbIIF</c>) and the <c>op_*</c> operator methods;
            /// - instance calls: <c>Equals(x)</c> and <c>GetType()</c>.
            /// </summary>
            /// <param name="mc">
            /// The call node. Its <c>Object</c> and <c>Arguments</c> entries are replaced in place by
            /// their visited forms.
            /// </param>
            /// <returns>The translated expression, or <paramref name="mc"/> itself when no rewrite applies.</returns>
            internal override SqlExpression VisitMethodCall(SqlMethodCall mc)
            {
                // Translate the operands first so every rewrite below sees visited children.
                mc.Object = this.VisitExpression(mc.Object);
                int argCount = mc.Arguments.Count;
                for (int k = 0; k < argCount; k++)
                {
                    mc.Arguments[k] = this.VisitExpression(mc.Arguments[k]);
                }

                if (!mc.Method.IsStatic)
                {
                    // Instance Equals(x) becomes a plain EQ comparison.
                    if (mc.Method.Name == "Equals" && mc.Arguments.Count == 1)
                    {
                        return sql.Binary(SqlNodeType.EQ, mc.Object, mc.Arguments[0]);
                    }
                    // GetType() becomes a static type or, for inheritance hierarchies, a
                    // discriminator-based type node.
                    if (mc.Method.Name == "GetType" && mc.Arguments.Count == 0)
                    {
                        MetaType sourceType = TypeSource.GetSourceMetaType(mc.Object, this.model);
                        if (!sourceType.HasInheritance)
                        {
                            return this.VisitExpression(sql.StaticType(sourceType, mc.SourceExpression));
                        }
                        Type discType = sourceType.Discriminator.Type;
                        SqlDiscriminatorOf discOf = new SqlDiscriminatorOf(mc.Object, discType, this.sql.TypeProvider.From(discType), mc.SourceExpression);
                        return this.VisitExpression(sql.DiscriminatedType(discOf, sourceType));
                    }
                    return mc;
                }

                // Static Equals(a, b) becomes an EQ2V comparison.
                if (mc.Method.Name == "Equals" && mc.Arguments.Count == 2)
                {
                    return sql.Binary(SqlNodeType.EQ2V, mc.Arguments[0], mc.Arguments[1], mc.Method);
                }

                if (mc.Method.DeclaringType == typeof(string) && mc.Method.Name == "Concat")
                {
                    // Concat may receive its operands packed into a single client array.
                    SqlClientArray packed = mc.Arguments[0] as SqlClientArray;
                    List <SqlExpression> pieces = mc.Arguments;
                    if (packed != null)
                    {
                        pieces = packed.Expressions;
                    }
                    if (pieces.Count == 0)
                    {
                        return sql.ValueFromObject("", false, mc.SourceExpression);
                    }
                    // Left fold: each non-textual piece is converted to string, then appended.
                    SqlExpression result = null;
                    for (int p = 0; p < pieces.Count; p++)
                    {
                        SqlExpression piece = pieces[p];
                        bool isText = piece.SqlType.IsString || piece.SqlType.IsChar;
                        if (!isText)
                        {
                            piece = sql.ConvertTo(typeof(string), piece);
                        }
                        result = (p == 0) ? piece : sql.Concat(result, piece);
                    }
                    return result;
                }

                if (IsVbIIF(mc))
                {
                    return TranslateVbIIF(mc);
                }

                // Map operator method names onto SQL node types, then build the node once.
                SqlNodeType? binaryOp = null;
                SqlNodeType? unaryOp  = null;
                switch (mc.Method.Name)
                {
                case "op_Equality":           binaryOp = SqlNodeType.EQ;     break;
                case "op_Inequality":         binaryOp = SqlNodeType.NE;     break;
                case "op_LessThan":           binaryOp = SqlNodeType.LT;     break;
                case "op_LessThanOrEqual":    binaryOp = SqlNodeType.LE;     break;
                case "op_GreaterThan":        binaryOp = SqlNodeType.GT;     break;
                case "op_GreaterThanOrEqual": binaryOp = SqlNodeType.GE;     break;
                case "op_Multiply":           binaryOp = SqlNodeType.Mul;    break;
                case "op_Division":           binaryOp = SqlNodeType.Div;    break;
                case "op_Subtraction":        binaryOp = SqlNodeType.Sub;    break;
                case "op_Addition":           binaryOp = SqlNodeType.Add;    break;
                case "op_Modulus":            binaryOp = SqlNodeType.Mod;    break;
                case "op_BitwiseAnd":         binaryOp = SqlNodeType.BitAnd; break;
                case "op_BitwiseOr":          binaryOp = SqlNodeType.BitOr;  break;
                case "op_ExclusiveOr":        binaryOp = SqlNodeType.BitXor; break;
                case "op_UnaryNegation":      unaryOp  = SqlNodeType.Negate; break;
                case "op_OnesComplement":     unaryOp  = SqlNodeType.BitNot; break;
                case "op_False":              unaryOp  = SqlNodeType.Not;    break;
                }

                if (binaryOp.HasValue)
                {
                    return sql.Binary(binaryOp.Value, mc.Arguments[0], mc.Arguments[1], mc.Method, mc.ClrType);
                }
                if (unaryOp.HasValue)
                {
                    return sql.Unary(unaryOp.Value, mc.Arguments[0], mc.Method, mc.SourceExpression);
                }
                return mc;
            }