/// <summary>
/// Advances the native data iterator by calling MXDataIterNext.
/// </summary>
/// <returns>True when the native call reports a positive flag, otherwise false.</returns>
public override bool Next()
        {
            var status = NativeMethods.MXDataIterNext(this._BlobPtr.Handle, out var hasNext);
            Logging.CHECK_EQ(status, 0);

            var advanced = hasNext > 0;
            return advanced;
        }
        /// <summary>
        /// Queries MXDataIterGetPadNum for the underlying iterator and returns the value it reports.
        /// </summary>
        public override int GetPadNum()
        {
            var status = NativeMethods.MXDataIterGetPadNum(this._BlobPtr.Handle, out var padCount);
            Logging.CHECK_EQ(status, 0);

            return padCount;
        }
        /// <summary>
        /// Fetches the label handle from the native iterator (MXDataIterGetLabel) and wraps it in an <see cref="NDArray"/>.
        /// </summary>
        public override NDArray GetLabel()
        {
            var status = NativeMethods.MXDataIterGetLabel(this._BlobPtr.Handle, out var labelHandle);
            Logging.CHECK_EQ(status, 0);

            var label = new NDArray(labelHandle);
            return label;
        }
Beispiel #4
0
        /// <summary>
        /// Accumulates accuracy statistics by comparing the per-sample arg-max of
        /// <paramref name="preds"/> with <paramref name="labels"/>.
        /// </summary>
        /// <param name="labels">Labels; must be one-dimensional (checked below).</param>
        /// <param name="preds">Predictions; reduced with ArgmaxChannel before comparison.</param>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        public override void Update(NDArray labels, NDArray preds)
        {
            if (labels == null)
            {
                throw new ArgumentNullException(nameof(labels));
            }

            if (preds == null)
            {
                throw new ArgumentNullException(nameof(preds));
            }

            var labelShape = labels.GetShape();
            Logging.CHECK_EQ(labelShape.Count, 1);

            var count     = labelShape[0];
            var predicted = new mx_float[count];
            var expected  = new mx_float[count];

            preds.ArgmaxChannel().SyncCopyToCPU(predicted);
            labels.SyncCopyToCPU(expected);

            // Each sample contributes one instance; a (near-)exact match also adds one to the metric sum.
            for (var k = 0; k < count; k++)
            {
                var isMatch = Math.Abs(predicted[k] - expected[k]) < float.Epsilon;
                this.SumMetric += isMatch ? 1 : 0;
                this.NumInst   += 1;
            }
        }
Beispiel #5
0
        /// <summary>
        /// Builds a name-to-creator lookup of every data iterator reported by MXListDataIters.
        /// </summary>
        public MXDataIterMap()
        {
            var status = NativeMethods.MXListDataIters(out var creatorCount, out var creatorsPtr);
            Logging.CHECK_EQ(status, 0);

            this._DataIterCreators = new Dictionary <string, DataIterHandle>((int)creatorCount);

            var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

            for (var index = 0; index < creatorCount; index++)
            {
                var creator = creators[index];

                // Only the name is used; the remaining iterator metadata is not stored.
                status = NativeMethods.MXDataIterGetIterInfo(creator,
                                                             out var namePtr,
                                                             out var description,
                                                             out var argCount,
                                                             out var argNames,
                                                             out var argTypeInfos,
                                                             out var argDescriptions);
                Logging.CHECK_EQ(status, 0);

                var iterName = Marshal.PtrToStringAnsi(namePtr);
                this._DataIterCreators.Add(iterName, creator);
            }
        }
