Exemplo n.º 1
0
        /// <summary>
        /// Walks every pair of target-dependency and factor-dependency thresholds, builds a
        /// <c>DecisionForest</c> once per distinct set of selected factors, and appends one
        /// semicolon-separated line per threshold pair to <c>depstat.csv</c> (a relative path;
        /// the file is created or overwritten).
        /// </summary>
        /// <param name="trainPath">Not referenced inside this method body.</param>
        /// <param name="testPath">Not referenced inside this method body.</param>
        /// <param name="ids">Not referenced inside this method body.</param>
        /// <param name="target">Target name passed to <c>FactorManager.Load</c>.</param>
        /// <param name="depstatPath">Path passed to <c>FactorManager.Load</c>.</param>
        /// <param name="measureField">
        /// Measure name passed to <c>FactorManager.Load</c> and echoed in the last column of every output line.
        /// </param>
        /// <remarks>
        /// Any exception raised while processing a threshold pair is logged and the sweep
        /// continues with the next pair.
        /// </remarks>
        private static void CreateDepStat(string trainPath, string testPath, string ids, string target, string depstatPath, string measureField)
        {
            var manager = new FactorManager();
            manager.Load(depstatPath, target, measureField);

            // Both threshold lists are visited from the largest value down to the smallest.
            List<double> factorDeps = new double[] { 0, 0.5, 1, 1.5 }.OrderByDescending(v => v).ToList();
            List<double> targetDeps = new List<double>(manager.GetTargetValues()).OrderByDescending(v => v).ToList();

            using (var writer = new StreamWriter(new FileStream("depstat.csv", FileMode.Create, FileAccess.Write)))
            {
                writer.WriteLine("td;fd;cnt;last_auc;best_auc;vars;measure");

                // Results keyed by the "@"-joined, sorted factor list so that an identical
                // factor set is only trained once.
                var resultsByFactorSet = new Dictionary<string, ClassifierResult>();

                foreach (double td in targetDeps)
                {
                    foreach (double fd in factorDeps)
                    {
                        try
                        {
                            manager.TargDep = td;
                            manager.FactorDep = fd;
                            manager.SelectFactors();

                            var selected = manager.VisibleFactors;
                            Array.Sort(selected);
                            string key = string.Join("@", selected);

                            ClassifierResult outcome;
                            if (resultsByFactorSet.TryGetValue(key, out outcome))
                            {
                                Logger.Log("skipping...");
                            }
                            else
                            {
                                var forest = new DecisionForest();
                                var keep = selected.ToDictionary(name => name);

                                // Every known factor that was not selected is dropped from the classifier input.
                                foreach (string variable in manager.FactorDict.Keys)
                                {
                                    if (keep.ContainsKey(variable))
                                        continue;

                                    forest.AddDropColumn(variable);
                                }

                                forest.LoadData();
                                outcome = forest.Build();
                                resultsByFactorSet.Add(key, outcome);
                            }

                            // The thresholds are read back from the manager rather than taken from the loop variables.
                            string line = string.Format(
                                "{0};{1};{2};{3};{4};{5};{6}",
                                manager.TargDep.ToString("F06"),
                                manager.FactorDep.ToString("F06"),
                                selected.Length,
                                outcome.LastResult.AUC,
                                outcome.BestResult.AUC,
                                key,
                                measureField);
                            writer.WriteLine(line);
                            writer.Flush();

                            Logger.Log(string.Format(
                                "td={0}; fd={1}; cnt={2};{3}",
                                td.ToString("F06"),
                                fd.ToString("F06"),
                                selected.Length,
                                outcome.LastResult.AUC));
                        }
                        catch (Exception e)
                        {
                            Logger.Log(e);
                        }
                    }
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Entry point. Loads the data file, collects per-value target statistics for every
        /// combination of 4 selected factor columns, and writes a new file
        /// (<c>&lt;data path&gt;_cat.csv</c>) in which each combination is replaced by the
        /// observed target probability of the row's value tuple.
        /// </summary>
        /// <param name="args">
        /// Exactly four arguments: data file path, factor configuration path, id column name,
        /// target column name. With any other count a usage line is logged and the method returns.
        /// </param>
        /// <remarks>
        /// Only rows with <c>Target &lt;= 1</c> contribute to the statistics. A value tuple that
        /// never occurs in such a row has no statistics entry and receives the overall default
        /// probability when the output is written. Any exception is logged rather than propagated.
        /// </remarks>
        static void Main(string[] args)
        {
            if (args.Length != 4)
            {
                Logger.Log("usage: program.exe <train.csv> <conf.csv> <id> <target_name>");
                return;
            }

            string dataPath = args[0];
            string confPath = args[1];
            string id = args[2];
            string target = args[3];

            Logger.Log("data: " + dataPath);
            Logger.Log("conf : " + confPath);
            Logger.Log("id : " + id);
            Logger.Log("target : " + target);

            try
            {
                var fmgr = new FactorManager();
                fmgr.Load(confPath, target);
                fmgr.TargDep = 10;
                fmgr.FactorDep = 100;
                fmgr.SelectFactors();
                var cols = fmgr.VisibleFactors.ToArray();

                //_loader.MaxRowsLoaded = 10000;
                _loader.AddTargetColumn(target);
                _loader.AddIdColumn(id);
                _loader.CollectDistrStat = true;
                _loader.Load(dataPath);

                // Number of columns combined into one tuple.
                const int n = 4;

                // factor tuple -> (value tuple -> statistics)
                var statDict = new Dictionary<TupleData, Dictionary<TupleData, StatItem>>();

                // collecting stats
                var iter = new CombinationIterator(cols, n);
                while (iter.MoveNext())
                {
                    var cval = iter.Current;
                    var ftuple = new TupleData(cval);

                    var valueStats = new Dictionary<TupleData, StatItem>();
                    statDict.Add(ftuple, valueStats);

                    foreach (var row in _loader.Rows)
                    {
                        // Create the entry only for rows that are actually counted. Creating it
                        // up front left entries with Count == 0, which produced a NaN probability
                        // and made the default-probability fallback below unreachable.
                        if (row.Target <= 1)
                        {
                            var vtuple = CreateValueTuple(cval, row);

                            StatItem stat;
                            if (!valueStats.TryGetValue(vtuple, out stat))
                            {
                                stat = new StatItem();
                                valueStats.Add(vtuple, stat);
                            }

                            stat.Count++;
                            stat.Targets += (int)row.Target;
                        }
                    }

                    // Every stored entry has Count >= 1 here, so the division is well defined.
                    foreach (var entry in valueStats.Values)
                    {
                        entry.TargetProb = entry.Targets / (double)entry.Count;
                    }

                    Logger.Log(ftuple + " done;");
                }

                // creating modified file
                using (var sw = new StreamWriter(new FileStream(dataPath + "_cat.csv", FileMode.Create, FileAccess.Write)))
                {
                    int idx = 0;
                    sw.WriteLine(CreateHeader(cols, n));
                    sw.Flush();

                    // Overall share of target value 1 among target values 0 and 1; used for
                    // value tuples that have no statistics entry.
                    double defProb = (double)_loader.TargetStat[1] / (_loader.TargetStat[1] + _loader.TargetStat[0]);

                    foreach (var row in _loader.Rows)
                    {
                        idx++;

                        var sb = new StringBuilder();
                        iter = new CombinationIterator(cols, n);
                        sb.Append(row.Id);

                        while (iter.MoveNext())
                        {
                            var cval = iter.Current;
                            var ftuple = new TupleData(cval);
                            var t = CreateValueTuple(cval, row);

                            StatItem stat;
                            double prob = statDict[ftuple].TryGetValue(t, out stat) ? stat.TargetProb : defProb;

                            sb.Append(";" + prob.ToString("F05"));
                        }

                        sb.Append(";" + row.Target);
                        sw.WriteLine(sb);

                        // Periodic progress report and flush.
                        if (idx % 12345 == 0)
                        {
                            Logger.Log(idx + " lines writed;");
                            sw.Flush();
                        }
                    }

                    Logger.Log(idx + " lines writed; done;");
                }
            }
            catch (Exception e)
            {
                Logger.Log(e);
            }
        }