Example #1
0
        /// <summary>
        /// Rebinds the first output found in the output map whose variable name equals
        /// <paramref name="oldName"/> so that it points at <paramref name="newName"/>.
        /// Does nothing when no output is bound to <paramref name="oldName"/>.
        /// </summary>
        /// <param name="oldName">Variable name currently bound to the output.</param>
        /// <param name="newName">Variable name to bind instead.</param>
        /// <param name="cascadeChanges">
        /// When true and a matching output was found, the rename is also forwarded to the
        /// node's context through <c>RenameContextVariable</c>.
        /// </param>
        public void RenameOutputVariable(string oldName, string newName, bool cascadeChanges = false)
        {
            // Search first and write afterwards, so the map is never modified while it is being enumerated.
            string matchedOutput = null;

            foreach (var entry in _outputMap)
            {
                if (entry.Value != oldName)
                {
                    continue;
                }

                matchedOutput = entry.Key;
                break;
            }

            // No output uses the old name: nothing to rename, and nothing to cascade.
            if (matchedOutput == null)
            {
                return;
            }

            _outputMap[matchedOutput] = newName;

            if (cascadeChanges)
            {
                _context.RenameContextVariable(oldName, newName);
            }
        }
        /// <summary>
        /// Builds the graph nodes that train a predictor for class <paramref name="k"/>: a label-indicator
        /// step over the macro node's "TrainingData" input, followed by the nodes of <c>input.Nodes</c>
        /// rewired so that their "TrainingData" input reads the indicator's output.
        /// </summary>
        /// <param name="env">Host environment used for the experiment, node validation and the temporary run context.</param>
        /// <param name="k">Class index passed to the label indicator.</param>
        /// <param name="label">Label column name; used as both the source and the output column of the indicator.</param>
        /// <param name="input">Macro arguments; <c>input.Nodes</c> supplies the training subgraph.</param>
        /// <param name="node">Macro node whose context receives the generated nodes and the subgraph variables.</param>
        /// <returns>
        /// The generated nodes (label-indicator nodes first, then the subgraph nodes) paired with the
        /// variable that names the subgraph's "PredictorModel" output.
        /// </returns>
        /// <exception cref="Exception">No node of the subgraph exposes a "PredictorModel" output.</exception>
        private static Tuple <List <EntryPointNode>, Var <IPredictorModel> > ProcessClass(IHostEnvironment env, int k, string label, Arguments input, EntryPointNode node)
        {
            // Label indicator: converts the label into T,F based on k, reading the macro's training data.
            var labelIndicator = new Legacy.Transforms.LabelIndicator();
            labelIndicator.ClassIndex = k;
            labelIndicator.Column = new[]
            {
                new Legacy.Transforms.LabelIndicatorTransformColumn
                {
                    ClassIndex = k,
                    Name = label,
                    Source = label
                }
            };
            labelIndicator.Data.VarName = node.GetInputVariable(nameof(input.TrainingData)).ToJson();

            // Materialize the indicator through a throwaway experiment and validate it against the main context.
            var experiment = new Experiment(env);
            var indicatorOutput = experiment.Add(labelIndicator);
            var indicatorNodes = EntryPointNode.ValidateNodes(env, node.Context, experiment.GetNodes());
            var macroNodes = new List<EntryPointNode>(indicatorNodes);

            // The user-supplied training subgraph is first parsed into its own temporary run context.
            var tempContext = new RunContext(env);
            var trainingNodes = EntryPointNode.ValidateNodes(env, tempContext, input.Nodes);

            // One rename table is shared by every subgraph node so that the renamed variables
            // don't conflict with the ones in the outer run context.
            var renames = new Dictionary<string, string>();
            Var<IPredictorModel> predictorModel = null;

            foreach (var trainingNode in trainingNodes)
            {
                // Rename the node's input/output variables, then replay the whole table on the temporary context.
                trainingNode.RenameAllVariables(renames);
                foreach (var rename in renames)
                {
                    tempContext.RenameContextVariable(rename.Key, rename.Value);
                }

                // Remember the model output; if several nodes expose one, the last node visited wins.
                if (trainingNode.GetOutputVariableName("PredictorModel") is string modelVarName)
                {
                    predictorModel = new Var<IPredictorModel> { VarName = modelVarName };
                }

                // Feed the label indicator's output wherever this node expected training data.
                if (trainingNode.GetInputVariable(nameof(input.TrainingData)) is VariableBinding trainingBinding)
                {
                    trainingBinding.Rename(indicatorOutput.OutputData.VarName);
                }

                // From here on the node belongs to the main context.
                trainingNode.SetContext(node.Context);
            }

            // Carry the subgraph's variables over into the main context.
            node.Context.AddContextVariables(tempContext);

            // The model variable is only ever assigned when a "PredictorModel" output was seen.
            if (predictorModel is null)
            {
                throw new Exception("Invalid input graph. Does not output predictor model.");
            }

            // The training subgraph follows the label-indicator nodes in the result.
            macroNodes.AddRange(trainingNodes);

            return Tuple.Create(macroNodes, predictorModel);
        }
