Exemplo n.º 1
0
        /// <summary>
        /// Initializes a new trainer for <paramref name="net"/> that uses the <paramref name="sgd"/> solver.
        /// </summary>
        /// <param name="net">The network to train.</param>
        /// <param name="sgd">The solver to use; must not be null or disposed.</param>
        /// <exception cref="ArgumentNullException"><paramref name="net"/> or <paramref name="sgd"/> is null.</exception>
        /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not a supported loss network type.</exception>
        public DnnTrainer(T net, Sgd sgd)
        {
            // Validate net up front; otherwise a null net only fails later inside CreateImp
            // with a NullReferenceException on net.NativePtr.
            if (net == null)
            {
                throw new ArgumentNullException(nameof(net));
            }

            if (sgd == null)
            {
                throw new ArgumentNullException(nameof(sgd));
            }

            sgd.ThrowIfDisposed();

            this._Imp = CreateImp(net, sgd);

            // This trainer exposes the loss-specific implementation's native handle as its own.
            this.NativePtr = this._Imp.NativePtr;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Creates the loss-specific trainer implementation matching the network type <typeparamref name="T"/>.
        /// </summary>
        /// <param name="net">The network to train; its <c>NativePtr</c> and <c>NetworkType</c> are forwarded to the trainer.</param>
        /// <param name="sgd">The solver handed to the trainer implementation.</param>
        /// <returns>The trainer implementation for <typeparamref name="T"/>.</returns>
        /// <exception cref="NotSupportedException"><typeparamref name="T"/> is not one of the supported loss network types.</exception>
        private static TrainerImp <T> CreateImp(T net, Sgd sgd)
        {
            var t = typeof(T);

            object imp;
            if (t == typeof(LossMetric))
            {
                imp = new LossMetricTrainer(net.NativePtr, net.NetworkType, sgd);
            }
            else if (t == typeof(LossMmod))
            {
                imp = new LossMmodTrainer(net.NativePtr, net.NetworkType, sgd);
            }
            else if (t == typeof(LossMulticlassLog))
            {
                imp = new LossMulticlassLogTrainer(net.NativePtr, net.NetworkType, sgd);
            }
            else if (t == typeof(LossMulticlassLogPerPixel))
            {
                imp = new LossMulticlassLogPerPixelTrainer(net.NativePtr, net.NetworkType, sgd);
            }
            else
            {
                throw new NotSupportedException($"Network type '{t}' is not supported.");
            }

            // Cast through object instead of using 'as': a trainer/type mismatch now fails here with
            // InvalidCastException rather than returning null and failing later on _Imp.NativePtr.
            return (TrainerImp<T>)imp;
        }
Exemplo n.º 3
0
 /// <summary>
 /// Wraps a native trainer created by <c>NativeMethods.LossMulticlassLogPerPixel_trainer_new2</c>.
 /// </summary>
 /// <param name="net">Native pointer of the network to train.</param>
 /// <param name="type">Network type identifier; stored in <c>NetworkType</c> and passed to the native constructor.</param>
 /// <param name="sgd">Solver whose native pointer is handed to the native trainer; must not be null or disposed.</param>
 /// <exception cref="ArgumentNullException"><paramref name="sgd"/> is null.</exception>
 public LossMulticlassLogPerPixelTrainer(IntPtr net, int type, Sgd sgd)
 {
     if (sgd == null)
     {
         throw new ArgumentNullException(nameof(sgd));
     }

     // Never hand a disposed solver's native pointer to native code.
     sgd.ThrowIfDisposed();

     this.NetworkType = type;
     this.NativePtr   = NativeMethods.LossMulticlassLogPerPixel_trainer_new2(type, net, sgd.NativePtr);
 }
Exemplo n.º 4
0
 /// <summary>
 /// Wraps a native trainer created by <c>NativeMethods.LossMmod_trainer_new2</c>.
 /// </summary>
 /// <param name="net">Native pointer of the network to train.</param>
 /// <param name="type">Network type identifier; stored in <c>NetworkType</c> and passed to the native constructor.</param>
 /// <param name="sgd">Solver whose native pointer is handed to the native trainer; must not be null or disposed.</param>
 /// <exception cref="ArgumentNullException"><paramref name="sgd"/> is null.</exception>
 public LossMmodTrainer(IntPtr net, int type, Sgd sgd)
 {
     if (sgd == null)
     {
         throw new ArgumentNullException(nameof(sgd));
     }

     // Never hand a disposed solver's native pointer to native code.
     sgd.ThrowIfDisposed();

     this.NetworkType = type;
     this.NativePtr   = NativeMethods.LossMmod_trainer_new2(type, net, sgd.NativePtr);
 }
Exemplo n.º 5
0
 /// <summary>
 /// Wraps a native trainer created by <c>NativeMethods.dnn_trainer_loss_multiclass_log_per_pixel_new_sgd</c>.
 /// </summary>
 /// <param name="net">Native pointer of the network to train.</param>
 /// <param name="type">Network type identifier; stored in <c>NetworkType</c> and passed to the native constructor.</param>
 /// <param name="sgd">Solver whose native pointer is handed to the native trainer; must not be null or disposed.</param>
 /// <exception cref="ArgumentNullException"><paramref name="sgd"/> is null.</exception>
 public LossMulticlassLogPerPixelTrainer(IntPtr net, int type, Sgd sgd)
 {
     if (sgd == null)
     {
         throw new ArgumentNullException(nameof(sgd));
     }

     // Never hand a disposed solver's native pointer to native code.
     sgd.ThrowIfDisposed();

     this.NetworkType = type;
     this.NativePtr   = NativeMethods.dnn_trainer_loss_multiclass_log_per_pixel_new_sgd(net, type, sgd.NativePtr);
 }
Exemplo n.º 6
0
 /// <summary>
 /// Wraps a native trainer created by <c>NativeMethods.dnn_trainer_loss_mmod_new_sgd</c>.
 /// </summary>
 /// <param name="net">Native pointer of the network to train.</param>
 /// <param name="type">Network type identifier; stored in <c>NetworkType</c> and passed to the native constructor.</param>
 /// <param name="sgd">Solver whose native pointer is handed to the native trainer; must not be null or disposed.</param>
 /// <exception cref="ArgumentNullException"><paramref name="sgd"/> is null.</exception>
 public LossMmodTrainer(IntPtr net, int type, Sgd sgd)
 {
     if (sgd == null)
     {
         throw new ArgumentNullException(nameof(sgd));
     }

     // Never hand a disposed solver's native pointer to native code.
     sgd.ThrowIfDisposed();

     this.NetworkType = type;
     this.NativePtr   = NativeMethods.dnn_trainer_loss_mmod_new_sgd(net, type, sgd.NativePtr);
 }