예제 #1
0
        /// <summary>
        /// Parses <paramref name="expression"/> with <c>LambdaParser</c> and then binds the result with
        /// <c>LambdaBinder</c>, returning the resulting node.
        /// </summary>
        /// <param name="env">Host environment used to open the binder's channel and to create exceptions.</param>
        /// <param name="expression">Source text of the lambda.</param>
        /// <param name="ivec">
        /// When positive, this index is placed in slot 0 of <paramref name="perm"/> and slots
        /// 1..<paramref name="ivec"/> receive 0..<paramref name="ivec"/>-1. Zero or a negative value leaves
        /// the permutation exactly as produced by <c>Utils.GetIdentityPermutation</c>.
        /// </param>
        /// <param name="inputTypes">Input types forwarded to the parser; its length sizes <paramref name="perm"/>.</param>
        /// <param name="perm">Receives the permutation that is forwarded to the parser.</param>
        /// <returns>The node produced by the parser after the binder has run over it.</returns>
        /// <remarks>
        /// Throws the exception built by <c>env.ExceptParam</c> (for <paramref name="expression"/>) when
        /// either parsing or binding reports at least one error; only the first error is included in the message.
        /// </remarks>
        internal static LambdaNode ParseAndBindLambda(IHostEnvironment env, string expression, int ivec, DataViewType[] inputTypes, out int[] perm)
        {
            perm = Utils.GetIdentityPermutation(inputTypes.Length);

            // Move ivec to the front and shift the entries that preceded it one slot to the right.
            // ivec == 0 needs no change, so a single "> 0" test covers every case.
            if (ivec > 0)
            {
                perm[0] = ivec;
                for (int i = 1; i <= ivec; i++)
                {
                    perm[i] = i - 1;
                }
            }

            CharCursor chars = new CharCursor(expression);
            var node = LambdaParser.Parse(out List<Error> errors, out List<int> lineMap, chars, perm, inputTypes);

            if (Utils.Size(errors) > 0)
            {
                throw env.ExceptParam(nameof(expression), $"parsing failed: {errors[0].GetMessage()}");
            }

            // The channel name previously read "LabmdaBinder.Run" (transposed letters).
            using (var ch = env.Start("LambdaBinder.Run"))
            {
                LambdaBinder.Run(env, ref errors, node, msg => ch.Error(msg));
            }

            if (Utils.Size(errors) > 0)
            {
                throw env.ExceptParam(nameof(expression), $"binding failed: {errors[0].GetMessage()}");
            }

            return node;
        }
        /// <summary>
        /// Creates an opaque data view over <paramref name="input"/> whose implementation receives
        /// <paramref name="predicate"/> and the index of the column named <paramref name="src"/>.
        /// </summary>
        /// <typeparam name="TSrc">Raw type consumed by the predicate; must equal <c>typeSrc.RawType</c>.</typeparam>
        /// <param name="env">Host environment; checked for null and used for all later validation.</param>
        /// <param name="name">Non-empty name forwarded to the implementation.</param>
        /// <param name="input">Data view that must contain the column <paramref name="src"/>.</param>
        /// <param name="src">Non-empty name of the source column.</param>
        /// <param name="typeSrc">Column type the predicate expects.</param>
        /// <param name="predicate">Predicate forwarded to the implementation.</param>
        /// <returns>An <c>OpaqueDataView</c> wrapping the created implementation.</returns>
        public static IDataView Create<TSrc>(IHostEnvironment env, string name, IDataView input,
                                             string src, DataViewType typeSrc, InPredicate<TSrc> predicate)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckNonEmpty(name, nameof(name));
            env.CheckValue(input, nameof(input));
            env.CheckNonEmpty(src, nameof(src));
            env.CheckValue(typeSrc, nameof(typeSrc));
            env.CheckValue(predicate, nameof(predicate));

            if (typeSrc.RawType != typeof(TSrc))
            {
                throw env.ExceptParam(nameof(predicate),
                                      "The source column type '{0}' doesn't match the input type of the predicate", typeSrc);
            }

            if (!input.Schema.TryGetColumnIndex(src, out int srcIndex))
            {
                throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src);
            }

            var columnType = input.Schema[srcIndex].Type;

            // REVIEW: Ideally this should support vector-type conversion. It currently doesn't.
            // A column that already has the requested size and item type needs no converter;
            // otherwise a standard conversion must exist.
            Delegate converter = null;
            bool isIdentity = columnType.SameSizeAndItemType(typeSrc);
            if (!isIdentity &&
                !Conversions.DefaultInstance.TryGetStandardConversion(columnType, typeSrc, out converter, out isIdentity))
            {
                throw env.ExceptParam(nameof(predicate),
                                      "The type of column '{0}', '{1}', cannot be converted to the input type of the predicate '{2}'",
                                      src, columnType, typeSrc);
            }

            IDataView view;
            if (!isIdentity)
            {
                // CreateImpl is bound with placeholder type arguments only to obtain its generic
                // definition, which is then closed over the actual column and predicate types.
                Func<IHostEnvironment, string, IDataView, int,
                     InPredicate<int>, ValueMapper<int, int>, Impl<int, int>> template = CreateImpl<int, int>;
                var factory = template.GetMethodInfo()
                                      .GetGenericMethodDefinition()
                                      .MakeGenericMethod(columnType.RawType, typeof(TSrc));
                var factoryArgs = new object[] { env, name, input, srcIndex, predicate, converter };
                view = (IDataView)factory.Invoke(null, factoryArgs);
            }
            else
            {
                view = new Impl<TSrc, TSrc>(env, name, input, srcIndex, predicate);
            }

            return new OpaqueDataView(view);
        }