Example #3
0
        /// <summary>
        /// Appends to <paramref name="macroNodes"/> the graph nodes that train a predictor for class
        /// <paramref name="k"/>: a "Transforms.LabelIndicator" node over the macro node's "TrainingData"
        /// input, followed by the nodes of <c>input.Nodes</c> rewired so that their "TrainingData" input
        /// reads the indicator's output.
        /// </summary>
        /// <param name="env">Host environment used for node creation, validation and the temporary run context.</param>
        /// <param name="macroNodes">Non-null list that receives the generated nodes.</param>
        /// <param name="k">Class index passed to the label indicator.</param>
        /// <param name="label">Label column name; used as both the source and the output column of the indicator.</param>
        /// <param name="input">Macro arguments; <c>input.Nodes</c> supplies the training subgraph.</param>
        /// <param name="node">Macro node whose context receives the generated nodes and the subgraph variables.</param>
        /// <returns>The variable that names the subgraph's "PredictorModel" output.</returns>
        /// <exception cref="Exception">No node of the subgraph exposes a "PredictorModel" output.</exception>
        private static Var <PredictorModel> ProcessClass(IHostEnvironment env, List <EntryPointNode> macroNodes, int k, string label, Arguments input, EntryPointNode node)
        {
            Contracts.AssertValue(macroNodes);

            // Label indicator arguments: convert the label into T,F based on k.
            var labelIndicatorArgs = new LabelIndicatorTransform.Arguments
            {
                ClassIndex = k,
                Column = new[]
                {
                    new LabelIndicatorTransform.Column { Name = label, Source = label }
                }
            };

            // Bind the indicator's Data parameter to the variable feeding the macro's training data.
            var dataBinding = new SimpleParameterBinding(nameof(labelIndicatorArgs.Data));
            var parameterBindings = new Dictionary<string, List<ParameterBinding>>
            {
                { nameof(labelIndicatorArgs.Data), new List<ParameterBinding> { dataBinding } }
            };
            var variableBindings = new Dictionary<ParameterBinding, VariableBinding>
            {
                { dataBinding, node.GetInputVariable(nameof(input.TrainingData)) }
            };

            // The indicator writes its output data to a fresh variable.
            var remappedLabelVar = new Var<IDataView>();
            var outputBindings = new Dictionary<string, string>
            {
                { nameof(CommonOutputs.TransformOutput.OutputData), remappedLabelVar.VarName }
            };

            macroNodes.Add(EntryPointNode.Create(env, "Transforms.LabelIndicator", labelIndicatorArgs, node.Context,
                                                 parameterBindings, variableBindings, outputBindings));

            // The user-supplied training subgraph is first parsed into its own temporary run context.
            var tempContext = new RunContext(env);
            var trainingNodes = EntryPointNode.ValidateNodes(env, tempContext, input.Nodes);

            // One rename table is shared by every subgraph node so that the renamed variables
            // don't conflict with the ones in the outer run context.
            var renames = new Dictionary<string, string>();
            Var<PredictorModel> predictorModel = null;

            foreach (var trainingNode in trainingNodes)
            {
                // Rename the node's input/output variables, then replay the whole table on the temporary context.
                trainingNode.RenameAllVariables(renames);
                foreach (var rename in renames)
                {
                    tempContext.RenameContextVariable(rename.Key, rename.Value);
                }

                // Remember the model output; if several nodes expose one, the last node visited wins.
                if (trainingNode.GetOutputVariableName("PredictorModel") is string modelVarName)
                {
                    predictorModel = new Var<PredictorModel> { VarName = modelVarName };
                }

                // Feed the label indicator's output wherever this node expected training data.
                if (trainingNode.GetInputVariable(nameof(input.TrainingData)) is VariableBinding trainingBinding)
                {
                    trainingBinding.Rename(remappedLabelVar.VarName);
                }

                // From here on the node belongs to the main context.
                trainingNode.SetContext(node.Context);
            }

            // Carry the subgraph's variables over into the main context.
            node.Context.AddContextVariables(tempContext);

            // The model variable is only ever assigned when a "PredictorModel" output was seen.
            if (predictorModel is null)
            {
                throw new Exception("Invalid input graph. Does not output predictor model.");
            }

            // The training subgraph follows the label-indicator node in the caller's list.
            macroNodes.AddRange(trainingNodes);
            return predictorModel;
        }