Ejemplo n.º 1
0
        /// <summary>
        /// Trains the model on in-memory tensors. A <c>KerasCommand.Fit</c> request carrying the
        /// serialized graph, hyper-parameters and both tensors is handed to the native
        /// <c>KerasFitModel</c> entry point. On success the reply is parsed as a <c>KerasProto</c>
        /// and its <c>Model</c> bytes are stored in <c>_model</c>.
        /// </summary>
        /// <param name="x">Input tensor. Its first dimension (<c>Sizes[0]</c>) is passed to the
        /// <c>ProgressWriter</c> (presumably as the sample count — confirm).</param>
        /// <param name="y">Target tensor.</param>
        /// <param name="batchSize">Batch size forwarded to the native side.</param>
        /// <param name="epochs">Epoch count forwarded to the native side and the progress writer.</param>
        /// <param name="verbose">Verbosity level forwarded to the native side.</param>
        /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
        /// <exception cref="KerasException">The native side reported an error.</exception>
        public void Fit(Tensor x, Tensor y, uint batchSize = 32, uint epochs = 10, uint verbose = 1)
        {
            if (x == null)
            {
                throw new ArgumentNullException(nameof(x));
            }
            if (y == null)
            {
                throw new ArgumentNullException(nameof(y));
            }

            KerasProto kerasProto = new KerasProto();

            kerasProto.Graph = _graph.ToString(Formatting.None);

            kerasProto.BatchSize = batchSize;
            kerasProto.Epochs    = epochs;

            kerasProto.Verbose = verbose;

            kerasProto.Inputs.Add(x.GetProto());
            kerasProto.Inputs.Add(y.GetProto());

            kerasProto.Command = KerasCommand.Fit;

            // The state and delegate are kept in fields on purpose: GetFunctionPointerForDelegate
            // does not keep the delegate alive, so it must stay reachable for as long as native
            // code may call back through the pointer.
            _state    = new ProgressCallbackState(new ProgressWriter(epochs, (uint)x.Sizes[0]));
            _callback = new ProgressCallback(_state.Callback);

            kerasProto.ProgressCallback = (ulong)Marshal.GetFunctionPointerForDelegate(_callback);

            using (var stream = new MemoryStream())
            {
                kerasProto.WriteTo(stream);
                var bytes = stream.ToArray();

                IntPtr outData = IntPtr.Zero;
                uint   outLen  = 0;
                ulong  outPtr  = 0;

                IntPtr exceptionData = IntPtr.Zero;
                uint   exceptionLen  = 0;
                ulong  exceptionPtr  = 0;

                KerasFitModel(bytes, (uint)bytes.Length, ref outData, ref outLen, ref outPtr, ref exceptionData, ref exceptionLen, ref exceptionPtr);

                // A non-zero exception length is the native side's error signal.
                if (exceptionLen == 0)
                {
                    byte[] resultBytes;
                    try
                    {
                        resultBytes = new byte[outLen];
                        // Marshal.Copy rejects a null source even for zero bytes.
                        if (outLen > 0)
                        {
                            Marshal.Copy(outData, resultBytes, 0, (int)outLen);
                        }
                    }
                    finally
                    {
                        // Return the native buffer even if the copy fails.
                        KerasDeletePointer(outPtr);
                    }

                    var resultProto = KerasProto.Parser.ParseFrom(resultBytes);
                    _model = resultProto.Model.ToArray();
                }
                else
                {
                    string message;
                    try
                    {
                        var exceptionBytes = new byte[exceptionLen];
                        Marshal.Copy(exceptionData, exceptionBytes, 0, (int)exceptionLen);
                        // UTF-8 is identical to ASCII for ASCII text, but keeps non-ASCII
                        // characters instead of replacing them with '?'.
                        message = Encoding.UTF8.GetString(exceptionBytes);
                    }
                    finally
                    {
                        KerasDeletePointer(exceptionPtr);
                    }

                    throw new KerasException(message);
                }
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Trains the model on data referenced by <paramref name="path"/>. A request carrying the
        /// serialized graph, the data description and the hyper-parameters is handed to the native
        /// <c>KerasFitModel</c> entry point. On success the raw bytes returned by the native side
        /// are stored in <c>_model</c>.
        /// </summary>
        /// <param name="path">Path forwarded to the native side as <c>KerasProto.Path</c>.</param>
        /// <param name="nsamples">Value forwarded as <c>Nsamples</c>.</param>
        /// <param name="nfeatures">Value forwarded as <c>Nfeatures</c>.</param>
        /// <param name="nlabels">Value forwarded as <c>Nlabels</c>.</param>
        /// <param name="batchSize">Batch size forwarded to the native side.</param>
        /// <param name="epochs">Epoch count forwarded to the native side.</param>
        /// <param name="verbose">Verbosity level forwarded to the native side.</param>
        /// <exception cref="ArgumentNullException"><paramref name="path"/> is null.</exception>
        /// <exception cref="KerasException">The native side reported an error.</exception>
        public void Fit(string path, uint nsamples, uint nfeatures, uint nlabels, uint batchSize, uint epochs, uint verbose = 1)
        {
            if (path == null)
            {
                throw new ArgumentNullException(nameof(path));
            }

            KerasProto kerasProto = new KerasProto();

            kerasProto.Graph     = _graph.ToString(Formatting.None);
            kerasProto.Path      = path;
            kerasProto.Nsamples  = nsamples;
            kerasProto.Nfeatures = nfeatures;
            kerasProto.Nlabels   = nlabels;

            kerasProto.BatchSize = batchSize;
            kerasProto.Epochs    = epochs;

            kerasProto.Verbose = verbose;

            // NOTE(review): unlike the tensor-based Fit overload, this request never sets
            // kerasProto.Command. It also treats the native reply as raw model bytes rather than a
            // serialized KerasProto — confirm this matches the native protocol.

            using (var stream = new MemoryStream())
            {
                kerasProto.WriteTo(stream);
                var bytes = stream.ToArray();

                IntPtr outData = IntPtr.Zero;
                uint   outLen  = 0;
                ulong  outPtr  = 0;

                IntPtr exceptionData = IntPtr.Zero;
                uint   exceptionLen  = 0;
                ulong  exceptionPtr  = 0;

                KerasFitModel(bytes, (uint)bytes.Length, ref outData, ref outLen, ref outPtr, ref exceptionData, ref exceptionLen, ref exceptionPtr);

                // A non-zero exception length is the native side's error signal.
                if (exceptionLen == 0)
                {
                    // Copy into a local first so a failed copy cannot leave _model holding a
                    // zero-filled buffer.
                    byte[] modelBytes;
                    try
                    {
                        modelBytes = new byte[outLen];
                        // Marshal.Copy rejects a null source even for zero bytes.
                        if (outLen > 0)
                        {
                            Marshal.Copy(outData, modelBytes, 0, (int)outLen);
                        }
                    }
                    finally
                    {
                        // Return the native buffer even if the copy fails.
                        KerasDeletePointer(outPtr);
                    }

                    _model = modelBytes;
                }
                else
                {
                    string message;
                    try
                    {
                        var exceptionBytes = new byte[exceptionLen];
                        Marshal.Copy(exceptionData, exceptionBytes, 0, (int)exceptionLen);
                        // UTF-8 is identical to ASCII for ASCII text, but keeps non-ASCII
                        // characters instead of replacing them with '?'.
                        message = Encoding.UTF8.GetString(exceptionBytes);
                    }
                    finally
                    {
                        KerasDeletePointer(exceptionPtr);
                    }

                    throw new KerasException(message);
                }
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Runs inference on <paramref name="x"/> through the native <c>KerasFitModel</c> entry
        /// point using a <c>KerasCommand.Predict</c> request.
        /// <para>
        /// The request carries the stored model bytes (when present), <c>_uuid</c>, <c>_path</c> and
        /// a JSON <c>PredictParams</c> object holding the <paramref name="cache"/> flag. The
        /// <c>ModelUuid</c> from the reply is saved back into <c>_uuid</c>.
        /// </para>
        /// </summary>
        /// <param name="x">Input tensor.</param>
        /// <param name="batchSize">Batch size forwarded to the native side.</param>
        /// <param name="verbose">Verbosity level forwarded to the native side.</param>
        /// <param name="cache">Sent as <c>"cache"</c> in <c>PredictParams</c>. How the native side
        /// uses it is not visible here.</param>
        /// <returns>The first output tensor of the reply.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
        /// <exception cref="KerasException">The native side reported an error or returned no outputs.</exception>
        public Tensor Predict(Tensor x, uint batchSize = 32, uint verbose = 1, bool cache = true)
        {
            if (x == null)
            {
                throw new ArgumentNullException(nameof(x));
            }

            KerasProto kerasProto = new KerasProto();

            kerasProto.BatchSize = batchSize;
            kerasProto.Verbose   = verbose;

            var jobj = new JObject()
            {
                ["cache"] = cache
            };

            kerasProto.PredictParams = jobj.ToString(Formatting.None);

            // TODO Consider not copying the model if we have a uuid
            if (_model != null)
            {
                kerasProto.Model = ByteString.CopyFrom(_model);
            }
            kerasProto.ModelUuid = _uuid;
            kerasProto.ModelPath = _path;
            kerasProto.Inputs.Add(x.GetProto());

            kerasProto.Command = KerasCommand.Predict;

            using (var stream = new MemoryStream())
            {
                kerasProto.WriteTo(stream);
                var bytes = stream.ToArray();

                IntPtr outData = IntPtr.Zero;
                uint   outLen  = 0;
                ulong  outPtr  = 0;

                IntPtr exceptionData = IntPtr.Zero;
                uint   exceptionLen  = 0;
                ulong  exceptionPtr  = 0;

                KerasFitModel(bytes, (uint)bytes.Length, ref outData, ref outLen, ref outPtr, ref exceptionData, ref exceptionLen, ref exceptionPtr);

                // A non-zero exception length is the native side's error signal.
                if (exceptionLen == 0)
                {
                    byte[] resultBytes;
                    try
                    {
                        resultBytes = new byte[outLen];
                        // Marshal.Copy rejects a null source even for zero bytes.
                        if (outLen > 0)
                        {
                            Marshal.Copy(outData, resultBytes, 0, (int)outLen);
                        }
                    }
                    finally
                    {
                        // Return the native buffer even if the copy fails.
                        KerasDeletePointer(outPtr);
                    }

                    var resultProto = KerasProto.Parser.ParseFrom(resultBytes);
                    _uuid = resultProto.ModelUuid;

                    // Fail with a descriptive error instead of an opaque index-out-of-range.
                    if (resultProto.Outputs.Count == 0)
                    {
                        throw new KerasException("Predict returned no outputs.");
                    }

                    return TensorUtils.Deserialize(resultProto.Outputs[0]);
                }
                else
                {
                    string message;
                    try
                    {
                        var exceptionBytes = new byte[exceptionLen];
                        Marshal.Copy(exceptionData, exceptionBytes, 0, (int)exceptionLen);
                        // UTF-8 is identical to ASCII for ASCII text, but keeps non-ASCII
                        // characters instead of replacing them with '?'.
                        message = Encoding.UTF8.GetString(exceptionBytes);
                    }
                    finally
                    {
                        KerasDeletePointer(exceptionPtr);
                    }

                    throw new KerasException(message);
                }
            }
        }