Exemple #1
0
        /// <summary>
        /// Copies the indexed training data out of <paramref name="di"/> into this trainer,
        /// reports the data set sizes through <c>display</c>, runs <c>findParameters</c> and
        /// wraps the resulting parameters in a <see cref="PerceptronModel"/>.
        /// </summary>
        /// <param name="iterations"> Forwarded unchanged to <c>findParameters</c>. </param>
        /// <param name="di"> The data indexer supplying contexts, values, outcomes and predicates. </param>
        /// <param name="cutoff"> Not read anywhere in this method. </param>
        /// <param name="useAverage"> Forwarded unchanged to <c>findParameters</c>. </param>
        /// <returns> A new <see cref="PerceptronModel"/> built from the computed parameters. </returns>
        public virtual AbstractModel trainModel(int iterations, DataIndexer di, int cutoff, bool useAverage)
        {
            display("Incorporating indexed data for training...  \n");

            // Per-event data taken from the indexer.
            contexts = di.Contexts;
            values = di.Values;
            numTimesEventsSeen = di.NumTimesEventsSeen;
            numEvents = di.NumEvents;
            numUniqueEvents = contexts.Length;

            // Outcome information.
            outcomeLabels = di.OutcomeLabels;
            outcomeList = di.OutcomeList;

            // Predicate information and the derived counts.
            predLabels = di.PredLabels;
            numPreds = predLabels.Length;
            numOutcomes = outcomeLabels.Length;

            display("done.\n");

            display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
            display("\t    Number of Outcomes: " + numOutcomes + "\n");
            display("\t  Number of Predicates: " + numPreds + "\n");

            display("Computing model parameters...\n");

            MutableContext[] trainedParameters = findParameters(iterations, useAverage);

            display("...done.\n");

            // Create and return the model.
            PerceptronModel trainedModel = new PerceptronModel(trainedParameters, predLabels, outcomeLabels);
            return trainedModel;
        }
Exemple #2
0
        /// <summary>
        /// Initializes the function from the data held by <paramref name="indexer"/>.
        /// </summary>
        /// <param name="indexer">
        /// Source of the contexts, outcome list, event counts and label arrays. Its values are
        /// only taken over when it is a <see cref="OnePassRealValueDataIndexer"/>.
        /// </param>
        public LogLikelihoodFunction(DataIndexer indexer)
        {
            // Values are only used for a OnePassRealValueDataIndexer; any other indexer leaves them null.
            this.values = indexer is OnePassRealValueDataIndexer ? indexer.Values : null;

            // Event data.
            this.contexts = indexer.Contexts;
            this.outcomeList = indexer.OutcomeList;
            this.numTimesEventsSeen = indexer.NumTimesEventsSeen;

            // Labels.
            this.outcomeLabels = indexer.OutcomeLabels;
            this.predLabels = indexer.PredLabels;

            // Sizes derived from the arrays above.
            this.numOutcomes = indexer.OutcomeLabels.Length;
            this.numFeatures = indexer.PredLabels.Length;
            this.numContexts = this.contexts.Length;
            this.domainDimension = numOutcomes * numFeatures;

            // Working storage: a numContexts x numOutcomes table; the gradient starts out unset.
            this.probModel = RectangularArrays.ReturnRectangularDoubleArray(numContexts, numOutcomes);
            this.gradient = null;
        }