예제 #3
0
        /// <summary>
        /// Returns a new evaluator, built with default arguments, for the given trainer kind.
        /// </summary>
        /// <param name="env">Host environment passed to the evaluator and used to create the exception.</param>
        /// <param name="kind">Trainer kind selecting which evaluator is created.</param>
        /// <returns>The evaluator matching <paramref name="kind"/>.</returns>
        /// <remarks>Throws via <c>env.ExceptParam</c> when <paramref name="kind"/> has no case below.</remarks>
        private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.TrainerKinds kind)
        {
            IMamlEvaluator evaluator;

            switch (kind)
            {
                case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer:
                    evaluator = new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer:
                    evaluator = new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureRegressorTrainer:
                    evaluator = new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureRankerTrainer:
                    evaluator = new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer:
                    evaluator = new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureClusteringTrainer:
                    evaluator = new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments());
                    break;

                case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer:
                    evaluator = new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments());
                    break;

                default:
                    throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator");
            }

            return evaluator;
        }
예제 #4
0
                /// <summary>
                /// Constructor used for the initial wrapping of <paramref name="mapper"/>.
                /// </summary>
                /// <param name="env">Host environment; a host is registered on it under <c>LoaderSignature</c>.</param>
                /// <param name="mapper">Mapper being wrapped; its schema must contain the score column.</param>
                /// <param name="type">Type stored as the label-name type.</param>
                /// <param name="getter">Getter stored as the label-name getter.</param>
                /// <param name="metadataKind">Non-empty metadata kind forwarded to the output schema.</param>
                /// <param name="canWrap">Optional callback stored for later use; may be null.</param>
                public Bound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, ValueGetter<VBuffer<T>> getter,
                             string metadataKind, Func<ISchemaBoundMapper, ColumnType, bool> canWrap)
                {
                    Contracts.CheckValue(env, nameof(env));
                    _host = env.Register(LoaderSignature);
                    _host.CheckValue(mapper, nameof(mapper));
                    _host.CheckValue(type, nameof(type));
                    _host.CheckValue(getter, nameof(getter));
                    _host.CheckNonEmpty(metadataKind, nameof(metadataKind));
                    _host.CheckValueOrNull(canWrap);

                    _mapper = mapper;

                    // The output schema is built around the mapper's score column, so it has to exist.
                    if (!mapper.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreColumn))
                    {
                        throw env.ExceptParam(nameof(mapper), "Mapper did not have a '{0}' column", MetadataUtils.Const.ScoreValueKind.Score);
                    }

                    _labelNameType = type;
                    _labelNameGetter = getter;
                    _metadataKind = metadataKind;
                    _canWrap = canWrap;

                    _outSchema = new SchemaImpl(mapper.Schema, scoreColumn, _labelNameType, _labelNameGetter, _metadataKind);
                }
