/// <summary>
///   Builds a backend function over CNTK <c>Function</c> outputs, optionally constructing a
///   CNTK <c>Trainer</c> when parameter updates are supplied. Port of the Keras CNTK backend's
///   <c>Function</c> class.
/// </summary>
/// <param name="c">The owning CNTK backend instance.</param>
/// <param name="inputs">Placeholder variables to be fed when this function is invoked.</param>
/// <param name="outputs">Output nodes. When <paramref name="updates"/> is non-empty, outputs[0] is taken as the loss.</param>
/// <param name="updates">Update specs: a one-element list is used as-is; a two-element list [target, value] becomes an assignment.</param>
/// <param name="name">Function name (not used inside this constructor body).</param>
public CNTKFunction(CNTKBackend c, List <Variable> inputs, CNTK.Function[] outputs, List <List <Tensor> > updates, string name)
        {
            this.c                 = c;
            this.placeholders      = inputs;
            this.trainer           = null;
            this.unrelated_updates = null;
            this.updates           = updates;
            if (updates.Count > 0)
            {
                if (len(outputs) <= 0)
                {
                    // Updates require a loss node to train against.
                    throw new InvalidOperationException("Updates were provided but no outputs exist; the first output is required as the loss.");
                }

                this.loss = outputs[0];
                // need group update by gradient place holder
                var u_ops             = new List <CNTK.Function>();
                var unrelated_updates = new List <CNTK.Function>();
                foreach (List <Tensor> update in updates)
                {
                    CNTK.Function u;

                    if (update.Count == 1)
                    {
                        // Already a complete update op.
                        u = c.In(update[0]);
                    }
                    else if (update.Count == 2)
                    {
                        // [target, value] pair: turn it into an assignment op.
                        u = C.Assign(c.In(update[0]), c.In(update[1]));
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }

                    // NOTE(review): the sibling overload below checks u.Arguments.Count here, which
                    // matches Keras' `len(u.arguments) == 0`; `Inputs` also includes parameters and
                    // constants, so this branch may classify updates differently — confirm intent.
                    if (u.Inputs.Count == 0)
                    {
                        u_ops.Add(u);
                    }
                    else
                    {
                        unrelated_updates.Add(u);
                    }
                }

                var update_func = C.Combine(new VariableVector(u_ops.Select(u => u.Output).ToArray()));

                // Gradient placeholders were tagged by name when the gradients were built.
                CNTK.Function[] grads = update_func.FindAllWithName("keras_grad_placeholder").ToArray();

                var u_list = new List <CNTK.Function>();
                var p_list = new List <CNTK.Parameter>();
                foreach (CNTK.Function g in grads)
                {
                    // NOTE(review): keyed by the Function node here, but the sibling overload keys
                    // grad_parameter_dict by Uid — verify both lookups target the same dictionary shape.
                    if (c.grad_parameter_dict.ContainsKey(g))
                    {
                        p_list.Add(c.grad_parameter_dict[g]);
                        u_list.Add(g);
                    }
                    else
                    {
                        throw new Exception($"CNTK backend: when constructing trainer, found gradient node {g} which is not related to any parameters in the model. Please double check how the gradient node is constructed.");
                    }
                }

                if (len(u_list) > 0)
                {
                    // NOTE(review): learning rate schedule is 0 here but 1 in the sibling overload;
                    // presumably the effective rate is applied through the gradient values — confirm.
                    Learner learner = Learner.SGDLearner(p_list, new TrainingParameterScheduleDouble(0));

                    // CNTK's trainer wants a single criterion node; combine loss and first metric if present.
                    var criterion = (len(outputs) > 1) ?
                                    C.Combine(new VariableVector(new[] { outputs[0], outputs[1] })) :
                                    outputs[0];

                    this.trainer = Trainer.CreateTrainer(model: outputs[0], lossFunction: criterion, evaluationFunction: null, parameterLearners: new[] { learner });

                    this.trainer_output = new UnorderedMapVariableValuePtr();
                    foreach (CNTK.Function f in outputs)
                    {
                        this.trainer_output.Add(f, null);
                    }
                }
                else if (len(u_ops) > 0)
                {
                    // No trainable updates: everything runs as a plain combined op instead.
                    unrelated_updates.AddRange(u_ops);
                }

                if (len(unrelated_updates) > 0)
                {
                    this.unrelated_updates = C.Combine(new VariableVector(unrelated_updates.Select(_ => _.Output).ToArray()));
                }
            }

            if (this.trainer == null)
            {
                // No trainer: evaluate every output manually through one combined function.
                this.metrics_outputs = outputs.Select(f => f.Output).ToArray();

                this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
                // cntk only could handle loss and 1 metric in trainer, for metrics more
                // than 2, need manual eval
            }
            else if (len(outputs) > 2)
            {
                // Trainer covers outputs[0..1]; remaining metrics are evaluated separately.
                this.metrics_outputs = Matrix.Get(outputs, 2, 0).Select(f => f.Output).ToArray();

                this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
            }
            else
            {
                this.metrics_func = null;
            }
        }