Exemple #3
0
        /// <summary>
        /// Reads the saved index JSON from <c>DataConfig.indexSaveFilePath</c> and, when the file
        /// is not empty, replaces <c>dataIndexes</c> with the deserialized content.
        /// </summary>
        /// <remarks>
        /// When <c>IN_UNITY_EDITOR</c> is defined, a missing save file is created (empty) first.
        /// Without that symbol a missing file is not handled here and the read will throw.
        /// </remarks>
        public void Load()
        {
            // create file save if not exist
#if (IN_UNITY_EDITOR)
            if (!File.Exists(DataConfig.indexSaveFilePath))
            {
                // File.CreateText returns an open StreamWriter. Dispose it immediately so the
                // file handle is released before the file is read below; leaving it open leaks
                // the handle and can make the following read fail with a sharing violation.
                File.CreateText(DataConfig.indexSaveFilePath).Dispose();
            }
#endif
            // load by Player Pref
            //if (!PlayerPrefs.HasKey(DataConfig.indexDataSaveKey))
            //    return;
            //string content = PlayerPrefs.GetString(DataConfig.indexDataSaveKey);

            string content = File.ReadAllText(DataConfig.indexSaveFilePath);

            Debug.Log("Load Index = " + content);
            if (content.Length > 0)
            {
                // Only the dataIndexes member of the deserialized object is taken over.
                DataIndexer newData = JsonUtility.FromJson <DataIndexer>(content);
                dataIndexes = newData.dataIndexes;
            }
        }
Exemple #4
0
        /// <summary>
        /// Trains a <see cref="QNModel"/> on the data in <paramref name="indexer"/> by repeatedly
        /// computing a search direction and running a line search until <c>isConverged</c>
        /// reports convergence.
        /// </summary>
        /// <param name="indexer"> The data indexer the objective function is generated from. </param>
        /// <returns> A model built from the objective function and the final line search point. </returns>
        public virtual QNModel trainModel(DataIndexer indexer)
        {
            LogLikelihoodFunction function = generateFunction(indexer);

            this.dimension = function.DomainDimension;
            this.updateInfo = new QNInfo(this, this.m, this.dimension);

            // Evaluate value and gradient at the starting point to seed the line search state.
            double[] startPoint = function.InitialPoint;
            double startValue = function.valueAt(startPoint);
            double[] startGradient = function.gradientAt(startPoint);

            LineSearchResult current = LineSearchResult.getInitialObject(startValue, startGradient, startPoint, 0);

            int iteration = 0;

            // At least one step is always taken; convergence is tested after each step.
            do
            {
                if (verbose)
                {
                    Console.Write(iteration);
                    iteration++;
                }

                double[] direction = computeDirection(function, current);
                current = LineSearch.doLineSearch(function, direction, current, verbose);

                updateInfo.updateInfo(current);
            }
            while (!isConverged(current));

            return new QNModel(function, current.NextPoint);
        }
Exemple #5
0
 /// <summary>
 /// Creates the controller and stores the supplied indexer in <c>_indexer</c>.
 /// </summary>
 /// <param name="indexer"> The indexer kept by this controller; stored as given, without a null check. </param>
 public IndexController(DataIndexer indexer)
 {
     _indexer = indexer;
 }
Exemple #6
0
 /// <summary>
 /// Creates the <see cref="LogLikelihoodFunction"/> for the given indexer.
 /// </summary>
 /// <param name="indexer"> The data indexer handed to the function's constructor. </param>
 /// <returns> A newly constructed <see cref="LogLikelihoodFunction"/>. </returns>
 private LogLikelihoodFunction generateFunction(DataIndexer indexer)
 {
     LogLikelihoodFunction function = new LogLikelihoodFunction(indexer);
     return function;
 }
