Ejemplo n.º 1
0
        /// <summary>
        /// Runs a single training step on <paramref name="batch"/> and returns the losses it produced.
        /// </summary>
        /// <remarks>
        /// The losses are computed by <c>ComputeLosses</c> while a <c>GradientTape</c> is active.
        /// The <c>GIUO</c>, <c>Conf</c> and <c>Prob</c> components are summed into one total, and the
        /// gradients of that total with respect to <c>model.trainable_variables</c> are handed to
        /// <paramref name="optimizer"/>. When executing eagerly, a total containing inf or NaN
        /// causes the update to be skipped and a warning to be traced instead. When not executing
        /// eagerly, that check is not attempted and the update is always issued.
        /// </remarks>
        /// <param name="batch">Batch forwarded to <c>ComputeLosses</c>.</param>
        /// <param name="model">Model whose trainable variables are updated.</param>
        /// <param name="optimizer">Optimizer that receives the (gradient, variable) pairs.</param>
        /// <param name="classCount">Forwarded to <c>ComputeLosses</c>.</param>
        /// <param name="strides">Forwarded to <c>ComputeLosses</c>.</param>
        /// <param name="bench">
        /// When <c>true</c>, only the losses are computed: no tape is created and the optimizer
        /// is never called.
        /// </param>
        /// <returns>The losses computed for this batch (returned even when the update was skipped).</returns>
        static Loss TrainStep(ObjectDetectionDataset.EntryBatch batch, Model model, IOptimizer optimizer, int classCount, ReadOnlySpan <int> strides, bool bench = false)
        {
            // Benchmark mode: loss computation only, no gradient bookkeeping.
            if (bench)
                return ComputeLosses(model, batch, classCount, strides);

            Loss stepLosses;
            var gradientTape = new GradientTape();

            using (gradientTape.StartUsing())
            {
                stepLosses = ComputeLosses(model, batch, classCount, strides);
                Tensor combined = stepLosses.GIUO + stepLosses.Conf + stepLosses.Prob;

                // The inf/NaN probe reads the value through numpy(), so it is only attempted
                // when executing eagerly; otherwise the && short-circuits and we fall through
                // to the update branch.
                if (tf.executing_eagerly()
                    && tf.logical_or(tf.is_inf(combined), tf.is_nan(combined)).numpy().any())
                {
                    Trace.TraceWarning("NaN/inf loss ignored");
                }
                else
                {
                    PythonList<Tensor> grads = gradientTape.gradient(combined, model.trainable_variables);

                    // Pair every gradient with the variable it belongs to, in order.
                    var gradsAndVars = grads.Zip(
                        (PythonList<Variable>)model.trainable_variables,
                        (grad, variable) => (grad, variable));

                    optimizer.apply_gradients(gradsAndVars);
                }
            }

            return stepLosses;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Runs <paramref name="model"/> on the images of <paramref name="batch"/> (called with
        /// <c>training: true</c>) and accumulates the loss over every output scale.
        /// </summary>
        /// <remarks>
        /// One scale is processed per entry of <c>YOLOv4.XYScale</c>. For scale <c>i</c> the model
        /// output is read at index <c>2 * i</c> (passed to <c>ComputeLoss</c> as <c>conv</c>) and
        /// at index <c>2 * i + 1</c> (passed as <c>pred</c>), together with
        /// <c>batch.BBoxLabels[i]</c>, <c>batch.BBoxes[i]</c> and <c>strides[i]</c>.
        /// </remarks>
        /// <param name="model">Model to evaluate; must not be <c>null</c>.</param>
        /// <param name="batch">Batch supplying <c>Images</c>, <c>BBoxLabels</c> and <c>BBoxes</c>.</param>
        /// <param name="classCount">Number of classes; must be positive.</param>
        /// <param name="strides">
        /// Stride size per scale; must contain at least as many entries as <c>YOLOv4.XYScale</c>.
        /// </param>
        /// <returns>The sum of the per-scale losses, starting from <c>Loss.Zero</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="model"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="classCount"/> is not positive.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="strides"/> has fewer entries than there are scales.
        /// </exception>
        public static Loss ComputeLosses(Model model,
                                         ObjectDetectionDataset.EntryBatch batch,
                                         int classCount, ReadOnlySpan <int> strides)
        {
            if (model is null)
            {
                throw new ArgumentNullException(nameof(model));
            }
            if (classCount <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(classCount));
            }

            // Fail fast: without this check a too-short span would only surface as an
            // IndexOutOfRangeException inside the loop, after the forward pass has already run.
            int scaleCount = YOLOv4.XYScale.Length;
            if (strides.Length < scaleCount)
            {
                throw new ArgumentException(
                    $"Expected at least {scaleCount} strides (one per output scale), but got {strides.Length}.",
                    nameof(strides));
            }

            IList<Tensor> output = model.__call__(batch.Images, training: true);
            var loss = Loss.Zero;

            for (int scaleIndex = 0; scaleIndex < scaleCount; scaleIndex++)
            {
                // The model output interleaves two tensors per scale.
                Tensor conv = output[scaleIndex * 2];
                Tensor pred = output[scaleIndex * 2 + 1];

                loss += ComputeLoss(pred, conv,
                                    targetLabels: batch.BBoxLabels[scaleIndex],
                                    targetBBoxes: batch.BBoxes[scaleIndex],
                                    strideSize: strides[scaleIndex],
                                    classCount: classCount,
                                    intersectionOverUnionLossThreshold: DefaultIntersectionOverUnionLossThreshold);
            }

            return loss;
        }
Ejemplo n.º 3
0
 /// <summary>
 /// Computes the losses for <paramref name="batch"/> by delegating directly to
 /// <c>ComputeLosses</c>; no gradients are taken and no optimizer is involved here.
 /// </summary>
 /// <param name="batch">Batch forwarded to <c>ComputeLosses</c>.</param>
 /// <param name="model">Model forwarded to <c>ComputeLosses</c>.</param>
 /// <param name="classCount">Forwarded to <c>ComputeLosses</c>.</param>
 /// <param name="strides">Forwarded to <c>ComputeLosses</c>.</param>
 /// <returns>Whatever <c>ComputeLosses</c> returns for these arguments.</returns>
 static Loss TestStep(ObjectDetectionDataset.EntryBatch batch, Model model, int classCount, ReadOnlySpan <int> strides)
     => ComputeLosses(model, batch, classCount, strides);