Пример #1
0
        /// <summary>
        /// Runs a single optimization step on <paramref name="batch"/>.
        /// </summary>
        /// <param name="batch">The batch of training entries to process.</param>
        /// <param name="model">Model whose trainable variables receive the update.</param>
        /// <param name="optimizer">Optimizer used to apply the computed gradients.</param>
        /// <param name="classCount">Forwarded to <c>ComputeLosses</c>.</param>
        /// <param name="strides">Forwarded to <c>ComputeLosses</c>.</param>
        /// <param name="bench">
        /// When <c>true</c>, the losses are only computed. No gradient tape is recorded
        /// and no update is applied to the model.
        /// </param>
        /// <returns>The loss components computed for this batch.</returns>
        /// <remarks>
        /// The objective being differentiated is <c>GIUO + Conf + Prob</c>. When executing eagerly
        /// and that sum contains an infinity or NaN, the update is skipped and a warning is traced.
        /// Outside eager mode this check is not performed.
        /// </remarks>
        static Loss TrainStep(ObjectDetectionDataset.EntryBatch batch, Model model, IOptimizer optimizer, int classCount, ReadOnlySpan<int> strides, bool bench = false)
        {
            if (bench)
            {
                // Forward pass only: no tape, no weight update.
                return ComputeLosses(model, batch, classCount, strides);
            }

            var recorder = new GradientTape();
            Loss stepLosses;

            using (recorder.StartUsing())
            {
                stepLosses = ComputeLosses(model, batch, classCount, strides);
                Tensor objective = stepLosses.GIUO + stepLosses.Conf + stepLosses.Prob;

                // The check reads concrete values back via numpy(), so it is only attempted eagerly;
                // short-circuiting keeps it from being evaluated at all otherwise.
                bool nonFiniteLoss = tf.executing_eagerly()
                    && tf.logical_or(tf.is_inf(objective), tf.is_nan(objective)).numpy().any();

                if (nonFiniteLoss)
                {
                    Trace.TraceWarning("NaN/inf loss ignored");
                }
                else
                {
                    PythonList<Tensor> grads = recorder.gradient(objective, model.trainable_variables);
                    var variables = (PythonList<Variable>)model.trainable_variables;

                    // Pair each gradient with its variable by position.
                    optimizer.apply_gradients(grads.Zip(variables, (g, v) => (g, v)));
                }
            }

            return stepLosses;
        }
Пример #2
0
        /// <summary>
        /// Computes gradients of <paramref name="loss"/> from <paramref name="tape"/> and applies them
        /// to <paramref name="trainable_variables"/> through <paramref name="optimizer"/>.
        /// </summary>
        /// <param name="tape">Tape that recorded the computation of <paramref name="loss"/>.</param>
        /// <param name="optimizer">Optimizer whose aggregation, clipping and apply steps are used.</param>
        /// <param name="loss">Scalar objective to differentiate.</param>
        /// <param name="trainable_variables">Variables to differentiate against and update.</param>
        void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
        {
            // Raw gradients, one per trainable variable (matched by position).
            var rawGrads = tape.gradient(loss, trainable_variables);

            // Run the optimizer's aggregation hook, then its clipping hook.
            var aggregatedGrads = optimizer._aggregate_gradients(zip(rawGrads, trainable_variables));
            var clippedGrads = optimizer._clip_gradients(aggregatedGrads);

            // NOTE(review): `as` yields null for any variable that is not a ResourceVariable;
            // confirm every trainable variable reaching this point is one.
            var targets = trainable_variables.Select(v => v as ResourceVariable);

            // Presumably disables apply_gradients' own aggregation, since it was already
            // done explicitly above — confirm against OptimizerV2.apply_gradients.
            optimizer.apply_gradients(zip(clippedGrads, targets),
                                      experimental_aggregate_gradients: false);
        }
Пример #3
0
 /// <summary>
 /// Installs a new <see cref="GradientTape"/> as <c>_tapeSet</c>, then calls
 /// <c>ops.RegisterFromAssembly()</c>.
 /// </summary>
 private void InitGradientEnvironment()
 {
     // Create the tape first; registration below runs afterwards, as before.
     var freshTape = new GradientTape();
     _tapeSet = freshTape;

     // Presumably registers gradient/op definitions discovered in the assembly — confirm in ops.
     ops.RegisterFromAssembly();
 }
Пример #4
0
 /// <summary>
 /// Captures the backward tape and the accumulating flag for this call state.
 /// </summary>
 /// <param name="backward_tape">Tape stored in the <c>backward_tape</c> member.</param>
 /// <param name="accumulating">Flag stored in the <c>accumulating</c> member.</param>
 public AccumulatorCallState(GradientTape backward_tape, bool accumulating)
     => (this.backward_tape, this.accumulating) = (backward_tape, accumulating);