Exemple #7
0
        // TODO: Need a way to report results and settings back for inclusion in model ...

        /// <summary>
        /// Trains a model from an event stream using the algorithm and data indexer named in
        /// <paramref name="trainParams"/>.
        /// </summary>
        /// <param name="events"> The training events; wrapped in a <see cref="HashSumEventStream"/> before indexing. </param>
        /// <param name="trainParams"> The training settings, read through the <c>get*Param</c> helpers. </param>
        /// <param name="reportMap">
        /// Passed to every <c>get*Param</c> call; when not null it additionally receives the
        /// entry <c>"Training-Eventhash"</c> with the hexadecimal hash sum of the events.
        /// </param>
        /// <returns> The trained model. </returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="trainParams"/> is rejected by <c>isValid</c>, or requests sequence training.
        /// </exception>
        /// <exception cref="IllegalStateException">
        /// The algorithm name or the data indexer name is not one of the handled values.
        /// </exception>
        public static AbstractModel train(EventStream events, IDictionary <string, string> trainParams,
                                          IDictionary <string, string> reportMap)
        {
            if (!isValid(trainParams))
            {
                throw new System.ArgumentException("trainParams are not valid!");
            }

            if (isSequenceTraining(trainParams))
            {
                throw new System.ArgumentException("sequence training is not supported by this method!");
            }

            string algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
            int iterations = getIntParam(trainParams, ITERATIONS_PARAM, ITERATIONS_DEFAULT, reportMap);
            int cutoff = getIntParam(trainParams, CUTOFF_PARAM, CUTOFF_DEFAULT, reportMap);

            // Classify the requested algorithm once; the flags drive both the indexer setup
            // and the trainer selection further down.
            bool isMaxent = MAXENT_VALUE.Equals(algorithmName);
            bool isMaxentQn = MAXENT_QN_VALUE.Equals(algorithmName);
            bool isPerceptron = PERCEPTRON_VALUE.Equals(algorithmName);

            if (!isMaxent && !isMaxentQn && !isPerceptron)
            {
                throw new IllegalStateException("Unexpected algorithm name: " + algorithmName);
            }

            // Both maxent variants index with sortAndMerge enabled, the perceptron without.
            bool sortAndMerge = isMaxent || isMaxentQn;

            HashSumEventStream hashedEvents = new HashSumEventStream(events);

            string dataIndexerName = getStringParam(trainParams, DATA_INDEXER_PARAM, DATA_INDEXER_TWO_PASS_VALUE,
                                                    reportMap);

            DataIndexer indexer;

            if (DATA_INDEXER_ONE_PASS_VALUE.Equals(dataIndexerName))
            {
                indexer = new OnePassDataIndexer(hashedEvents, cutoff, sortAndMerge);
            }
            else if (DATA_INDEXER_TWO_PASS_VALUE.Equals(dataIndexerName))
            {
                indexer = new TwoPassDataIndexer(hashedEvents, cutoff, sortAndMerge);
            }
            else
            {
                throw new IllegalStateException("Unexpected data indexer name: " + dataIndexerName);
            }

            AbstractModel trained;

            if (isMaxent)
            {
                int threads = getIntParam(trainParams, "Threads", 1, reportMap);

                trained = opennlp.maxent.GIS.trainModel(iterations, indexer, true, false, null, 0, threads);
            }
            else if (isMaxentQn)
            {
                int updates = getIntParam(trainParams, "numOfUpdates", QNTrainer.DEFAULT_M, reportMap);
                int maxEvaluations = getIntParam(trainParams, "maxFctEval", QNTrainer.DEFAULT_MAX_FCT_EVAL, reportMap);

                QNTrainer quasiNewtonTrainer = new QNTrainer(updates, maxEvaluations, true);
                trained = quasiNewtonTrainer.trainModel(indexer);
            }
            else if (isPerceptron)
            {
                bool averagingRequested = getBooleanParam(trainParams, "UseAverage", true, reportMap);
                bool useSkippedAveraging = getBooleanParam(trainParams, "UseSkippedAveraging", false, reportMap);

                // Skipped averaging forces averaging on: overwrite otherwise it might not work.
                bool useAverage = averagingRequested || useSkippedAveraging;

                double stepSizeDecrease = getDoubleParam(trainParams, "StepSizeDecrease", 0, reportMap);
                double tolerance = getDoubleParam(trainParams, "Tolerance", PerceptronTrainer.TOLERANCE_DEFAULT,
                                                  reportMap);

                PerceptronTrainer perceptronTrainer = new PerceptronTrainer
                {
                    SkippedAveraging = useSkippedAveraging
                };

                // The trainer's own step size setting is only overridden for a positive value.
                if (stepSizeDecrease > 0)
                {
                    perceptronTrainer.StepSizeDecrease = stepSizeDecrease;
                }

                perceptronTrainer.Tolerance = tolerance;

                trained = perceptronTrainer.trainModel(iterations, indexer, cutoff, useAverage);
            }
            else
            {
                throw new IllegalStateException("Algorithm not supported: " + algorithmName);
            }

            if (reportMap != null)
            {
                reportMap["Training-Eventhash"] = hashedEvents.calculateHashSum().ToString("X"); // 16 Java : i.e. Hex
            }

            return trained;
        }
