public DataParallelExecutorGroup(Symbol sym, IList<string> argNames, IList<string> paramNames,
    IList<Context> ctx, IList<Tuple<int, int>> slices, IDataIProvide trainData,
    DataParallelExecutorGroup sharedGroup = null)
{
    Model.CheckArguments(sym);

    // When bucketing, reuse the shared group's per-device data arrays;
    // otherwise allocate one fresh dictionary per device.
    if (sharedGroup == null)
    {
        this._sharedDataArrays = ctx.Select(s => new Dictionary<string, NdArray>()).ToList();
    }
    else
    {
        this._sharedDataArrays = sharedGroup._sharedDataArrays;
    }

    this._dataNames = trainData.ProvideData.Select(s => s.Key).ToList();
    this._labelNames = trainData.ProvideLabel.Select(s => s.Key).ToList();
    this._auxNames = sym.ListAuxiliaryStates();

    // Indices of the trainable parameters within the symbol's argument list.
    this._paramIdx = argNames.Select((x, i) => new { x, i })
        .Where(w => paramNames.Contains(w.x))
        .Select(s => s.i)
        .ToList();
    this._paramNames = _paramIdx.Select(s => argNames[s]).ToList();

    // Bind one executor per device. Dim 0 of every input/label shape is
    // replaced by that device's slice of the batch; the other dims are kept.
    this.TrainExecs = new List<Executor>();
    for (int i = 0; i < ctx.Count; i++)
    {
        var concat = trainData.ProvideData.Concat(trainData.ProvideLabel);
        var dataShapes = concat.ToDictionary(kv => kv.Key, kv =>
        {
            var shape = new List<uint> { (uint)(slices[i].Item2 - slices[i].Item1) };
            shape.AddRange(kv.Value.Data().Skip(1));
            return shape.ToArray();
        });
        var sharedExec = sharedGroup == null ? null : sharedGroup.TrainExecs[i];
        var trainExec = BindExec(sym, ctx[i], dataShapes, this._paramNames,
            needGradInput: true, baseExec: sharedExec,
            sharedDataArrays: this._sharedDataArrays[i]);
        this.TrainExecs.Add(trainExec);
    }

    // Regroup the per-executor arrays by name/index so that loading data and
    // updating parameters can iterate over all devices at once.
    this._dataArrays = _dataNames.Select(name =>
        TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();
    this._labelArrays = _labelNames.Select(name =>
        TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();
    this.ParamArrays = _paramIdx.Select(i =>
        TrainExecs.Select(e => e.ArgArrays[i]).ToList()).ToList();
    this.GradArrays = _paramIdx.Select(i =>
        TrainExecs.Select(e => e.GradArrays[i]).ToList()).ToList();
    this.AuxArrays = Enumerable.Range(0, this._auxNames.Count).Select(i =>
        TrainExecs.Select(e => e.AuxArrays[i]).ToList()).ToList();
    this._slices = slices;
}
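The shape arithmetic above only touches the batch dimension: each executor binds with dim 0 replaced by its slice length, and all remaining dims carried over unchanged. A minimal sketch of that derivation for one device (the concrete shapes, slice values, and two-device split are illustrative assumptions, not values taken from the code above):

// Illustrative only: how dataShapes is derived for a single device.
// Assume a global batch of 128 split evenly across two devices, i.e.
//   slices = { (0, 64), (64, 128) }
// and an input "data" whose full shape is (128, 3, 224, 224).
var slice = Tuple.Create(0, 64);                     // device 0's share of the batch
var fullShape = new uint[] { 128, 3, 224, 224 };
var deviceShape = new List<uint> { (uint)(slice.Item2 - slice.Item1) };
deviceShape.AddRange(fullShape.Skip(1));             // keep all non-batch dims
// deviceShape is now (64, 3, 224, 224) -- the shape BindExec receives for device 0.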
public DataParallelExecutorManager(Symbol symbol, SymbolGenerate symGen, IList<Context> ctx,
    IDataIter trainData, IList<string> paramNames, IList<string> argNames,
    IList<string> auxNames, IList<int> workLoadList, ILog logger)
{
    if (logger == null)
    {
        logger = LogManager.GetLogger("");
    }
    this._logger = logger;

    var numDevice = ctx.Count;
    logger.Info("Start training with " + string.Join(", ", ctx));

    // Default to an equal workload per device, then split the batch accordingly.
    if (workLoadList == null)
    {
        workLoadList = Enumerable.Repeat(1, numDevice).ToList();
    }
    Util.Assert(workLoadList.Count == numDevice, "Invalid settings for work load.");

    var slices = SplitInputSlice(trainData.BatchSize, workLoadList);
    this._slices = slices;

    this._argNames = argNames;
    this.ParamNames = paramNames;
    this._auxNames = auxNames;
    this._ctx = ctx;

    this._execgrp = new DataParallelExecutorGroup(symbol, this._argNames, this.ParamNames,
        this._ctx, this._slices, trainData);

    this._symbol = symbol;
    this._symGen = symGen;
    this._currExecgrp = null; // this is set when data is loaded

    // With bucketing, keep one executor group per bucket, seeded with the default bucket.
    if (this._symGen != null)
    {
        this._execgrpBucket = new Dictionary<string, DataParallelExecutorGroup>
        {
            { trainData.DefaultBucketKey, this._execgrp }
        };
    }
}
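SplitInputSlice itself is defined elsewhere; assuming it mirrors MXNet's _split_input_slice, it apportions the batch across devices in proportion to workLoadList. A hypothetical sketch of that behavior (SplitInputSliceSketch is an illustrative name, not the repo's helper):

// Hypothetical re-implementation sketch, not the actual helper used above:
// divides batchSize into contiguous (start, end) ranges proportional to workLoadList.
static List<Tuple<int, int>> SplitInputSliceSketch(int batchSize, IList<int> workLoadList)
{
    var totalWork = workLoadList.Sum();
    var slices = new List<Tuple<int, int>>();
    int start = 0;
    for (int i = 0; i < workLoadList.Count; i++)
    {
        // Round each device's share; the last device absorbs any remainder.
        int end = i == workLoadList.Count - 1
            ? batchSize
            : start + (int)Math.Round((double)batchSize * workLoadList[i] / totalWork);
        slices.Add(Tuple.Create(start, end));
        start = end;
    }
    return slices;
}
// With batchSize = 100 and workLoadList = { 1, 1, 2 }, this yields (0,25), (25,50), (50,100).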
public void LoadDataBatch(IDataBatch dataBatch)
{
    if (this._symGen != null)
    {
        // Bucketing: look up (or lazily create) the executor group for this batch's bucket.
        var key = dataBatch.BucketKey;
        if (!_execgrpBucket.ContainsKey(key))
        {
            // Create a new bucket entry, sharing memory with the default executor group.
            var symbol = _symGen(key);
            var execgrp = new DataParallelExecutorGroup(symbol, this._argNames, this.ParamNames,
                this._ctx, this._slices, dataBatch, sharedGroup: this._execgrp);
            this._execgrpBucket[key] = execgrp;
        }
        this._currExecgrp = this._execgrpBucket[key];
    }
    else
    {
        this._currExecgrp = this._execgrp;
    }
    this._currExecgrp.LoadDataBatch(dataBatch);
}
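Taken together, a training loop drives these pieces roughly as follows. This is a hedged sketch: the iterator (trainIter), the network symbol (net), the Context.Gpu factory, and the forward/backward calls elided at the end are assumptions about the surrounding API, not members shown in this section.

// Illustrative wiring only; all names not defined above are assumed.
var ctx = new List<Context> { Context.Gpu(0), Context.Gpu(1) };
var manager = new DataParallelExecutorManager(
    symbol: net, symGen: null, ctx: ctx, trainData: trainIter,
    paramNames: paramNames, argNames: argNames, auxNames: auxNames,
    workLoadList: null, logger: null);

foreach (var batch in trainIter)
{
    // Copies each device's slice of the batch into its bound executor
    // (and, when bucketing is enabled, first selects the right executor group).
    manager.LoadDataBatch(batch);
    // ... forward/backward passes and parameter updates would follow here.
}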