/// <summary>
/// Trains the model on in-memory tensors. Builds a <see cref="KerasProto"/> request with
/// <c>KerasCommand.Fit</c>, passes it to the native <c>KerasFitModel</c> entry point, and on
/// success stores the <c>Model</c> bytes of the parsed response in <c>_model</c>.
/// </summary>
/// <param name="x">First input tensor. Its <c>Sizes[0]</c> is handed to the progress writer
/// (presumably the sample count — confirm against <c>ProgressWriter</c>).</param>
/// <param name="y">Second input tensor (presumably the targets).</param>
/// <param name="batchSize">Batch size forwarded in the request.</param>
/// <param name="epochs">Epoch count forwarded in the request and to the progress writer.</param>
/// <param name="verbose">Verbosity level forwarded in the request.</param>
/// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
/// <exception cref="KerasException">The native side reported an error.</exception>
public void Fit(Tensor x, Tensor y, uint batchSize = 32, uint epochs = 10, uint verbose = 1)
{
    if (x == null)
    {
        throw new ArgumentNullException(nameof(x));
    }
    if (y == null)
    {
        throw new ArgumentNullException(nameof(y));
    }

    KerasProto kerasProto = new KerasProto();
    kerasProto.Graph = _graph.ToString(Formatting.None);
    kerasProto.BatchSize = batchSize;
    kerasProto.Epochs = epochs;
    kerasProto.Verbose = verbose;
    kerasProto.Inputs.Add(x.GetProto());
    kerasProto.Inputs.Add(y.GetProto());
    kerasProto.Command = KerasCommand.Fit;

    // The delegate is kept in a field so it stays reachable while native code may still
    // call through the function pointer obtained from it.
    _state = new ProgressCallbackState(new ProgressWriter(epochs, (uint)x.Sizes[0]));
    _callback = new ProgressCallback(_state.Callback);
    kerasProto.ProgressCallback = (ulong)Marshal.GetFunctionPointerForDelegate(_callback);

    byte[] request;
    using (var stream = new MemoryStream())
    {
        kerasProto.WriteTo(stream);
        request = stream.ToArray();
    }

    IntPtr outData = IntPtr.Zero;
    uint outLen = 0;
    ulong outPtr = 0;
    IntPtr exceptionData = IntPtr.Zero;
    uint exceptionLen = 0;
    ulong exceptionPtr = 0;
    KerasFitModel(request, (uint)request.Length, ref outData, ref outLen, ref outPtr,
        ref exceptionData, ref exceptionLen, ref exceptionPtr);

    // A non-empty exception buffer signals failure on the native side.
    if (exceptionLen != 0)
    {
        string message;
        try
        {
            var messageBytes = new byte[exceptionLen];
            Marshal.Copy(exceptionData, messageBytes, 0, (int)exceptionLen);
            // UTF-8 decodes pure-ASCII text identically, but preserves non-ASCII characters
            // that Encoding.ASCII would have replaced with '?'.
            message = Encoding.UTF8.GetString(messageBytes);
        }
        finally
        {
            // Release the native buffer even if the copy fails.
            KerasDeletePointer(exceptionPtr);
        }
        throw new KerasException(message);
    }

    byte[] resultBytes;
    try
    {
        resultBytes = new byte[outLen];
        Marshal.Copy(outData, resultBytes, 0, (int)outLen);
    }
    finally
    {
        // Release the native buffer even if the copy fails.
        KerasDeletePointer(outPtr);
    }

    var resultProto = KerasProto.Parser.ParseFrom(resultBytes);
    _model = resultProto.Model.ToByteArray();
}
/// <summary>
/// Trains the model on a dataset identified by <paramref name="path"/> (presumably read by
/// the native side). Sends the graph, dataset dimensions and training settings to the native
/// <c>KerasFitModel</c> entry point and stores the returned bytes directly in <c>_model</c>.
/// </summary>
/// <remarks>
/// NOTE(review): unlike the tensor-based overload, this request sets neither <c>Command</c>
/// nor <c>ProgressCallback</c>. The response is also stored as raw model bytes rather than
/// parsed as a <see cref="KerasProto"/>. Confirm the native side handles path-based
/// requests this way.
/// </remarks>
/// <param name="path">Dataset location forwarded as <c>Path</c>.</param>
/// <param name="nsamples">Forwarded as <c>Nsamples</c> (presumably the sample count).</param>
/// <param name="nfeatures">Forwarded as <c>Nfeatures</c> (presumably features per sample).</param>
/// <param name="nlabels">Forwarded as <c>Nlabels</c> (presumably labels per sample).</param>
/// <param name="batchSize">Batch size forwarded in the request.</param>
/// <param name="epochs">Epoch count forwarded in the request.</param>
/// <param name="verbose">Verbosity level forwarded in the request.</param>
/// <exception cref="ArgumentNullException"><paramref name="path"/> is null.</exception>
/// <exception cref="KerasException">The native side reported an error.</exception>
public void Fit(string path, uint nsamples, uint nfeatures, uint nlabels, uint batchSize, uint epochs, uint verbose = 1)
{
    if (path == null)
    {
        throw new ArgumentNullException(nameof(path));
    }

    KerasProto kerasProto = new KerasProto();
    kerasProto.Graph = _graph.ToString(Formatting.None);
    kerasProto.Path = path;
    kerasProto.Nsamples = nsamples;
    kerasProto.Nfeatures = nfeatures;
    kerasProto.Nlabels = nlabels;
    kerasProto.BatchSize = batchSize;
    kerasProto.Epochs = epochs;
    kerasProto.Verbose = verbose;

    byte[] request;
    using (var stream = new MemoryStream())
    {
        kerasProto.WriteTo(stream);
        request = stream.ToArray();
    }

    IntPtr outData = IntPtr.Zero;
    uint outLen = 0;
    ulong outPtr = 0;
    IntPtr exceptionData = IntPtr.Zero;
    uint exceptionLen = 0;
    ulong exceptionPtr = 0;
    KerasFitModel(request, (uint)request.Length, ref outData, ref outLen, ref outPtr,
        ref exceptionData, ref exceptionLen, ref exceptionPtr);

    // A non-empty exception buffer signals failure on the native side.
    if (exceptionLen != 0)
    {
        string message;
        try
        {
            var messageBytes = new byte[exceptionLen];
            Marshal.Copy(exceptionData, messageBytes, 0, (int)exceptionLen);
            // UTF-8 decodes pure-ASCII text identically, but preserves non-ASCII characters
            // that Encoding.ASCII would have replaced with '?'.
            message = Encoding.UTF8.GetString(messageBytes);
        }
        finally
        {
            // Release the native buffer even if the copy fails.
            KerasDeletePointer(exceptionPtr);
        }
        throw new KerasException(message);
    }

    // Copy into a local first so _model is only replaced once the copy has succeeded;
    // previously a failed copy left _model holding a zero-filled array.
    byte[] model;
    try
    {
        model = new byte[outLen];
        Marshal.Copy(outData, model, 0, (int)outLen);
    }
    finally
    {
        // Release the native buffer even if the copy fails.
        KerasDeletePointer(outPtr);
    }
    _model = model;
}
/// <summary>
/// Runs inference on <paramref name="x"/> by sending a <see cref="KerasProto"/> request with
/// <c>KerasCommand.Predict</c> to the native <c>KerasFitModel</c> entry point.
/// </summary>
/// <remarks>
/// The request carries the serialized model (when <c>_model</c> is set), the stored model
/// UUID and the model path. The UUID returned in the response is written back to <c>_uuid</c>.
/// </remarks>
/// <param name="x">Input tensor.</param>
/// <param name="batchSize">Batch size forwarded in the request.</param>
/// <param name="verbose">Verbosity level forwarded in the request.</param>
/// <param name="cache">Sent as the <c>"cache"</c> entry of the JSON <c>PredictParams</c>
/// (presumably lets the native side cache the loaded model — confirm).</param>
/// <returns>The first output tensor of the response.</returns>
/// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
/// <exception cref="KerasException">The native side reported an error, or the response
/// contained no outputs.</exception>
public Tensor Predict(Tensor x, uint batchSize = 32, uint verbose = 1, bool cache = true)
{
    if (x == null)
    {
        throw new ArgumentNullException(nameof(x));
    }

    KerasProto kerasProto = new KerasProto();
    kerasProto.BatchSize = batchSize;
    kerasProto.Verbose = verbose;
    var jobj = new JObject() { ["cache"] = cache };
    kerasProto.PredictParams = jobj.ToString(Formatting.None);
    // TODO Consider not copying the model if we have a uuid
    if (_model != null)
    {
        kerasProto.Model = ByteString.CopyFrom(_model);
    }
    // Generated protobuf string setters throw on null, so fall back to empty strings
    // when no UUID/path has been recorded yet.
    kerasProto.ModelUuid = _uuid ?? string.Empty;
    kerasProto.ModelPath = _path ?? string.Empty;
    kerasProto.Inputs.Add(x.GetProto());
    kerasProto.Command = KerasCommand.Predict;

    byte[] request;
    using (var stream = new MemoryStream())
    {
        kerasProto.WriteTo(stream);
        request = stream.ToArray();
    }

    IntPtr outData = IntPtr.Zero;
    uint outLen = 0;
    ulong outPtr = 0;
    IntPtr exceptionData = IntPtr.Zero;
    uint exceptionLen = 0;
    ulong exceptionPtr = 0;
    KerasFitModel(request, (uint)request.Length, ref outData, ref outLen, ref outPtr,
        ref exceptionData, ref exceptionLen, ref exceptionPtr);

    // A non-empty exception buffer signals failure on the native side.
    if (exceptionLen != 0)
    {
        string message;
        try
        {
            var messageBytes = new byte[exceptionLen];
            Marshal.Copy(exceptionData, messageBytes, 0, (int)exceptionLen);
            // UTF-8 decodes pure-ASCII text identically, but preserves non-ASCII characters
            // that Encoding.ASCII would have replaced with '?'.
            message = Encoding.UTF8.GetString(messageBytes);
        }
        finally
        {
            // Release the native buffer even if the copy fails.
            KerasDeletePointer(exceptionPtr);
        }
        throw new KerasException(message);
    }

    byte[] resultBytes;
    try
    {
        resultBytes = new byte[outLen];
        Marshal.Copy(outData, resultBytes, 0, (int)outLen);
    }
    finally
    {
        // Release the native buffer even if the copy fails.
        KerasDeletePointer(outPtr);
    }

    var resultProto = KerasProto.Parser.ParseFrom(resultBytes);
    _uuid = resultProto.ModelUuid;
    if (resultProto.Outputs.Count == 0)
    {
        throw new KerasException("Predict returned no outputs.");
    }
    return TensorUtils.Deserialize(resultProto.Outputs[0]);
}