예제 #5
0
        /// <summary>
        /// Reads the <see cref="double"/> value of column <paramref name="columnName"/> from a data view
        /// that is expected to contain exactly one row.
        /// </summary>
        /// <param name="env">Host environment; checked for null and used for all later checks.</param>
        /// <param name="result">Data view to read from.</param>
        /// <param name="columnName">Non-empty name of the column holding the value.</param>
        /// <returns>The value read from the single row.</returns>
        /// <remarks>
        /// Fails through <paramref name="env"/> when the column is missing, when there is no row,
        /// or when there is more than one row.
        /// </remarks>
        public static double ExtractValueFromIdv(IHostEnvironment env, IDataView result, string columnName)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(result, nameof(result));
            env.CheckNonEmpty(columnName, nameof(columnName));

            if (!result.Schema.TryGetColumnIndex(columnName, out var valueColumn))
            {
                throw env.ExceptParam(nameof(columnName), $"Schema does not contain column: {columnName}");
            }

            double value = 0;

            // Only the requested column is activated on the cursor.
            using (var rows = result.GetRowCursor(col => col == valueColumn))
            {
                var readValue = rows.GetGetter<double>(valueColumn);

                bool hasFirstRow = rows.MoveNext();
                env.Check(hasFirstRow, "Expected an IDataView with a single row. Results dataset has no rows to extract.");
                readValue(ref value);

                bool hasSecondRow = rows.MoveNext();
                env.Check(!hasSecondRow, "Expected an IDataView with a single row. Results dataset has too many rows.");
            }

            return value;
        }
예제 #6
0
            /// <summary>
            /// Verifies that the class count returned by <c>CheckLabelColumn</c> for
            /// <paramref name="predictors"/> is exactly 2.
            /// </summary>
            /// <param name="user">
            /// Selects the exception flavor on failure: <c>env.ExceptParam</c> when true,
            /// <c>env.ExceptDecode</c> when false.
            /// </param>
            /// <param name="env">Host environment used to create the exception.</param>
            /// <param name="predictors">Predictor models whose label column is checked.</param>
            private static void CheckBinaryLabel(bool user, IHostEnvironment env, IPredictorModel[] predictors)
            {
                int numClasses = CheckLabelColumn(env, predictors, true);
                if (numClasses == 2)
                {
                    return;
                }

                string message = string.Format("Expected label to have exactly 2 classes, instead has {0}", numClasses);
                if (user)
                {
                    throw env.ExceptParam(nameof(predictors), message);
                }

                throw env.ExceptDecode(message);
            }