Beispiel #6
0
        /// <summary>
        /// Applies one optimizer step to <paramref name="weight"/> in place by invoking a native
        /// update operator whose single output is the weight's own handle.
        /// </summary>
        /// <param name="index">Parameter index used to look up per-parameter state, learning rate and weight decay.</param>
        /// <param name="weight">Weight array; updated in place.</param>
        /// <param name="grad">Gradient for <paramref name="weight"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="weight"/> or <paramref name="grad"/> is null.</exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }
            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // Create the per-index state on first use.
            if (!this._States.ContainsKey(index))
            {
                this.CreateState(index, weight);
            }

            this.Params["lr"] = this.GetLearningRate(index).ToString(CultureInfo.InvariantCulture);
            this.Params["wd"] = this.GetWeightDecay(index).ToString(CultureInfo.InvariantCulture);
            this.UpdateCount(index);
            var keys   = this.GetParamKeys_();
            var values = this.GetParamValues_();

            Logging.CHECK_EQ(keys.Length, values.Length);

            var inputs = new NDArrayHandle[3];

            inputs[0] = weight.GetHandle();
            inputs[1] = grad.GetHandle();

            // The weight handle is the output, so the native op writes the result back into it.
            var numOutputs = 1;
            var output     = weight.GetHandle();
            var outputs    = new[] { output };

            var state = this._States[index];
            if (state == null)
            {
                // No stored state: invoke the plain update op with (weight, grad).
                Logging.CHECK_EQ(NativeMethods.MXImperativeInvoke(this._UpdateHandle,
                                                                  2,
                                                                  inputs,
                                                                  ref numOutputs,
                                                                  ref outputs,
                                                                  keys.Length,
                                                                  keys,
                                                                  values), NativeMethods.OK);
            }
            else
            {
                // Stored state is passed as the third input to the _MomUpdateHandle op.
                inputs[2] = state.GetHandle();
                Logging.CHECK_EQ(NativeMethods.MXImperativeInvoke(this._MomUpdateHandle,
                                                                  3,
                                                                  inputs,
                                                                  ref numOutputs,
                                                                  ref outputs,
                                                                  keys.Length,
                                                                  keys,
                                                                  values), NativeMethods.OK);
            }
        }
 /// <summary>
 /// Verifies that labels and predictions agree: full shape equality when
 /// <paramref name="strict"/> is true, otherwise only equal element counts.
 /// </summary>
 protected static void CheckLabelShapes(NDArray labels, NDArray preds, bool strict = false)
 {
     if (!strict)
     {
         Logging.CHECK_EQ(labels.Size, preds.Size);
         return;
     }

     var labelShape = new Shape(labels.GetShape());
     var predShape  = new Shape(preds.GetShape());
     Logging.CHECK_EQ(labelShape, predShape);
 }
        /// <summary>
        /// Applies one optimizer step to <paramref name="weight"/> in place by invoking the native
        /// update operator with the per-index mean and variance states as extra inputs.
        /// </summary>
        /// <param name="index">Parameter index used to look up state, learning rate and weight decay.</param>
        /// <param name="weight">Weight array; updated in place (its handle is the operator output).</param>
        /// <param name="grad">Gradient for <paramref name="weight"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="weight"/> or <paramref name="grad"/> is null.</exception>
        public override void Update(int index, NDArray weight, NDArray grad)
        {
            if (weight == null)
            {
                throw new ArgumentNullException(nameof(weight));
            }
            if (grad == null)
            {
                throw new ArgumentNullException(nameof(grad));
            }

            // Create the per-index state on first use (presence is keyed off _Mean).
            if (!this._Mean.ContainsKey(index))
            {
                this.CreateState(index, weight);
            }

            this.Params["lr"] = this.GetLearningRate(index).ToString(CultureInfo.InvariantCulture);
            this.Params["wd"] = this.GetWeightDecay(index).ToString(CultureInfo.InvariantCulture);
            this.UpdateCount(index);
            var keys   = this.GetParamKeys_();
            var values = this.GetParamValues_();

            Logging.CHECK_EQ(keys.Length, values.Length);

            // NOTE(review): no bias correction (lr *= sqrt(1 - beta2^t) / (1 - beta1^t)) is applied here;
            // the "lr" passed to the native op is the value from GetLearningRate above.
            // Confirm whether the native update operator is expected to handle this.

            var inputs = new NDArrayHandle[4];

            inputs[0] = weight.GetHandle();
            inputs[1] = grad.GetHandle();
            inputs[2] = this._Mean[index].GetHandle();
            inputs[3] = this._Var[index].GetHandle();

            var numOutputs = 1;
            var output     = weight.GetHandle();
            var outputs    = new[] { output };

            // Fail loudly on a native error instead of silently leaving the weight unchanged.
            Logging.CHECK_EQ(NativeMethods.MXImperativeInvoke(this._UpdateHandle,
                                                              4,
                                                              inputs,
                                                              ref numOutputs,
                                                              ref outputs,
                                                              keys.Length,
                                                              keys,
                                                              values), NativeMethods.OK);
        }
        /// <summary>
        /// Runs this operator imperatively on the accumulated input NDArrays.
        /// </summary>
        /// <param name="outputs">
        /// Output arrays whose handles are passed to the native call. When the list is empty,
        /// the handles returned by the native call are wrapped and appended to it.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="outputs"/> is null.</exception>
        public void Invoke(List <NDArray> outputs)
        {
            if (outputs == null)
            {
                throw new ArgumentNullException(nameof(outputs));
            }

            // When keyword inputs were supplied, they must pair one-to-one with the input symbols.
            if (this._InputKeys.Count > 0)
            {
                Logging.CHECK_EQ(this._InputKeys.Count, this._InputSymbols.Count);
            }

            var paramKeys   = this._Params.Keys.ToArray();
            var paramValues = paramKeys.Select(key => this._Params[key]).ToArray();

            var inputCount  = this._InputNdarrays.Count;
            var outputCount = outputs.Count;

            // With no caller-supplied outputs, pass null and read back whatever handles the native call returns.
            NDArrayHandle[] receiver = outputCount > 0
                                           ? outputs.Select(array => array.NativePtr).ToArray()
                                           : null;

            Logging.CHECK_EQ(NativeMethods.MXImperativeInvoke(this._Handle,
                                                              inputCount,
                                                              this._InputNdarrays.ToArray(),
                                                              ref outputCount,
                                                              ref receiver,
                                                              paramKeys.Length,
                                                              paramKeys,
                                                              paramValues), NativeMethods.OK);

            // Caller-supplied outputs are presumably written in place; only collect results otherwise.
            if (outputs.Count == 0)
            {
                outputs.AddRange(receiver.Select(handle => new NDArray(handle)));
            }
        }
