예제 #1
0
        /// <summary>
        /// Builds a native matrix-factorization problem from the labeled (row, column, value) entries
        /// read through <paramref name="cursor"/> and trains a new model on it, stored in <c>_pMFModel</c>.
        /// Any model left over from a previous call is destroyed first.
        /// </summary>
        /// <param name="ch">Channel that receives the informational "Training ..." message.</param>
        /// <param name="rowCount">Number of matrix rows; stored as the problem's M.</param>
        /// <param name="colCount">Number of matrix columns; stored as the problem's N.</param>
        /// <param name="cursor">Cursor over the training rows.</param>
        /// <param name="labGetter">Getter for the label (matrix value) of the current row.</param>
        /// <param name="rowGetter">Getter for the row index of the current row.</param>
        /// <param name="colGetter">Getter for the column index of the current row.</param>
        public unsafe void Train(IChannel ch, int rowCount, int colCount,
                                 RowCursor cursor, ValueGetter<float> labGetter,
                                 ValueGetter<uint> rowGetter, ValueGetter<uint> colGetter)
        {
            // Release the previously trained native model (if any) before replacing it.
            if (_pMFModel != null)
            {
                MFDestroyModel(ref _pMFModel);
                _host.Assert(_pMFModel == null);
            }

            MFNode[] trainingNodes = ConstructLabeledNodesFrom(ch, cursor, labGetter, rowGetter, colGetter, rowCount, colCount);

            // Pin the node array and the parameter field so their addresses stay valid
            // for the duration of the native MFTrain call.
            fixed (MFNode* pNodes = &trainingNodes[0])
            fixed (MFParameter* pParam = &_mfParam)
            {
                MFProblem problem = new MFProblem
                {
                    R = pNodes,
                    M = rowCount,
                    N = colCount,
                    Nnz = trainingNodes.Length,
                };

                ch.Info("Training {0} by {1} problem on {2} examples",
                        problem.M, problem.N, problem.Nnz);

                _pMFModel = MFTrain(&problem, pParam);
            }
        }
        /// <summary>
        /// Builds a native matrix-factorization training problem from the entries read through
        /// <paramref name="cursor"/> and a validation problem from the entries read through
        /// <paramref name="validCursor"/>. It then trains a new model with validation and stores it
        /// in <c>_pMFModel</c>. Any model left over from a previous call is destroyed first.
        /// </summary>
        /// <param name="ch">Channel that receives the informational "Training ..." message.</param>
        /// <param name="rowCount">Number of matrix rows; used as M for both problems.</param>
        /// <param name="colCount">Number of matrix columns; used as N for both problems.</param>
        /// <param name="cursor">Cursor over the training rows.</param>
        /// <param name="labGetter">Getter for the training label (matrix value).</param>
        /// <param name="rowGetter">Getter for the training row index.</param>
        /// <param name="colGetter">Getter for the training column index.</param>
        /// <param name="validCursor">Cursor over the validation rows.</param>
        /// <param name="validLabGetter">Getter for the validation label (matrix value).</param>
        /// <param name="validRowGetter">Getter for the validation row index.</param>
        /// <param name="validColGetter">Getter for the validation column index.</param>
        public unsafe void TrainWithValidation(IChannel ch, int rowCount, int colCount,
                                               DataViewRowCursor cursor, ValueGetter<float> labGetter,
                                               ValueGetter<uint> rowGetter, ValueGetter<uint> colGetter,
                                               DataViewRowCursor validCursor, ValueGetter<float> validLabGetter,
                                               ValueGetter<uint> validRowGetter, ValueGetter<uint> validColGetter)
        {
            // Release the previously trained native model (if any) before replacing it.
            if (_pMFModel != null)
            {
                MFDestroyModel(ref _pMFModel);
                _host.Assert(_pMFModel == null);
            }

            MFNode[] nodes = ConstructLabeledNodesFrom(ch, cursor, labGetter, rowGetter, colGetter, rowCount, colCount);
            MFNode[] validNodes = ConstructLabeledNodesFrom(ch, validCursor, validLabGetter, validRowGetter, validColGetter, rowCount, colCount);
            MFProblem prob = new MFProblem();
            MFProblem validProb = new MFProblem();

            // Pin both node arrays so the raw pointers handed to native code stay valid.
            fixed (MFNode* nodesPtr = &nodes[0])
            fixed (MFNode* validNodesPtr = &validNodes[0])
            {
                prob.R = nodesPtr;
                prob.M = rowCount;
                prob.N = colCount;
                prob.Nnz = nodes.Length;

                validProb.R = validNodesPtr;
                validProb.M = rowCount;
                validProb.N = colCount;
                // Nnz must describe the array R points to (validNodes), not the training array.
                // Using nodes.Length here over-reports the validation size. Presumably the native
                // trainer would then read past the end of validNodes when the training set is larger.
                validProb.Nnz = validNodes.Length;

                ch.Info("Training {0} by {1} problem on {2} examples with a {3} by {4} validation set including {5} examples",
                        prob.M, prob.N, prob.Nnz, validProb.M, validProb.N, validProb.Nnz);

                fixed (MFParameter* pParam = &_mfParam)
                {
                    _pMFModel = MFTrainWithValidation(&prob, &validProb, pParam);
                }
            }
        }