示例#1
0
        /// <summary>
        /// Binds one training executor per context in <paramref name="ctx"/> and gathers the
        /// per-executor data, label, parameter, gradient and auxiliary arrays into lists whose
        /// outer index is the input/parameter and whose inner index is the executor.
        /// </summary>
        /// <param name="sym">Symbol to bind; passed to <c>Model.CheckArguments</c> before anything is bound.</param>
        /// <param name="argNames">Argument names of <paramref name="sym"/>. A name's position in this list is used to
        /// index each executor's <c>ArgArrays</c>/<c>GradArrays</c>, so the two orders are assumed to match.</param>
        /// <param name="paramNames">Names of arguments treated as parameters; entries not present in
        /// <paramref name="argNames"/> are ignored.</param>
        /// <param name="ctx">Contexts to bind on; one executor is created for each.</param>
        /// <param name="slices">Per-context (start, stop) pair; <c>stop - start</c> replaces the leading dimension of every
        /// data/label shape (presumably the per-device batch size). Needs at least <c>ctx.Count</c> non-null entries,
        /// each with stop &gt;= start.</param>
        /// <param name="trainData">Provider whose <c>ProvideData</c> and <c>ProvideLabel</c> entries supply input names
        /// (<c>Key</c>) and dimensions (<c>Value.Data()</c>).</param>
        /// <param name="sharedGroup">Optional group whose shared data-array dictionaries are reused by reference and whose
        /// executor <c>i</c> is passed as <c>baseExec</c> when binding context <c>i</c>.</param>
        /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="slices"/> or <paramref name="sharedGroup"/> covers fewer
        /// contexts than <paramref name="ctx"/>, or a slice is null or has stop &lt; start.</exception>
        public DataParallelExecutorGroup(Symbol sym,
                                         IList <string> argNames,
                                         IList <string> paramNames,
                                         IList <Context> ctx,
                                         IList <Tuple <int, int> > slices,
                                         IDataIProvide trainData,
                                         DataParallelExecutorGroup sharedGroup = null)
        {
            if (sym == null) { throw new ArgumentNullException(nameof(sym)); }
            if (argNames == null) { throw new ArgumentNullException(nameof(argNames)); }
            if (paramNames == null) { throw new ArgumentNullException(nameof(paramNames)); }
            if (ctx == null) { throw new ArgumentNullException(nameof(ctx)); }
            if (slices == null) { throw new ArgumentNullException(nameof(slices)); }
            if (trainData == null) { throw new ArgumentNullException(nameof(trainData)); }

            Model.CheckArguments(sym);

            // Validate everything indexed by context up front, so a bad argument fails before
            // any executor is bound instead of part-way through the binding loop.
            if (slices.Count < ctx.Count)
            {
                throw new ArgumentException($"Expected at least {ctx.Count} slices (one per context) but got {slices.Count}.", nameof(slices));
            }
            if (sharedGroup != null && sharedGroup.TrainExecs.Count < ctx.Count)
            {
                throw new ArgumentException($"Shared group has {sharedGroup.TrainExecs.Count} executors but {ctx.Count} contexts were given.", nameof(sharedGroup));
            }
            for (int i = 0; i < ctx.Count; i++)
            {
                // A reversed slice would make (uint)(stop - start) silently wrap to a huge dimension.
                if (slices[i] == null || slices[i].Item2 < slices[i].Item1)
                {
                    throw new ArgumentException($"Slice {i} must be non-null with stop >= start.", nameof(slices));
                }
            }

            if (sharedGroup == null)
            {
                this._sharedDataArrays = ctx.Select(s => new Dictionary <string, NdArray>()).ToList();
            }
            else
            {
                // Shared by reference: both groups see the same per-context dictionaries.
                this._sharedDataArrays = sharedGroup._sharedDataArrays;
            }
            this._dataNames  = trainData.ProvideData.Select(s => s.Key).ToList();
            this._labelNames = trainData.ProvideLabel.Select(s => s.Key).ToList();
            this._auxNames   = sym.ListAuxiliaryStates();

            // Hash lookup keeps this linear instead of scanning paramNames once per argument.
            var paramNameSet = new HashSet <string>(paramNames);
            this._paramIdx   = Enumerable.Range(0, argNames.Count).Where(idx => paramNameSet.Contains(argNames[idx])).ToList();
            this._paramNames = _paramIdx.Select(idx => argNames[idx]).ToList();

            // The input descriptions are the same for every context; build them once.
            var inputDescs = trainData.ProvideData.Concat(trainData.ProvideLabel).ToList();

            this.TrainExecs = new List <Executor>();
            for (int i = 0; i < ctx.Count; i++)
            {
                // Only the leading dimension differs between contexts.
                uint sliceLength = (uint)(slices[i].Item2 - slices[i].Item1);
                var dataShapes = inputDescs.ToDictionary(
                    desc => desc.Key,
                    desc => new[] { sliceLength }.Concat(desc.Value.Data().Skip(1)).ToArray());

                // NOTE(review): passing the shared group's executor as baseExec presumably lets
                // BindExec reuse its memory — confirm against BindExec.
                var sharedExec = sharedGroup == null ? null : sharedGroup.TrainExecs[i];

                var trainExec = BindExec(sym, ctx[i], dataShapes, this._paramNames,
                                         needGradInput: true, baseExec: sharedExec,
                                         sharedDataArrays: this._sharedDataArrays[i]);

                this.TrainExecs.Add(trainExec);
            }

            // Regroup executor-major results: outer index = input/parameter, inner index = executor.
            this._dataArrays  = _dataNames.Select(name => TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();
            this._labelArrays = _labelNames.Select(name => TrainExecs.Select((e, i) => Tuple.Create(slices[i], e.ArgDict[name])).ToList()).ToList();

            this.ParamArrays = _paramIdx.Select(i => TrainExecs.Select((e) => e.ArgArrays[i]).ToList()).ToList();

            this.GradArrays = _paramIdx.Select(i => TrainExecs.Select((e) => e.GradArrays[i]).ToList()).ToList();

            this.AuxArrays = Enumerable.Range(0, this._auxNames.Count).Select(i => TrainExecs.Select((e) => e.AuxArrays[i]).ToList()).ToList();

            this._slices = slices;
        }
示例#2
0
        /// <summary>
        /// Binds one training executor per context in <paramref name="ctx"/> and gathers the
        /// per-executor data, label, parameter, gradient and auxiliary arrays into lists whose
        /// outer index is the input/parameter and whose inner index is the executor.
        /// </summary>
        /// <param name="sym">Symbol to bind; passed to <c>Util._check_arguments</c> before anything is bound.</param>
        /// <param name="arg_names">Argument names of <paramref name="sym"/>. A name's position in this list is used to
        /// index each executor's <c>arg_arrays</c>/<c>grad_arrays</c>, so the two orders are assumed to match.</param>
        /// <param name="param_names">Names of arguments treated as parameters; entries not present in
        /// <paramref name="arg_names"/> are ignored.</param>
        /// <param name="ctx">Contexts to bind on; one executor is created for each.</param>
        /// <param name="slices">Per-context (start, stop) pair; <c>stop - start</c> replaces the leading dimension of every
        /// data/label shape (presumably the per-device batch size). Needs at least <c>ctx.Count</c> non-null entries,
        /// each with stop &gt;= start.</param>
        /// <param name="train_data">Provider whose <c>provide_data</c> and <c>provide_label</c> entries supply input names
        /// (<c>Key</c>) and dimensions (<c>Value.Data()</c>).</param>
        /// <param name="shared_group">Optional group whose shared data-array dictionaries are reused by reference and whose
        /// executor <c>i</c> is passed as <c>base_exec</c> when binding context <c>i</c>.</param>
        /// <exception cref="ArgumentNullException">A required argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="slices"/> or <paramref name="shared_group"/> covers fewer
        /// contexts than <paramref name="ctx"/>, or a slice is null or has stop &lt; start.</exception>
        public DataParallelExecutorGroup(Symbol sym,
                                         List <string> arg_names,
                                         List <string> param_names,
                                         List <Context> ctx,
                                         List <Tuple <int, int> > slices,
                                         IDataIProvide train_data,
                                         DataParallelExecutorGroup shared_group = null)
        {
            if (sym == null) { throw new ArgumentNullException(nameof(sym)); }
            if (arg_names == null) { throw new ArgumentNullException(nameof(arg_names)); }
            if (param_names == null) { throw new ArgumentNullException(nameof(param_names)); }
            if (ctx == null) { throw new ArgumentNullException(nameof(ctx)); }
            if (slices == null) { throw new ArgumentNullException(nameof(slices)); }
            if (train_data == null) { throw new ArgumentNullException(nameof(train_data)); }

            Util._check_arguments(sym);

            // Validate everything indexed by context up front, so a bad argument fails before
            // any executor is bound instead of part-way through the binding loop.
            if (slices.Count < ctx.Count)
            {
                throw new ArgumentException($"Expected at least {ctx.Count} slices (one per context) but got {slices.Count}.", nameof(slices));
            }
            if (shared_group != null && shared_group.train_execs.Count < ctx.Count)
            {
                throw new ArgumentException($"Shared group has {shared_group.train_execs.Count} executors but {ctx.Count} contexts were given.", nameof(shared_group));
            }
            for (int i = 0; i < ctx.Count; i++)
            {
                // A reversed slice would make (uint)(stop - start) silently wrap to a huge dimension.
                if (slices[i] == null || slices[i].Item2 < slices[i].Item1)
                {
                    throw new ArgumentException($"Slice {i} must be non-null with stop >= start.", nameof(slices));
                }
            }

            if (shared_group == null)
            {
                this._shared_data_arrays = ctx.Select(s => new Dictionary <string, NDArray>()).ToList();
            }
            else
            {
                // Shared by reference: both groups see the same per-context dictionaries.
                this._shared_data_arrays = shared_group._shared_data_arrays;
            }
            this._data_names  = train_data.provide_data.Select(s => s.Key).ToList();
            this._label_names = train_data.provide_label.Select(s => s.Key).ToList();
            this._aux_names   = sym.ListAuxiliaryStates();

            // Hash lookup keeps this linear instead of scanning param_names once per argument.
            var param_name_set = new HashSet <string>(param_names);
            this._param_idx   = Enumerable.Range(0, arg_names.Count).Where(idx => param_name_set.Contains(arg_names[idx])).ToList();
            this._param_names = _param_idx.Select(idx => arg_names[idx]).ToList();

            // The input descriptions are the same for every context; build them once.
            var input_descs = train_data.provide_data.Concat(train_data.provide_label).ToList();

            this.train_execs = new List <Executor>();
            for (int i = 0; i < ctx.Count; i++)
            {
                // Only the leading dimension differs between contexts.
                uint slice_length = (uint)(slices[i].Item2 - slices[i].Item1);
                var data_shapes = input_descs.ToDictionary(
                    desc => desc.Key,
                    desc => new[] { slice_length }.Concat(desc.Value.Data().Skip(1)).ToArray());

                // NOTE(review): passing the shared group's executor as base_exec presumably lets
                // _bind_exec reuse its memory — confirm against _bind_exec.
                var shared_exec = shared_group == null ? null : shared_group.train_execs[i];

                var train_exec = _bind_exec(sym, ctx[i], data_shapes, this._param_names,
                                            need_grad_input: true, base_exec: shared_exec,
                                            shared_data_arrays: this._shared_data_arrays[i]);

                this.train_execs.Add(train_exec);
            }

            // Regroup executor-major results: outer index = input/parameter, inner index = executor.
            this._data_arrays  = _data_names.Select(name => train_execs.Select((e, i) => Tuple.Create(slices[i], e.arg_dict[name])).ToList()).ToList();
            this._label_arrays = _label_names.Select(name => train_execs.Select((e, i) => Tuple.Create(slices[i], e.arg_dict[name])).ToList()).ToList();

            this.param_arrays = _param_idx.Select(i => train_execs.Select((e) => e.arg_arrays[i]).ToList()).ToList();

            this.grad_arrays = _param_idx.Select(i => train_execs.Select((e) => e.grad_arrays[i]).ToList()).ToList();

            this._aux_arrays = Enumerable.Range(0, this._aux_names.Count).Select(i => train_execs.Select((e) => e.aux_arrays[i]).ToList()).ToList();

            this._slices = slices;
        }