Beispiel #10
0
        /// <summary>
        /// Saves a set of named NDArrays to <paramref name="fileName"/> via MXNDArraySave.
        /// </summary>
        /// <param name="fileName">Destination file path.</param>
        /// <param name="arrayMap">Arrays to save; each dictionary key is saved as that array's name.</param>
        /// <exception cref="ArgumentNullException"><paramref name="arrayMap"/> is null.</exception>
        public static void Save(string fileName, IDictionary <string, NDArray> arrayMap)
        {
            if (arrayMap == null)
            {
                throw new ArgumentNullException(nameof(arrayMap));
            }

            var args = new NDArrayHandle[arrayMap.Count];
            var keys = new string[arrayMap.Count];

            // Fill name and handle from the same entry so that keys[i] always names args[i].
            var i = 0;
            foreach (var pair in arrayMap)
            {
                keys[i] = pair.Key;
                args[i] = pair.Value.GetHandle();
                i++;
            }

            Logging.CHECK_EQ(NativeMethods.MXNDArraySave(fileName, (uint)args.Length, args, keys), NativeMethods.OK);
        }
        /// <summary>
        /// Returns the indices reported by MXDataIterGetIndex, each cast from the 64-bit values
        /// produced by InteropHelper.ToUInt64Array to int.
        /// </summary>
        public override int[] GetIndex()
        {
            var status = NativeMethods.MXDataIterGetIndex(this._BlobPtr.Handle, out var indexPtr, out var indexCount);
            Logging.CHECK_EQ(status, 0);

            var raw     = InteropHelper.ToUInt64Array(indexPtr, (uint)indexCount);
            var indices = new int[indexCount];

            for (var j = 0; j < indices.Length; j++)
            {
                indices[j] = (int)raw[j];
            }

            return indices;
        }