예제 #7
0
        // Constructor that exercises many spellings of the contract helpers (CheckParam, CheckValue,
        // CheckUserArg, Check, CheckDecode, Except, ExceptParam) with both nameof(...) and literal
        // string arguments. The trailing "Should fail" / "Should succeed" remarks presumably record
        // the verdict expected from a code analyzer that inspects these calls -- TODO confirm.
        // NOTE(review): if this is an analyzer fixture, expected diagnostics may be keyed by position;
        // verify before reflowing or reordering any statement below.
        public TypeName(IHostEnvironment env, float p, int foo)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckParam(0 <= p && p <= 1, nameof(p), "Should be in range [0,1]");
            env.CheckParam(0 <= p && p <= 1, "p");                   // Should fail.
            env.CheckParam(0 <= p && p <= 1, nameof(p) + nameof(p)); // Should fail.
            env.CheckValue(paramName: nameof(p), val: "p");          // Should succeed despite confusing order.
            env.CheckValue(paramName: "p", val: nameof(p));          // Should fail despite confusing order.
            env.CheckValue("p", nameof(p));
            env.CheckUserArg(foo > 5, "foo", "Nice");
            env.CheckUserArg(foo > 5, nameof(foo), "Nice");
            env.Except();                                           // Not throwing or doing anything with the exception, so should fail.
            Contracts.ExceptParam(nameof(env), "What a silly env"); // Should also fail.
            if (false)
            {
                throw env.Except(); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam(nameof(env), "What a silly env"); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam("env", "What a silly env"); // Should fail due to name error.
            }
            // Here the created exception is stored in a local instead of being discarded or thrown.
            var e = env.Except();

            // The Check calls below vary only in how the message argument is produced: interpolated
            // string, literal, local variable, string.Format, member access, and concatenation.
            env.Check(true, $"Hello {foo} is cool");
            env.Check(true, "Hello it is cool");
            string coolMessage = "Hello it is cool";

            env.Check(true, coolMessage);
            env.Check(true, string.Format("Hello {0} is cool", foo));
            env.Check(true, Messages.CoolMessage);
            env.CheckDecode(true, "Not suspicious, no ModelLoadContext");
            Contracts.Check(true, "Fine: " + nameof(env));
            Contracts.Check(true, "Less fine: " + env.GetType().Name);
            // Same literal-name pattern as above, but with the arguments split across two lines.
            Contracts.CheckUserArg(0 <= p && p <= 1,
                                   "p", "On a new line");
        }
예제 #8
0
        /// <summary>
        /// Extracts the indices and types of the input columns to the whitening transform.
        /// </summary>
        /// <param name="env">Host environment used to create exceptions.</param>
        /// <param name="inputData">Data view whose schema is searched for the input columns.</param>
        /// <param name="columns">Column descriptions; the <c>Input</c> name of each one is looked up.</param>
        /// <param name="srcTypes">Receives the type of each input column, in the order of <paramref name="columns"/>.</param>
        /// <param name="cols">Receives the schema index of each input column, in the order of <paramref name="columns"/>.</param>
        /// <remarks>
        /// Throws via <c>env.ExceptSchemaMismatch</c> when a column is missing, and via
        /// <c>env.ExceptParam</c> when <c>TestColumn</c> returns a non-null reason for a column type.
        /// </remarks>
        private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputData, ColumnInfo[] columns, out ColumnType[] srcTypes, out int[] cols)
        {
            cols = new int[columns.Length];
            srcTypes = new ColumnType[columns.Length];
            var inputSchema = inputData.Schema;

            for (int i = 0; i < columns.Length; i++)
            {
                if (!inputSchema.TryGetColumnIndex(columns[i].Input, out cols[i]))
                {
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input);
                }

                srcTypes[i] = inputSchema.GetColumnType(cols[i]);

                var reason = TestColumn(srcTypes[i]);
                if (reason != null)
                {
                    // nameof(inputData.Schema) evaluates to "Schema", which is not a parameter of this
                    // method; report the parameter that actually carries the offending schema.
                    throw env.ExceptParam(nameof(inputData), reason);
                }
            }
        }
예제 #9
0
        /// <summary>
        /// Extracts the indices and types of the input columns to the whitening transform.
        /// </summary>
        /// <param name="env">Host environment used to create exceptions.</param>
        /// <param name="inputData">Data view whose schema is searched for the input columns.</param>
        /// <param name="columns">Column options; the <c>InputColumnName</c> of each one is looked up.</param>
        /// <param name="srcTypes">Receives the type of each input column, in the order of <paramref name="columns"/>.</param>
        /// <param name="cols">Receives the schema index of each input column, in the order of <paramref name="columns"/>.</param>
        /// <remarks>
        /// Throws via <c>env.ExceptSchemaMismatch</c> when a column is missing, and via
        /// <c>env.ExceptParam</c> when <c>TestColumn</c> returns a non-null reason for a column type.
        /// </remarks>
        private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputData, VectorWhiteningEstimator.ColumnOptions[] columns, out DataViewType[] srcTypes, out int[] cols)
        {
            cols = new int[columns.Length];
            srcTypes = new DataViewType[columns.Length];
            var inputSchema = inputData.Schema;

            for (int i = 0; i < columns.Length; i++)
            {
                var col = inputSchema.GetColumnOrNull(columns[i].InputColumnName);
                if (!col.HasValue)
                {
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].InputColumnName);
                }

                cols[i] = col.Value.Index;
                srcTypes[i] = col.Value.Type;

                var reason = TestColumn(srcTypes[i]);
                if (reason != null)
                {
                    // nameof(inputData.Schema) evaluates to "Schema", which is not a parameter of this
                    // method; report the parameter that actually carries the offending schema.
                    throw env.ExceptParam(nameof(inputData), reason);
                }
            }
        }
