/// <summary>
/// Puts all of the buffered instances of <paramref name="inputDataManager"/> into arrays of pointers and
/// passes them, together with the training hyper-parameters, to the native SymSGD <c>LearnAll</c> entry point.
/// </summary>
/// <remarks>
/// The instance buffers are not pinned here. Raw pointers are taken from the <see cref="GCHandle"/>s returned by
/// <c>InputDataManager.GiveNextInstance</c> via <see cref="GCHandle.AddrOfPinnedObject"/>, which requires
/// those handles (and <paramref name="stateGCHandle"/>) to be pinned for the duration of this call.
/// </remarks>
/// <param name="inputDataManager">The buffered data. Must contain at least one instance.</param>
/// <param name="tuneLR">Specifies if SymSGD should tune alpha (the learning rate) automatically.</param>
/// <param name="lr">Initial learning rate. Passed by reference to the native code, which may update it.</param>
/// <param name="l2Const">Forwarded unchanged to the native code (presumably the L2 regularization constant — confirm).</param>
/// <param name="piw">Forwarded unchanged to the native code. NOTE(review): meaning not visible here — confirm against the native side.</param>
/// <param name="weightVector">The storage for the weight vector. Must not be empty; it is pinned and passed as a raw pointer.
/// NOTE(review): presumably the native code assumes at least <paramref name="numFeatres"/> entries — confirm.</param>
/// <param name="bias">The bias term. Passed by reference to the native code, which may update it.</param>
/// <param name="numFeatres">Number of features.</param>
/// <param name="numPasses">Number of passes.</param>
/// <param name="numThreads">Number of threads.</param>
/// <param name="tuneNumLocIter">Specifies if SymSGD should tune <paramref name="numLocIter"/> automatically.</param>
/// <param name="numLocIter">Number of thread-local iterations of SGD before combining with the global model. Passed by reference.</param>
/// <param name="tolerance">Tolerance for the amount of decrease in the total loss in consecutive passes.</param>
/// <param name="needShuffle">Specifies if data needs to be shuffled.</param>
/// <param name="shouldInitialize">Specifies if this is the first time to run SymSGD.</param>
/// <param name="stateGCHandle">Pinned handle to the native <c>State</c>; its address is passed to the native code.</param>
/// <param name="info">Callback forwarded to the native code (presumably for progress/diagnostic messages — confirm).</param>
/// <exception cref="ArgumentException">
/// <paramref name="weightVector"/> is empty, or <paramref name="inputDataManager"/> contains no instances.
/// </exception>
/// <exception cref="InvalidOperationException">
/// The number of instances enumerated from <paramref name="inputDataManager"/> differs from its <c>Count</c>.
/// </exception>
public static void LearnAll(InputDataManager inputDataManager, bool tuneLR,
    ref float lr, float l2Const, float piw, Span<float> weightVector, ref float bias, int numFeatres,
    int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance,
    bool needShuffle, bool shouldInitialize, GCHandle stateGCHandle, ChannelCallBack info)
{
    // Pinning &weightVector[0] on an empty span would otherwise fail with an opaque IndexOutOfRangeException.
    if (weightVector.IsEmpty)
        throw new ArgumentException("The weight vector must not be empty.", nameof(weightVector));

    inputDataManager.PrepareCursoring();

    int totalNumInstances = inputDataManager.Count;
    // The native code receives pointers to the first element of each per-instance array;
    // with no instances there is nothing meaningful to hand over.
    if (totalNumInstances <= 0)
        throw new ArgumentException("The input data manager does not contain any instances.", nameof(inputDataManager));

    // Each instance has a pointer to its indices array (null when dense) and a pointer to its values array.
    int*[] arrayIndicesPointers = new int*[totalNumInstances];
    float*[] arrayValuesPointers = new float*[totalNumInstances];
    // Labels of the instances.
    float[] instLabels = new float[totalNumInstances];
    // Number of features stored for each instance.
    int[] instSizes = new int[totalNumInstances];

    int instanceIndex = 0;
    // Going through the buffer to set the properties and the pointers.
    while (inputDataManager.GiveNextInstance(out InstanceProperties? prop, out GCHandle? indicesGcHandle,
        out int indicesStartIndex, out GCHandle? valuesGcHandle, out int valuesStartIndex))
    {
        if (instanceIndex >= totalNumInstances)
        {
            throw new InvalidOperationException(
                $"The input data manager yielded more instances than its reported count of {totalNumInstances}.");
        }

        InstanceProperties properties = prop.Value;
        if (properties.IsDense)
        {
            // Dense instances carry no index array; the native side receives null for them.
            arrayIndicesPointers[instanceIndex] = null;
        }
        else
        {
            int* pIndicesArray = (int*)indicesGcHandle.Value.AddrOfPinnedObject();
            arrayIndicesPointers[instanceIndex] = pIndicesArray + indicesStartIndex;
        }

        float* pValuesArray = (float*)valuesGcHandle.Value.AddrOfPinnedObject();
        arrayValuesPointers[instanceIndex] = pValuesArray + valuesStartIndex;
        instLabels[instanceIndex] = properties.Label;
        instSizes[instanceIndex] = properties.FeatureCount;
        instanceIndex++;
    }

    // Unfilled trailing entries would hand null value pointers and zero sizes to the native code.
    if (instanceIndex != totalNumInstances)
    {
        throw new InvalidOperationException(
            $"The input data manager yielded {instanceIndex} instances but reported a count of {totalNumInstances}.");
    }

    // Pin the managed arrays so their addresses stay valid while the native code runs.
    fixed (float* pWeightVector = &weightVector[0])
    fixed (int** pIndicesPointer = arrayIndicesPointers)
    fixed (float** pValuesPointer = arrayValuesPointers)
    fixed (int* pInstSizes = instSizes)
    fixed (float* pInstLabels = instLabels)
    {
        LearnAll(totalNumInstances, pInstSizes, pIndicesPointer, pValuesPointer, pInstLabels, tuneLR, ref lr,
            l2Const, piw, pWeightVector, ref bias, numFeatres, numPasses, numThreads, tuneNumLocIter,
            ref numLocIter, tolerance, needShuffle, shouldInitialize,
            (State*)stateGCHandle.AddrOfPinnedObject(), info);
    }
}
/// <summary>
/// Native SymSGD training entry point. It receives raw pointers to the per-instance data gathered by the
/// managed <c>LearnAll</c> overload, plus the training hyper-parameters and a pointer to the native state.
/// </summary>
/// <remarks>
/// NOTE(review): no <c>[DllImport]</c> attribute is visible on this declaration in the reviewed text; an
/// <c>extern</c> method needs one to bind at runtime — confirm it is present.
/// NOTE(review): the <c>bool</c> parameters use the default P/Invoke marshaling (4-byte Win32 BOOL). If the
/// native signature declares a 1-byte C++ <c>bool</c>, consider <c>[MarshalAs(UnmanagedType.U1)]</c> — confirm.
/// </remarks>
/// <param name="totalNumInstances">Number of entries in each of the per-instance arrays below.</param>
/// <param name="instSizes">Per-instance feature counts.</param>
/// <param name="instIndices">Per-instance pointers into index arrays; an entry is null for a dense instance.</param>
/// <param name="instValues">Per-instance pointers into value arrays.</param>
/// <param name="labels">Per-instance labels.</param>
/// <param name="tuneLR">Whether the native code should tune the learning rate automatically.</param>
/// <param name="lr">Learning rate, passed by reference.</param>
/// <param name="l2Const">Forwarded from the managed caller (presumably the L2 regularization constant — confirm).</param>
/// <param name="piw">Forwarded from the managed caller; meaning not visible here — confirm against the native side.</param>
/// <param name="weightVector">Pointer to the pinned weight storage.</param>
/// <param name="bias">Bias term, passed by reference.</param>
/// <param name="numFeatres">Number of features.</param>
/// <param name="numPasses">Number of passes.</param>
/// <param name="numThreads">Number of threads.</param>
/// <param name="tuneNumLocIter">Whether the native code should tune <paramref name="numLocIter"/> automatically.</param>
/// <param name="numLocIter">Number of thread-local SGD iterations, passed by reference.</param>
/// <param name="tolerance">Tolerance for the decrease in total loss between consecutive passes.</param>
/// <param name="needShuffle">Whether the data needs to be shuffled.</param>
/// <param name="shouldInitialize">Whether this is the first run of SymSGD.</param>
/// <param name="state">Pointer to the pinned native state.</param>
/// <param name="info">Callback passed through to the native code.</param>
private static extern void LearnAll(int totalNumInstances, int *instSizes, int **instIndices, float **instValues, float *labels, bool tuneLR, ref float lr, float l2Const, float piw, float *weightVector, ref float bias, int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, State *state, ChannelCallBack info);