Beispiel #12
0
        /// <summary>
        /// Builds two lookups from the native library: atomic symbol creators keyed by their
        /// reported name, and op handles keyed by the names from NNListAllOpNames.
        /// </summary>
        public OpMap()
        {
            var status = NativeMethods.MXSymbolListAtomicSymbolCreators(out var creatorCount, out var creatorsPtr);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            this._SymbolCreators = new Dictionary <string, AtomicSymbolCreator>((int)creatorCount);

            var creators = InteropHelper.ToPointerArray(creatorsPtr, creatorCount);

            for (var c = 0; c < creatorCount; c++)
            {
                var creator    = creators[c];
                var returnType = System.IntPtr.Zero;

                // Only the name is used; the remaining symbol metadata is not stored.
                status = NativeMethods.MXSymbolGetAtomicSymbolInfo(creator,
                                                                   out var namePtr,
                                                                   out var description,
                                                                   out var argCount,
                                                                   out var argNames,
                                                                   out var argTypeInfos,
                                                                   out var argDescriptions,
                                                                   out var keyVarNumArgs,
                                                                   ref returnType);
                Logging.CHECK_EQ(status, NativeMethods.OK);

                var creatorName = Marshal.PtrToStringAnsi(namePtr);
                this._SymbolCreators.Add(creatorName, creator);
            }

            status = NativeMethods.NNListAllOpNames(out var opCount, out var opNamesPtr);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            this._OpHandles = new Dictionary <string, AtomicSymbolCreator>((int)opCount);

            var opNames = InteropHelper.ToPointerArray(opNamesPtr, opCount);

            for (var o = 0; o < opCount; o++)
            {
                var opNamePtr = opNames[o];

                status = NativeMethods.NNGetOpHandle(opNamePtr, out var opHandle);
                Logging.CHECK_EQ(status, NativeMethods.OK);

                var opName = Marshal.PtrToStringAnsi(opNamePtr);
                this._OpHandles.Add(opName, opHandle);
            }
        }
Beispiel #13
0
        /// <summary>
        /// Pairs <paramref name="names"/> with <paramref name="arrays"/> by position into a dictionary.
        /// </summary>
        /// <param name="names">Array names; must be unique.</param>
        /// <param name="arrays">Arrays; must have the same count as the unique names.</param>
        /// <returns>A dictionary mapping names[i] to arrays[i].</returns>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        private static IDictionary <string, NDArray> GetDictionary(IList <string> names, IList <NDArray> arrays)
        {
            if (names == null)
            {
                throw new ArgumentNullException(nameof(names));
            }
            if (arrays == null)
            {
                throw new ArgumentNullException(nameof(arrays));
            }

            var ret = new Dictionary <string, NDArray>();

            var set = new HashSet <string>();

            // HashSet.Add returns false only for a name already seen, so the check fails exactly on duplicates.
            foreach (var s in names)
            {
                Logging.CHECK(set.Add(s), $"Duplicate names detected, {s}");
            }

            Logging.CHECK_EQ(set.Count, arrays.Count, "names size not equal to arrays size");
            for (var i = 0; i < names.Count; ++i)
            {
                ret[names[i]] = arrays[i];
            }

            return ret;
        }
Beispiel #14
0
        /// <summary>
        /// Loads NDArrays from <paramref name="fileName"/> into a name-keyed sorted dictionary.
        /// </summary>
        /// <returns>The loaded arrays keyed by name; empty when the file reports no names.</returns>
        public static IDictionary <string, NDArray> LoadToMap(string fileName)
        {
            var loaded = new SortedDictionary <string, NDArray>();

            var status = NativeMethods.MXNDArrayLoad(fileName,
                                                     out var arrayCount,
                                                     out var arraysPtr,
                                                     out var nameCount,
                                                     out var namesPtr);
            Logging.CHECK_EQ(status, NativeMethods.OK);

            // Without names there is nothing to key the arrays by, so the map stays empty.
            if (nameCount <= 0)
            {
                return loaded;
            }

            var handles = InteropHelper.ToPointerArray(arraysPtr, arrayCount);
            var names   = InteropHelper.ToPointerArray(namesPtr, nameCount);

            Logging.CHECK_EQ(nameCount, arrayCount);
            for (mx_uint i = 0; i < arrayCount; ++i)
            {
                loaded[Marshal.PtrToStringAnsi(names[i])] = new NDArray(handles[i]);
            }

            return loaded;
        }