예제 #10
0
        /// <summary>
        /// Builds a <c>DatabaseLoader</c> whose columns are derived by reflection from the public
        /// instance fields and readable, non-indexer public instance properties of <typeparamref name="TInput"/>.
        /// </summary>
        /// <typeparam name="TInput">Type whose members describe the columns to load.</typeparam>
        /// <param name="host">Host environment passed to the loader and used for the "no members" error.</param>
        /// <returns>A loader configured with one column per discovered member.</returns>
        /// <remarks>
        /// Throws when no member is found, when a member carries both <c>LoadColumnAttribute</c> and
        /// <c>LoadColumnNameAttribute</c>, or when a member's type (or array element type) has no data kind.
        /// </remarks>
        internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment host)
        {
            const BindingFlags publicInstance = BindingFlags.Public | BindingFlags.Instance;
            var inputType = typeof(TInput);

            var fields = inputType.GetFields(publicInstance);
            var properties = inputType
                .GetProperties(publicInstance)
                .Where(p => p.CanRead && p.GetGetMethod() != null && p.GetIndexParameters().Length == 0);

            // Fields come first, followed by properties.
            var members = fields.Concat<MemberInfo>(properties).ToArray();
            if (members.Length == 0)
            {
                throw host.ExceptParam(nameof(TInput), $"Should define at least one public, readable field or property in {nameof(TInput)}.");
            }

            var columns = new List<Column>();

            foreach (var member in members)
            {
                // An explicit ColumnNameAttribute overrides the member name.
                var nameAttr = member.GetCustomAttribute<ColumnNameAttribute>();
                var column = new Column { Name = nameAttr?.Name ?? member.Name };

                var byIndexAttr = member.GetCustomAttribute<LoadColumnAttribute>();
                var byNameAttr = member.GetCustomAttribute<LoadColumnNameAttribute>();

                if (byIndexAttr is object && byNameAttr is object)
                {
                    throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}");
                }

                if (byIndexAttr is object)
                {
                    column.Source = byIndexAttr.Sources.Select(s => Range.FromTextLoaderRange(s)).ToArray();
                }
                else if (byNameAttr is object)
                {
                    column.Source = byNameAttr.Sources.Select(s => new Range(s)).ToArray();
                }

                // Array-typed members are described by their element type.
                InternalDataKind dk;
                switch (member)
                {
                    case FieldInfo field:
                        var fieldType = field.FieldType;
                        var fieldItemType = fieldType.IsArray ? fieldType.GetElementType() : fieldType;
                        if (!InternalDataKindExtensions.TryGetDataKind(fieldItemType, out dk))
                        {
                            throw Contracts.Except($"Field {member.Name} is of unsupported type.");
                        }
                        break;

                    case PropertyInfo property:
                        var propertyType = property.PropertyType;
                        var propertyItemType = propertyType.IsArray ? propertyType.GetElementType() : propertyType;
                        if (!InternalDataKindExtensions.TryGetDataKind(propertyItemType, out dk))
                        {
                            throw Contracts.Except($"Property {member.Name} is of unsupported type.");
                        }
                        break;

                    default:
                        Contracts.Assert(false);
                        throw Contracts.ExceptNotSupp("Expected a FieldInfo or a PropertyInfo");
                }

                column.Type = dk.ToDbType();
                columns.Add(column);
            }

            var options = new Options();
            options.Columns = columns.ToArray();

            return new DatabaseLoader(host, options);
        }
