public void TestEstimatorSymSgdInitPredictor()
        {
            // Checks that the initial predictor actually influences SymSGD. The Score columns of
            // (1) the SDCA model, (2) SymSGD warm-started from the SDCA sub-model and (3) SymSGD
            // trained from scratch must be pairwise different on the first few rows.
            (var pipe, var dataView) = GetBinaryClassificationPipeline();
            var transformedData = pipe.Fit(dataView).Transform(dataView);

            // Model whose linear sub-model seeds the warm-started SymSGD run below.
            var initPredictor = ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent().Fit(transformedData);
            var data          = initPredictor.Transform(transformedData);

            // SymSGD initialised from the SDCA sub-model.
            var withInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options())
                .Fit(transformedData, modelParameters: initPredictor.Model.SubModel);
            var outInitData = withInitPredictor.Transform(transformedData);

            // SymSGD trained without an initial predictor.
            var notInitPredictor = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options()).Fit(transformedData);
            var outNoInitData    = notInitPredictor.Transform(transformedData);

            int numExamples = 10;
            var col1        = data.GetColumn<float>(data.Schema["Score"]).Take(numExamples).ToArray();
            var col2        = outInitData.GetColumn<float>(outInitData.Schema["Score"]).Take(numExamples).ToArray();
            var col3        = outNoInitData.GetColumn<float>(outNoInitData.Schema["Score"]).Take(numExamples).ToArray();

            // All three columns are computed over the same transformedData, so they should have the same,
            // non-zero row count. Checking this up front avoids an out-of-range index and a vacuous comparison.
            Contracts.Check(col1.Length > 0 && col1.Length == col2.Length && col1.Length == col3.Length,
                "Expected the same non-empty number of Score values from every model.");

            // True when the two score arrays differ in at least one position. Exact float inequality is
            // intentional: the test only asks whether the models produced different outputs at all.
            bool AnyDifferent(float[] left, float[] right) => left.Zip(right, (l, r) => l != r).Any(isDiff => isDiff);

            // Contracts.Check rather than Contracts.Assert. Assert is a debug-only (Conditional "DEBUG")
            // check, so the test could never fail in Release builds. Separate checks also report which pair agreed.
            Contracts.Check(AnyDifferent(col1, col2), "SDCA scores and warm-started SymSGD scores should differ.");
            Contracts.Check(AnyDifferent(col2, col3), "Warm-started and cold-started SymSGD scores should differ.");
            Contracts.Check(AnyDifferent(col1, col3), "SDCA scores and cold-started SymSGD scores should differ.");
            Done();
        }
        // NOTE(review): this file contains another method with exactly this name and signature, which
        // uses the newer Options/Fit API. Two such methods in one class do not compile. This looks like
        // a leftover from a merge; keep only the variant that matches the ML.NET version in use.
        public void TestEstimatorSymSgdInitPredictor()
        {
            // Checks that the initial predictor actually influences SymSGD. The Score columns of
            // (1) the LinearClassificationTrainer model, (2) SymSGD initialised from it and
            // (3) SymSGD trained from scratch must be pairwise different on the first few rows.
            (var pipe, var dataView) = GetBinaryClassificationPipeline();
            var transformedData = pipe.Fit(dataView).Transform(dataView);

            // Model passed to SymSGD as its initial predictor.
            var initPredictor = new LinearClassificationTrainer(Env, "Features", "Label").Fit(transformedData);
            var data          = initPredictor.Transform(transformedData);

            // SymSGD initialised from initPredictor's model.
            var withInitPredictor = new SymSgdClassificationTrainer(Env, "Features", "Label").Train(transformedData, initialPredictor: initPredictor.Model);
            var outInitData       = withInitPredictor.Transform(transformedData);

            // SymSGD trained without an initial predictor.
            var notInitPredictor = new SymSgdClassificationTrainer(Env, "Features", "Label").Train(transformedData);
            var outNoInitData    = notInitPredictor.Transform(transformedData);

            int numExamples = 10;
            var col1        = data.GetColumn<float>(Env, "Score").Take(numExamples).ToArray();
            var col2        = outInitData.GetColumn<float>(Env, "Score").Take(numExamples).ToArray();
            var col3        = outNoInitData.GetColumn<float>(Env, "Score").Take(numExamples).ToArray();

            // All three columns are computed over the same transformedData, so they should have the same,
            // non-zero row count. Checking this up front avoids an out-of-range index and a vacuous comparison.
            Contracts.Check(col1.Length > 0 && col1.Length == col2.Length && col1.Length == col3.Length,
                "Expected the same non-empty number of Score values from every model.");

            // True when the two score arrays differ in at least one position. Exact float inequality is
            // intentional: the test only asks whether the models produced different outputs at all.
            bool AnyDifferent(float[] left, float[] right) => left.Zip(right, (l, r) => l != r).Any(isDiff => isDiff);

            // Contracts.Check rather than Contracts.Assert. Assert is a debug-only (Conditional "DEBUG")
            // check, so the test could never fail in Release builds. Separate checks also report which pair agreed.
            Contracts.Check(AnyDifferent(col1, col2), "Initial-predictor scores and warm-started SymSGD scores should differ.");
            Contracts.Check(AnyDifferent(col2, col3), "Warm-started and cold-started SymSGD scores should differ.");
            Contracts.Check(AnyDifferent(col1, col3), "Initial-predictor scores and cold-started SymSGD scores should differ.");
            Done();
        }
        public void TestEstimatorSymSgdClassificationTrainer()
        {
            // Runs the shared estimator checks on "featurization + SymSGD". It then fits SymSGD a second
            // time, seeded with the parameters of its own first model, to exercise the warm-start path.
            var (featurizer, rawData) = GetBinaryClassificationPipeline();
            var symSgd = new SymSgdClassificationTrainer(Env, new SymSgdClassificationTrainer.Options());

            TestEstimatorCore(featurizer.Append(symSgd), rawData);

            var featurized = featurizer.Fit(rawData).Transform(rawData);
            var firstModel = symSgd.Fit(featurized);

            // The warm-started transformer is deliberately discarded: this step only has to complete.
            symSgd.Fit(featurized, modelParameters: firstModel.Model.SubModel);
            Done();
        }