/// <summary>
/// Creates a native estimator for the configured grain and data columns, fits it by
/// streaming the rows of <paramref name="input"/> to it, and then creates the native
/// transformer from the trained estimator.
/// </summary>
/// <param name="input">
/// Data to train on. Only schema columns whose names are in <c>_allColumnNames</c> are used.
/// </param>
/// <returns>
/// A <c>TransformerEstimatorSafeHandle</c> wrapping the native transformer pointer,
/// constructed with <c>DestroyTransformerNative</c>.
/// </returns>
/// <exception cref="Exception">
/// A native call reported failure; the message is whatever
/// <c>GetErrorDetailsAndFreeNativeMemory</c> returns for the error handle.
/// </exception>
private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
{
    IntPtr estimator;
    IntPtr errorHandle;
    bool success;

    var allColumns = input.Schema
        .Where(x => _allColumnNames.Contains(x.Name))
        .Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns))
        .ToDictionary(x => x.Column.Name);

    // Create TypeId[] for types of grain and data columns.
    var dataColumnTypes = new TypeId[_dataColumns.Length];
    var grainColumnTypes = new TypeId[_grainColumns.Length];

    foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
    {
        grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
    }

    foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
    {
        dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
    }

    // Pin the managed arrays/field so the native side can read them during the call.
    fixed (bool* suppressErrors = &_suppressTypeErrors)
    fixed (TypeId* rawDataColumnTypes = dataColumnTypes)
    fixed (TypeId* rawGrainColumnTypes = grainColumnTypes)
    {
        success = CreateEstimatorNative(
            rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length),
            rawDataColumnTypes, new IntPtr(dataColumnTypes.Length),
            _imputeMode, suppressErrors, out estimator, out errorHandle);
    }

    if (!success)
    {
        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
    }

    using (var estimatorHandle = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
    {
        TrainingState trainingState;
        FitResult fitResult;

        // Create buffer to hold binary data.
        var memoryStream = new MemoryStream(4096);
        var binaryWriter = new BinaryWriter(memoryStream, Encoding.UTF8);

        // A using statement can't be applied here because the cursor is passed by ref to
        // ResetCursor (which presumably swaps in a fresh cursor). The try/finally below
        // disposes whichever cursor is current, including when an exception is thrown.
        var cursor = input.GetRowCursorForAllColumns();
        try
        {
            // Initialize getters.
            foreach (var column in allColumns.Values)
            {
                column.InitializeGetter(cursor);
            }

            // Start the loop with the cursor in a valid state already.
            var valid = cursor.MoveNext();

            // Make sure it's not an empty data frame.
            Debug.Assert(valid);

            while (true)
            {
                // Get the state of the native estimator.
                success = GetStateNative(estimatorHandle, out trainingState, out errorHandle);
                if (!success)
                {
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                }

                // If we are no longer training then exit loop.
                if (trainingState != TrainingState.Training)
                {
                    break;
                }

                // Build byte array to send column data to native featurizer.
                BuildColumnByteArray(allColumns, ref binaryWriter);

                // Fit the estimator. Only the first Position bytes of the buffer are valid.
                fixed (byte* bufferPointer = memoryStream.GetBuffer())
                {
                    var binaryArchiveData = new NativeBinaryArchiveData()
                    {
                        Data = bufferPointer,
                        DataSize = new IntPtr(memoryStream.Position)
                    };
                    success = FitNative(estimatorHandle, binaryArchiveData, out fitResult, out errorHandle);
                }

                // Reset memory stream to 0 so the next row overwrites the buffer.
                memoryStream.Position = 0;

                if (!success)
                {
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                }

                // If we need to reset the data to the beginning.
                if (fitResult == FitResult.ResetAndContinue)
                {
                    ResetCursor(input, ref cursor, allColumns);
                }

                // If we are at the end of the data.
                if (!cursor.MoveNext())
                {
                    // If we get here fitResult should never be ResetAndContinue.
                    Debug.Assert(fitResult != FitResult.ResetAndContinue);

                    // Capture the result: previously the return value was dropped and the
                    // check below re-tested the stale FitNative result.
                    success = OnDataCompletedNative(estimatorHandle, out errorHandle);
                    if (!success)
                    {
                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                    }

                    ResetCursor(input, ref cursor, allColumns);
                }
            }

            // When done training complete the estimator.
            success = CompleteTrainingNative(estimatorHandle, out errorHandle);
            if (!success)
            {
                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
            }

            // Create the native transformer from the estimator.
            success = CreateTransformerFromEstimatorNative(estimatorHandle, out IntPtr transformer, out errorHandle);
            if (!success)
            {
                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
            }

            return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
        }
        finally
        {
            // Manually dispose of the cursor since we don't have a using statement.
            cursor?.Dispose();
        }
    }
}
/// <summary>
/// Creates a native estimator for the configured grain and data columns, fits it by
/// passing each row of <paramref name="input"/> to it through a reusable byte buffer,
/// and then creates the native transformer from the trained estimator.
/// </summary>
/// <param name="input">
/// Data to train on. Only schema columns whose names are in <c>_allColumnNames</c> are used.
/// </param>
/// <returns>
/// A <c>TransformerEstimatorSafeHandle</c> wrapping the native transformer pointer,
/// constructed with <c>DestroyTransformerNative</c>.
/// </returns>
/// <exception cref="Exception">
/// A native call reported failure; the message is whatever
/// <c>GetErrorDetailsAndFreeNativeMemory</c> returns for the error handle.
/// </exception>
private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
{
    IntPtr estimator;
    IntPtr errorHandle;
    bool success;

    var allColumns = input.Schema
        .Where(x => _allColumnNames.Contains(x.Name))
        .Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns))
        .ToDictionary(x => x.Column.Name);

    // Create buffer to hold binary data. Passed by ref to BuildColumnByteArray,
    // which presumably may replace it with a larger array - confirm against that helper.
    var columnBuffer = new byte[4096];

    // Create TypeId[] for types of grain and data columns.
    var dataColumnTypes = new TypeId[_dataColumns.Length];
    var grainColumnTypes = new TypeId[_grainColumns.Length];

    foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
    {
        grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
    }

    foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
    {
        dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
    }

    // Pin the managed arrays/field so the native side can read them during the call.
    fixed (bool* suppressErrors = &_suppressTypeErrors)
    fixed (TypeId* rawDataColumnTypes = dataColumnTypes)
    fixed (TypeId* rawGrainColumnTypes = grainColumnTypes)
    {
        success = CreateEstimatorNative(
            rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length),
            rawDataColumnTypes, new IntPtr(dataColumnTypes.Length),
            _imputeMode, suppressErrors, out estimator, out errorHandle);
    }

    if (!success)
    {
        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
    }

    using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
    {
        var fitResult = FitResult.Continue;

        // Each iteration opens a fresh cursor and makes one pass over the data; the loop
        // repeats until CompleteTrainingNative reports FitResult.Complete.
        while (fitResult != FitResult.Complete)
        {
            using (var cursor = input.GetRowCursorForAllColumns())
            {
                // Initialize getters for start of loop.
                foreach (var column in allColumns.Values)
                {
                    column.InitializeGetter(cursor);
                }

                while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext())
                {
                    BuildColumnByteArray(allColumns, ref columnBuffer, out int serializedDataLength);

                    // Only the first serializedDataLength bytes of the buffer are handed over.
                    fixed (byte* bufferPointer = columnBuffer)
                    {
                        var binaryArchiveData = new NativeBinaryArchiveData()
                        {
                            Data = bufferPointer,
                            DataSize = new IntPtr(serializedDataLength)
                        };
                        success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle);
                    }

                    if (!success)
                    {
                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                    }
                }

                // Called after every pass; its fitResult decides whether another pass is needed.
                success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle);
                if (!success)
                {
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                }
            }
        }

        success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle);
        if (!success)
        {
            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
        }

        return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
    }
}