        private TransformInference.SuggestedTransform[] InferAndFilter(IDataView data, TransformInference.Arguments args,
            TransformInference.SuggestedTransform[] existingTransforms = null)
        {
            // Infer transforms using the transform inference experts.
            var levelTransforms = TransformInference.InferTransforms(_env, data, args);

            // Retain only those inferred transforms that were also passed in.
            if (existingTransforms != null)
            {
                return levelTransforms.Where(t => existingTransforms.Any(t2 => t2.Equals(t))).ToArray();
            }
            return levelTransforms;
        }
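        // A minimal usage sketch for InferAndFilter; the names "data", "args", and
        // "previousLevel" are illustrative assumptions, not part of the source:
        //
        //     // No filter: every transform the experts infer is kept.
        //     var all = InferAndFilter(data, args);
        //
        //     // With a filter: only transforms also present in previousLevel survive.
        //     var surviving = InferAndFilter(data, args, previousLevel);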
        /// <summary>
        /// Gets a final transform to concatenate all numeric columns into a "Features" vector column.
        /// Note: May return an empty set if the Features column is already present and is the only relevant numeric column
        /// (in other words, if the concatenation transform would have nothing to do).
        /// </summary>
        private static TransformInference.SuggestedTransform[] GetFinalFeatureConcat(IHostEnvironment env,
            IDataView dataSample, int[] excludedColumnIndices, int level, int atomicIdOffset, RoleMappedData dataRoles)
        {
            var finalArgs = new TransformInference.Arguments
            {
                EstimatedSampleFraction = 1.0,
                ExcludeFeaturesConcatTransforms = false,
                ExcludedColumnIndices = excludedColumnIndices
            };

            var featuresConcatTransforms = TransformInference.InferConcatNumericFeatures(env, dataSample, finalArgs, dataRoles);

            for (int i = 0; i < featuresConcatTransforms.Length; i++)
            {
                featuresConcatTransforms[i].RoutingStructure.Level = level;
                featuresConcatTransforms[i].AtomicGroupId += atomicIdOffset;
            }

            // featuresConcatTransforms is already an array; copying it again is unnecessary.
            return featuresConcatTransforms;
        }
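        // A hypothetical call site for GetFinalFeatureConcat (the argument values are
        // illustrative assumptions): the concat transform is tagged with the current
        // pipeline level, and its atomic-group id is offset past the ids already used
        // by earlier suggestions so the groups do not collide.
        //
        //     var finalConcat = GetFinalFeatureConcat(env, dataSample,
        //         excludedColumnIndices: new int[0], level: 1,
        //         atomicIdOffset: suggestedSoFar.Length, dataRoles: dataRoles);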
        public static SuggestedRecipe[] InferRecipesFromData(IHostEnvironment env, string dataFile, string schemaDefinitionFile,
            out Type predictorType, out string settingsString, out TransformInference.InferenceResult inferenceResult,
            bool excludeFeaturesConcatTransforms = false)
        {
            Contracts.CheckValue(env, nameof(env));
            var h = env.Register("InferRecipesFromData", seed: 0, verbose: false);

            using (var ch = h.Start("InferRecipesFromData"))
            {
                // Validate that the schema file has content if one was provided.
                // Warn the user early if it was provided but is being skipped.
                string schemaJson = null;
                if (!string.IsNullOrEmpty(schemaDefinitionFile))
                {
                    try
                    {
                        schemaJson = File.ReadAllText(schemaDefinitionFile);
                    }
                    catch (Exception ex)
                    {
                        ch.Warning($"Unable to read the schema file. Proceeding to infer the schema :{ex.Message}");
                    }
                }

                ch.Info("Loading file sample into memory.");
                var sample = TextFileSample.CreateFromFullFile(h, dataFile);

                ch.Info("Detecting separator and columns");
                var splitResult = TextFileContents.TrySplitColumns(h, sample, TextFileContents.DefaultSeparators);

                // Default to clustering in case separator detection fails.
                predictorType = typeof(SignatureClusteringTrainer);
                settingsString = "";
                if (!splitResult.IsSuccess)
                {
                    throw ch.ExceptDecode("Couldn't detect separator.");
                }

                ch.Info($"Separator detected as '{splitResult.Separator}', there's {splitResult.ColumnCount} columns.");

                ColumnGroupingInference.GroupingColumn[] columns;
                bool hasHeader = false;
                if (string.IsNullOrEmpty(schemaJson))
                {
                    ch.Warning("Empty schema file. Proceeding to infer the schema.");
                    columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
                }
                else
                {
                    try
                    {
                        columns = JsonConvert.DeserializeObject<ColumnGroupingInference.GroupingColumn[]>(schemaJson);
                        ch.Info("Using the provided schema file.");
                    }
                    catch
                    {
                        ch.Warning("Invalid json in the schema file. Proceeding to infer the schema.");
                        columns = InferenceUtils.InferColumnPurposes(ch, h, sample, splitResult, out hasHeader);
                    }
                }

                var finalLoaderArgs = new TextLoader.Arguments
                {
                    Column = ColumnGroupingInference.GenerateLoaderColumns(columns),
                    HasHeader = hasHeader,
                    Separator = splitResult.Separator,
                    AllowSparse = splitResult.AllowSparse,
                    AllowQuoting = splitResult.AllowQuote
                };

                settingsString = CommandLine.CmdParser.GetSettings(ch, finalLoaderArgs, new TextLoader.Arguments());
                ch.Info($"Loader options: {settingsString}");

                ch.Info("Inferring recipes");
                var finalData = TextLoader.ReadFile(h, finalLoaderArgs, sample);
                var cached = new CacheDataView(h, finalData,
                    Enumerable.Range(0, finalLoaderArgs.Column.Length).ToArray());

                var purposeColumns = columns.Select((x, i) => new PurposeInference.Column(i, x.Purpose, x.ItemKind)).ToArray();

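                // Estimate what fraction of the full file the in-memory sample covers,
                // so transform inference can scale its column statistics accordingly.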
                var fraction = sample.FullFileSize == null ? 1.0 : (double)sample.SampleSize / sample.FullFileSize.Value;
                var transformInferenceResult = TransformInference.InferTransforms(h, cached, purposeColumns,
                    new TransformInference.Arguments
                    {
                        EstimatedSampleFraction = fraction,
                        ExcludeFeaturesConcatTransforms = excludeFeaturesConcatTransforms
                    });
                predictorType = InferenceUtils.InferPredictorCategoryType(cached, purposeColumns);
                var recipeInferenceResult = InferRecipes(h, transformInferenceResult, predictorType);

                ch.Done();

                inferenceResult = transformInferenceResult;
                return recipeInferenceResult.SuggestedRecipes;
            }
        }
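        // A minimal sketch of calling InferRecipesFromData. The data file path and the
        // way the outputs are consumed are illustrative assumptions, not part of the
        // original source; it assumes the same usings as the surrounding file plus System.
        public static void InferRecipesExample(IHostEnvironment env)
        {
            // Passing null for schemaDefinitionFile makes the method infer the schema.
            var recipes = InferRecipesFromData(env, "data.tsv", null,
                out var predictorType, out var loaderSettings, out var inference);

            // predictorType names the inferred trainer signature (clustering when no
            // label is found), loaderSettings holds the TextLoader options as a
            // command-line string, and inference carries the suggested transforms.
            Console.WriteLine($"Predictor category: {predictorType.Name}");
            Console.WriteLine($"Loader settings: {loaderSettings}");
            Console.WriteLine($"Suggested recipes: {recipes.Length}");
        }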