/// <summary>
/// Maintains moving averages of variables.
/// </summary>
/// <remarks>
/// Every variable that is not yet tracked in <c>_averages</c> gets a slot created from its
/// initialized value and is added to the <c>MOVING_AVERAGE_VARIABLES</c> collection.
/// Variables that are already tracked keep their existing slot, so this method can be
/// called repeatedly and <paramref name="var_list"/> may contain the same variable twice.
/// </remarks>
/// <param name="var_list">
/// Variables to average. When null, the result of <c>variables.trainable_variables()</c>
/// cast to <c>RefVariable[]</c> is used.
/// </param>
/// <returns>
/// The result of <c>control_flow_ops.group</c> over the per-variable
/// <c>assign_moving_average</c> updates, built inside the name scope for <c>name</c>.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown when <c>_num_updates</c> has a value; that path is not implemented.
/// </exception>
public Operation apply(RefVariable[] var_list = null)
{
    if (var_list == null)
    {
        // NOTE(review): `as` yields null when trainable_variables() does not return a
        // RefVariable[]; confirm its return type against the declaration.
        var_list = variables.trainable_variables() as RefVariable[];
    }

    foreach (var var in var_list)
    {
        // A slot already exists for this variable: reuse it instead of failing.
        if (_averages.ContainsKey(var))
        {
            continue;
        }

        ops.init_scope();
        var slot_creator = new SlotCreator();
        var value = var.initialized_value();
        var avg = slot_creator.create_slot(var, value, name, colocate_with_primary: true);
        ops.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
        _averages[var] = avg;
    }

    return tf_with(ops.name_scope(name), scope =>
    {
        var decay = ops.convert_to_tensor(_decay, name: "decay");
        if (_num_updates.HasValue)
        {
            throw new NotImplementedException("ExponentialMovingAverage.apply");
        }

        var updates = new List<Tensor>();
        foreach (var var in var_list)
        {
            // Zero-debiasing is never enabled here (original note: _averages[var] in zero_debias_true).
            var zero_debias = false;
            var ama = moving_averages.assign_moving_average(_averages[var], var, decay, zero_debias: zero_debias);
            updates.Add(ama);
        }

        return control_flow_ops.group(updates.ToArray(), name: scope);
    });
}
/// <summary>
/// Maintains moving averages of variables. This implementation is unfinished: after
/// visiting the variables it always throws <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="var_list">
/// Variables to process. When null, the result of <c>variables.trainable_variables()</c>
/// cast to <c>RefVariable[]</c> is used.
/// </param>
/// <returns>Never returns normally; the method ends by throwing.</returns>
/// <exception cref="NotImplementedException">Always thrown once the loop completes.</exception>
public Operation apply(RefVariable[] var_list = null)
{
    if (var_list == null)
    {
        var_list = variables.trainable_variables() as RefVariable[];
    }

    foreach (var variable in var_list)
    {
        // Variables already present in _averages need no preparation.
        if (_averages.Contains(variable))
        {
            continue;
        }

        ops.init_scope();
        var creator = new SlotCreator();
        // The initialized value is requested but not consumed yet.
        variable.initialized_value();
        // var avg = creator.create_zeros_slot
    }

    throw new NotImplementedException("");
}