// Beispiel #2 (non-code scraper artifact between the two constructor examples; preserved as a comment)
        /// <summary>
        ///   Builds a backend function over CNTK <c>Variable</c> outputs, optionally constructing a
        ///   CNTK <c>Trainer</c> when parameter updates are supplied. Port of the Keras CNTK
        ///   backend's <c>Function</c> class (see link below).
        /// </summary>
        /// <param name="c">The owning CNTK backend instance.</param>
        /// <param name="inputs">Placeholder variables to be fed when this function is invoked.</param>
        /// <param name="outputs">Output nodes. When <paramref name="updates"/> is non-empty, outputs[0] is taken as the loss.</param>
        /// <param name="updates">Update specs: a one-element list is used as-is; a two-element list [target, value] becomes an assignment.</param>
        /// <param name="name">Function name (not used inside this constructor body).</param>
        public CNTKFunction(CNTKBackend c, Variable[] inputs, CNTK.Variable[] outputs, List <List <Tensor> > updates, string name)
        {
            // https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/backend/cntk_backend.py#L1501
            this.c                 = c;
            this.placeholders      = inputs;
            this.trainer           = null;
            this.unrelated_updates = null;
            this.updates           = updates;
            if (updates.Count > 0)
            {
                if (len(outputs) <= 0)
                {
                    // Updates require a loss node to train against.
                    throw new InvalidOperationException("Updates were provided but no outputs exist; the first output is required as the loss.");
                }

                this.loss = outputs[0];
                // need group update by gradient place holder
                var u_ops             = new List <CNTK.Function>();
                var unrelated_updates = new List <CNTK.Function>();
                foreach (List <Tensor> update in updates)
                {
                    CNTK.Function u;

                    if (update.Count == 1)
                    {
                        // Already a complete update op.
                        u = c.In(update[0]);
                    }
                    else if (update.Count == 2)
                    {
                        // [target, value] pair: turn it into an assignment op.
                        u = C.Assign(c.In(update[0]), c.In(update[1]));
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }

                    // Mirrors Keras' `len(u.arguments) == 0`: argument-free ops can be grouped
                    // into the trainer's update function; the rest run separately.
                    if (u.Arguments.Count == 0)
                    {
                        u_ops.Add(u);
                    }
                    else
                    {
                        unrelated_updates.Add(u);
                    }
                }

                var update_func = C.Combine(new VariableVector(u_ops.Select(u => u.Output).ToArray()));

                // Gradient placeholders were tagged by name when the gradients were built.
                CNTK.Constant[] grads = update_func.Inputs.Where(x => x.Name == "keras_grad_placeholder").Select(x => new Constant(x)).ToArray();

                var u_list = new List <CNTK.Constant>();
                var p_list = new List <CNTK.Parameter>();
                foreach (CNTK.Constant g in grads)
                {
                    // NOTE(review): keyed by Uid here, but the sibling overload keys
                    // grad_parameter_dict by the node itself — verify both lookups agree.
                    if (c.grad_parameter_dict.ContainsKey(g.Uid))
                    {
                        p_list.Add(c.grad_parameter_dict[g.Uid]);
                        u_list.Add(g);
                    }
                    else
                    {
                        throw new Exception($"CNTK backend: when constructing trainer, found gradient node {g} which is not related to any parameters in the model. Please double check how the gradient node is constructed.");
                    }
                }

                if (len(u_list) > 0)
                {
                    // NOTE(review): learning rate schedule is 1 here but 0 in the sibling overload — confirm intent.
                    Learner learner = Learner.SGDLearner(p_list, new TrainingParameterScheduleDouble(1));

                    // NOTE(review): outputs[1] is read unconditionally; if callers can pass a single
                    // output together with updates this throws IndexOutOfRangeException — confirm callers.
                    this.trainer = Trainer.CreateTrainer(model: outputs[0],
                                                         lossFunction: outputs[0],
                                                         evaluationFunction: outputs[1],
                                                         parameterLearners: new[] { learner });
                }
                else if (len(u_ops) > 0)
                {
                    // No trainable updates: everything runs as a plain combined op instead.
                    unrelated_updates.AddRange(u_ops);
                }

                if (len(unrelated_updates) > 0)
                {
                    this.unrelated_updates = C.Combine(new VariableVector(unrelated_updates.Select(_ => _.Output).ToArray()));
                }
            }

            if (this.trainer == null)
            {
                // No trainer: evaluate every output manually through one combined function.
                this.metrics_outputs = outputs;

                this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
                // cntk only could handle loss and 1 metric in trainer, for metrics more
                // than 2, need manual eval
            }
            else if (len(outputs) > 2)
            {
                // Trainer covers outputs[0..1]; remaining metrics are evaluated separately.
                this.metrics_outputs = Matrix.Get(outputs, 2, 0);

                this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
            }
            else
            {
                this.metrics_func = null;
            }
        }