/// <summary>
/// Initializes a new trainer for <paramref name="net"/> driven by the SGD solver <paramref name="sgd"/>.
/// The loss-specific native trainer is selected from the network type <typeparamref name="T"/>.
/// </summary>
/// <param name="net">The network to train. Must not be <c>null</c>.</param>
/// <param name="sgd">The SGD solver. Must not be <c>null</c>; <c>ThrowIfDisposed</c> is called on it before use.</param>
/// <exception cref="ArgumentNullException"><paramref name="net"/> or <paramref name="sgd"/> is <c>null</c>.</exception>
/// <exception cref="NotSupportedException"><typeparamref name="T"/> has no matching trainer implementation.</exception>
public DnnTrainer(T net, Sgd sgd)
{
    // Validate up front: CreateImp dereferences net.NativePtr / net.NetworkType, so a null
    // network would otherwise surface as a NullReferenceException from inside the factory.
    if (net == null)
    {
        throw new ArgumentNullException(nameof(net));
    }

    if (sgd == null)
    {
        throw new ArgumentNullException(nameof(sgd));
    }

    sgd.ThrowIfDisposed();

    this._Imp = CreateImp(net, sgd);

    // This wrapper exposes the same native handle as the loss-specific implementation it owns.
    this.NativePtr = this._Imp.NativePtr;
}
/// <summary>
/// Creates the loss-specific trainer implementation that matches the network type <typeparamref name="T"/>.
/// Supported types are <c>LossMetric</c>, <c>LossMmod</c>, <c>LossMulticlassLog</c> and
/// <c>LossMulticlassLogPerPixel</c>.
/// </summary>
/// <param name="net">The network whose native pointer and network type are passed to the trainer.</param>
/// <param name="sgd">The SGD solver forwarded to the trainer.</param>
/// <returns>The trainer implementation, cast with <c>as</c> to <c>TrainerImp&lt;T&gt;</c>.</returns>
/// <exception cref="NotSupportedException"><typeparamref name="T"/> is not one of the supported loss network types.</exception>
private static TrainerImp<T> CreateImp(T net, Sgd sgd)
{
    var t = typeof(T);

    // NOTE(review): `as` yields null if a concrete trainer does not derive from TrainerImp<T>;
    // the exact typeof(T) checks below are presumably what guarantee the cast succeeds.
    if (t == typeof(LossMetric))
    {
        return new LossMetricTrainer(net.NativePtr, net.NetworkType, sgd) as TrainerImp<T>;
    }

    if (t == typeof(LossMmod))
    {
        return new LossMmodTrainer(net.NativePtr, net.NetworkType, sgd) as TrainerImp<T>;
    }

    if (t == typeof(LossMulticlassLog))
    {
        return new LossMulticlassLogTrainer(net.NativePtr, net.NetworkType, sgd) as TrainerImp<T>;
    }

    if (t == typeof(LossMulticlassLogPerPixel))
    {
        return new LossMulticlassLogPerPixelTrainer(net.NativePtr, net.NetworkType, sgd) as TrainerImp<T>;
    }

    // Name the rejected type so the failure is diagnosable from the exception alone.
    throw new NotSupportedException($"No trainer implementation is available for network type '{t.FullName}'.");
}
/// <summary>
/// Creates the native trainer for a LossMulticlassLogPerPixel network using an SGD solver.
/// </summary>
/// <param name="net">Native pointer of the network to train.</param>
/// <param name="type">Network type identifier, stored in <c>NetworkType</c> and passed to the native constructor.</param>
/// <param name="sgd">The SGD solver. Must not be <c>null</c>; <c>ThrowIfDisposed</c> is called on it before use.</param>
/// <exception cref="ArgumentNullException"><paramref name="sgd"/> is <c>null</c>.</exception>
public LossMulticlassLogPerPixelTrainer(IntPtr net, int type, Sgd sgd)
{
    // This constructor is public, so it cannot rely on DnnTrainer having validated the solver;
    // mirror that validation before reading sgd.NativePtr.
    if (sgd == null)
    {
        throw new ArgumentNullException(nameof(sgd));
    }

    sgd.ThrowIfDisposed();

    this.NetworkType = type;
    this.NativePtr = NativeMethods.LossMulticlassLogPerPixel_trainer_new2(type, net, sgd.NativePtr);
}
/// <summary>
/// Creates the native trainer for a LossMmod network using an SGD solver.
/// </summary>
/// <param name="net">Native pointer of the network to train.</param>
/// <param name="type">Network type identifier, stored in <c>NetworkType</c> and passed to the native constructor.</param>
/// <param name="sgd">The SGD solver. Must not be <c>null</c>; <c>ThrowIfDisposed</c> is called on it before use.</param>
/// <exception cref="ArgumentNullException"><paramref name="sgd"/> is <c>null</c>.</exception>
public LossMmodTrainer(IntPtr net, int type, Sgd sgd)
{
    // This constructor is public, so it cannot rely on DnnTrainer having validated the solver;
    // mirror that validation before reading sgd.NativePtr.
    if (sgd == null)
    {
        throw new ArgumentNullException(nameof(sgd));
    }

    sgd.ThrowIfDisposed();

    this.NetworkType = type;
    this.NativePtr = NativeMethods.LossMmod_trainer_new2(type, net, sgd.NativePtr);
}
/// <summary>
/// Creates the native <c>loss_multiclass_log_per_pixel</c> trainer for a network using an SGD solver.
/// </summary>
/// <param name="net">Native pointer of the network to train.</param>
/// <param name="type">Network type identifier, stored in <c>NetworkType</c> and passed to the native constructor.</param>
/// <param name="sgd">The SGD solver. Must not be <c>null</c>; <c>ThrowIfDisposed</c> is called on it before use.</param>
/// <exception cref="ArgumentNullException"><paramref name="sgd"/> is <c>null</c>.</exception>
public LossMulticlassLogPerPixelTrainer(IntPtr net, int type, Sgd sgd)
{
    // This constructor is public, so it cannot rely on DnnTrainer having validated the solver;
    // mirror that validation before reading sgd.NativePtr.
    if (sgd == null)
    {
        throw new ArgumentNullException(nameof(sgd));
    }

    sgd.ThrowIfDisposed();

    this.NetworkType = type;
    this.NativePtr = NativeMethods.dnn_trainer_loss_multiclass_log_per_pixel_new_sgd(net, type, sgd.NativePtr);
}
/// <summary>
/// Creates the native <c>loss_mmod</c> trainer for a network using an SGD solver.
/// </summary>
/// <param name="net">Native pointer of the network to train.</param>
/// <param name="type">Network type identifier, stored in <c>NetworkType</c> and passed to the native constructor.</param>
/// <param name="sgd">The SGD solver. Must not be <c>null</c>; <c>ThrowIfDisposed</c> is called on it before use.</param>
/// <exception cref="ArgumentNullException"><paramref name="sgd"/> is <c>null</c>.</exception>
public LossMmodTrainer(IntPtr net, int type, Sgd sgd)
{
    // This constructor is public, so it cannot rely on DnnTrainer having validated the solver;
    // mirror that validation before reading sgd.NativePtr.
    if (sgd == null)
    {
        throw new ArgumentNullException(nameof(sgd));
    }

    sgd.ThrowIfDisposed();

    this.NetworkType = type;
    this.NativePtr = NativeMethods.dnn_trainer_loss_mmod_new_sgd(net, type, sgd.NativePtr);
}