Beispiel #15
0
        /// <summary>
        /// Creates an atomic symbol from this operator's parameters and composes it with the
        /// accumulated input symbols.
        /// </summary>
        /// <param name="name">Symbol name; an empty string is passed to MXSymbolCompose as null.</param>
        /// <returns>The composed <see cref="Symbol"/>.</returns>
        public Symbol CreateSymbol(string name = "")
        {
            // When keyword inputs were supplied, they must pair one-to-one with the input symbols.
            if (this._InputKeys.Count > 0)
            {
                Logging.CHECK_EQ(this._InputKeys.Count, this._InputSymbols.Count);
            }

            string symbolName = null;
            if (name != "")
            {
                symbolName = name;
            }

            var paramKeys   = this._Params.Keys.ToArray();
            var paramValues = paramKeys.Select(key => this._Params[key]).ToArray();

            var inputKeys = this._InputKeys.Count == 0 ? null : this._InputKeys.ToArray();

            var createStatus = NativeMethods.MXSymbolCreateAtomicSymbol(this._Handle,
                                                                        (uint)paramKeys.Length,
                                                                        paramKeys,
                                                                        paramValues,
                                                                        out var created);
            Logging.CHECK_EQ(createStatus, NativeMethods.OK);

            var composeStatus = NativeMethods.MXSymbolCompose(created,
                                                              symbolName,
                                                              (uint)this._InputSymbols.Count,
                                                              inputKeys,
                                                              this._InputSymbols.ToArray());
            Logging.CHECK_EQ(composeStatus, NativeMethods.OK);

            return new Symbol(created);
        }
Beispiel #16
0
        /// <summary>
        /// Saves <paramref name="arrayList"/> to <paramref name="fileName"/> via MXNDArraySave, without names.
        /// </summary>
        public static void Save(string fileName, IList <NDArray> arrayList)
        {
            var handles = arrayList.Select(item => item.GetHandle()).ToArray();

            var status = NativeMethods.MXNDArraySave(fileName, (uint)handles.Length, handles, null);
            Logging.CHECK_EQ(status, NativeMethods.OK);
        }
 /// <summary>
 /// Registers an optimizer factory under <paramref name="name"/> in cmap.
 /// </summary>
 /// <param name="name">Registration key; must not already be present.</param>
 /// <param name="creator">Factory producing the optimizer.</param>
 /// <returns>Always 0.</returns>
 public static int Register(string name, Func <Optimizer> creator)
 {
     // Include the offending name so a duplicate registration is diagnosable.
     Logging.CHECK_EQ(cmap.ContainsKey(name), false, $"{name} already registered");
     cmap.Add(name, creator);
     return 0;
 }
Beispiel #18
0
 /// <summary>
 /// Writes this symbol to <paramref name="fileName"/> via MXSymbolSaveToFile.
 /// </summary>
 public void Save(string fileName)
 {
     var status = NativeMethods.MXSymbolSaveToFile(this.GetHandle(), fileName);
     Logging.CHECK_EQ(status, NativeMethods.OK);
 }
        /// <summary>
        /// Calls MXDataIterBeforeFirst on the underlying native iterator.
        /// </summary>
        public override void BeforeFirst()
        {
            var handle = this._BlobPtr.Handle;
            Logging.CHECK_EQ(NativeMethods.MXDataIterBeforeFirst(handle), 0);
        }
Beispiel #20
0
 /// <summary>
 /// Calls MXNDArrayWaitAll and checks that it reports success.
 /// </summary>
 public static void WaitAll()
 {
     var status = NativeMethods.MXNDArrayWaitAll();
     Logging.CHECK_EQ(status, NativeMethods.OK);
 }
Beispiel #21
0
 /// <summary>
 /// Calls MXNDArrayWaitToWrite on this array's native handle and checks that it reports success.
 /// </summary>
 public void WaitToWrite()
 {
     var status = NativeMethods.MXNDArrayWaitToWrite(this._Blob.Handle);
     Logging.CHECK_EQ(status, NativeMethods.OK);
 }
Beispiel #22
0
 /// <summary>
 /// Forwards to the native MXNotifyShutdown and checks that it reports success.
 /// </summary>
 public static void MXNotifyShutdown()
 {
     var status = NativeMethods.MXNotifyShutdown();
     Logging.CHECK_EQ(status, NativeMethods.OK);
 }