Example #1
        public static void LogisticIrt(int numParams, PriorType priorType, AlgorithmType algType, string conditionPrefix = "")
        {
            // timing on Intel Core 2 Duo P9500 with 4GB RAM running Windows Vista
            // 10_250 trial 1:
            // Bayesian/Hierarchical 2000 = 5.4s inference only
            // Variational/Hierarchical 50 iter = 4.2s inference only
            // Variational/Hierarchical 10 iter = 0.85s inference only
            // Variational_JJ/Hierarchical 50 iter = 0.1s inference only
            // Variational_JJ/Hierarchical 10 iter = 0.04s inference only
            // time on desktop:
            // Variational/Hierarchical 10 iter = 0.75s inference only (including test)
            // Variational_JJ/Hierarchical 10 iter = 0.07s inference only (including test)
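            // Build the training model: numParams selects 1-PL (difficulty only), 2-PL (+ discrimination), or 3-PL (+ guessing).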
            LogisticIrtModel train = new LogisticIrtModel(numParams, priorType);
            //train.engine.NumberOfIterations = 100;
            //train.engine.ShowTimings = true;
            string logistic_type = "";

            //logistic_type = "JJ";
            if (logistic_type == "JJ")
            {
                train.engine.Compiler.GivePriorityTo(typeof(LogisticOp_JJ96));
            }
            bool specialInitialization = false;

            if (specialInitialization)
            {
                // change initialization
                train.abilityMean.InitialiseTo(new Gaussian(5, 10));
                train.abilityPrecision.InitialiseTo(new Gamma(1, 10));
                train.difficultyMean.InitialiseTo(new Gaussian(5, 10));
                train.difficultyPrecision.InitialiseTo(new Gamma(1, 10));
            }
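            // Separate test model used only to turn the inferred posteriors into predicted response probabilities.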
            LogisticIrtTestModel test = new LogisticIrtTestModel(numParams);

            train.engine.ShowProgress = false;
            test.engine.ShowProgress  = false;
            if (algType == AlgorithmType.Variational)
            {
                train.engine.Algorithm = new VariationalMessagePassing();
                test.engine.Algorithm  = new VariationalMessagePassing();
            }
            bool showTiming = false;

            string baseFolder = @"..\..\";
            string modelName  = numParams + "-PL";
            //modelName = "Mild_skew";
            //modelName = "Extreme_skew";
            //modelName = "Lsat";
            //modelName = "Wide_b";
            string        modelFolder = baseFolder + @"Data_mat\" + modelName;
            DirectoryInfo modelDir    = new DirectoryInfo(modelFolder);

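            // Process each condition folder under Data_mat\<model>; every *.mat file in it supplies a response matrix under the key "Y".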
            foreach (DirectoryInfo conditionDir in modelDir.GetDirectories())
            {
                string condition = conditionDir.Name;
                if (!condition.StartsWith(conditionPrefix))
                {
                    continue;
                }
                int    trimStart   = condition.Length - 1; // currently unused: the Substring call below is commented out
                string inputFolder = baseFolder + @"Data_mat\" + modelName + @"\" + condition;
                string alg;
                if (algType == AlgorithmType.Variational)
                {
                    alg = "Variational" + logistic_type + @"\" + priorType;
                }
                else
                {
                    alg = algType + @"\" + priorType;
                }
                string outputFolder = baseFolder + @"Estimates_mat\" + modelName + @"\" + alg + @"\" + condition;
                Console.WriteLine(outputFolder);
                DirectoryInfo outputDir = Directory.CreateDirectory(outputFolder);
                DirectoryInfo inputDir  = new DirectoryInfo(inputFolder);
                foreach (FileInfo file in inputDir.GetFiles("*.mat"))
                {
                    string name           = file.Name;
                    string number         = name; //.Substring(trimStart);
                    string outputFileName = outputFolder + @"\" + number;
                    if (File.Exists(outputFileName))
                    {
                        continue;
                    }
                    Console.WriteLine(file.FullName);
                    Dictionary <string, object> dict = MatlabReader.Read(file.FullName);
                    Matrix     m = (Matrix)dict["Y"];
                    Gaussian[] abilityPost, difficultyPost;
                    Gamma[]    discriminationPost = null;
                    Beta[]     guessProbPost      = null;
                    Matrix     responseProbMean;
                    if (algType != AlgorithmType.MCMC)
                    {
                        // VMP
                        Stopwatch watch = new Stopwatch();
                        watch.Start();
                        train.ObserveResponses(m);
                        train.RunToConvergence();
                        abilityPost    = train.engine.Infer <Gaussian[]>(train.ability);
                        difficultyPost = train.engine.Infer <Gaussian[]>(train.difficulty);
                        if (numParams >= 2)
                        {
                            discriminationPost = train.engine.Infer <Gamma[]>(train.discrimination);
                        }
                        if (numParams >= 3)
                        {
                            guessProbPost = train.engine.Infer <Beta[]>(train.guessProb);
                        }
                        responseProbMean = test.GetResponseProbs(abilityPost, difficultyPost, discriminationPost, guessProbPost);
                        watch.Stop();
                        if (showTiming)
                        {
                            Console.WriteLine(algType + " elapsed time = {0}ms", watch.ElapsedMilliseconds);
                        }
                    }
                    else
                    {
                        // sampler
                        LogisticIrtSampler sampler = new LogisticIrtSampler();
                        sampler.abilityMeanPrior        = Gaussian.FromMeanAndVariance(0, 1e6);
                        sampler.abilityPrecPrior        = Gamma.FromShapeAndRate(1, 1);
                        sampler.difficultyMeanPrior     = Gaussian.FromMeanAndVariance(0, 1e6);
                        sampler.difficultyPrecPrior     = Gamma.FromShapeAndRate(1, 1);
                        sampler.discriminationMeanPrior = Gaussian.FromMeanAndVariance(0, 1e6);
                        sampler.discriminationPrecPrior = Gamma.FromShapeAndRate(1, 1);
                        // for debugging
                        //sampler.abilityObserved = ((Matrix)dict["ability"]).ToArray<double>();
                        //sampler.difficultyObserved = ((Matrix)dict["difficulty"]).ToArray<double>();
                        //sampler.discriminationObserved = ((Matrix)dict["discrimination"]).ToArray<double>();
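                        // If the training model fixes a hyperparameter (observed), give the sampler the matching point-mass prior.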
                        if (train.abilityMean.IsObserved)
                        {
                            sampler.abilityMeanPrior = Gaussian.PointMass(train.abilityMean.ObservedValue);
                        }
                        if (train.abilityPrecision.IsObserved)
                        {
                            sampler.abilityPrecPrior = Gamma.PointMass(train.abilityPrecision.ObservedValue);
                        }
                        if (train.difficultyMean.IsObserved)
                        {
                            sampler.difficultyMeanPrior = Gaussian.PointMass(train.difficultyMean.ObservedValue);
                        }
                        if (train.difficultyPrecision.IsObserved)
                        {
                            sampler.difficultyPrecPrior = Gamma.PointMass(train.difficultyPrecision.ObservedValue);
                        }
                        if (train.discriminationMean.IsObserved)
                        {
                            sampler.discriminationMeanPrior = Gaussian.PointMass(train.discriminationMean.ObservedValue);
                        }
                        if (train.discriminationPrecision.IsObserved)
                        {
                            sampler.discriminationPrecPrior = Gamma.PointMass(train.discriminationPrecision.ObservedValue);
                        }
                        Stopwatch watch = new Stopwatch();
                        watch.Start();
                        sampler.Sample(new Options(), m);
                        abilityPost        = sampler.abilityPost;
                        difficultyPost     = sampler.difficultyPost;
                        responseProbMean   = sampler.responseProbMean;
                        discriminationPost = sampler.discriminationPost;
                        watch.Stop();
                        if (showTiming)
                        {
                            Console.WriteLine("MCMC elapsed time = {0}ms", watch.ElapsedMilliseconds);
                        }
                    }
                    bool showEstimates = false;
                    if (showEstimates)
                    {
                        Console.WriteLine("abilityMean = {0}", train.engine.Infer(train.abilityMean));
                        Console.WriteLine("abilityPrecision = {0}", train.engine.Infer(train.abilityPrecision));
                        //Console.WriteLine("abilityMean2 = {0}", train.engine.Infer(train.abilityMean2));
                        //Console.WriteLine("abilityPrecision2 = {0}", train.engine.Infer(train.abilityPrecision2));
                        Console.WriteLine("difficultyMean = {0}", train.engine.Infer(train.difficultyMean));
                        Console.WriteLine("difficultyPrecision = {0}", train.engine.Infer(train.difficultyPrecision));
                        for (int i = 0; i < 10; i++)
                        {
                            Console.WriteLine(responseProbMean[i]);
                        }
                        //Console.WriteLine(ToMeanMatrix(difficultyPost));
                    }
                    using (MatlabWriter writer = new MatlabWriter(outputFileName))
                    {
                        writer.Write("ability", ToMeanMatrix(abilityPost));
                        writer.Write("ability_se", ToStddevMatrix(abilityPost));
                        writer.Write("difficulty", ToMeanMatrix(difficultyPost));
                        writer.Write("difficulty_se", ToStddevMatrix(difficultyPost));
                        if (discriminationPost != null)
                        {
                            writer.Write("discrimination", ToMeanMatrix(discriminationPost));
                            writer.Write("discrimination_se", ToStddevMatrix(discriminationPost));
                        }
                        if (guessProbPost != null)
                        {
                            writer.Write("guessing", ToMeanAndStddevMatrix(guessProbPost));
                        }
                        writer.Write("p", responseProbMean);
                    }
                    //break;
                }
                //break;
            }
        }
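
For orientation, a minimal caller for this routine might look like the sketch below. AlgorithmType.Variational and AlgorithmType.MCMC appear in the code above; the PriorType.Hierarchical member and the driver method itself are assumptions suggested by the "Variational/Hierarchical" timing notes, not confirmed by this listing.

        // Hypothetical driver, not part of the original sample.
        public static void RunBatchSweep()
        {
            // 2-PL model fitted with VMP over every condition folder under Data_mat\2-PL.
            LogisticIrt(2, PriorType.Hierarchical, AlgorithmType.Variational);

            // MCMC comparison on the same data, restricted to conditions whose name starts with "n" (example prefix only).
            LogisticIrt(2, PriorType.Hierarchical, AlgorithmType.MCMC, "n");
        }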
Example #2
        public static void OneShot(PriorType priorType, AlgorithmType algType, Matrix responses, string outputFolder, Options options)
        {
            Directory.CreateDirectory(outputFolder);
            LogisticIrtModel     train = new LogisticIrtModel(options.numParams, priorType);
            LogisticIrtTestModel test  = new LogisticIrtTestModel(options.numParams);

            train.engine.Compiler.WriteSourceFiles = false;
            test.engine.Compiler.WriteSourceFiles  = false;
            train.engine.ShowProgress = false;
            test.engine.ShowProgress  = false;
            if (algType == AlgorithmType.Variational)
            {
                train.engine.Algorithm = new VariationalMessagePassing();
                test.engine.Algorithm  = new VariationalMessagePassing();
            }
            Gaussian[] abilityPost, difficultyPost;
            Gamma[]    discriminationPost = null;
            Beta[]     guessProbPost      = null;
            Matrix     responseProbMean;
            Matrix     abilityCred, difficultyCred;

            if (algType != AlgorithmType.MCMC)
            {
                train.ObserveResponses(responses);
                train.RunToConvergence();
                abilityPost    = train.engine.Infer <Gaussian[]>(train.ability);
                difficultyPost = train.engine.Infer <Gaussian[]>(train.difficulty);
                if (options.numParams >= 2)
                {
                    discriminationPost = train.engine.Infer <Gamma[]>(train.discrimination);
                }
                if (options.numParams >= 3)
                {
                    guessProbPost = train.engine.Infer <Beta[]>(train.guessProb);
                }
                responseProbMean = test.GetResponseProbs(abilityPost, difficultyPost, discriminationPost, guessProbPost);
            }
            else
            { // MCMC
                LogisticIrtSampler sampler = new LogisticIrtSampler();
                sampler.abilityMeanPrior        = Gaussian.FromMeanAndVariance(0, 1e6);
                sampler.abilityPrecPrior        = Gamma.FromShapeAndRate(1, 1);
                sampler.difficultyMeanPrior     = Gaussian.FromMeanAndVariance(0, 1e6);
                sampler.difficultyPrecPrior     = Gamma.FromShapeAndRate(1, 1);
                sampler.discriminationMeanPrior = Gaussian.FromMeanAndVariance(0, 1e6);
                sampler.discriminationPrecPrior = Gamma.FromShapeAndRate(1, 1);
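                // Mirror any hyperparameters fixed in the training model as point-mass priors for the sampler.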
                if (train.abilityMean.IsObserved)
                {
                    sampler.abilityMeanPrior = Gaussian.PointMass(train.abilityMean.ObservedValue);
                }
                if (train.abilityPrecision.IsObserved)
                {
                    sampler.abilityPrecPrior = Gamma.PointMass(train.abilityPrecision.ObservedValue);
                }
                if (train.difficultyMean.IsObserved)
                {
                    sampler.difficultyMeanPrior = Gaussian.PointMass(train.difficultyMean.ObservedValue);
                }
                if (train.difficultyPrecision.IsObserved)
                {
                    sampler.difficultyPrecPrior = Gamma.PointMass(train.difficultyPrecision.ObservedValue);
                }
                if (train.discriminationMean.IsObserved)
                {
                    sampler.discriminationMeanPrior = Gaussian.PointMass(train.discriminationMean.ObservedValue);
                }
                if (train.discriminationPrecision.IsObserved)
                {
                    sampler.discriminationPrecPrior = Gamma.PointMass(train.discriminationPrecision.ObservedValue);
                }
                sampler.Sample(options, responses);
                abilityPost        = sampler.abilityPost;
                difficultyPost     = sampler.difficultyPost;
                responseProbMean   = sampler.responseProbMean;
                discriminationPost = sampler.discriminationPost;
                abilityCred        = sampler.abilityCred;
                difficultyCred     = sampler.difficultyCred;
                WriteMatrix(abilityCred, outputFolder + @"\ability_ci.txt");
                WriteMatrix(difficultyCred, outputFolder + @"\difficulty_ci.txt");
            }
            bool showEstimates = false;

            if (showEstimates)
            {
                Console.WriteLine("abilityMean = {0}", train.engine.Infer(train.abilityMean));
                Console.WriteLine("abilityPrecision = {0}", train.engine.Infer(train.abilityPrecision));
                //Console.WriteLine("abilityMean2 = {0}", train.engine.Infer(train.abilityMean2));
                //Console.WriteLine("abilityPrecision2 = {0}", train.engine.Infer(train.abilityPrecision2));
                Console.WriteLine("difficultyMean = {0}", train.engine.Infer(train.difficultyMean));
                Console.WriteLine("difficultyPrecision = {0}", train.engine.Infer(train.difficultyPrecision));
            }
            WriteMatrix(ToMeanMatrix(abilityPost), outputFolder + @"\ability.txt");
            WriteMatrix(ToStddevMatrix(abilityPost), outputFolder + @"\ability_se.txt");
            WriteMatrix(ToMeanMatrix(difficultyPost), outputFolder + @"\difficulty.txt");
            WriteMatrix(ToStddevMatrix(difficultyPost), outputFolder + @"\difficulty_se.txt");
            WriteMatrix(responseProbMean, outputFolder + @"\p.txt");
            if (discriminationPost != null)
            {
                WriteMatrix(ToMeanMatrix(discriminationPost), outputFolder + @"\discrimination.txt");
                WriteMatrix(ToStddevMatrix(discriminationPost), outputFolder + @"\discrimination_se.txt");
            }
            if (guessProbPost != null)
            {
                WriteMatrix(ToMeanMatrix(guessProbPost), outputFolder + @"\guess.txt");
                WriteMatrix(ToStddevMatrix(guessProbPost), outputFolder + @"\guess_se.txt");
            }
        }
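
A one-shot caller might be sketched as follows. The input path, output folder, and driver method are illustrative only, and treating Options.numParams as a writable field is an assumption (the code above only constructs new Options() and reads options.numParams).

        // Hypothetical driver, not part of the original sample; paths and field settability are assumptions.
        public static void RunOneShot()
        {
            Dictionary<string, object> dict = MatlabReader.Read(@"..\..\Data_mat\2-PL\example\1.mat");
            Matrix responses = (Matrix)dict["Y"];
            Options options = new Options();
            options.numParams = 2; // assumed writable, mirroring how it is read above
            OneShot(PriorType.Hierarchical, AlgorithmType.Variational, responses, @"..\..\Estimates_mat\OneShot", options);
        }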