コード例 #1
0
        /// <summary>
        /// Exports a trained YOLOv4 network.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Rebuilds the trainable network, loads <c>WeigthsPath</c> into it, and wraps its
        /// dynamic input and outputs with <c>YOLO.CreateSaveable</c>. The wrapper is given the
        /// anchors, <c>YOLOv4.XYScale</c> and <c>ScoreThreshold</c>.
        /// </para>
        /// <para>
        /// It then prints a summary and saves the result to <c>OutputPath</c> with
        /// <c>save_format: "tf"</c> and without optimizer state.
        /// </para>
        /// </remarks>
        /// <param name="remainingArguments">Not read by this method.</param>
        /// <returns>Always <c>0</c>.</returns>
        public override int Run(string[] remainingArguments)
        {
            // The weights file is loaded into the same trainable architecture it was saved from.
            var network = YOLO.CreateV4Trainable(
                inputSize: this.InputSize,
                classCount: this.ClassCount,
                strides: this.Strides);

            // NOTE(review): this property is spelled "WeigthsPath" here, while the training
            // command uses "WeightsPath". Renaming it would also require changing its declaration.
            network.load_weights(this.WeigthsPath);

            var rawOutputs = YOLOv4.Output.Get(network);
            Tensor dynamicInput = network.input_dyn;
            var anchorTensor = tf.constant(this.Anchors);

            var exportModel = YOLO.CreateSaveable(
                inputSize: this.InputSize,
                dynamicInput,
                rawOutputs,
                classCount: this.ClassCount,
                strides: this.Strides,
                anchors: anchorTensor,
                xyScale: YOLOv4.XYScale,
                scoreThreshold: this.ScoreThreshold);

            exportModel.summary();

            // Optimizer state is deliberately excluded from the exported model.
            exportModel.save(this.OutputPath, save_format: "tf", include_optimizer: false);
            return 0;
        }
コード例 #2
0
ファイル: TrainV4.cs プロジェクト: molekm/YOLOv4
        /// <summary>
        /// Trains a YOLOv4 network on <c>Annotations</c>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// First-stage and second-stage epoch counts are passed to <c>YOLO.TrainGenerator</c>,
        /// using Adam with a warmup learning-rate schedule.
        /// </para>
        /// <para>
        /// Unless this is a benchmark or test run, checkpoint weights are written during
        /// training and final weights to <c>yoloV4.weights-trained</c>.
        /// </para>
        /// <para>
        /// Side effect: when <c>TestRun</c> is set, <c>Annotations</c> is replaced by its first
        /// <c>BatchSize * 3</c> entries.
        /// </para>
        /// </remarks>
        /// <param name="remainingArguments">Not read by this method.</param>
        /// <returns>Always <c>0</c>.</returns>
        public override int Run(string[] remainingArguments)
        {
            // Route Trace output to the console's error stream.
            Trace.Listeners.Add(new ConsoleTraceListener(useErrorStream: true));
            tf.debugging.set_log_device_placement(this.LogDevicePlacement);

            // The session must be configured before the model below is created.
            if (this.GpuAllowGrowth)
            {
                EnableGpuMemoryGrowth();
            }

            if (this.TestRun)
            {
                // Quick smoke test: keep only enough annotations for three batches.
                this.Annotations = this.Annotations.Take(this.BatchSize * 3).ToArray();
            }

            var dataset = new ObjectDetectionDataset(
                this.Annotations,
                classNames: this.ClassNames,
                strides: this.Strides,
                inputSize: this.InputSize,
                anchors: this.Anchors,
                anchorsPerScale: this.AnchorsPerScale,
                maxBBoxPerScale: this.MaxBBoxPerScale);
            var model = YOLO.CreateV4Trainable(dataset.InputSize, dataset.ClassNames.Length, dataset.Strides);

            var totalSteps = (long)(this.FirstStageEpochs + this.SecondStageEpochs) * dataset.BatchCount(this.BatchSize);
            var warmupSteps = this.WarmupEpochs * dataset.BatchCount(this.BatchSize);
            var learningRateSchedule = new YOLO.LearningRateSchedule(totalSteps: totalSteps, warmupSteps: warmupSteps);

            // See https://github.com/AlexeyAB/darknet/issues/1845 (presumably the source of this epsilon).
            var optimizer = new Adam(learning_rate: learningRateSchedule, epsilon: 0.000001);

            if (this.ModelSummary)
            {
                model.summary();
            }

            if (this.WeightsPath != null)
            {
                // Continue from an existing weights file when one was supplied.
                model.load_weights(this.WeightsPath);
            }

            var callbacks = new List<ICallback>
            {
                new LearningRateLogger(),
                new TensorBoard(log_dir: this.LogDir, batch_size: this.BatchSize, profile_batch: 4),
            };

            // Benchmark and test runs never write weight files.
            bool persistWeights = !(this.Benchmark || this.TestRun);
            if (persistWeights)
            {
                // The file name is templated by epoch number.
                callbacks.Add(new ModelCheckpoint("yoloV4.weights.{epoch:02d}", save_weights_only: true));
            }

            YOLO.TrainGenerator(
                model,
                optimizer,
                dataset,
                batchSize: this.BatchSize,
                firstStageEpochs: this.FirstStageEpochs,
                secondStageEpochs: this.SecondStageEpochs,
                callbacks: callbacks);

            if (persistWeights)
            {
                model.save_weights("yoloV4.weights-trained");
            }

            // A full model.save("yoloV4-trained") is intentionally not called.
            // It fails unless layers are given proper (unique) names; see
            // https://stackoverflow.com/questions/61402903/unable-to-create-group-name-already-exists
            return 0;

            // Replaces the Keras session with one whose GPU memory allocation grows on demand.
            void EnableGpuMemoryGrowth()
            {
                dynamic sessionConfig = config_pb2.ConfigProto.CreateInstance();
                sessionConfig.gpu_options.allow_growth = true;
                tf.keras.backend.set_session(Session.NewDyn(config: sessionConfig));
            }
        }