Exemplo n.º 1
0
        /// <summary>
        /// Trains the query/doc DSSM networks, resuming from the highest contiguous iteration
        /// whose query model file (MODEL_PATH + "_QUERY_ITER" + n) already exists.
        /// </summary>
        /// <remarks>
        /// Each training iteration makes one pass over <c>PairStream</c> and saves the models for
        /// that iteration. When ParameterSetting.ISVALIDATE is set, it then validates and either
        /// accepts the new models or rolls them back to the last accepted backup. Finally it persists
        /// <c>LearningParameters.lr_mid</c> so a later run can resume with it. After the last
        /// iteration the current models are saved as "_QUERY_DONE" / "_DOC_DONE".
        /// Returns right after loading/initialization when ParameterSetting.NOTrain is set.
        /// </remarks>
        public override void Training()
        {
            Init(DNN_Query, DNN_Doc);
            DNN dnn_query_backup = null, dnn_doc_backup = null;

            Program.Print("Starting DNN Learning!");

            float VALIDATION_Eval = 0;

            //// determine the last stopped iteration: the highest contiguous iter with a saved query model
            int lastRunStopIter = -1;
            for (int iter = 0; iter <= ParameterSetting.MAX_ITER; ++iter)
            {
                if (!File.Exists(ParameterSetting.MODEL_PATH + "_QUERY_ITER" + iter.ToString()))
                {
                    break;
                }
                lastRunStopIter = iter;
            }

            if (lastRunStopIter == -1)
            {
                //// fresh run: persist the initial models as iteration 0
                Program.Print("Initialization (Iter 0)");
                Program.Print("Saving models ...");
                Tuple<string, string> dssmModelPaths = ComposeDSSMModelPaths(0);
                SaveDssmModels(dssmModelPaths.Item1, dssmModelPaths.Item2);
                if (ParameterSetting.ISVALIDATE)
                {
                    VALIDATION_Eval = ValidateDssmModels(dssmModelPaths);
                }
                SaveLearningRate(0);
                lastRunStopIter = 0;
            }
            else if (ParameterSetting.ISVALIDATE)
            {
                //// go through all previously saved runs and print their validation scores;
                //// the models (and learning rate) of the last one stay loaded afterwards
                for (int iter = 0; iter <= lastRunStopIter; ++iter)
                {
                    Program.Print("Loading from previously trained Iter " + iter.ToString());
                    Tuple<string, string> dssmModelPaths = ComposeDSSMModelPaths(iter);
                    LoadModel(dssmModelPaths.Item1,
                              ref DNN_Query,
                              dssmModelPaths.Item2,
                              ref DNN_Doc,
                              false);
                    VALIDATION_Eval = ValidateDssmModels(dssmModelPaths);
                    RestoreLearningRate(iter);
                }
            }
            else
            {
                //// just load the last iteration (same path scheme used when saving)
                int iter = lastRunStopIter;
                Program.Print("Loading from previously trained Iter " + iter.ToString());
                Tuple<string, string> dssmModelPaths = ComposeDSSMModelPaths(iter);
                LoadModel(dssmModelPaths.Item1,
                          ref DNN_Query,
                          dssmModelPaths.Item2,
                          ref DNN_Doc,
                          false);
                RestoreLearningRate(iter);
            }

            //// Clone to backup models; the backups are the rollback target when validation rejects an iteration
            if (ParameterSetting.ISVALIDATE)
            {
                dnn_query_backup = (DNN)DNN_Query.CreateBackupClone();
                if (!ParameterSetting.IS_SHAREMODEL)
                {
                    dnn_doc_backup = (DNN)DNN_Doc.CreateBackupClone();
                }
            }

            if (ParameterSetting.NOTrain)
            {
                return;
            }
            Program.Print("total query sample number : " + PairStream.qstream.total_Batch_Size.ToString());
            Program.Print("total doc sample number : " + PairStream.dstream.total_Batch_Size.ToString());
            Program.Print("Training batches: " + PairStream.qstream.BATCH_NUM.ToString());
            LearningParameters.total_doc_num = PairStream.dstream.total_Batch_Size;

            //// validation score of the currently loaded (accepted) models; new iterations are compared against it
            float previous_devEval = VALIDATION_Eval;

            Program.Print("Start Training");
            Program.Print("-----------------------------------------------------------");

            for (int iter = lastRunStopIter + 1; iter <= ParameterSetting.MAX_ITER; iter++)
            {
                Program.Print("ITER : " + iter.ToString());
                LearningParameters.learning_rate = LearningParameters.lr_mid;
                LearningParameters.momentum      = 0.0f;

                Program.timer.Reset();
                Program.timer.Start();

                //// load the training file and all associated streams, the "open action" is cheap
                if (iter != lastRunStopIter + 1)
                {
                    //// we don't need to load if "iter == lastRunStopIter + 1", because it has been already opened.
                    //// we only open a new pair from the second iteration
                    LoadPairDataAtIdx();
                }

                PairStream.Init_Batch();
                float trainingLoss = 0;
                LearningParameters.neg_static_sample = false;
                int mmindex = 0;

                while (PairStream.Next_Batch(SrcNorm, TgtNorm))
                {
                    trainingLoss += feedstream_batch(PairStream.GPU_qbatch, PairStream.GPU_dbatch, true, PairStream.srNCEProbDist);
                    mmindex      += 1;
                    if (mmindex % 50 == 0)
                    {
                        Console.Write("Training :{0}\r", mmindex.ToString());
                    }
                }

                Program.Print("Training Loss : " + trainingLoss.ToString());
                Program.Print("Learning Rate : " + (LearningParameters.learning_rate.ToString()));
                Tuple<string, string> dssmModelPaths = ComposeDSSMModelPaths(iter);
                Program.Print("Saving models ...");
                //// NOTE(review): the models are saved before the accept/reject decision, so a rejected
                //// iteration's checkpoint holds the rejected weights; a resumed run will load them.
                SaveDssmModels(dssmModelPaths.Item1, dssmModelPaths.Item2);

                if (ParameterSetting.ISVALIDATE)
                {
                    VALIDATION_Eval = ValidateDssmModels(dssmModelPaths);

                    if (VALIDATION_Eval >= previous_devEval - LearningParameters.accept_range)
                    {
                        Console.WriteLine("Accepted it");
                        previous_devEval = VALIDATION_Eval;
                        if (LearningParameters.IsrateDown)
                        {
                            LearningParameters.lr_mid = LearningParameters.lr_mid * LearningParameters.down_rate;
                        }
                        //// save model to backups
                        dnn_query_backup.Init(DNN_Query);
                        if (!ParameterSetting.IS_SHAREMODEL)
                        {
                            dnn_doc_backup.Init(DNN_Doc);
                        }
                    }
                    else
                    {
                        Console.WriteLine("Reject it");

                        LearningParameters.IsrateDown = true;
                        LearningParameters.lr_mid     = LearningParameters.lr_mid * LearningParameters.reject_rate;

                        //// recover model from the last saved backup
                        DNN_Query.Init(dnn_query_backup);
                        if (!ParameterSetting.IS_SHAREMODEL)
                        {
                            DNN_Doc.Init(dnn_doc_backup);
                        }
                    }
                }

                //// write the learning rate after this iter so a resumed run can pick it up
                SaveLearningRate(iter);

                Program.timer.Stop();
                Program.Print("Training Runing Time : " + Program.timer.Elapsed.ToString());
                Program.Print("-----------------------------------------------------------");
            }

            //// Final save
            SaveDssmModels(ParameterSetting.MODEL_PATH + "_QUERY_DONE", ParameterSetting.MODEL_PATH + "_DOC_DONE");
        }

        /// <summary>
        /// Calls CopyOutFromCuda (presumably syncing device weights to host) and saves the query
        /// model; does the same for the doc model unless ParameterSetting.IS_SHAREMODEL is set.
        /// </summary>
        /// <param name="queryModelPath">Destination file for the query model.</param>
        /// <param name="docModelPath">Destination file for the doc model (unused when the model is shared).</param>
        private void SaveDssmModels(string queryModelPath, string docModelPath)
        {
            DNN_Query.CopyOutFromCuda();
            DNN_Query.Model_Save(queryModelPath);
            if (!ParameterSetting.IS_SHAREMODEL)
            {
                DNN_Doc.CopyOutFromCuda();
                DNN_Doc.Model_Save(docModelPath);
            }
        }

        /// <summary>
        /// Runs validation and prints the score. Uses EvaluateModelOnly on the saved model files
        /// when ParameterSetting.VALIDATE_MODEL_ONLY is set, otherwise Evaluate().
        /// </summary>
        /// <param name="dssmModelPaths">Saved (query, doc) model paths, as returned by ComposeDSSMModelPaths.</param>
        /// <returns>The validation score.</returns>
        private float ValidateDssmModels(Tuple<string, string> dssmModelPaths)
        {
            Program.Print("Start validation process ...");
            float eval;
            if (ParameterSetting.VALIDATE_MODEL_ONLY)
            {
                eval = EvaluateModelOnly(dssmModelPaths.Item1, dssmModelPaths.Item2);
            }
            else
            {
                eval = Evaluate();
            }
            Program.Print("Dataset VALIDATION :\n/*******************************/ \n" + eval.ToString() + " \n/*******************************/ \n");
            return eval;
        }

        /// <summary>
        /// Path of the file holding the learning rate in effect after <paramref name="iter"/>.
        /// Used for both writing and reading so the two can never diverge.
        /// </summary>
        private static string LearningRateFilePath(int iter)
        {
            return ParameterSetting.MODEL_PATH + "_LEARNING_RATE_ITER" + iter.ToString();
        }

        /// <summary>
        /// Persists LearningParameters.lr_mid for <paramref name="iter"/> using a culture-invariant,
        /// round-trippable number format.
        /// </summary>
        private void SaveLearningRate(int iter)
        {
            File.WriteAllText(LearningRateFilePath(iter),
                              LearningParameters.lr_mid.ToString("R", System.Globalization.CultureInfo.InvariantCulture));
        }

        /// <summary>
        /// Restores LearningParameters.lr_mid from the file written by SaveLearningRate for
        /// <paramref name="iter"/>, if that file exists. Leaves lr_mid unchanged (with a warning)
        /// when the content cannot be parsed.
        /// </summary>
        private void RestoreLearningRate(int iter)
        {
            string path = LearningRateFilePath(iter);
            if (!File.Exists(path))
            {
                return;
            }

            string text = File.ReadAllText(path).Trim();
            float restoredRate;
            //// invariant culture first (current format); fall back to the current culture for files
            //// written by older versions that used culture-sensitive ToString()
            if (float.TryParse(text, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out restoredRate)
                || float.TryParse(text, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.CurrentCulture, out restoredRate))
            {
                LearningParameters.lr_mid = restoredRate;
            }
            else
            {
                Program.Print("Warning: could not parse learning rate from " + path + "; keeping " + LearningParameters.lr_mid.ToString());
            }
        }