コード例 #1
0
 /// <summary>
 /// Loads an affine column function for a source column, dispatching on <paramref name="typeSrc"/>
 /// to the single-precision (<c>Sng</c>) or double-precision (<c>Dbl</c>) implementation, in either
 /// its scalar (<c>ImplOne</c>) or vector (<c>ImplVec</c>) form.
 /// </summary>
 /// <param name="ctx">Model load context, passed unchanged to the selected implementation.</param>
 /// <param name="host">Host used for argument validation and error reporting; must not be null.</param>
 /// <param name="typeSrc">Type of the source column; must not be null. Supported: R4, R8, Vec&lt;R4, n&gt;, Vec&lt;R8, n&gt;.</param>
 /// <returns>The column function created by the matching implementation.</returns>
 /// <remarks>
 /// Any other column type results in the exception produced by <c>host.ExceptUserArg</c> being thrown.
 /// </remarks>
 public static AffineColumnFunction Create(ModelLoadContext ctx, IHost host, ColumnType typeSrc)
 {
     Contracts.CheckValue(host, nameof(host));
     // Validate up front so a missing type is reported as a bad argument through the host,
     // rather than as a NullReferenceException from the first property access below.
     host.CheckValue(typeSrc, nameof(typeSrc));

     if (typeSrc.IsNumber)
     {
         if (typeSrc == NumberType.R4)
         {
             return Sng.ImplOne.Create(ctx, host, typeSrc);
         }
         if (typeSrc == NumberType.R8)
         {
             return Dbl.ImplOne.Create(ctx, host, typeSrc);
         }
     }
     // Explicit IsVector test, matching the sibling builder factories: only vectors of numbers
     // are routed to the vector implementations.
     else if (typeSrc.IsVector && typeSrc.ItemType.IsNumber)
     {
         if (typeSrc.ItemType == NumberType.R4)
         {
             return Sng.ImplVec.Create(ctx, host, typeSrc);
         }
         if (typeSrc.ItemType == NumberType.R8)
         {
             return Dbl.ImplVec.Create(ctx, host, typeSrc);
         }
     }
     throw host.ExceptUserArg(nameof(AffineArgumentsBase.Column), "Wrong column type. Expected: R4, R8, Vec<R4, n> or Vec<R8, n>. Got: {0}.", typeSrc.ToString());
 }
コード例 #2
0
            /// <summary>
            /// Creates the mean/variance function builder for the column at <paramref name="icol"/>,
            /// choosing the <c>Sng</c> (Single) or <c>Dbl</c> (Double) implementation and the scalar or
            /// vector variant from <paramref name="srcType"/>.
            /// </summary>
            /// <param name="args">Normalizer arguments; <c>args.Column[icol]</c> names the column in error messages.</param>
            /// <param name="host">Host used for debug assertions and error reporting.</param>
            /// <param name="icol">Index of the column within <c>args.Column</c>.</param>
            /// <param name="srcIndex">Index of the source column, used to obtain its value getter from <paramref name="cursor"/>.</param>
            /// <param name="srcType">Type of the source column. Supported: R4, R8, Vec&lt;R4, n&gt;, Vec&lt;R8, n&gt;.</param>
            /// <param name="cursor">Cursor that supplies the value getter for <paramref name="srcIndex"/>.</param>
            /// <returns>A builder for the matching element type and shape.</returns>
            /// <remarks>
            /// Any other column type results in the exception produced by <c>host.ExceptUserArg</c> being thrown.
            /// </remarks>
            public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHost host,
                int icol, int srcIndex, ColumnType srcType, IRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);

                // Scalar source: the builder reads one value per row.
                if (srcType.IsNumber)
                {
                    if (srcType == NumberType.R4)
                    {
                        var singleGetter = cursor.GetGetter<Single>(srcIndex);
                        return Sng.MeanVarOneColumnFunctionBuilder.Create(args, host, icol, srcType, singleGetter);
                    }
                    if (srcType == NumberType.R8)
                    {
                        var doubleGetter = cursor.GetGetter<Double>(srcIndex);
                        return Dbl.MeanVarOneColumnFunctionBuilder.Create(args, host, icol, srcType, doubleGetter);
                    }
                }

                // Vector source: the builder reads one VBuffer of values per row.
                bool isNumericVector = srcType.IsVector && srcType.ItemType.IsNumber;
                if (isNumericVector)
                {
                    var itemType = srcType.ItemType;
                    if (itemType == NumberType.R4)
                    {
                        var singleVecGetter = cursor.GetGetter<VBuffer<Single>>(srcIndex);
                        return Sng.MeanVarVecColumnFunctionBuilder.Create(args, host, icol, srcType, singleVecGetter);
                    }
                    if (itemType == NumberType.R8)
                    {
                        var doubleVecGetter = cursor.GetGetter<VBuffer<Double>>(srcIndex);
                        return Dbl.MeanVarVecColumnFunctionBuilder.Create(args, host, icol, srcType, doubleVecGetter);
                    }
                }

                throw host.ExceptUserArg(nameof(args.Column), "Wrong column type for column {0}. Expected: R4, R8, Vec<R4, n> or Vec<R8, n>. Got: {1}.", args.Column[icol].Source, srcType.ToString());
            }
