/// <summary>
/// Builds the LightGBM option dictionary from <c>Args</c> and creates the parallel-training
/// component. If <c>Args.ParallelTrainer</c> is not set, a <c>SingleTrainer</c> is used instead.
/// When the component reports a parallel type other than "serial" and more than one machine,
/// this method does three things:
/// it writes the "tree_learner" option,
/// it merges in any additional parameters the component supplies, and
/// it initializes the native LightGBM network with the component's reduce-scatter and
/// all-gather callbacks.
/// </summary>
private void InitParallelTraining()
{
    Options = Args.ToDictionary(Host);
    ParallelTraining = Args.ParallelTrainer == null
        ? new SingleTrainer()
        : Args.ParallelTrainer.CreateComponent(Host);

    // Serial or single-machine training needs no extra options and no network setup.
    if (ParallelTraining.ParallelType() == "serial" || ParallelTraining.NumMachines() <= 1)
    {
        return;
    }

    Options["tree_learner"] = ParallelTraining.ParallelType();

    // Extra options from the parallel component overwrite any same-named entries already present.
    var extraParams = ParallelTraining.AdditionalParams();
    if (extraParams != null)
    {
        foreach (var kv in extraParams)
        {
            Options[kv.Key] = kv.Value;
        }
    }

    // Validate both collective callbacks (presumably rejecting null) before passing them to native code.
    Contracts.CheckValue(ParallelTraining.GetReduceScatterFunction(), nameof(ParallelTraining.GetReduceScatterFunction));
    Contracts.CheckValue(ParallelTraining.GetAllgatherFunction(), nameof(ParallelTraining.GetAllgatherFunction));

    var machineCount = ParallelTraining.NumMachines();
    var rank = ParallelTraining.Rank();
    var reduceScatter = ParallelTraining.GetReduceScatterFunction();
    var allgather = ParallelTraining.GetAllgatherFunction();
    LightGbmInterfaceUtils.Check(
        WrappedLightGbmInterface.NetworkInitWithFunctions(machineCount, rank, reduceScatter, allgather));
}
/// <summary>
/// Initializes the trainer from the given LightGBM arguments.
/// </summary>
/// <param name="env">Host environment, forwarded to the base constructor together with <paramref name="name"/>.</param>
/// <param name="args">LightGBM training arguments; must not be null.</param>
/// <param name="name">Trainer name, forwarded to the base constructor.</param>
private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name)
    : base(env, name)
{
    Host.CheckValue(args, nameof(args));
    Args = args;

    // InitParallelTraining builds both Options and ParallelTraining.
    // Building them here as well would create the parallel-trainer component twice and throw the
    // first instance away.
    InitParallelTraining();
}
/// <summary>
/// Initializes the trainer from the given LightGBM arguments. A host named
/// <paramref name="name"/> is registered on <paramref name="env"/>.
/// </summary>
/// <param name="env">Host environment; must not be null. It is also stored for later use.</param>
/// <param name="args">LightGBM training arguments; must not be null.</param>
/// <param name="predictionKind">Prediction kind stored by this trainer.</param>
/// <param name="name">Name used to register the host; must not be null or whitespace.</param>
protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, PredictionKind predictionKind, string name)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckNonWhiteSpace(name, nameof(name));
    Host = env.Register(name);
    Host.CheckValue(args, nameof(args));
    Args = args;
    _predictionKind = predictionKind;
    _env = env;

    // InitParallelTraining builds both Options and ParallelTraining.
    // Building them here as well would create the parallel-trainer component twice and throw the
    // first instance away.
    InitParallelTraining();
}