/// <summary>
/// Sets up data-parallel training: splits each batch of <paramref name="train_data"/> across
/// <paramref name="ctx"/> according to <paramref name="work_load_list"/> and binds the default
/// executor group for <paramref name="symbol"/>.
/// </summary>
/// <param name="symbol">Symbol bound for the default executor group.</param>
/// <param name="ctx">Devices to train on; its length is the device count.</param>
/// <param name="train_data">Iterator whose batch size is split across devices and whose descriptors drive binding.</param>
/// <param name="arg_names">Argument names forwarded to the executor group.</param>
/// <param name="param_names">Names of the arguments treated as trainable parameters.</param>
/// <param name="aux_names">Auxiliary state names, stored on the manager.</param>
/// <param name="work_load_list">Relative work load per device; when null every device gets weight 1.</param>
/// <param name="logger">Not read by this constructor; the start message goes through the static <c>Logger.Info</c>.</param>
/// <param name="sym_gen">Optional bucket-key to symbol generator; when set, the default group is registered
/// under <c>train_data.DefaultBucketKey</c>.</param>
/// <exception cref="MXNetException">When <paramref name="work_load_list"/> is given but its length differs from the device count.</exception>
public DataParallelExecutorManager(Symbol symbol, Context[] ctx, DataIter train_data, string[] arg_names,
                                           string[] param_names,
                                           string[] aux_names, int[] work_load_list = null, Logger logger = null, Func <int, Symbol> sym_gen = null)
        {
            num_device = ctx.Length;
            Logger.Info($"Start training with {num_device}");

            // An explicit work load must provide exactly one weight per device.
            if (work_load_list != null && work_load_list.Length != num_device)
            {
                throw new MXNetException("Invalid setting for work load");
            }

            if (work_load_list == null)
            {
                // Default: every device receives an equal share of the batch.
                var uniform = new int[num_device];
                for (var dev = 0; dev < uniform.Length; dev++)
                {
                    uniform[dev] = 1;
                }

                work_load_list = uniform;
            }

            slices = ExecuterManager.SplitInputSlice(train_data.BatchSize, work_load_list);

            this.symbol      = symbol;
            this.arg_names   = arg_names;
            this.param_names = param_names;
            this.aux_names   = aux_names;
            contexts         = ctx;
            execgrp          = new DataParallelExecutorGroup(symbol, arg_names, param_names, ctx, slices, train_data);

            this.sym_gen = sym_gen;
            if (this.sym_gen == null)
            {
                return;
            }

            // Bucketing mode: the default group serves the iterator's default bucket key.
            execgrp_bucket.Add(train_data.DefaultBucketKey, execgrp);
        }
        /// <summary>
        /// Selects the executor group that should process <paramref name="data_batch"/>, stores it in
        /// <c>curr_execgrp</c>, and loads the batch into it.
        /// </summary>
        /// <remarks>
        /// Without bucketing (<c>sym_gen</c> is null) the default group <c>execgrp</c> is always used.
        /// With bucketing, a group is bound lazily the first time a bucket key is seen. It is built from
        /// <c>sym_gen(key)</c> with <c>execgrp</c> passed as the shared group, and cached in
        /// <c>execgrp_bucket</c>. The default <c>symbol</c> and <c>execgrp</c> fields are not modified,
        /// so every bucket shares with the same default group.
        /// </remarks>
        /// <param name="data_batch">Batch to load; its <c>BucketKey</c> must be set when bucketing is enabled.</param>
        public void LoadDataBatch(DataBatch data_batch)
        {
            if (sym_gen == null)
            {
                curr_execgrp = execgrp;
            }
            else
            {
                var key = data_batch.BucketKey.Value;

                // Only bind a new group for keys we have not seen yet. The previous check was
                // inverted: unseen keys threw KeyNotFoundException below, and known keys were
                // rebound on every batch.
                if (!execgrp_bucket.ContainsKey(key))
                {
                    // Use locals so the default symbol/group stay intact and remain the
                    // shared group for every bucket.
                    var bucketSymbol = sym_gen(key);
                    var bucketGroup  = new DataParallelExecutorGroup(bucketSymbol, arg_names, param_names, contexts, slices,
                                                                     NDArrayIter.FromBatch(data_batch), execgrp);
                    execgrp_bucket[key] = bucketGroup;
                }

                curr_execgrp = execgrp_bucket[key];
            }

            curr_execgrp.LoadDataBatch(data_batch);
        }
        /// <summary>
        /// Binds one training executor per device in <paramref name="ctxlist"/> and collects the
        /// executors' data, label, parameter, gradient and auxiliary arrays.
        /// </summary>
        /// <remarks>
        /// The collected array lists are flat. For each name (or parameter/aux index), they hold one
        /// entry per executor, in device order.
        /// </remarks>
        /// <param name="sym">Symbol bound on every device.</param>
        /// <param name="arg_names">Argument names; positions are used as indices into each executor's argument/gradient arrays.</param>
        /// <param name="param_names">Names of the arguments treated as trainable parameters.</param>
        /// <param name="ctxlist">Devices to bind on; indexed in parallel with <paramref name="slices"/>.</param>
        /// <param name="slices">Batch slice per device; its length (End - Begin) becomes that device's leading dimension.</param>
        /// <param name="train_data">Source of the data/label names, shapes and types.</param>
        /// <param name="shared_group">Optional group whose executors and shared data arrays are reused by this group.</param>
        public DataParallelExecutorGroup(Symbol sym, string[] arg_names, string[] param_names, Context[] ctxlist,
                                         Slice[] slices, DataIter train_data, DataParallelExecutorGroup shared_group = null)
        {
            ExecuterManager.CheckArguments(sym);

            if (shared_group == null)
            {
                // One fresh shared-array dictionary per device.
                foreach (var item in ctxlist)
                {
                    shared_data_arrays.Add(new NDArrayDict());
                }
            }
            else
            {
                shared_data_arrays = shared_group.shared_data_arrays;
            }

            foreach (var item in train_data.ProvideData)
            {
                data_names.Add(item.Name);
            }

            foreach (var item in train_data.ProvideLabel)
            {
                label_names.Add(item.Name);
            }

            aux_names = sym.ListAuxiliaryStates().ToList();

            // Record the positions (in arg_names) of the trainable parameters.
            for (var i = 0; i < arg_names.Length; i++)
            {
                if (param_names.Contains(arg_names[i]))
                {
                    param_idx.Add(i);
                    this.param_names.Add(arg_names[i]);
                }
            }

            for (var i = 0; i < ctxlist.Length; i++)
            {
                // Each device only sees its slice of the batch, so the leading dimension of every
                // data/label shape is replaced by the slice length.
                var sliceSize   = slices[i].End.Value - slices[i].Begin;
                var data_shapes = new Dictionary <string, Shape>();
                var data_types  = new Dictionary <string, DType>();

                foreach (var item in train_data.ProvideData)
                {
                    var dims = item.Shape.Data.ToList();
                    dims[0] = sliceSize;
                    data_shapes[item.Name] = new Shape(dims);
                    data_types[item.Name]  = item.DataType;
                }

                foreach (var item in train_data.ProvideLabel)
                {
                    var dims = item.Shape.Data.ToList();
                    dims[0] = sliceSize;
                    data_shapes[item.Name] = new Shape(dims);
                    data_types[item.Name]  = item.DataType;
                }

                var shared_exec = shared_group?.train_execs[i];
                var train_exec  = ExecuterManager.BindExec(sym, ctxlist[i], data_shapes, param_names, true,
                                                           shared_exec, shared_data_arrays[i], data_types);

                train_execs.Add(train_exec);
            }

            foreach (var name in data_names)
            {
                foreach (var exec in train_execs)
                {
                    data_arrays.Add(exec.ArgmentDictionary()[name]);
                }
            }

            foreach (var name in label_names)
            {
                foreach (var exec in train_execs)
                {
                    label_arrays.Add(exec.ArgmentDictionary()[name]);
                }
            }

            // Parameter values and their gradients share the same argument index.
            foreach (var idx in param_idx)
            {
                foreach (var exec in train_execs)
                {
                    param_arrays.Add(exec.ArgmentArrays[idx]);
                    grad_arrays.Add(exec.GradientArrays[idx]);
                }
            }

            // Index by aux-state position (idx), not by device position. The old code used the
            // device loop variable here, which collected the wrong arrays and could go out of range.
            for (var idx = 0; idx < aux_names.Count; idx++)
            {
                foreach (var exec in train_execs)
                {
                    aux_arrays.Add(exec.AuxiliaryArrays[idx]);
                }
            }

            this.slices = slices;
        }