Exemple #8
0
 /// <summary>
 /// Convenience overload: trains with averaging switched on by delegating to the
 /// four-argument <c>trainModel</c> with <c>useAverage</c> set to <c>true</c>.
 /// </summary>
 /// <param name="iterations"> Forwarded unchanged. </param>
 /// <param name="di"> Forwarded unchanged. </param>
 /// <param name="cutoff"> Forwarded unchanged. </param>
 /// <returns> The model returned by the four-argument overload. </returns>
 public virtual AbstractModel trainModel(int iterations, DataIndexer di, int cutoff)
 {
     const bool useAverage = true;
     return trainModel(iterations, di, cutoff, useAverage);
 }
Exemple #9
0
        /// <summary>
        /// Builds a <see cref="GenericDataType"/> message from a PX model: one data set whose
        /// items are one series per matrix row, each holding one observation per matrix column.
        /// </summary>
        /// <param name="model"> The model whose <c>Meta</c> and <c>Data</c> are converted. </param>
        /// <returns> The message with its <c>DataSet</c> filled in and a header from <c>createHeader(model, true)</c>. </returns>
        public static GenericDataType createGenericData(PXModel model)
        {
            DataSetType ds = new DataSetType();

            ds.KeyFamilyRef = model.Meta.Matrix.CleanID();

            // Used below to obtain the stub value indices for each matrix row.
            DataIndexer di = new DataIndexer(model.Meta);

            // Get all table level notes (this includes notes for variables)
            List <AnnotationType> dsAnnotations = new List <AnnotationType>();

            if (model.Meta.Notes != null)
            {
                dsAnnotations.AddRange(model.Meta.Notes.ToSDMXAnnotation());
            }
            foreach (Variable var in model.Meta.Stub)
            {
                if (var.Notes != null)
                {
                    dsAnnotations.AddRange(var.Notes.ToSDMXAnnotation());
                }
            }
            foreach (Variable var in model.Meta.Heading)
            {
                if (var.Notes != null)
                {
                    dsAnnotations.AddRange(var.Notes.ToSDMXAnnotation());
                }
            }

            // Annotations are only set when at least one note was found.
            if (dsAnnotations.Count > 0)
            {
                ds.Annotations = dsAnnotations.ToArray();
            }

            // Without a content variable the content attributes are attached once, at data set
            // level; with one they are attached per series further down.
            if (model.Meta.ContentVariable == null)
            {
                List <org.sdmx.ValueType> dsAtts = new List <org.sdmx.ValueType>();

                if (model.Meta.ContentInfo != null)
                {
                    // Unit of measure
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "UNIT_MEASURE";
                        att.value   = model.Meta.ContentInfo.Units;
                        dsAtts.Add(att);
                    }
                    // Decimals
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "DECIMALS";
                        att.value   = model.Meta.Decimals.ToString();
                        dsAtts.Add(att);
                    }

                    // Stock/flow/average indicator
                    if (model.Meta.ContentInfo.StockFa != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "SFA_INDICATOR";
                        att.value   = model.Meta.ContentInfo.StockFa;
                        dsAtts.Add(att);
                    }

                    // Seasonal adjustment
                    if (model.Meta.ContentInfo.SeasAdj != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "SEAS_ADJ";
                        att.value   = model.Meta.ContentInfo.SeasAdj;
                        dsAtts.Add(att);
                    }

                    // Daily adjustment
                    if (model.Meta.ContentInfo.DayAdj != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "DAY_ADJ";
                        att.value   = model.Meta.ContentInfo.DayAdj;
                        dsAtts.Add(att);
                    }

                    // Base period
                    if (model.Meta.ContentInfo.Baseperiod != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "BASE_PER";
                        att.value   = model.Meta.ContentInfo.Baseperiod;
                        dsAtts.Add(att);
                    }

                    // Reference period
                    if (model.Meta.ContentInfo.RefPeriod != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "REF_PERIOD";
                        att.value   = model.Meta.ContentInfo.RefPeriod;
                        dsAtts.Add(att);
                    }

                    // Current / fixed prices
                    if (model.Meta.ContentInfo.CFPrices != null)
                    {
                        org.sdmx.ValueType att = new org.sdmx.ValueType();
                        att.concept = "PRICE_BASIS";
                        att.value   = model.Meta.ContentInfo.CFPrices;
                        dsAtts.Add(att);
                    }
                }

                ds.Attributes = dsAtts.ToArray();
            }


            // One series per matrix row.
            ds.Items = new Object[model.Data.MatrixRowCount];
            for (int i = 0; i < model.Data.MatrixRowCount; i++)
            {
                SeriesType series = new SeriesType();
                // Series key layout: slot 0 is FREQ, slots 1..n are the stub variables.
                series.SeriesKey = new org.sdmx.ValueType[model.Meta.Stub.Count + 1];

                org.sdmx.ValueType key = new org.sdmx.ValueType();
                key.concept = "FREQ";

                // NOTE(review): Heading[0] is read unconditionally here and for the observation
                // times below - presumably the first heading variable is the time variable; confirm.
                switch (model.Meta.Heading[0].TimeScale)
                {
                case TimeScaleType.Annual:
                    key.value = "A";
                    break;

                case TimeScaleType.Halfyear:
                    key.value = "B";
                    break;

                case TimeScaleType.Monthly:
                    key.value = "M";
                    break;

                case TimeScaleType.Quartely:
                    key.value = "Q";
                    break;

                case TimeScaleType.Weekly:
                    key.value = "W";
                    break;

                default:
                    //TODO: any other time scale leaves the FREQ value unset
                    break;
                }
                series.SeriesKey[0] = key;
                // Position the indexer on row i so that di.StubIndecies describes this row.
                di.SetContext(i, 0);
                // Create annotations based on value notes (not variable notes)
                List <AnnotationType> serAnnotations = new List <AnnotationType>();
                for (int j = 0; j < model.Meta.Stub.Count; j++)
                {
                    key                     = new org.sdmx.ValueType();
                    key.concept             = model.Meta.Stub[j].Name.CleanID();
                    key.value               = model.Meta.Stub[j].Values[di.StubIndecies[j]].Code.CleanID();
                    series.SeriesKey[j + 1] = key;
                    if (model.Meta.Stub[j].Values[di.StubIndecies[j]].Notes != null)
                    {
                        serAnnotations.AddRange(model.Meta.Stub[j].Values[di.StubIndecies[j]].Notes.ToSDMXAnnotation());
                    }
                }
                if (serAnnotations.Count > 0)
                {
                    series.Annotations = serAnnotations.ToArray();
                }

                // One observation per matrix column.
                series.Obs = new ObsType[model.Data.MatrixColumnCount];
                //Added code for reading the cellnotes
                // NOTE(review): a new DataFormatter is constructed for every row although it only
                // depends on the model - possibly hoistable; confirm the constructor has no per-row state.
                DataFormatter formatter = new DataFormatter(model);
                for (int j = 0; j < model.Data.MatrixColumnCount; j++)
                {
                    string  notes = null;
                    ObsType obs   = new ObsType();

                    // Set observation time
                    obs.Time = model.Meta.Heading[0].Values[j].ToSDMXTime();

                    // A cell counts as missing when its value is one of the protected (null) values.
                    Boolean missing = PXConstant.ProtectedNullValues.Contains(model.Data.ReadElement(i, j)) || PXConstant.ProtectedValues.Contains(model.Data.ReadElement(i, j));
                    //Create observation status attribute
                    org.sdmx.ValueType status = new org.sdmx.ValueType();
                    status.concept = "OBS_STATUS";

                    obs.Attributes    = new org.sdmx.ValueType[1];
                    obs.Attributes[0] = status;

                    // Set observation value and status code
                    if (!missing)
                    {
                        obs.ObsValue                = new ObsValueType();
                        obs.ObsValue.value          = model.Data.ReadElement(i, j);
                        obs.ObsValue.valueSpecified = true;
                        status.value                = "A";
                    }
                    else
                    {
                        // Missing cells get status "M" and no ObsValue.
                        status.value = "M";
                    }


                    // Cell notes
                    formatter.ReadElement(i, j, ref notes);

                    if (notes != null && notes.Length != 0)
                    {
                        // The note text becomes a single annotation; its language is hard-coded to "en".
                        AnnotationType  annotation     = new AnnotationType();
                        List <TextType> annotationText = new List <TextType>();
                        TextType        text           = new TextType();
                        text.lang  = "en";
                        text.Value = notes;
                        annotationText.Add(text);
                        annotation.AnnotationText = annotationText.ToArray();
                        obs.Annotations           = new AnnotationType[1];
                        obs.Annotations[0]        = annotation;
                    }


                    series.Obs[j] = obs;
                }

                // With a content variable the content attributes are taken per series from the
                // content value selected by this row's stub index.
                // NOTE(review): the guards below test model.Meta.ContentInfo.X while the value is
                // read from the content value's own ContentInfo - confirm this is intended.
                // NOTE(review): unlike the data set level branch, no DECIMALS attribute is added here.
                if (model.Meta.ContentVariable != null)
                {
                    List <org.sdmx.ValueType> serAtts = new List <org.sdmx.ValueType>();

                    if (model.Meta.ContentInfo != null)
                    {
                        // Unit of measure
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "UNIT_MEASURE";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.Units;
                            serAtts.Add(att);
                        }
                        // Stock/flow/average indicator
                        if (model.Meta.ContentInfo.StockFa != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "SFA_INDICATOR";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.StockFa;
                            serAtts.Add(att);
                        }

                        // Seasonal adjustment
                        if (model.Meta.ContentInfo.SeasAdj != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "SEAS_ADJ";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.SeasAdj;
                            serAtts.Add(att);
                        }

                        // Daily adjustment
                        if (model.Meta.ContentInfo.DayAdj != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "DAY_ADJ";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.DayAdj;
                            serAtts.Add(att);
                        }

                        // Base period
                        if (model.Meta.ContentInfo.Baseperiod != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "BASE_PER";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.Baseperiod;
                            serAtts.Add(att);
                        }

                        // Reference period
                        if (model.Meta.ContentInfo.RefPeriod != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "REF_PERIOD";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.RefPeriod;
                            serAtts.Add(att);
                        }

                        // Current / fixed prices
                        if (model.Meta.ContentInfo.CFPrices != null)
                        {
                            org.sdmx.ValueType att = new org.sdmx.ValueType();
                            att.concept = "PRICE_BASIS";
                            int cIndex = model.Meta.Stub.GetIndexByCode(model.Meta.ContentVariable.Code);
                            att.value = model.Meta.ContentVariable.Values[di.StubIndecies[cIndex]].ContentInfo.CFPrices;
                            serAtts.Add(att);
                        }
                    }
                    series.Attributes = serAtts.ToArray();
                }

                ds.Items[i] = series;
            }

            GenericDataType message = new GenericDataType();

            message.DataSet = ds;
            message.Header  = createHeader(model, true);

            return(message);
        }
 /// <summary>
 /// Creates the controller and stores the supplied indexer in the <c>indexer</c> member.
 /// </summary>
 /// <param name="indexer"> The indexer kept by this controller; stored as given, without a null check. </param>
 public IndexController(DataIndexer indexer)
 {
     this.indexer = indexer;
 }
