/// <summary>
/// Creates a compute graph bound to a weight-tensor factory and a device id.
/// </summary>
/// <param name="weightFactory">
/// Factory stored for creating weight tensors. It must be a <c>WeightTensorFactory</c> or null;
/// a non-null factory of any other type is rejected instead of being silently stored as null.
/// </param>
/// <param name="deviceId">Device identifier stored in <c>m_deviceId</c>.</param>
/// <param name="needBack">Stored in <c>m_needsBackprop</c>; indicates whether backward actions are needed.</param>
/// <param name="visNetwork">
/// When true, allocates the MSAGL drawing graph and the edge-name set used for network visualization.
/// </param>
/// <param name="backprop">
/// Existing list of backward actions to reuse; when null a new, empty list is created.
/// </param>
/// <param name="isSubGraph">Stored in <c>m_isSubGraph</c>; marks this instance as a sub-graph.</param>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="weightFactory"/> is non-null but not a <c>WeightTensorFactory</c>.
/// </exception>
public ComputeGraphTensor(IWeightFactory weightFactory, int deviceId, bool needBack = true, bool visNetwork = false, ConcurrentList<Action> backprop = null, bool isSubGraph = false)
{
    // An 'as' cast on a mismatched factory would yield null here and only fail later when a
    // tensor is created, so reject it at construction time instead.
    if (weightFactory != null && !(weightFactory is WeightTensorFactory))
    {
        throw new ArgumentException("weightFactory must be a WeightTensorFactory.", "weightFactory");
    }

    // Reuse the caller's action list when supplied so actions recorded here go into that same list.
    m_backprop = backprop ?? new ConcurrentList<Action>();
    m_weightTensorFactory = weightFactory as WeightTensorFactory;
    m_needsBackprop = needBack;
    m_deviceId = deviceId;
    m_visNeuralNetwork = visNetwork;
    m_isSubGraph = isSubGraph;
    m_name2SubGraph = new Dictionary<string, Subgraph>();

    if (m_visNeuralNetwork)
    {
        // Initialize parameters for neural network visualization
        m_opsViz = new Microsoft.Msagl.Drawing.Graph();
        m_setEdges = new HashSet<string>();
    }
}
/// <summary>
/// Invokes the most recently added backward action, then removes it from the action list.
/// </summary>
/// <remarks>
/// Removal happens only after the action returns, so an action that throws stays in the list.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown when the action list is empty.</exception>
public void RunTopBackward()
{
    // NOTE(review): the constructor stores the action list in m_backprop; confirm that
    // 'backprop' (declared outside this view) refers to that same list.
    int topIndex = backprop.Count - 1;
    if (topIndex < 0)
    {
        // Without this guard the indexer is called with -1 and fails with an uninformative range error.
        throw new InvalidOperationException("There is no backward action to run.");
    }

    backprop[topIndex]();
    backprop.RemoveLastItem();
}