Пример #1
0
 /// <summary>
 /// Builds one weight factory per configured device id. Every slot receives its own,
 /// freshly constructed <c>WeightTensorFactory</c> (instances are not shared between slots).
 /// </summary>
 private void InitWeightsFactory()
 {
     // Allocate and publish the array before filling it, then populate it through a local alias.
     var factories = new IWeightFactory[m_deviceIds.Length];
     m_weightFactory = factories;

     int slot = 0;
     while (slot < factories.Length)
     {
         factories[slot] = new WeightTensorFactory();
         slot++;
     }
 }
Пример #2
0
 /// <summary>
 /// Builds one weight factory per configured device id, choosing the implementation from
 /// <c>m_archType</c>: <c>WeightTensorFactory</c> when it equals <c>ArchTypeEnums.GPU_CUDA</c>,
 /// <c>WeightMatrixFactory</c> for any other value. Each slot gets its own new instance.
 /// </summary>
 private void InitWeightsFactory()
 {
     var factories = new IWeightFactory[m_deviceIds.Length];
     m_weightFactory = factories;

     // The architecture check is evaluated exactly once, before any slot is filled.
     bool useTensorFactory = m_archType == ArchTypeEnums.GPU_CUDA;

     for (int slot = 0; slot < factories.Length; slot++)
     {
         if (useTensorFactory)
         {
             factories[slot] = new WeightTensorFactory();
         }
         else
         {
             factories[slot] = new WeightMatrixFactory();
         }
     }
 }
Пример #3
0
        /// <summary>
        /// Creates a compute graph associated with a single device id.
        /// </summary>
        /// <param name="weightFactory">
        /// Weight factory for this graph. It is stored through an <c>as</c> cast to
        /// <c>WeightTensorFactory</c>, so any other implementation (or <c>null</c>) leaves the
        /// stored factory field <c>null</c>.
        /// </param>
        /// <param name="deviceId">Device id recorded for this graph.</param>
        /// <param name="needBack">Value stored as the graph's needs-backprop flag.</param>
        /// <param name="backprop">
        /// Existing list of backprop actions to reuse; when <c>null</c>, a new empty list is created.
        /// </param>
        /// <param name="isSubGraph">Value stored as the graph's sub-graph flag.</param>
        public ComputeGraphTensor(IWeightFactory weightFactory, int deviceId, bool needBack = true, ConcurrentList<Action> backprop = null, bool isSubGraph = false)
        {
            m_backprop = backprop ?? new ConcurrentList<Action>();

            // NOTE(review): `as` silently yields null for any factory that is not a
            // WeightTensorFactory (e.g. a WeightMatrixFactory); any later use of this field would
            // then dereference null. Consider failing fast here once all callers are audited.
            m_weightTensorFactory = weightFactory as WeightTensorFactory;
            m_needsBackprop = needBack;
            m_deviceId = deviceId;
            m_isSubGraph = isSubGraph;

            m_tensorsBindToCurrentGraph = new List<IWeightTensor>();
        }