예제 #11
0
        /// <summary>
        /// Reads every row of <paramref name="data"/> and converts it into a <c>PipelineResultRow</c>
        /// built from the six named columns.
        /// </summary>
        /// <param name="env">Host environment used to create the "column not found" exceptions.</param>
        /// <param name="data">Data view to read.</param>
        /// <param name="graphColName">Name of the text column holding the graph.</param>
        /// <param name="metricColName">Name of the <see cref="double"/> metric column.</param>
        /// <param name="idColName">Name of the text column holding the pipeline id.</param>
        /// <param name="trainingMetricColName">Name of the <see cref="double"/> training-metric column.</param>
        /// <param name="firstInputColName">Name of the text column holding the first input.</param>
        /// <param name="predictorModelColName">Name of the text column holding the predictor model.</param>
        /// <returns>One result per row, in cursor order.</returns>
        public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data,
                                                         string graphColName, string metricColName, string idColName, string trainingMetricColName,
                                                         string firstInputColName, string predictorModelColName)
        {
            var schema = data.Schema;

            // Resolve all column indices up front; the first missing column is the one reported.
            if (!schema.TryGetColumnIndex(graphColName, out var graphIdx))
            {
                throw env.ExceptParam(nameof(graphColName), $"Column name {graphColName} not found");
            }

            if (!schema.TryGetColumnIndex(metricColName, out var metricIdx))
            {
                throw env.ExceptParam(nameof(metricColName), $"Column name {metricColName} not found");
            }

            if (!schema.TryGetColumnIndex(trainingMetricColName, out var trainingMetricIdx))
            {
                throw env.ExceptParam(nameof(trainingMetricColName), $"Column name {trainingMetricColName} not found");
            }

            if (!schema.TryGetColumnIndex(idColName, out var idIdx))
            {
                throw env.ExceptParam(nameof(idColName), $"Column name {idColName} not found");
            }

            if (!schema.TryGetColumnIndex(firstInputColName, out var firstInputIdx))
            {
                throw env.ExceptParam(nameof(firstInputColName), $"Column name {firstInputColName} not found");
            }

            if (!schema.TryGetColumnIndex(predictorModelColName, out var predictorModelIdx))
            {
                throw env.ExceptParam(nameof(predictorModelColName), $"Column name {predictorModelColName} not found");
            }

            var rows = new List<PipelineResultRow>();

            using (var cursor = data.GetRowCursor(col => true))
            {
                var readMetric = cursor.GetGetter<double>(metricIdx);
                var readGraph = cursor.GetGetter<ReadOnlyMemory<char>>(graphIdx);
                var readId = cursor.GetGetter<ReadOnlyMemory<char>>(idIdx);
                var readTrainingMetric = cursor.GetGetter<double>(trainingMetricIdx);
                var readFirstInput = cursor.GetGetter<ReadOnlyMemory<char>>(firstInputIdx);
                var readPredictorModel = cursor.GetGetter<ReadOnlyMemory<char>>(predictorModelIdx);

                // Buffers that the getters overwrite on every row.
                double metric = 0;
                double trainingMetric = 0;
                ReadOnlyMemory<char> graph = default;
                ReadOnlyMemory<char> id = default;
                ReadOnlyMemory<char> firstInput = default;
                ReadOnlyMemory<char> predictorModel = default;

                while (cursor.MoveNext())
                {
                    readMetric(ref metric);
                    readGraph(ref graph);
                    readId(ref id);
                    readTrainingMetric(ref trainingMetric);
                    readFirstInput(ref firstInput);
                    readPredictorModel(ref predictorModel);

                    var row = new PipelineResultRow(graph.ToString(),
                                                    metric, id.ToString(), trainingMetric,
                                                    firstInput.ToString(), predictorModel.ToString());
                    rows.Add(row);
                }
            }

            return rows.ToArray();
        }
