Example #1
0
        /// <summary>
        /// Invokes the cached operator on <paramref name="args"/> and wraps every native output handle
        /// in an <see cref="NDArray"/>, converted via ToSType to the storage type reported for that output.
        /// </summary>
        /// <param name="args">Input arrays whose native handles are passed to MXInvokeCachedOpEx.</param>
        /// <returns>One array per output reported by the native call, in native output order.</returns>
        public NDArrayList Call(NDArrayList args)
        {
            var inputCount   = args.Length;
            var inputHandles = MxUtil.GetNDArrayHandles(args);

            // NOTE(review): the status code returned by the native call is not checked here.
            NativeMethods.MXInvokeCachedOpEx(handle, inputCount, inputHandles, out var outputCount,
                                             out var outputHandles, out var outputStypes);

            var collected = new NDArrayList();
            var index     = 0;
            while (index < outputCount)
            {
                var wrapped = new NDArray(outputHandles[index]);
                collected.Add(wrapped.ToSType((StorageStype)outputStypes[index]));
                index++;
            }

            return collected.ToArray();
        }
Example #2
0
        /// <summary>
        /// Computes gradients of <paramref name="heads"/> with respect to <paramref name="variables"/>
        /// through MXAutogradBackwardEx and returns them wrapped as <see cref="NDArray"/> instances.
        /// </summary>
        /// <param name="heads">Output arrays to differentiate.</param>
        /// <param name="variables">Arrays to compute gradients for; their handles and count are passed to the native call.</param>
        /// <param name="head_grads">Optional head gradients (default null); passed with <paramref name="heads"/> to ParseHead.</param>
        /// <param name="retain_graph">Forwarded to the native call as 0/1.</param>
        /// <param name="create_graph">Forwarded to the native call as 0/1.</param>
        /// <param name="train_mode">Forwarded to the native call as 0/1.</param>
        /// <returns>One array per gradient handle returned by the native call.</returns>
        public static NDArrayList Grad(NDArrayList heads, NDArrayList variables, NDArrayList head_grads = null,
                                       bool retain_graph = false, bool create_graph = true, bool train_mode = true)
        {
            var (head_handles, head_grads_handles) = ParseHead(heads, head_grads);

            // NOTE(review): the per-gradient storage types reported by the native call are discarded;
            // confirm that NDArray(handle) recovers the storage type from the handle itself.
            NativeMethods.MXAutogradBackwardEx(head_handles.Length, head_handles, head_grads_handles, variables.Length,
                                               MxUtil.GetNDArrayHandles(variables), Convert.ToInt32(retain_graph),
                                               Convert.ToInt32(create_graph), Convert.ToInt32(train_mode),
                                               out var grad_handles, out _);

            var result = new NDArrayList();
            foreach (var gradHandle in grad_handles)
            {
                result.Add(new NDArray(gradHandle));
            }

            // The list is local, so return it directly instead of copying it via ToArray().
            return result;
        }
        /// <summary>
        /// Updates <paramref name="metric"/> once per (training executor, slice) pair, using that executor's
        /// outputs and the labels belonging to its slice.
        /// </summary>
        /// <param name="metric">Metric that receives one Update call per executor.</param>
        /// <param name="labels">
        /// Full-batch labels when <paramref name="pre_sliced"/> is false; otherwise indexed by executor position
        /// (labels[i] is used for the i-th executor).
        /// </param>
        /// <param name="pre_sliced">
        /// When false, every label is sliced with the executor's slice bounds; when true, labels[i] is used as-is.
        /// </param>
        public void UpdateMetric(EvalMetric metric, NDArrayList labels, bool pre_sliced = false)
        {
            var i = 0;

            // Enumerate eagerly: Enumerable.Zip is deferred, so a Zip whose result is never iterated
            // never invokes its selector, and the metric would never be updated.
            foreach (var (exec, slice) in train_execs.Zip(slices, (e, s) => (e, s)))
            {
                // Fresh list per executor, so labels gathered for earlier executors are not passed
                // again alongside this executor's outputs.
                var labels_slice = new NDArrayList();

                if (!pre_sliced)
                {
                    foreach (var label in labels)
                    {
                        labels_slice.Add(label.Slice(slice.Begin, slice.End.Value));
                    }
                }
                else
                {
                    labels_slice.Add(labels[i]);
                }

                metric.Update(labels_slice.ToArray(), exec.Outputs.ToArray());
                i++;
            }
        }
 /// <summary>
 /// Loads <paramref name="data_batch"/> into the executors' arrays: the batch's data into
 /// data_arrays and its labels into label_arrays, via ExecuterManager.
 /// </summary>
 /// <param name="data_batch">Batch whose data and labels are loaded.</param>
 public void LoadDataBatch(DataBatch data_batch)
 {
     ExecuterManager.LoadData(data_batch, data_arrays.ToArray());
     // Label arrays must be filled from the batch's labels. LoadData presumably reads the batch's
     // data part, so LoadLabel is used here, mirroring MXNet's _load_data/_load_label pair.
     // NOTE(review): assumes ExecuterManager exposes LoadLabel with the same signature as LoadData.
     ExecuterManager.LoadLabel(data_batch, label_arrays.ToArray());
 }