/// <summary>
/// Allocates the per-device weight factory array, sized to <c>m_deviceIds</c>,
/// and fills every slot with a fresh <see cref="WeightTensorFactory"/>.
/// </summary>
private void InitWeightsFactory()
{
    int deviceCount = m_deviceIds.Length;
    m_weightFactory = new IWeightFactory[deviceCount];

    int slot = 0;
    while (slot < deviceCount)
    {
        m_weightFactory[slot] = new WeightTensorFactory();
        slot++;
    }
}
/// <summary>
/// Allocates the per-device weight factory array, sized to <c>m_deviceIds</c>.
/// Each slot gets a <see cref="WeightTensorFactory"/> when <c>m_archType</c> is
/// <c>ArchTypeEnums.GPU_CUDA</c>, and a <see cref="WeightMatrixFactory"/> otherwise.
/// </summary>
private void InitWeightsFactory()
{
    m_weightFactory = new IWeightFactory[m_deviceIds.Length];

    // The architecture is checked once; every device receives the same factory kind.
    bool useTensorFactory = m_archType == ArchTypeEnums.GPU_CUDA;

    for (int slot = 0; slot < m_weightFactory.Length; slot++)
    {
        if (useTensorFactory)
        {
            m_weightFactory[slot] = new WeightTensorFactory();
        }
        else
        {
            m_weightFactory[slot] = new WeightMatrixFactory();
        }
    }
}
/// <summary>
/// Creates a compute graph instance and initializes its backprop list,
/// weight tensor factory, device id, flags, and the list of bound tensors.
/// </summary>
/// <param name="weightFactory">
/// Weight factory. It is stored only if it is a <see cref="WeightTensorFactory"/>.
/// Any other implementation, or null, leaves <c>m_weightTensorFactory</c> null.
/// </param>
/// <param name="deviceId">Device identifier, stored in <c>m_deviceId</c>.</param>
/// <param name="needBack">Stored in <c>m_needsBackprop</c>. Presumably it toggles recording of backward actions.</param>
/// <param name="backprop">Backprop action list to use. A new empty list is created when null.</param>
/// <param name="isSubGraph">Stored in <c>m_isSubGraph</c>.</param>
public ComputeGraphTensor(IWeightFactory weightFactory, int deviceId, bool needBack = true, ConcurrentList<Action> backprop = null, bool isSubGraph = false)
{
    // Reuse a caller-supplied list when given (presumably a parent graph's, so
    // actions accumulate in one place); otherwise start with an empty one.
    m_backprop = backprop ?? new ConcurrentList<Action>();

    // NOTE(review): `as` silently yields null for any factory that is not a
    // WeightTensorFactory (e.g. a WeightMatrixFactory). Code reading
    // m_weightTensorFactory must tolerate null. Confirm callers never pass one.
    m_weightTensorFactory = weightFactory as WeightTensorFactory;

    m_needsBackprop = needBack;
    m_deviceId = deviceId;
    m_isSubGraph = isSubGraph;
    m_tensorsBindToCurrentGraph = new List<IWeightTensor>();
}