/// <summary>
/// Validates the training data, records the feature-vector length in <c>FeatureCount</c>,
/// converts the data and runs the core training, all inside a "Training" channel.
/// </summary>
/// <param name="trainData">The role-mapped training data; must not be null.</param>
public override void Train(RoleMappedData trainData)
{
    using (var channel = Host.Start("Training"))
    {
        // The null check is reported through the channel. The remaining calls are schema
        // requirements on the data (presumably they throw on a mismatched schema).
        channel.CheckValue(trainData, nameof(trainData));
        trainData.CheckRegressionLabel();
        trainData.CheckFeatureFloatVector();
        trainData.CheckOptFloatWeight();

        // Remember how many slots the feature column has before the data is converted.
        var featureColumn = trainData.Schema.Feature;
        FeatureCount = featureColumn.Type.ValueCount;

        ConvertData(trainData);
        TrainCore(channel);

        // Reached only when nothing above threw.
        channel.Done();
    }
}
示例#2
0
        /// <summary>
        /// Bins the training data (and the validation data, when supplied) into datasets,
        /// using the binning settings from <c>GamTrainerOptions</c>. Results are stored in
        /// <c>TrainSet</c>, <c>FeatureMap</c> and, if validation data is given, <c>ValidSet</c>.
        /// </summary>
        /// <param name="trainData">Role-mapped training data.</param>
        /// <param name="validationData">Optional role-mapped validation data; may be null.</param>
        private void ConvertData(RoleMappedData trainData, RoleMappedData validationData)
        {
            // Schema requirements on the training data.
            trainData.CheckFeatureFloatVector();
            trainData.CheckOptFloatWeight();
            CheckLabel(trainData);

            // Whether the binner works on a transposed view is decided from the DiskTranspose option and the data.
            var transpose = UseTranspose(GamTrainerOptions.DiskTranspose, trainData);
            var binner = new ExamplesToFastTreeBins(
                Host,
                GamTrainerOptions.MaximumBinCountPerFeature,
                transpose,
                !GamTrainerOptions.FeatureFlocks,
                GamTrainerOptions.MinimumExampleCountPerLeaf,
                float.PositiveInfinity);

            // The parallel-training environment is initialized before the training set is built against it.
            ParallelTraining.InitEnvironment();
            TrainSet = binner.FindBinsAndReturnDataset(trainData, PredictionKind, ParallelTraining, null, false);
            FeatureMap = binner.FeatureMap;

            // Validation data goes through the same binner via GetCompatibleDataset
            // (presumably so its bins line up with the training set).
            if (validationData != null)
            {
                ValidSet = binner.GetCompatibleDataset(validationData, PredictionKind, null, false);
            }

            Host.Assert(FeatureMap == null || FeatureMap.Length == TrainSet.NumFeatures);
        }
        /// <summary>
        /// Validates the training data and checks that the model's parameter count fits in a
        /// single array. It then runs the core training (which drives the optimizer) inside a
        /// "Training" channel.
        /// </summary>
        /// <param name="data">The role-mapped training data; must not be null.</param>
        /// <remarks>
        /// The exception produced by <c>Contracts.ExceptParam</c> is thrown when
        /// ('# of features' + 1) * '# of classes' would exceed <c>Utils.ArrayMaxSize</c>.
        /// </remarks>
        public override void Train(RoleMappedData data)
        {
            Contracts.CheckValue(data, nameof(data));

            // Schema checks. The first one also writes the feature count into NumFeatures via its out parameter.
            data.CheckFeatureFloatVector(out NumFeatures);
            CheckLabel(data);
            data.CheckOptFloatWeight();

            // The model holds (NumFeatures + 1) * ClassCount parameters. For positive ClassCount,
            // that product fits within ArrayMaxSize exactly when
            // NumFeatures < ArrayMaxSize / ClassCount (integer division).
            // So anything at or above the quotient is rejected.
            var featureLimit = Utils.ArrayMaxSize / ClassCount;
            if (NumFeatures >= featureLimit)
            {
                var message = String.Format("The number of model parameters which is equal to ('# of features' + 1) * '# of classes' should be less than or equal to {0}.", Utils.ArrayMaxSize);
                throw Contracts.ExceptParam(nameof(data), message);
            }

            using (var channel = Host.Start("Training"))
            {
                TrainCore(channel, data);
                // Reached only when TrainCore did not throw.
                channel.Done();
            }
        }