/// <summary>
/// Builds a group of executors, binding one executor per entry of <paramref name="ctx"/>.
/// </summary>
/// <param name="sym">Symbol to bind; it is passed through <c>Model.CheckArguments</c> first.</param>
/// <param name="argNames">Argument names; positions in this list are used to index each executor's <c>ArgArrays</c> and <c>GradArrays</c>.</param>
/// <param name="paramNames">Names from <paramref name="argNames"/> that are treated as parameters.</param>
/// <param name="ctx">Contexts to bind on; one executor is created per entry.</param>
/// <param name="slices">Per-context pairs; <c>Item2 - Item1</c> replaces the leading dimension of every data and label shape for that context.</param>
/// <param name="trainData">Source of the data and label names and shapes.</param>
/// <param name="sharedGroup">Optional group whose shared data arrays and per-index executors are reused when binding.</param>
/// <exception cref="ArgumentNullException">A required list argument or <paramref name="trainData"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="slices"/> has fewer entries than <paramref name="ctx"/>.</exception>
public DataParallelExecutorGroup(Symbol sym, IList<string> argNames, IList<string> paramNames, IList<Context> ctx, IList<Tuple<int, int>> slices, IDataIProvide trainData, DataParallelExecutorGroup sharedGroup = null)
{
    Model.CheckArguments(sym);

    if (argNames == null)
    {
        throw new ArgumentNullException(nameof(argNames));
    }

    if (paramNames == null)
    {
        throw new ArgumentNullException(nameof(paramNames));
    }

    if (ctx == null)
    {
        throw new ArgumentNullException(nameof(ctx));
    }

    if (slices == null)
    {
        throw new ArgumentNullException(nameof(slices));
    }

    if (trainData == null)
    {
        throw new ArgumentNullException(nameof(trainData));
    }

    // Fail before any executor is bound instead of part-way through the bind loop.
    if (slices.Count < ctx.Count)
    {
        throw new ArgumentException("slices must contain at least one entry per context.", nameof(slices));
    }

    if (sharedGroup == null)
    {
        this._sharedDataArrays = ctx.Select(c => new Dictionary<string, NdArray>()).ToList();
    }
    else
    {
        this._sharedDataArrays = sharedGroup._sharedDataArrays;
    }

    this._dataNames = trainData.ProvideData.Select(entry => entry.Key).ToList();
    this._labelNames = trainData.ProvideLabel.Select(entry => entry.Key).ToList();
    this._auxNames = sym.ListAuxiliaryStates();

    // Set lookup avoids rescanning paramNames once per argument name.
    var paramNameSet = new HashSet<string>(paramNames);
    this._paramIdx = Enumerable.Range(0, argNames.Count)
        .Where(index => paramNameSet.Contains(argNames[index]))
        .ToList();
    this._paramNames = _paramIdx.Select(index => argNames[index]).ToList();

    this.TrainExecs = new List<Executor>();
    for (int i = 0; i < ctx.Count; i++)
    {
        // Leading dimension for this context; identical for every data/label entry.
        uint sliceLength = (uint)(slices[i].Item2 - slices[i].Item1);

        var dataShapes = trainData.ProvideData
            .Concat(trainData.ProvideLabel)
            .ToDictionary(
                entry => entry.Key,
                entry =>
                {
                    // Keep the trailing dimensions, swap in the per-context leading one.
                    var shape = new List<uint> { sliceLength };
                    shape.AddRange(entry.Value.Data().Skip(1));
                    return shape.ToArray();
                });

        var sharedExec = sharedGroup == null ? null : sharedGroup.TrainExecs[i];
        var trainExec = BindExec(sym, ctx[i], dataShapes, this._paramNames, needGradInput: true, baseExec: sharedExec, sharedDataArrays: this._sharedDataArrays[i]);
        this.TrainExecs.Add(trainExec);
    }

    // For each data/label name: the per-executor (slice, array) pairs.
    this._dataArrays = _dataNames
        .Select(name => TrainExecs.Select((exec, k) => Tuple.Create(slices[k], exec.ArgDict[name])).ToList())
        .ToList();
    this._labelArrays = _labelNames
        .Select(name => TrainExecs.Select((exec, k) => Tuple.Create(slices[k], exec.ArgDict[name])).ToList())
        .ToList();

    // For each parameter index: the matching array from every executor.
    this.ParamArrays = _paramIdx
        .Select(index => TrainExecs.Select(exec => exec.ArgArrays[index]).ToList())
        .ToList();
    this.GradArrays = _paramIdx
        .Select(index => TrainExecs.Select(exec => exec.GradArrays[index]).ToList())
        .ToList();
    this.AuxArrays = Enumerable.Range(0, this._auxNames.Count)
        .Select(index => TrainExecs.Select(exec => exec.AuxArrays[index]).ToList())
        .ToList();

    this._slices = slices;
}
/// <summary>
/// Builds a group of executors, binding one executor per entry of <paramref name="ctx"/>.
/// </summary>
/// <param name="sym">Symbol to bind; it is passed through <c>Util._check_arguments</c> first.</param>
/// <param name="arg_names">Argument names; positions in this list are used to index each executor's <c>arg_arrays</c> and <c>grad_arrays</c>.</param>
/// <param name="param_names">Names from <paramref name="arg_names"/> that are treated as parameters.</param>
/// <param name="ctx">Contexts to bind on; one executor is created per entry.</param>
/// <param name="slices">Per-context pairs; <c>Item2 - Item1</c> replaces the leading dimension of every data and label shape for that context.</param>
/// <param name="train_data">Source of the data and label names and shapes.</param>
/// <param name="shared_group">Optional group whose shared data arrays and per-index executors are reused when binding.</param>
/// <exception cref="ArgumentNullException">A required list argument or <paramref name="train_data"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="slices"/> has fewer entries than <paramref name="ctx"/>.</exception>
public DataParallelExecutorGroup(Symbol sym, List<string> arg_names, List<string> param_names, List<Context> ctx, List<Tuple<int, int>> slices, IDataIProvide train_data, DataParallelExecutorGroup shared_group = null)
{
    Util._check_arguments(sym);

    if (arg_names == null)
    {
        throw new ArgumentNullException(nameof(arg_names));
    }

    if (param_names == null)
    {
        throw new ArgumentNullException(nameof(param_names));
    }

    if (ctx == null)
    {
        throw new ArgumentNullException(nameof(ctx));
    }

    if (slices == null)
    {
        throw new ArgumentNullException(nameof(slices));
    }

    if (train_data == null)
    {
        throw new ArgumentNullException(nameof(train_data));
    }

    // Fail before any executor is bound instead of part-way through the bind loop.
    if (slices.Count < ctx.Count)
    {
        throw new ArgumentException("slices must contain at least one entry per context.", nameof(slices));
    }

    if (shared_group == null)
    {
        this._shared_data_arrays = ctx.Select(c => new Dictionary<string, NDArray>()).ToList();
    }
    else
    {
        this._shared_data_arrays = shared_group._shared_data_arrays;
    }

    this._data_names = train_data.provide_data.Select(entry => entry.Key).ToList();
    this._label_names = train_data.provide_label.Select(entry => entry.Key).ToList();
    this._aux_names = sym.ListAuxiliaryStates();

    // Set lookup avoids rescanning param_names once per argument name.
    var paramNameSet = new HashSet<string>(param_names);
    this._param_idx = Enumerable.Range(0, arg_names.Count)
        .Where(index => paramNameSet.Contains(arg_names[index]))
        .ToList();
    this._param_names = _param_idx.Select(index => arg_names[index]).ToList();

    this.train_execs = new List<Executor>();
    for (int i = 0; i < ctx.Count; i++)
    {
        // Leading dimension for this context; identical for every data/label entry.
        uint sliceLength = (uint)(slices[i].Item2 - slices[i].Item1);

        var data_shapes = train_data.provide_data
            .Concat(train_data.provide_label)
            .ToDictionary(
                entry => entry.Key,
                entry =>
                {
                    // Keep the trailing dimensions, swap in the per-context leading one.
                    var shape = new List<uint> { sliceLength };
                    shape.AddRange(entry.Value.Data().Skip(1));
                    return shape.ToArray();
                });

        var shared_exec = shared_group == null ? null : shared_group.train_execs[i];
        var train_exec = _bind_exec(sym, ctx[i], data_shapes, this._param_names, need_grad_input: true, base_exec: shared_exec, shared_data_arrays: this._shared_data_arrays[i]);
        this.train_execs.Add(train_exec);
    }

    // For each data/label name: the per-executor (slice, array) pairs.
    this._data_arrays = _data_names
        .Select(name => train_execs.Select((exec, k) => Tuple.Create(slices[k], exec.arg_dict[name])).ToList())
        .ToList();
    this._label_arrays = _label_names
        .Select(name => train_execs.Select((exec, k) => Tuple.Create(slices[k], exec.arg_dict[name])).ToList())
        .ToList();

    // For each parameter index: the matching array from every executor.
    this.param_arrays = _param_idx
        .Select(index => train_execs.Select(exec => exec.arg_arrays[index]).ToList())
        .ToList();
    this.grad_arrays = _param_idx
        .Select(index => train_execs.Select(exec => exec.grad_arrays[index]).ToList())
        .ToList();
    this._aux_arrays = Enumerable.Range(0, this._aux_names.Count)
        .Select(index => train_execs.Select(exec => exec.aux_arrays[index]).ToList())
        .ToList();

    this._slices = slices;
}