/// <summary>
/// Sets up data-parallel training: validates the per-device work load,
/// splits the input batch into per-device slices and binds an executor
/// group over all contexts.
/// </summary>
/// <param name="symbol">Network symbol to train.</param>
/// <param name="ctx">Devices to train on; one executor is bound per context.</param>
/// <param name="train_data">Training iterator (supplies batch size and bucket key).</param>
/// <param name="work_load_list">Relative work load per device; defaults to equal shares.</param>
/// <param name="sym_gen">Optional symbol generator for bucketing models.</param>
/// <exception cref="MXNetException">Thrown when work_load_list length does not match the device count.</exception>
public DataParallelExecutorManager(Symbol symbol, Context[] ctx, DataIter train_data, string[] arg_names, string[] param_names, string[] aux_names, int[] work_load_list = null, Logger logger = null, Func<int, Symbol> sym_gen = null)
{
    num_device = ctx.Length;
    // NOTE(review): the `logger` parameter is never used; logging goes through
    // the static Logger — confirm this is intentional.
    Logger.Info(string.Format("Start training with {0}", num_device));

    // Guard first: an explicit work load list must cover every device.
    if (work_load_list != null && work_load_list.Length != num_device)
    {
        throw new MXNetException("Invalid setting for work load");
    }

    if (work_load_list == null)
    {
        // Default: every device gets an equal share of the batch.
        work_load_list = new int[num_device];
        for (var device = 0; device < num_device; device++)
        {
            work_load_list[device] = 1;
        }
    }

    slices = ExecuterManager.SplitInputSlice(train_data.BatchSize, work_load_list);

    this.arg_names = arg_names;
    this.param_names = param_names;
    this.aux_names = aux_names;
    contexts = ctx;

    execgrp = new DataParallelExecutorGroup(symbol, arg_names, param_names, ctx, slices, train_data);
    this.symbol = symbol;
    this.sym_gen = sym_gen;

    // Bucketing models cache the initial group under the default bucket key.
    if (sym_gen != null)
    {
        execgrp_bucket.Add(train_data.DefaultBucketKey, execgrp);
    }
}
/// <summary>
/// Binds one training executor per context, replacing the leading (batch)
/// dimension of every data/label shape with that device's slice length,
/// then collects flat per-executor views of the bound data, label,
/// parameter, gradient and auxiliary arrays.
/// </summary>
/// <param name="sym">Symbol to bind.</param>
/// <param name="ctxlist">One context per device.</param>
/// <param name="slices">Per-device batch slices; must align with ctxlist.</param>
/// <param name="shared_group">Optional group to share parameter memory with (bucketing).</param>
public DataParallelExecutorGroup(Symbol sym, string[] arg_names, string[] param_names, Context[] ctxlist, Slice[] slices, DataIter train_data, DataParallelExecutorGroup shared_group = null)
{
    ExecuterManager.CheckArguments(sym);

    // Reuse the shared group's memory dictionaries when given, otherwise
    // allocate one fresh dictionary per device.
    if (shared_group == null)
    {
        foreach (var item in ctxlist)
        {
            shared_data_arrays.Add(new NDArrayDict());
        }
    }
    else
    {
        shared_data_arrays = shared_group.shared_data_arrays;
    }

    foreach (var item in train_data.ProvideData)
    {
        data_names.Add(item.Name);
    }

    foreach (var item in train_data.ProvideLabel)
    {
        label_names.Add(item.Name);
    }

    aux_names = sym.ListAuxiliaryStates().ToList();

    // Record which argument positions correspond to trainable parameters.
    for (var i = 0; i < arg_names.Length; i++)
    {
        if (param_names.Contains(arg_names[i]))
        {
            param_idx.Add(i);
            this.param_names.Add(arg_names[i]);
        }
    }

    // Bind one executor per device with slice-sized input shapes.
    for (var i = 0; i < ctxlist.Length; i++)
    {
        var data_shapes = new Dictionary<string, Shape>();
        var data_types = new Dictionary<string, DType>();
        var shapeData = new List<int>();

        foreach (var item in train_data.ProvideData)
        {
            // Swap the batch dimension for this device's slice length.
            shapeData = item.Shape.Data.ToList();
            shapeData.RemoveAt(0);
            shapeData.Insert(0, slices[i].End.Value - slices[i].Begin);
            data_shapes[item.Name] = new Shape(shapeData);
            data_types[item.Name] = item.DataType;
        }

        foreach (var item in train_data.ProvideLabel)
        {
            shapeData = item.Shape.Data.ToList();
            shapeData.RemoveAt(0);
            shapeData.Insert(0, slices[i].End.Value - slices[i].Begin);
            data_shapes[item.Name] = new Shape(shapeData);
            data_types[item.Name] = item.DataType;
        }

        var shared_exec = shared_group == null ? null : shared_group.train_execs[i];
        var train_exec = ExecuterManager.BindExec(sym, ctxlist[i], data_shapes, param_names, true, shared_exec, shared_data_arrays[i], data_types);
        train_execs.Add(train_exec);
    }

    // Flat views over the bound arrays, outer loop by name/index, inner by executor.
    foreach (var name in data_names)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            data_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
        }
    }

    foreach (var name in label_names)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            label_arrays.Add(train_execs[i].ArgmentDictionary()[name]);
        }
    }

    foreach (var idx in param_idx)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            param_arrays.Add(train_execs[i].ArgmentArrays[idx]);
        }
    }

    foreach (var idx in param_idx)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            grad_arrays.Add(train_execs[i].GradientArrays[idx]);
        }
    }

    for (var idx = 0; idx < aux_names.Count; idx++)
    {
        for (var i = 0; i < train_execs.Count; i++)
        {
            // BUGFIX: was AuxiliaryArrays[i] (the executor index), which paired
            // the wrong aux state with each executor and can index out of range
            // when the device count differs from the aux-state count. Each
            // executor's aux state `idx` is what belongs here.
            aux_arrays.Add(train_execs[i].AuxiliaryArrays[idx]);
        }
    }

    this.slices = slices;
}
/// <summary>
/// Copies one batch into every executor's bound input arrays, splitting it
/// across devices according to the precomputed slices.
/// </summary>
/// <param name="data_batch">Batch whose data and labels are distributed to the executors.</param>
public void LoadDataBatch(DataBatch data_batch)
{
    ExecuterManager.LoadData(data_batch, data_arrays.ToArray());
    // BUGFIX: labels were previously passed to LoadData, which copies the
    // batch's *data* into the label arrays. Labels must go through the
    // label-specific path (mirrors MXNet's _load_data/_load_label pair).
    ExecuterManager.LoadLabel(data_batch, label_arrays.ToArray());
}