コード例 #3
0
            /// <summary>
            /// Creates the supervised-binning function builder for the column at <paramref name="icol"/>
            /// after validating the label column named by <c>args.LabelColumn</c>.
            /// </summary>
            /// <param name="args">Normalizer arguments; supplies the label column name and <c>args.Column[icol]</c>.</param>
            /// <param name="host">Host used for debug assertions, user-argument checks and error reporting.</param>
            /// <param name="icol">Index of the column within <c>args.Column</c>.</param>
            /// <param name="srcIndex">Index of the source column, passed to the selected builder.</param>
            /// <param name="srcType">Type of the source column. Supported: R4, R8, Vec&lt;R4, n&gt;, Vec&lt;R8, n&gt;.</param>
            /// <param name="cursor">Cursor whose schema is used to resolve the label column; passed to the selected builder.</param>
            /// <returns>A builder for the matching element type and shape.</returns>
            /// <remarks>
            /// The label column name must be non-blank, and the label must be either a key type with a
            /// positive key count or a numeric type; violations are reported via <c>host.CheckUserArg</c>.
            /// An unsupported <paramref name="srcType"/> results in the exception produced by
            /// <c>host.ExceptUserArg</c> being thrown.
            /// </remarks>
            public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, IHost host,
                int icol, int srcIndex, ColumnType srcType, IRowCursor cursor)
            {
                Contracts.AssertValue(host);
                host.AssertValue(args);

                // The label is validated before the source column type is inspected.
                var labelName = args.LabelColumn;
                host.CheckUserArg(!string.IsNullOrWhiteSpace(labelName), nameof(args.LabelColumn), "Must specify the label column name");

                var schema = cursor.Schema;
                int labelCol = GetLabelColumnId(host, schema, labelName);
                var labelType = schema.GetColumnType(labelCol);

                if (!labelType.IsKey)
                {
                    host.CheckUserArg(labelType.IsNumber, nameof(args.LabelColumn), "Label column must be a number or a key type");
                }
                else
                {
                    host.CheckUserArg(labelType.KeyCount > 0, nameof(args.LabelColumn), "Label column must have a known cardinality");
                }

                // Scalar source column.
                if (srcType.IsNumber)
                {
                    if (srcType == NumberType.R4)
                    {
                        return Sng.SupervisedBinOneColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelCol, cursor);
                    }
                    if (srcType == NumberType.R8)
                    {
                        return Dbl.SupervisedBinOneColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelCol, cursor);
                    }
                }

                // Vector source column.
                bool isNumericVector = srcType.IsVector && srcType.ItemType.IsNumber;
                if (isNumericVector)
                {
                    var itemType = srcType.ItemType;
                    if (itemType == NumberType.R4)
                    {
                        return Sng.SupervisedBinVecColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelCol, cursor);
                    }
                    if (itemType == NumberType.R8)
                    {
                        return Dbl.SupervisedBinVecColumnFunctionBuilder.Create(args, host, icol, srcIndex, labelCol, cursor);
                    }
                }

                throw host.ExceptUserArg(nameof(args.Column), "Wrong column type for column {0}. Expected: R4, R8, Vec<R4, n> or Vec<R8, n>. Got: {1}.", args.Column[icol].Source, srcType.ToString());
            }