/// <summary>
/// Creates a compute graph bound to a device, optionally sharing an existing list of backward actions.
/// </summary>
/// <param name="weightFactory">Weight factory; it is stored only when it is a <c>WeightTensorFactory</c>. Any other value, including null, leaves the stored factory null.</param>
/// <param name="deviceId">Device identifier stored for this graph.</param>
/// <param name="needBack">Stored as the graph's "needs backprop" flag.</param>
/// <param name="visNetwork">When true, the objects used for network visualization are allocated.</param>
/// <param name="backprop">Existing backward-action list to reuse; when null, a new empty list is created.</param>
/// <param name="isSubGraph">Stored as the graph's sub-graph flag.</param>
public ComputeGraphTensor(IWeightFactory weightFactory, int deviceId, bool needBack = true, bool visNetwork = false, ConcurrentList <Action> backprop = null, bool isSubGraph = false)
        {
            // Share the caller's backward-action list if one was given; otherwise start empty.
            m_backprop = backprop ?? new ConcurrentList<Action>();

            // 'as' produces null when the factory is not a WeightTensorFactory (or is null).
            m_weightTensorFactory = weightFactory as WeightTensorFactory;

            // Plain configuration copied from the arguments.
            m_deviceId         = deviceId;
            m_needsBackprop    = needBack;
            m_isSubGraph       = isSubGraph;
            m_visNeuralNetwork = visNetwork;

            m_name2SubGraph = new Dictionary<string, Subgraph>();

            if (!visNetwork)
            {
                // Visualization disabled: the visualization members are not allocated here.
                return;
            }

            // Visualization enabled: allocate the drawing graph and the set used to track edges.
            m_opsViz   = new Microsoft.Msagl.Drawing.Graph();
            m_setEdges = new HashSet<string>();
        }
示例#2
0
        /// <summary>
        /// Invokes the most recently added backward action, then removes it from the backward-action list.
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown when the backward-action list is empty.</exception>
        public void RunTopBackward()
        {
            int count = backprop.Count;
            if (count == 0)
            {
                // Fail with a descriptive error instead of indexing the list at -1.
                throw new InvalidOperationException("There is no pending backward action to run.");
            }

            var topAction = backprop[count - 1];
            topAction();

            // Removal happens only after the action ran, so an action that throws stays on the list.
            backprop.RemoveLastItem();
        }