예제 #12
0
        // REVIEW: It would be nice to support propagation of select metadata.
        /// <summary>
        /// Creates an opaque data view over <paramref name="input"/> whose implementation receives
        /// <paramref name="mapper"/> and a column pair built from <paramref name="src"/> and <paramref name="dst"/>.
        /// </summary>
        /// <typeparam name="TSrc">Raw type consumed by the mapper; must equal <c>typeSrc.RawType</c>.</typeparam>
        /// <typeparam name="TDst">Raw type produced by the mapper; must equal <c>typeDst.RawType</c>.</typeparam>
        /// <param name="env">Host environment; checked for null and used for all later validation.</param>
        /// <param name="name">Non-empty name forwarded to the implementation.</param>
        /// <param name="input">Data view that must contain the column <paramref name="src"/>.</param>
        /// <param name="src">Non-empty name of the source column.</param>
        /// <param name="dst">Non-empty name of the destination column.</param>
        /// <param name="typeSrc">Column type the mapper expects as input.</param>
        /// <param name="typeDst">Column type the mapper produces.</param>
        /// <param name="mapper">Mapper forwarded to the implementation.</param>
        /// <param name="keyValueGetter">Optional; only accepted when the item type of <paramref name="typeDst"/> is a <c>KeyType</c>.</param>
        /// <param name="slotNamesGetter">Optional; only accepted when <paramref name="typeDst"/> is a known-size vector.</param>
        /// <returns>An <c>OpaqueDataView</c> wrapping the created implementation.</returns>
        public static IDataView Create<TSrc, TDst>(IHostEnvironment env, string name, IDataView input,
                                                   string src, string dst, ColumnType typeSrc, ColumnType typeDst, ValueMapper<TSrc, TDst> mapper,
                                                   ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter = null, ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckNonEmpty(name, nameof(name));
            env.CheckValue(input, nameof(input));
            env.CheckNonEmpty(src, nameof(src));
            env.CheckNonEmpty(dst, nameof(dst));
            env.CheckValue(typeSrc, nameof(typeSrc));
            env.CheckValue(typeDst, nameof(typeDst));
            env.CheckValue(mapper, nameof(mapper));
            env.Check(keyValueGetter == null || typeDst.GetItemType() is KeyType);
            env.Check(slotNamesGetter == null || typeDst.IsKnownSizeVector());

            if (typeSrc.RawType != typeof(TSrc))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The source column type '{0}' doesn't match the input type of the mapper", typeSrc);
            }

            if (typeDst.RawType != typeof(TDst))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The destination column type '{0}' doesn't match the output type of the mapper", typeDst);
            }

            if (!input.Schema.TryGetColumnIndex(src, out int srcIndex))
            {
                throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src);
            }

            var columnType = input.Schema[srcIndex].Type;

            // REVIEW: Ideally this should support vector-type conversion. It currently doesn't.
            // A column that already has the requested size and item type needs no converter;
            // otherwise a standard conversion must exist.
            Delegate converter = null;
            bool isIdentity = columnType.SameSizeAndItemType(typeSrc);
            if (!isIdentity &&
                !Conversions.Instance.TryGetStandardConversion(columnType, typeSrc, out converter, out isIdentity))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The type of column '{0}', '{1}', cannot be converted to the input type of the mapper '{2}'",
                                      src, columnType, typeSrc);
            }

            var col = new Column(src, dst);

            IDataView view;
            if (!isIdentity)
            {
                // CreateImpl is bound with placeholder type arguments only to obtain its generic
                // definition, which is then closed over the actual column, source and destination types.
                Func<IHostEnvironment, string, IDataView, Column, ColumnType, ValueMapper<int, int>,
                     ValueMapper<int, int>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>,
                     Impl<int, int, int>> template = CreateImpl<int, int, int>;
                var factory = template.GetMethodInfo()
                                      .GetGenericMethodDefinition()
                                      .MakeGenericMethod(columnType.RawType, typeof(TSrc), typeof(TDst));
                var factoryArgs = new object[] { env, name, input, col, typeDst, converter, mapper, keyValueGetter, slotNamesGetter };
                view = (IDataView)factory.Invoke(null, factoryArgs);
            }
            else
            {
                view = new Impl<TSrc, TDst, TDst>(env, name, input, col, typeDst, mapper, keyValueGetter: keyValueGetter, slotNamesGetter: slotNamesGetter);
            }

            return new OpaqueDataView(view);
        }