Exemple #11
0
        /// <summary>
        /// Train a model using the GIS algorithm.
        /// </summary>
        /// <param name="iterations"> The number of GIS iterations to perform. </param>
        /// <param name="di"> The data indexer used to compress events in memory. </param>
        /// <param name="modelPrior"> The prior distribution used to train this model. </param>
        /// <param name="cutoff">
        /// Minimum predicate count: without simple smoothing, an outcome is only active for a
        /// predicate whose count in <c>di.PredCounts</c> is at least this value.
        /// </param>
        /// <param name="threads">
        /// Number of per-thread model expectation arrays to allocate; must be at least 1.
        /// </param>
        /// <returns> The newly trained model, which can be used immediately or saved
        ///         to disk using an opennlp.maxent.io.GISModelWriter object. </returns>
        /// <exception cref="System.ArgumentException"> <paramref name="threads"/> is zero or negative. </exception>
        public virtual GISModel trainModel(int iterations, DataIndexer di, Prior modelPrior, int cutoff, int threads)
        {
            if (threads <= 0)
            {
                throw new System.ArgumentException("threads must be at least one or greater but is " + threads + "!");
            }

            modelExpects = new MutableContext[threads][];

            // ************ Incorporate all of the needed info *****************
            display("Incorporating indexed data for training...  \n");
            contexts = di.Contexts;
            values = di.Values;
            this.cutoff = cutoff;
            predicateCounts = di.PredCounts;
            numTimesEventsSeen = di.NumTimesEventsSeen;
            numUniqueEvents = contexts.Length;
            this.prior = modelPrior;

            // Determine the correction constant: the largest per-event weight, where an event's
            // weight is the sum of its values when it has any, and its context count otherwise.
            double correctionConstant = 0;

            for (int eventIndex = 0; eventIndex < contexts.Length; eventIndex++)
            {
                double eventWeight;

                if (values != null && values[eventIndex] != null)
                {
                    // Accumulated in single precision.
                    float valueSum = values[eventIndex][0];
                    for (int valueIndex = 1; valueIndex < values[eventIndex].Length; valueIndex++)
                    {
                        valueSum += values[eventIndex][valueIndex];
                    }

                    eventWeight = valueSum;
                }
                else
                {
                    eventWeight = contexts[eventIndex].Length;
                }

                if (eventWeight > correctionConstant)
                {
                    correctionConstant = eventWeight;
                }
            }
            display("done.\n");

            outcomeLabels = di.OutcomeLabels;
            outcomeList = di.OutcomeList;
            numOutcomes = outcomeLabels.Length;

            predLabels = di.PredLabels;
            prior.setLabels(outcomeLabels, predLabels);
            numPreds = predLabels.Length;

            display("\tNumber of Event Tokens: " + numUniqueEvents + "\n");
            display("\t    Number of Outcomes: " + numOutcomes + "\n");
            display("\t  Number of Predicates: " + numPreds + "\n");

            // Set up feature arrays: predCount[predicate][outcome] accumulates how often the
            // pair was seen, weighted by the event's value when the event carries values.
            float[][] predCount = RectangularArrays.ReturnRectangularFloatArray(numPreds, numOutcomes);
            for (int eventIndex = 0; eventIndex < numUniqueEvents; eventIndex++)
            {
                bool eventHasValues = values != null && values[eventIndex] != null;

                for (int slot = 0; slot < contexts[eventIndex].Length; slot++)
                {
                    if (eventHasValues)
                    {
                        predCount[contexts[eventIndex][slot]][outcomeList[eventIndex]] += numTimesEventsSeen[eventIndex] * values[eventIndex][slot];
                    }
                    else
                    {
                        predCount[contexts[eventIndex][slot]][outcomeList[eventIndex]] += numTimesEventsSeen[eventIndex];
                    }
                }
            }

            di = null; // don't need it anymore

            // A fake "observation" to cover features which are not detected in
            // the data.  The default is to assume that we observed "1/10th" of a
            // feature during training.
            double smoothingObservation = _smoothingObservation;

            // Get the observed expectations of the features. Strictly speaking,
            // we should divide the counts by the number of Tokens, but because of
            // the way the model's expectations are approximated in the
            // implementation, this is cancelled out when we compute the next
            // iteration of a parameter, making the extra divisions wasteful.
            parameters = new MutableContext[numPreds];
            for (int threadIndex = 0; threadIndex < modelExpects.Length; threadIndex++)
            {
                modelExpects[threadIndex] = new MutableContext[numPreds];
            }
            observedExpects = new MutableContext[numPreds];

            // The model does need the correction constant and the correction feature. The correction constant
            // is only needed during training, and the correction feature is not necessary.
            // For compatibility reasons the model contains from now on a correction constant of 1,
            // and a correction param 0.
            evalParams = new EvalParameters(parameters, 0, 1, numOutcomes);

            // Scratch buffer reused for every predicate, plus the shared "all outcomes" pattern.
            int[] activeOutcomes = new int[numOutcomes];
            int[] allOutcomesPattern = new int[numOutcomes];
            for (int outcome = 0; outcome < numOutcomes; outcome++)
            {
                allOutcomesPattern[outcome] = outcome;
            }

            for (int pi = 0; pi < numPreds; pi++)
            {
                int numActiveOutcomes = 0;
                int[] outcomePattern;

                if (useSimpleSmoothing)
                {
                    // With simple smoothing every outcome is active for every predicate.
                    numActiveOutcomes = numOutcomes;
                    outcomePattern = allOutcomesPattern;
                }
                else
                {
                    // Determine the active outcomes: seen with this predicate and above the cutoff.
                    for (int outcome = 0; outcome < numOutcomes; outcome++)
                    {
                        if (predCount[pi][outcome] > 0 && predicateCounts[pi] >= cutoff)
                        {
                            activeOutcomes[numActiveOutcomes++] = outcome;
                        }
                    }

                    if (numActiveOutcomes == numOutcomes)
                    {
                        outcomePattern = allOutcomesPattern;
                    }
                    else
                    {
                        outcomePattern = new int[numActiveOutcomes];
                        System.Array.Copy(activeOutcomes, outcomePattern, numActiveOutcomes);
                    }
                }

                parameters[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
                for (int threadIndex = 0; threadIndex < modelExpects.Length; threadIndex++)
                {
                    modelExpects[threadIndex][pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);
                }
                observedExpects[pi] = new MutableContext(outcomePattern, new double[numActiveOutcomes]);

                for (int aoi = 0; aoi < numActiveOutcomes; aoi++)
                {
                    int outcome = outcomePattern[aoi];

                    parameters[pi].setParameter(aoi, 0.0);
                    foreach (MutableContext[] threadExpects in modelExpects)
                    {
                        threadExpects[pi].setParameter(aoi, 0.0);
                    }

                    // Observed expectation: the real count when there is one, otherwise the
                    // smoothing observation (only when simple smoothing is on).
                    if (predCount[pi][outcome] > 0)
                    {
                        observedExpects[pi].setParameter(aoi, predCount[pi][outcome]);
                    }
                    else if (useSimpleSmoothing)
                    {
                        observedExpects[pi].setParameter(aoi, smoothingObservation);
                    }
                }
            }

            predCount = null; // don't need it anymore

            display("...done.\n");

            // *************** Find the parameters ***********************
            if (threads == 1)
            {
                display("Computing model parameters ...\n");
            }
            else
            {
                display("Computing model parameters in " + threads + " threads...\n");
            }

            findParameters(iterations, correctionConstant);

            // ************* Create and return the model *****************
            // To be compatible with old models the correction constant is always 1
            return new GISModel(parameters, predLabels, outcomeLabels, 1, evalParams.CorrectionParam);
        }
Exemple #12
0
 /// <summary>
 /// Train a model using the GIS algorithm with a <see cref="UniformPrior"/> and a single thread.
 /// </summary>
 /// <param name="iterations">  The number of GIS iterations to perform. </param>
 /// <param name="di"> The data indexer used to compress events in memory. </param>
 /// <param name="cutoff"> Forwarded unchanged to the five-argument overload. </param>
 /// <returns> The newly trained model, which can be used immediately or saved
 ///         to disk using an opennlp.maxent.io.GISModelWriter object. </returns>
 public virtual GISModel trainModel(int iterations, DataIndexer di, int cutoff)
 {
     UniformPrior uniformPrior = new UniformPrior();
     return trainModel(iterations, di, uniformPrior, cutoff, 1);
 }