Example #1
0
        /// <summary>
        /// Creates a native estimator for the configured grain and data columns, fits it on the rows of
        /// <paramref name="input"/> until the native side reports it is no longer training, and then
        /// creates a native transformer from that estimator.
        /// </summary>
        /// <param name="input">Data view whose schema supplies the column types and whose rows are fed to the estimator.</param>
        /// <returns>A safe handle wrapping the native transformer pointer, constructed with <c>DestroyTransformerNative</c>.</returns>
        /// <exception cref="Exception">Thrown with the text from <c>GetErrorDetailsAndFreeNativeMemory</c> whenever a native call reports failure.</exception>
        private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
        {
            IntPtr estimator;
            IntPtr errorHandle;
            bool   success;

            // Wrap every schema column named in _allColumnNames, keyed by column name.
            var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name);

            // Create TypeId[] for types of grain and data columns;
            var dataColumnTypes  = new TypeId[_dataColumns.Length];
            var grainColumnTypes = new TypeId[_grainColumns.Length];

            foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
            {
                grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
            }

            foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
            {
                dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
            }

            // Both type arrays must be completely filled before the estimator is created,
            // so this runs after (not inside) the loops above.
            fixed(bool *suppressErrors = &_suppressTypeErrors)
            fixed(TypeId * rawDataColumnTypes  = dataColumnTypes)
            fixed(TypeId * rawGrainColumnTypes = grainColumnTypes)
            {
                success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle);
            }

            if (!success)
            {
                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
            }

            using (var estimatorHandle = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
            {
                TrainingState trainingState;
                FitResult     fitResult;

                // Create buffer to hold binary data
                var memoryStream = new MemoryStream(4096);
                var binaryWriter = new BinaryWriter(memoryStream, Encoding.UTF8);

                // The cursor and the writer are passed by ref to helpers, so they cannot be
                // 'using' variables; the finally block below disposes them on every exit path.
                var cursor = input.GetRowCursorForAllColumns();
                try
                {
                    // Initialize getters
                    foreach (var column in allColumns.Values)
                    {
                        column.InitializeGetter(cursor);
                    }

                    // Start the loop with the cursor in a valid state already.
                    var valid = cursor.MoveNext();

                    // Make sure its not an empty data frame
                    Debug.Assert(valid);

                    while (true)
                    {
                        // Get the state of the native estimator.
                        success = GetStateNative(estimatorHandle, out trainingState, out errorHandle);
                        if (!success)
                        {
                            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                        }

                        // If we are no longer training then exit loop.
                        if (trainingState != TrainingState.Training)
                        {
                            break;
                        }

                        // Build byte array to send column data to native featurizer
                        BuildColumnByteArray(allColumns, ref binaryWriter);

                        // Fit the estimator. Only the first Position bytes of the underlying
                        // buffer belong to the current row, hence DataSize = Position.
                        fixed(byte *bufferPointer = memoryStream.GetBuffer())
                        {
                            var binaryArchiveData = new NativeBinaryArchiveData()
                            {
                                Data = bufferPointer, DataSize = new IntPtr(memoryStream.Position)
                            };

                            success = FitNative(estimatorHandle, binaryArchiveData, out fitResult, out errorHandle);
                        }

                        // Reset memory stream to 0 so the next row overwrites the same buffer.
                        memoryStream.Position = 0;

                        if (!success)
                        {
                            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                        }

                        // If we need to reset the data to the beginning.
                        if (fitResult == FitResult.ResetAndContinue)
                        {
                            ResetCursor(input, ref cursor, allColumns);
                        }

                        // If we are at the end of the data.
                        if (!cursor.MoveNext())
                        {
                            // If we get here fitResult should never be ResetAndContinue
                            Debug.Assert(fitResult != FitResult.ResetAndContinue);

                            // Capture the result; previously it was discarded and the check
                            // below re-tested the stale FitNative result.
                            success = OnDataCompletedNative(estimatorHandle, out errorHandle);
                            if (!success)
                            {
                                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                            }

                            ResetCursor(input, ref cursor, allColumns);
                        }
                    }

                    // When done training complete the estimator.
                    success = CompleteTrainingNative(estimatorHandle, out errorHandle);
                    if (!success)
                    {
                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                    }

                    // Create the native transformer from the estimator;
                    success = CreateTransformerFromEstimatorNative(estimatorHandle, out IntPtr transformer, out errorHandle);
                    if (!success)
                    {
                        throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                    }

                    return(new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative));
                }
                finally
                {
                    // Manually dispose since the ref-passed locals rule out using statements.
                    cursor.Dispose();
                    binaryWriter.Dispose();
                    memoryStream.Dispose();
                }
            }
        }
Example #2
0
        /// <summary>
        /// Creates a native estimator for the configured grain and data columns, repeatedly passes over
        /// <paramref name="input"/> fitting the estimator until the native side reports
        /// <c>FitResult.Complete</c>, and then creates a native transformer from that estimator.
        /// </summary>
        /// <param name="input">Data view whose schema supplies the column types and whose rows are fed to the estimator.</param>
        /// <returns>A safe handle wrapping the native transformer pointer, constructed with <c>DestroyTransformerNative</c>.</returns>
        /// <exception cref="Exception">Thrown with the text from <c>GetErrorDetailsAndFreeNativeMemory</c> whenever a native call reports failure.</exception>
        private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
        {
            IntPtr estimator;
            IntPtr errorHandle;
            bool   success;

            // Wrap every schema column named in _allColumnNames, keyed by column name.
            var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name);

            // Create buffer to hold binary data; passed by ref so the helper may replace it.
            var columnBuffer = new byte[4096];

            // Create TypeId[] for types of grain and data columns;
            var dataColumnTypes  = new TypeId[_dataColumns.Length];
            var grainColumnTypes = new TypeId[_grainColumns.Length];

            foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
            {
                grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
            }

            foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
            {
                dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
            }

            // Both type arrays must be completely filled before the estimator is created,
            // so this runs after (not inside) the loops above.
            fixed(bool *suppressErrors = &_suppressTypeErrors)
            fixed(TypeId * rawDataColumnTypes  = dataColumnTypes)
            fixed(TypeId * rawGrainColumnTypes = grainColumnTypes)
            {
                success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle);
            }

            if (!success)
            {
                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
            }

            using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
            {
                var fitResult = FitResult.Continue;

                // Each iteration is one pass over the data with a fresh cursor; the loop ends
                // once CompleteTrainingNative reports FitResult.Complete.
                while (fitResult != FitResult.Complete)
                {
                    using (var cursor = input.GetRowCursorForAllColumns())
                    {
                        // Initialize getters for start of loop
                        foreach (var column in allColumns.Values)
                        {
                            column.InitializeGetter(cursor);
                        }

                        // Feed rows until the data runs out or the native side stops asking for more.
                        while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext())
                        {
                            BuildColumnByteArray(allColumns, ref columnBuffer, out int serializedDataLength);

                            // Only the first serializedDataLength bytes of the buffer are sent.
                            fixed(byte *bufferPointer = columnBuffer)
                            {
                                var binaryArchiveData = new NativeBinaryArchiveData()
                                {
                                    Data = bufferPointer, DataSize = new IntPtr(serializedDataLength)
                                };

                                success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle);
                            }

                            if (!success)
                            {
                                throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                            }
                        }

                        // Signal the end of this pass; the returned fitResult decides whether another pass runs.
                        success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle);
                        if (!success)
                        {
                            throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                        }
                    }
                }

                // Create the native transformer from the trained estimator.
                success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle);
                if (!success)
                {
                    throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
                }

                return(new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative));
            }
        }