/// <summary>
///   Initializes a new <see cref="CNTKFunction"/> from a set of <see cref="CNTK.Function"/> outputs.
/// </summary>
public CNTKFunction(CNTKBackend c, List<Variable> inputs, CNTK.Function[] outputs, List<List<Tensor>> updates, string name)
{
    this.c = c;
    this.placeholders = inputs;
    this.trainer = null;
    this.unrelated_updates = null;
    this.updates = updates;

    if (updates.Count > 0)
    {
        if (len(outputs) <= 0)
            throw new Exception("Updates were specified, but no outputs were given.");

        this.loss = outputs[0];

        // Group the updates by gradient placeholder: ops with no free
        // arguments can be batched into a single combined update function,
        // while the remaining ones are kept apart as "unrelated" updates.
        var u_ops = new List<CNTK.Function>();
        var unrelated_updates = new List<CNTK.Function>();

        foreach (List<Tensor> update in updates)
        {
            CNTK.Function u;

            if (update.Count == 1)
                u = c.In(update[0]);
            else if (update.Count == 2)
                u = C.Assign(c.In(update[0]), c.In(update[1]));
            else
                throw new NotImplementedException();

            // Match the Keras reference implementation (and the overload below),
            // which checks u.arguments (free input variables), not the full input list.
            if (u.Arguments.Count == 0)
                u_ops.Add(u);
            else
                unrelated_updates.Add(u);
        }

        var update_func = C.Combine(new VariableVector(u_ops.Select(u => u.Output).ToArray()));

        CNTK.Function[] grads = update_func.FindAllWithName("keras_grad_placeholder").ToArray();

        var u_list = new List<CNTK.Function>();
        var p_list = new List<CNTK.Parameter>();

        foreach (CNTK.Function g in grads)
        {
            if (c.grad_parameter_dict.ContainsKey(g))
            {
                p_list.Add(c.grad_parameter_dict[g]);
                u_list.Add(g);
            }
            else
            {
                throw new Exception($"CNTK backend: when constructing trainer, found gradient node {g} which is not related to any parameters in the model. Please double check how the gradient node is constructed.");
            }
        }

        if (len(u_list) > 0)
        {
            Learner learner = Learner.SGDLearner(p_list, new TrainingParameterScheduleDouble(0));

            var criterion = (len(outputs) > 1) ?
                C.Combine(new VariableVector(new[] { outputs[0], outputs[1] })) :
                outputs[0];

            this.trainer = Trainer.CreateTrainer(model: outputs[0], lossFunction: criterion,
                evaluationFunction: null, parameterLearners: new[] { learner });

            this.trainer_output = new UnorderedMapVariableValuePtr();
            foreach (CNTK.Function f in outputs)
                this.trainer_output.Add(f, null);
        }
        else if (len(u_ops) > 0)
        {
            unrelated_updates.AddRange(u_ops);
        }

        if (len(unrelated_updates) > 0)
            this.unrelated_updates = C.Combine(new VariableVector(unrelated_updates.Select(_ => _.Output).ToArray()));
    }

    // CNTK's trainer can only handle the loss plus a single metric; when there
    // are more than two outputs, the extra metrics have to be evaluated manually.
    if (this.trainer == null)
    {
        this.metrics_outputs = outputs.Select(f => f.Output).ToArray();
        this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
    }
    else if (len(outputs) > 2)
    {
        this.metrics_outputs = Matrix.Get(outputs, 2, 0).Select(f => f.Output).ToArray();
        this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
    }
    else
    {
        this.metrics_func = null;
    }
}
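// A minimal sketch (hypothetical helper, not part of the backend) of why the
// grouping above works: an assignment between a parameter and a constant
// yields an op with no free arguments, so it can be batched into the single
// Combine'd update function, whereas an op that still depends on an input
// variable keeps its arguments and would land in unrelated_updates.
private static bool ExampleAssignIsBatchable()
{
    var target = new CNTK.Parameter(new int[] { 1 }, CNTK.DataType.Float, 0.0, CNTK.DeviceDescriptor.CPUDevice, "target");
    var value = new CNTK.Constant(new int[] { 1 }, CNTK.DataType.Float, 1.0, CNTK.DeviceDescriptor.CPUDevice);
    CNTK.Function u = C.Assign(target, value);
    return u.Arguments.Count == 0; // true: this update would be grouped into u_ops
}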
/// <summary>
///   Initializes a new <see cref="CNTKFunction"/> from a set of <see cref="CNTK.Variable"/> outputs.
/// </summary>
public CNTKFunction(CNTKBackend c, Variable[] inputs, CNTK.Variable[] outputs, List<List<Tensor>> updates, string name)
{
    // Based on: https://github.com/fchollet/keras/blob/f65a56fb65062c8d14d215c9f4b1015b97cc5bf3/keras/backend/cntk_backend.py#L1501

    this.c = c;
    this.placeholders = inputs;
    this.trainer = null;
    this.unrelated_updates = null;
    this.updates = updates;

    if (updates.Count > 0)
    {
        if (len(outputs) <= 0)
            throw new Exception("Updates were specified, but no outputs were given.");

        this.loss = outputs[0];

        // Group the updates by gradient placeholder: ops with no free
        // arguments can be batched into a single combined update function,
        // while the remaining ones are kept apart as "unrelated" updates.
        var u_ops = new List<CNTK.Function>();
        var unrelated_updates = new List<CNTK.Function>();

        foreach (List<Tensor> update in updates)
        {
            CNTK.Function u;

            if (update.Count == 1)
                u = c.In(update[0]);
            else if (update.Count == 2)
                u = C.Assign(c.In(update[0]), c.In(update[1]));
            else
                throw new NotImplementedException();

            if (u.Arguments.Count == 0)
                u_ops.Add(u);
            else
                unrelated_updates.Add(u);
        }

        var update_func = C.Combine(new VariableVector(u_ops.Select(u => u.Output).ToArray()));

        CNTK.Constant[] grads = update_func.Inputs
            .Where(x => x.Name == "keras_grad_placeholder")
            .Select(x => new Constant(x))
            .ToArray();

        var u_list = new List<CNTK.Constant>();
        var p_list = new List<CNTK.Parameter>();

        foreach (CNTK.Constant g in grads)
        {
            if (c.grad_parameter_dict.ContainsKey(g.Uid))
            {
                p_list.Add(c.grad_parameter_dict[g.Uid]);
                u_list.Add(g);
            }
            else
            {
                throw new Exception($"CNTK backend: when constructing trainer, found gradient node {g} which is not related to any parameters in the model. Please double check how the gradient node is constructed.");
            }
        }

        if (len(u_list) > 0)
        {
            Learner learner = Learner.SGDLearner(p_list, new TrainingParameterScheduleDouble(1));

            // Guard against single-output models, mirroring the criterion
            // handling in the overload above.
            this.trainer = Trainer.CreateTrainer(model: outputs[0], lossFunction: outputs[0],
                evaluationFunction: (len(outputs) > 1) ? outputs[1] : null,
                parameterLearners: new[] { learner });
        }
        else if (len(u_ops) > 0)
        {
            unrelated_updates.AddRange(u_ops);
        }

        if (len(unrelated_updates) > 0)
            this.unrelated_updates = C.Combine(new VariableVector(unrelated_updates.Select(_ => _.Output).ToArray()));
    }

    // CNTK's trainer can only handle the loss plus a single metric; when there
    // are more than two outputs, the extra metrics have to be evaluated manually.
    if (this.trainer == null)
    {
        this.metrics_outputs = outputs;
        this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
    }
    else if (len(outputs) > 2)
    {
        this.metrics_outputs = Matrix.Get(outputs, 2, 0);
        this.metrics_func = C.Combine(new VariableVector(this.metrics_outputs));
    }
    else
    {
        this.metrics_func = null;
    }
}
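// A minimal sketch (hypothetical helper, not part of the backend) of the
// Trainer wiring both constructors perform, reduced to its essentials: build
// a model and a loss, attach an SGD learner to the model's parameters, and
// tie them together with Trainer.CreateTrainer. Passing null for the
// evaluation function is allowed, as in the overloads above; shapes, names,
// and the learning rate here are illustrative only.
private static Trainer ExampleTrainerWiring()
{
    var x = CNTK.Variable.InputVariable(new int[] { 2 }, CNTK.DataType.Float, "x");
    var w = new CNTK.Parameter(new int[] { 1, 2 }, CNTK.DataType.Float, 0.0, CNTK.DeviceDescriptor.CPUDevice, "w");
    CNTK.Function model = C.Times(w, x);
    var label = CNTK.Variable.InputVariable(new int[] { 1 }, CNTK.DataType.Float, "label");
    CNTK.Function loss = C.SquaredError(model, label);
    Learner learner = Learner.SGDLearner(new[] { w }, new TrainingParameterScheduleDouble(0.01));
    return Trainer.CreateTrainer(model, loss, null, new[] { learner });
}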