Code Example #1
File: WorkCycle.cs  Project: dtklinh/CRFTool
        public void Do(SoftwareGraphLearningParameters parameters)
        {
            Build.Do();

            var graphs = new List <IGWGraph <SGLNodeData, SGLEdgeData, SGLGraphData> >();

            #region Create graphs

            var erdösGraphCreator    = new ErdösGraphCreator();
            var categoryGraphCreator = new CategoryGraphCreator();

            for (int i = 0; i < parameters.NumberOfGraphs; i++)
            {
                var newGraph = erdösGraphCreator.CreateGraph(new ErdösGraphCreationParameter(parameters));
                // create the category graph
                categoryGraphCreator.CreateCategoryGraph(newGraph);

                graphs.Add(newGraph);

                //var graph3D = newGraph.Wrap3D();
                //new ShowGraph3D(graph3D).Request();
            }



            #endregion

            #region Create observations
            // TODO: Which procedure should be used to create the observations?
            // for now: for each category: draw x uniformly from [0,1): Prob(0) = x, Prob(1) = 1 - x

            for (int i = 0; i < parameters.NumberOfGraphs; i++)
            {
                var graph = graphs[i];
                // create the observation for this graph
                foreach (var node in graph.Nodes)
                {
                    var categoryGraph = graph.Data.CategoryGraph;
                    var category      = categoryGraph.Nodes.ToList().Find(catNode => catNode.Data.Category == node.Data.Category);

                    double probability = random.NextDouble();
                    if (probability <= category.Data.ObservationProbabilty)
                    {
                        node.Data.Observation = 0;
                    }
                    else
                    {
                        node.Data.Observation = 1;
                    }
                }
            }

            //information about graph

            /*
             * var graphView = new GraphView();
             * foreach (var graph in graphs)
             * {
             *  graphView.GetGraphInfo(graph);
             *
             * } */



            #endregion

            #region Training

            // two parameters: correlation parameter || conformity parameter
            // additional parameters (a (negative), b) (given in SoftwareGraphLearningParameters)
            // the objective function is composed of two goals:
            // Goal 1: homogeneity: for now: for each category: Score := a * (Math.Abs(0.95 - homogenityRatio))
            // Goal 2: independence: for now: for each category: Score += b, if
            // Goal 3: correlation with the local observations
            // Math.Sign([mean node score] - 0.5) == Math.Sign([mean node labeling] - 0.5)
            // Goal 3: correlation



            // results stores the average homogeneities of the graphs for the 10 * 10 data points
            var results = new List <LocalResult>();
            for (int k = 1; k <= 10; k++)
            {
                for (int l = 1; l <= 10; l++)
                {
                    var conformity  = k * 0.1;
                    var correlation = l * 0.1;
                    var isingModel  = new IsingModel(conformity, correlation);
                    // resultsTemp stores the homogeneities of the graphs for the same correlation + conformity
                    var resultsTemp = new List <LocalResult>();
                    int counter     = 0;
                    foreach (var graph in graphs)
                    {
                        var localResult = new LocalResult(parameters.NumberCategories, conformity, correlation);

                        #region homogeneities normal

                        // compute CRF scores from the observations & correlation/conformity
                        // compute scores for nodes and edges
                        isingModel.CreateCRFScore(graph);

                        // start the Viterbi heuristic
                        var request = new SolveInference(graph, null, parameters.NumberLabels);
                        request.RequestInDefaultContext();

                        // collect the results
                        var resultingLabeling = request.Solution.Labeling;

                        // map the labeling onto the nodes
                        var nodes = graph.Nodes.ToList();
                        foreach (var node in nodes)
                        {
                            node.Data.AssignedLabel = resultingLabeling[node.GraphId];
                        }

                        // compute the homogeneity -> store it in localResult
                        var categoryGraph = graph.Data.CategoryGraph;
                        // iterate over each category
                        foreach (var catNode in categoryGraph.Nodes)
                        {
                            int amountZeroLabeled = 0;
                            // for each node in the current category, count the number of 0 labels
                            foreach (var node in catNode.Data.Nodes)
                            {
                                if (node.Data.AssignedLabel == 0)
                                {
                                    amountZeroLabeled++;
                                }
                            }
                            // homogeneity = max(a; 1 - a), where a = fraction labeled 0
                            var homogenityRatio = Math.Max((amountZeroLabeled * 1.0) / catNode.Data.NumberNodes,
                                                           1 - (amountZeroLabeled * 1.0) / catNode.Data.NumberNodes);
                            // store the homogeneity in the homogeneity array
                            localResult.Homs[catNode.Data.Category] = homogenityRatio;
                        }
                        #endregion

                        #region distinction isolated

                        // compute the distinction in isolation

                        foreach (var edge in graph.Edges)
                        {   //set score of all inter edges to 0
                            if (edge.Data.Type == EdgeType.Inter)
                            {
                                edge.Data.Scores = new double[2, 2] {
                                    { 0, 0 }, { 0, 0 }
                                };
                            }
                        }
                        // start the Viterbi heuristic
                        request = new SolveInference(graph, null, parameters.NumberLabels);
                        request.RequestInDefaultContext();

                        // collect the results
                        resultingLabeling = request.Solution.Labeling;

                        // map the labeling onto the nodes
                        nodes = graph.Nodes.ToList();
                        foreach (var node in nodes)
                        {
                            node.Data.LabelTemp = resultingLabeling[node.GraphId];
                        }

                        categoryGraph = graph.Data.CategoryGraph;
                        // iterate over each category
                        foreach (var catNode in categoryGraph.Nodes)
                        {
                            int amountDifferentLabeled = 0;
                            // for each category, count the number of differently labeled nodes
                            foreach (var node in catNode.Data.Nodes)
                            {
                                if (node.Data.AssignedLabel != node.Data.LabelTemp)
                                {
                                    amountDifferentLabeled++;
                                }
                            }
                            // distinctRatio = fraction of differently labeled nodes
                            var distinctRatio = (amountDifferentLabeled * 1.0) / catNode.Data.NumberNodes;
                            // store distinctRatio in the distinction array
                            localResult.Distincts[catNode.Data.Category] = distinctRatio;
                        }

                        #endregion

                        resultsTemp.Add(localResult);

                        if (counter == 0)
                        {
                            graph.SaveAsJSON("exampleGraph_" + k + "_" + l + ".txt");
                            //var graph3D = graph.Wrap3D();
                            //new ShowGraph3D(graph3D).Request();
                        }
                        counter++;
                    }//end of foreach graph in graphs

                    // now compute the average values over all entries in resultsTemp
                    // and store these averages in results
                    var averageResult = new LocalResult(parameters.NumberCategories, conformity, correlation);

                    // sum up the values
                    foreach (var localresult in resultsTemp)
                    {
                        averageResult.AddValues(localresult);
                    }
                    // averages per category
                    for (int i = 0; i < averageResult.Homs.Length; i++)
                    {
                        averageResult.Homs[i]      /= resultsTemp.Count;
                        averageResult.Distincts[i] /= resultsTemp.Count;
                    }
                    // overall averages
                    double avgHomogenity  = 0;
                    double avgDistinction = 0;
                    for (int i = 0; i < averageResult.Homs.Length; i++)
                    {
                        avgHomogenity  += averageResult.Homs[i];
                        avgDistinction += averageResult.Distincts[i];
                    }
                    avgHomogenity  /= averageResult.Homs.Length;
                    avgDistinction /= averageResult.Distincts.Length;
                    averageResult.AvgHomogenity  = avgHomogenity;
                    averageResult.AvgDistinction = avgDistinction;
                    averageResult.ResultValue    = avgHomogenity + (1 - avgDistinction);

                    // output
                    Log.Post("Conformity: " + conformity + " - Correlation: " + correlation + "   ", LogCategory.Result);
                    Log.Post("Avg Homogenity: " + averageResult.AvgHomogenity, LogCategory.Result);
                    Log.Post("Avg Distinction: " + averageResult.AvgDistinction, LogCategory.Result);
                    Log.Post(Environment.NewLine, LogCategory.Result);

                    results.Add(averageResult);
                }
            }

            /*  Expectations:
             *
             *  1) equal values in all communities
             *  2) homogeneity increasing in correlation
             *
             * */

            // evaluation of homogeneity and independence

            var bestResult = results.MaxEntry(r => r.ResultValue);

            Log.Post("Exhaustive Search Result:");
            Log.Post("conformity: " + bestResult.Conformity);
            Log.Post("correlation: " + bestResult.Correlation);


            bestResult.SaveAsJSON(@"..\..\bestResult.txt");
            results.SaveAsJSON(@"..\..\results.txt");

            #endregion


            #region OLM


            var con = bestResult.Conformity;
            var cor = bestResult.Correlation;

            //create referenceLabeling for best parameters
            var isingModell = new IsingModel(con, cor);

            foreach (var graph in graphs)
            {
                isingModell.CreateCRFScore(graph);
                var request = new SolveInference(graph, null, parameters.NumberLabels);
                request.RequestInDefaultContext();
                graph.Data.ReferenceLabeling = request.Solution.Labeling;
            }

            var req = new OLMRequest(OLMVariant.Default, graphs);
            req.BasisMerkmale = new BasisMerkmal <ICRFNodeData, ICRFEdgeData, ICRFGraphData>[]
            { new IsingMerkmalNode(), new IsingMerkmalEdge() };
            req.LossFunctionValidation = LossFunction;
            req.MaxIterations          = 100;

            req.RequestInDefaultContext();

            double[] olmWeights = req.Result.ResultingWeights;


            Log.Post("OLM Result: ");
            for (int i = 0; i < olmWeights.Length; i++)
            {
                Log.Post(olmWeights[i] + "");
            }
            #endregion
        }//end of Do Method
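
The exhaustive search in Do scores every (conformity, correlation) pair by averaging a per-category homogeneity ratio and distinction ratio and combining them as avgHomogenity + (1 - avgDistinction). The following is a minimal sketch, not part of WorkCycle.cs, of how those ratios could be computed on plain label arrays; the class and method names (RatioSketch, HomogeneityRatio, DistinctionRatio) are hypothetical and it only assumes System and System.Linq.

        // Minimal sketch (not part of WorkCycle.cs): per-category ratios on plain label arrays.
        public static class RatioSketch
        {
            // homogeneity = max(a, 1 - a), where a = fraction of nodes labeled 0
            public static double HomogeneityRatio(int[] labels)
            {
                double zeroShare = labels.Count(l => l == 0) / (double)labels.Length;
                return Math.Max(zeroShare, 1 - zeroShare);
            }

            // distinction = fraction of nodes whose label changes once all inter edges are scored with 0
            public static double DistinctionRatio(int[] labels, int[] labelsIsolated)
            {
                int different = 0;
                for (int i = 0; i < labels.Length; i++)
                {
                    if (labels[i] != labelsIsolated[i])
                    {
                        different++;
                    }
                }
                return different / (double)labels.Length;
            }

            // objective used for the exhaustive search: high homogeneity and low distinction
            public static double ResultValue(double avgHomogeneity, double avgDistinction)
            {
                return avgHomogeneity + (1 - avgDistinction);
            }
        }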
Code Example #2
File: WorkflowOne.cs  Project: dtklinh/CRFTool
        public void Execute()
        {
            //    - Prerequisites:
            //	  - Graphstructure
            //    - Training Data
            //    - 2 Node Classifications
            TrainingData   = new List <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >();
            EvaluationData = new List <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >();

            // decision whether the user wants to train or load pre-trained data
            var requestTraining = new UserDecision("Use Training.", "Load Training Result.");

            requestTraining.Request();

            if (requestTraining.Decision == 0) // use training
            {
                //    -n characteristics for each node
                {
                    var request = new UserInput(UserInputLookFor.Folder);
                    request.DefaultPath = "..\\..\\CRFToolApp\\bin\\Graphs";
                    request.TextForUser = "******";
                    request.Request();
                    GraphDataFolder = request.UserText;

                    foreach (var file in Directory.EnumerateFiles(GraphDataFolder))
                    {
                        var graph = JSONX.LoadFromJSON <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >(file);
                        TrainingData.Add(graph);
                    }
                }

                //   - Step 1:

                //   - discretize characteristics
                #region Use Training Pede
                {
                    // create features
                    CreateFeatures(Dataset, TrainingData);
                    InitCRFScores(TrainingData);

                    var request = new OLMRequest(OLMVariant.Ising, TrainingData);
                    request.BasisMerkmale.AddRange(Dataset.NodeFeatures);
                    request.BasisMerkmale.AddRange(Dataset.EdgeFeatures);

                    request.LossFunctionValidation = OLM.LossRatio;

                    request.Request();

                    // create the corresponding scores for each graph (including evaluation)
                    CreateCRFScores(TrainingData, Dataset.NodeFeatures, request.Result.ResultingWeights);

                    // store trained Weights
                    Dataset.NumberIntervals    = NumberIntervals;
                    Dataset.Characteristics    = TrainingData.First().Data.Characteristics.ToArray();
                    Dataset.EdgeCharacteristic = "IsingEdgeCharacteristic";
                    Dataset.Weights            = request.Result.ResultingWeights;
                    Dataset.SaveAsJSON("results.json");
                }

                #endregion
            }
            else
            { // load pre-trained data
                var request = new UserInput();
                request.TextForUser = "******";
                request.Request();
                var file           = request.UserText;
                var trainingResult = JSONX.LoadFromJSON <WorkflowOneDataset>(file);
                Dataset = trainingResult;
            }
            //- Step2:

            // User Choice here
            {
                var request = new UserInput(UserInputLookFor.Folder);
                request.TextForUser = "******";
                request.Request();
                GraphDataFolder = request.UserText;

                foreach (var file in Directory.EnumerateFiles(GraphDataFolder))
                {
                    var graph = JSONX.LoadFromJSON <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >(file);
                    EvaluationData.Add(graph);
                }
            }

            // remove double edges
            {
                var edgesToRemove = new LinkedList <GWEdge <CRFNodeData, CRFEdgeData, CRFGraphData> >();
                foreach (var graph in EvaluationData)
                {
                    foreach (var edge in graph.Edges.ToList())
                    {
                        if (graph.Edges.Any(e => (e != edge) && ((e.Foot == edge.Head && e.Head == edge.Foot && e.GWId.CompareTo(edge.GWId) < 0) || (e.Foot == edge.Foot && e.Head == edge.Head && e.GWId.CompareTo(edge.GWId) < 0))))
                        {
                            edgesToRemove.AddLast(edge); // LinkedList<T> exposes AddLast rather than a public Add
                            graph.Edges.Remove(edge);
                        }
                    }
                }
                foreach (var edge in edgesToRemove)
                {
                    edge.Foot.Edges.Remove(edge);
                    edge.Head.Edges.Remove(edge);
                }
            }


            // create the scores
            CreateCRFScores(EvaluationData, Dataset.NodeFeatures, Dataset.Weights);

            //   - Create ROC Curve
            {
            }
            //   - Give Maximum with Viterbi
            {
                foreach (var graph in EvaluationData)
                {
                    var request = new SolveInference(graph, null, 2);
                    request.Request();
                    graph.Data.AssginedLabeling = request.Solution.Labeling;
                }

                //show results in 3D Viewer
                {
                    //var request = new ShowGraphs();
                    //request.Graphs = EvaluationData;
                    //request.Request();
                }
            }
            //   - Give Sample with MCMC
            {
                foreach (var graph in EvaluationData)
                {
                    SoftwareGraphLearningParameters parameters = new SoftwareGraphLearningParameters();
                    parameters.NumberOfGraphs          = 60;
                    parameters.NumberNodes             = 50;
                    parameters.NumberLabels            = 2;
                    parameters.NumberCategories        = 4;
                    parameters.IntraConnectivityDegree = 0.15;
                    parameters.InterConnectivityDegree = 0.01;


                    //sample parameters
                    var samplerParameters = new MHSamplerParameters();
                    var sglGraph          = graph.Convert((nodeData) => new SGLNodeData()
                    {
                    }, (edgeData) => new SGLEdgeData(), (graphData) => new SGLGraphData());

                    samplerParameters.Graph        = sglGraph;
                    samplerParameters.NumberChains = 1;

                    // start the sampler
                    var gibbsSampler = new MHSampler();
                    gibbsSampler.Do(samplerParameters);
                }
            }
        }
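
The "remove double edges" block in Execute checks, for every edge, whether another edge with the same (possibly reversed) endpoints exists and removes the duplicates. As a minimal sketch of the same idea using a single pass over the edges and a set of unordered endpoint keys, assuming simplified edge objects (EdgeStub and its Foot/Head string ids are hypothetical stand-ins for the GWEdge types above):

        // Minimal sketch (not from WorkflowOne.cs): keep only one edge per unordered endpoint pair.
        public sealed class EdgeStub
        {
            public string Foot;
            public string Head;
        }

        public static List<EdgeStub> RemoveDoubleEdges(IEnumerable<EdgeStub> edges)
        {
            var seen = new HashSet<(string, string)>();
            var kept = new List<EdgeStub>();
            foreach (var edge in edges)
            {
                // order the endpoints so that A-B and B-A map to the same key
                var key = string.CompareOrdinal(edge.Foot, edge.Head) <= 0
                    ? (edge.Foot, edge.Head)
                    : (edge.Head, edge.Foot);
                if (seen.Add(key))
                {
                    kept.Add(edge);
                }
            }
            return kept;
        }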
Code Example #3
        /*
         *  The version of the project cycle discussed with Mr. Waack for testing the different training variants of OLM
         *
         */
        public void RunCycle(TrainingEvaluationCycleInputParameters inputParameters)
        {
            #region Step 0: Prepare the data

            // cache frequently used variables for readability:
            var inputGraph = inputParameters.Graph;
            var graphList  = new List <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >();

            // create the graphs
            for (int i = 0; i < inputParameters.NumberOfGraphInstances; i++)
            {
                var newGraph = inputGraph.Clone(nd => new CRFNodeData()
                {
                    X = nd.Data.X, Y = nd.Data.Y, Z = nd.Data.Z
                }, ed => new CRFEdgeData(), gd => new CRFGraphData());
                graphList.Add(newGraph);
            }

            // create the required objects:
            seedingMethodPatchCreation = new SeedingMethodPatchCreation(inputParameters.NumberOfSeedsForPatchCreation, inputParameters.MaximumTotalPatchSize);

            #endregion


            #region Step 1: Create the reference labelings.

            int[][] referenceLabelings = new int[inputParameters.NumberOfGraphInstances][];
            for (int i = 0; i < inputParameters.NumberOfGraphInstances; i++)
            {
                seedingMethodPatchCreation.CreatePatchAndSetAsReferenceLabel(graphList[i]);

                if (i == 0 && GraphVisalization == true)
                {
                    var graph3D = graphList[i].Wrap3D(nd => new Node3DWrap <CRFNodeData>(nd.Data)
                    {
                        ReferenceLabel = nd.Data.ReferenceLabel, X = nd.Data.X, Y = nd.Data.Y, Z = nd.Data.Z
                    }, (ed) => new Edge3DWrap <CRFEdgeData>(ed.Data)
                    {
                        Weight = 1.0
                    });
                    new ShowGraph3D(graph3D).Request();
                }
            }


            #endregion

            #region Step 2: Create observations (and scores)

            var createObservationsUnit = new CreateObservationsUnit(inputParameters.TransitionProbabilities);
            var isingModel             = new IsingModel(inputParameters.IsingConformityParameter, inputParameters.IsingCorrelationParameter);
            for (int i = 0; i < inputParameters.NumberOfGraphInstances; i++)
            {
                var graph = graphList[i];
                createObservationsUnit.CreateObservation(graph);
                //graph.Data.Observations = observation;

                // create the corresponding scores
                isingModel.CreateCRFScore(graph);

                if (i == 0)
                {
                    var graph3D = graph.Wrap3D();
                    new ShowGraph3D(graph3D).Request();
                }
            }
            #endregion

            #region Step 3: Split the data into evaluation and training
            // ratio: 50 / 50
            int separation = inputParameters.NumberOfGraphInstances / 2;

            var testGraphs = new List <IGWGraph <ICRFNodeData, ICRFEdgeData, ICRFGraphData> >
                                 (new IGWGraph <ICRFNodeData, ICRFEdgeData, ICRFGraphData> [separation]);
            var evaluationGraphs = new List <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >
                                       (new GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> [inputParameters.NumberOfGraphInstances - separation]);

            for (int i = 0; i < separation; i++)
            {
                testGraphs[i] = graphList[i];
            }
            int k = 0;
            for (int j = separation; j < inputParameters.NumberOfGraphInstances; j++)
            {
                evaluationGraphs[k++] = graphList[j];
            }

            #endregion

            #region Step 4: Train and evaluate the different OLM variants

            // object for evaluation
            var evaluationResults = new Dictionary <OLMVariant, OLMEvaluationResult>();

            foreach (var trainingVariant in inputParameters.TrainingVariantsToTest)
            {
                evaluationResults.Add(trainingVariant, new OLMEvaluationResult());

                #region Step 4.1: Train the OLM variant
                {
                    var request = new OLMRequest(trainingVariant, testGraphs);
                    request.BasisMerkmale.AddRange(new IsingMerkmalNode(), new IsingMerkmalEdge());
                    // TODO: extract the loss function
                    request.LossFunctionValidation = (a, b) =>
                    {
                        var loss = 0.0;
                        for (int i = 0; i < a.Length; i++)
                        {
                            loss += a[i] != b[i] ? 1 : 0;
                        }
                        return(loss / a.Length);
                    };

                    request.Request();

                    var olmResult = request.Result;


                    // update Ising parameters in IsingModel
                    isingModel.ConformityParameter  = olmResult.ResultingWeights[0];
                    isingModel.CorrelationParameter = olmResult.ResultingWeights[1];

                    // create the corresponding scores for each graph (including evaluation)
                    foreach (var graph in graphList)
                    {
                        isingModel.CreateCRFScore(graph);
                    }
                }
                #endregion

                #region Step 4.2: Evaluate the OLM variant

                var keys    = new ComputeKeys();
                var results = new OLMEvaluationResult();
                results.ConformityParameter  = isingModel.ConformityParameter;
                results.CorrelationParameter = isingModel.CorrelationParameter;

                // 1) start the Viterbi heuristic (request: SolveInference) + add additional parameters
                for (int graph = 0; graph < evaluationGraphs.Count; graph++)
                {
                    var request2 = new SolveInference(evaluationGraphs[graph], inputParameters.NumberOfLabels,
                                                      inputParameters.BufferSizeViterbi);

                    request2.RequestInDefaultContext();

                    // 2) evaluate the result of the request (request.Solution yields a labeling)
                    int[] predictionLabeling = request2.Solution.Labeling;

                    // 3) evaluate the results of all evaluation graphs (TP, TN, FP, FN, MCC) and cache them
                    // new object so that it can be accessed in step 5.
                    var result = keys.computeEvalutionGraphResult(evaluationGraphs[graph], predictionLabeling);
                    // insert into the dictionary -> list
                    evaluationResults[trainingVariant].GraphResults.Add(result);
                }

                // compute the average values
                foreach (OLMVariant variant in evaluationResults.Keys)
                {
                    results.ComputeValues(evaluationResults[trainingVariant]);
                }

                // debug output
                Log.Post("Average Values");
                Log.Post("Sensitivity: " + evaluationResults[trainingVariant].AverageSensitivity +
                         "\t Specificy: " + evaluationResults[trainingVariant].AverageSpecificity +
                         "\t MCC: " + evaluationResults[trainingVariant].AverageMCC +
                         //"\t Accuracy: " + evaluationResults[trainingVariant].AverageAccuracy +
                         "\t TotalTP: " + evaluationResults[trainingVariant].TotalTP + "\n");

                #endregion
            }

            #endregion

            #region Step 5: Present and save the results
            // output of the keys
            //outputKeys(evaluation, inputParameters, evaluationGraphs);

            // output of the labels
            //outputLabelingsScores(graphList, inputParameters);


            // TODO: Marlon
            // graphical output

            var olmPresentationRequest = new ShowOLMResult(evaluationResults.Values.ToList());
            //foreach (var variant in evaluationResults.Keys)
            //{

            //    //foreach (var graphresult in evaluationResults[variant].GraphResults)
            //    //{
            //    //    //var graph = graphresult.Graph;
            //    //}
            //}
            olmPresentationRequest.Request();
            #endregion
        }
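
The LossFunctionValidation lambda in Step 4.1 of RunCycle is a normalized Hamming loss: the fraction of positions where the reference labeling and the prediction differ. A standalone version of the same computation, as a minimal sketch (HammingLoss is a hypothetical helper name introduced here):

        // Minimal sketch (not from the cycle above): normalized Hamming loss between two labelings.
        public static double HammingLoss(int[] reference, int[] prediction)
        {
            double loss = 0.0;
            for (int i = 0; i < reference.Length; i++)
            {
                loss += reference[i] != prediction[i] ? 1 : 0;
            }
            return loss / reference.Length;
        }

        // e.g. HammingLoss(new[] { 0, 1, 1, 0 }, new[] { 0, 1, 0, 0 }) == 0.25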
Code Example #4
File: OLMBase.cs  Project: dtklinh/CRFTool
        public void Do(int weights, IEnumerable <IGWGraph <NodeData, EdgeData, GraphData> > graphs, int maxIterations, OLMRequest olmrequest)
        {
            Weights = weights;
            int    validationQuantils = 2;
            double quantilratio       = 1.0 / validationQuantils;
            var    quantiledGraphs    = new List <IGWGraph <NodeData, EdgeData, GraphData> > [validationQuantils];

            for (int i = 0; i < validationQuantils; i++)
            {
                quantiledGraphs[i] = new List <IGWGraph <NodeData, EdgeData, GraphData> >();
            }

            // divide the graphs into training / validation
            foreach (var graph in graphs)
            {
                var quantil = random.Next(validationQuantils);
                quantiledGraphs[quantil].Add(graph);
            }

            for (int quantilIteration = 0; quantilIteration < validationQuantils; quantilIteration++)
            {
                TrainingGraphs   = new List <IGWGraph <NodeData, EdgeData, GraphData> >();
                ValidationGraphs = new List <IGWGraph <NodeData, EdgeData, GraphData> >();


                //CoreResidues = 0;
                for (int quantil = 0; quantil < validationQuantils; quantil++)
                {
                    if (quantil == quantilIteration)
                    {
                        ValidationGraphs.AddRange(quantiledGraphs[quantil]);
                    }
                    else
                    {
                        TrainingGraphs.AddRange(quantiledGraphs[quantil]);
                    }
                }

                //foreach (var graph in ValidationGraphs)
                //{
                //    CoreResidues += graph.Data.CoreResidues;
                //}

                Iteration     = 0;
                MaxIterations = maxIterations;

                SetStartingWeights();

                this.WeightObservationUnit.Init(weightCurrent);
                var lossOpt     = double.MaxValue;
                var lossOptOld  = 0.0;
                var lossCurrent = 0.0;

                OLMTracker = new OLMTracking(weights, new int[] { 1, 3, 5, 8, 12, 20, 50 }, weightCurrent, Name + "_q" + quantilIteration + "_OLMTracking.txt");

                var interfaceValid    = 0;
                var noninterfaceValid = 0;

                foreach (var graph in ValidationGraphs)
                {
                    interfaceValid    += graph.Data.ReferenceLabeling.Sum();
                    noninterfaceValid += graph.Nodes.Count() - graph.Data.ReferenceLabeling.Sum();
                }

                var sitesValid = interfaceValid + noninterfaceValid;

                while (!CheckCancelCriteria())
                {
                    Iteration++;

                    var oldWVector = weightCurrent.ToArray();
                    weightCurrent = DoIteration(TrainingGraphs, weightCurrent, Iteration);

                    ResultingWeights = weightCurrent;

                    //for (int i = 1; i < 20; i++)
                    //{
                    //    for (int k = 1; k < 20; k++)
                    //    {
                    //        weightCurrent[0] = Math.Pow(-1.0, i) * ((int)(i / 2)) * 0.1;
                    //        weightCurrent[1] = Math.Pow(-1.0, k) * ((int)(k / 2)) * 0.1;

                    tp          = 0; tn = 0; fp = 0; fn = 0;
                    lossCurrent = 0.0;
                    foreach (var graph in ValidationGraphs)
                    {
                        SetWeightsCRF(weightCurrent, graph);

                        var request = new SolveInference(graph as IGWGraph <ICRFNodeData, ICRFEdgeData, ICRFGraphData>, null, Labels, BufferSizeInference);
                        request.RequestInDefaultContext();

                        var prediction = request.Solution.Labeling;
                        lossCurrent += LossFunctionValidation(graph.Data.ReferenceLabeling, prediction);

                        TrackResults(graph.Data.ReferenceLabeling, prediction);
                    }
                    WriterResults();
                    lossCurrent /= sitesValid;

                    if (lossCurrent < lossOpt)
                    {
                        lossOptOld = lossOpt;
                        lossOpt    = lossCurrent;
                        weightOpt  = weightCurrent;
                    }

                    OLMTracker.Track(weightCurrent, lossCurrent);
                    var iterationResult = new OLMIterationResult(weightCurrent.ToArray(), lossCurrent);
                    olmrequest.Result.ResultsHistory.IterationResultHistory.Add(iterationResult);

                    //    }
                    //}
                }
            }

            OLMTracker.WriteWeights();

            //return weightOpt;
        }
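
Do splits the input graphs into validationQuantils random buckets and then runs one training/validation pass per bucket. The following is a minimal sketch of that split on plain lists, assuming nothing beyond System and System.Collections.Generic (KFoldSplits is a hypothetical helper, not part of OLMBase.cs):

        // Minimal sketch: every item goes into a random fold; each fold serves once as the validation set.
        public static IEnumerable<(List<T> Training, List<T> Validation)> KFoldSplits<T>(
            IEnumerable<T> items, int folds, Random random)
        {
            var buckets = new List<T>[folds];
            for (int i = 0; i < folds; i++)
            {
                buckets[i] = new List<T>();
            }

            // assign every item to a random fold
            foreach (var item in items)
            {
                buckets[random.Next(folds)].Add(item);
            }

            // each fold is the validation set exactly once; the rest is training data
            for (int fold = 0; fold < folds; fold++)
            {
                var training = new List<T>();
                for (int i = 0; i < folds; i++)
                {
                    if (i != fold)
                    {
                        training.AddRange(buckets[i]);
                    }
                }
                yield return (training, buckets[fold]);
            }
        }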
Code Example #5
        public void TestOlm()
        {
            var graph = CreateTestGraphCRF();

            var request = new OLMRequest(OLMVariant.Default, graph.ToIEnumerable());
        }
Code Example #6
        /*
         *  The version of the project cycle discussed with Mr. Waack for testing the different training variants of OLM
         *
         */
        public void RunCycle(TrainingEvaluationCycleInputParameters inputParameters)
        {
            #region Step 1: Prepare the data

            var graphList         = inputParameters.Graphs;
            int numberOfLabels    = inputParameters.NumberOfLabels;
            int numberOfIntervals = inputParameters.NumberOfIntervals;

            #endregion

            #region Step 2: Create observations (and scores)

            // var createObservationsUnit = new CreateObservationsUnit(inputParameters.Threshold);
            var createObservationsUnit = new CreateObservationsUnit(inputParameters.TransitionProbabilities);

            if (UseIsingModel)
            {
                Log.Post("Ising-Model");
            }
            else
            {
                Log.Post("Potts-Model with " + inputParameters.NumberOfIntervals + " Intervals");
            }

            var isingModel = new IsingModel(inputParameters.IsingConformityParameter, inputParameters.IsingCorrelationParameter);
            //var pottsModel = new PottsModel(inputParameters.PottsConformityParameters, inputParameters.IsingCorrelationParameter,
            //    inputParameters.AmplifierControlParameter, inputParameters.NumberOfLabels);
            var pottsModel = new PottsModelComplex(inputParameters.PottsConformityParameters, inputParameters.PottsCorrelationParameters,
                                                   inputParameters.AmplifierControlParameter, inputParameters.NumberOfLabels);

            for (int i = 0; i < inputParameters.NumberOfGraphInstances; i++)
            {
                var graph = graphList[i];
                createObservationsUnit.CreateObservation(graph);
                //createObservationsUnit.CreateObservationThresholding(graph);

                // create the corresponding scores
                if (UseIsingModel)
                {
                    isingModel.CreateCRFScore(graph);
                }

                else
                {
                    pottsModel.InitCRFScore(graph);
                }

                if (i == 0 && GraphVisualization == true)
                {
                    var graph3D = graph.Wrap3D();
                    new ShowGraph3D(graph3D).Request();
                }
            }
            #endregion

            #region Step 3: Split the data into evaluation and training
            // ratio: 80 / 20
            int separation = inputParameters.NumberOfGraphInstances - inputParameters.NumberOfGraphInstances / 5;
            // leave-one-out ratio
            //int separation = inputParameters.NumberOfGraphInstances - 1;

            var trainingGraphs = new List <IGWGraph <ICRFNodeData, ICRFEdgeData, ICRFGraphData> >
                                     (new IGWGraph <ICRFNodeData, ICRFEdgeData, ICRFGraphData> [separation]);
            var evaluationGraphs = new List <GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> >
                                       (new GWGraph <CRFNodeData, CRFEdgeData, CRFGraphData> [inputParameters.NumberOfGraphInstances - separation]);
            var randomizedGraphList = graphList.RandomizeOrder().ToList();

            for (int i = 0; i < separation; i++)
            {
                trainingGraphs[i] = randomizedGraphList[i];
                //trainingGraphs[i] = graphList[i];
            }
            int k = 0;
            for (int j = separation; j < inputParameters.NumberOfGraphInstances; j++, k++)
            {
                evaluationGraphs[k] = randomizedGraphList[j];
                //evaluationGraphs[i] = graphList[i];
            }

            Log.Post("Evaluation Graph ID: " + evaluationGraphs[0].Id);
            #endregion

            #region Step 4: Train and evaluate the different OLM variants

            // object for evaluation
            var evaluationResults = new Dictionary <OLMVariant, OLMEvaluationResult>();

            foreach (var trainingVariant in inputParameters.TrainingVariantsToTest)
            {
                evaluationResults.Add(trainingVariant, new OLMEvaluationResult());

                #region Step 4.1: Train the OLM variant
                {
                    var request = new OLMRequest(trainingVariant, trainingGraphs);
                    if (UseIsingModel)
                    {
                        request.BasisMerkmale.AddRange(new IsingMerkmalNode(), new IsingMerkmalEdge());
                    }
                    else
                    {
                        request.BasisMerkmale.AddRange(pottsModel.AddNodeFeatures(graphList, numberOfIntervals));
                        //request.BasisMerkmale.Add(new IsingMerkmalEdge());
                        request.BasisMerkmale.AddRange(pottsModel.AddEdgeFeatures(graphList, numberOfIntervals));
                    }

                    // loss function
                    request.LossFunctionIteration  = OLM.OLM.LossRatio;
                    request.LossFunctionValidation = OLM.OLM.LossRatio;

                    // execute training methods by calling OLMManager -> OLMBase
                    request.Request();

                    var olmResult = request.Result;

                    // update parameters in PottsModel
                    if (UseIsingModel)
                    {
                        isingModel.ConformityParameter  = olmResult.ResultingWeights[0];
                        isingModel.CorrelationParameter = olmResult.ResultingWeights[1];
                    }
                    else
                    {
                        int i = 0;
                        for (i = 0; i < pottsModel.ConformityParameter.Length; i++)
                        {
                            pottsModel.ConformityParameter[i] = olmResult.ResultingWeights[i];
                        }
                        //pottsModel.CorrelationParameter = olmResult.ResultingWeights[numberOfIntervals * 2];
                        for (int j = 0; j < pottsModel.CorrelationParameter.Length; j++)
                        {
                            pottsModel.CorrelationParameter[j] = olmResult.ResultingWeights[i++];
                        }
                    }

                    // create the corresponding scores for each graph (including evaluation)
                    foreach (var graph in graphList)
                    {
                        if (UseIsingModel)
                        {
                            isingModel.CreateCRFScore(graph);
                        }
                        else
                        {
                            pottsModel.CreateCRFScore(graph, request.BasisMerkmale);
                        }
                    }
                }
                #endregion

                #region Step 4.2: Evaluate the OLM variant

                var keys    = new ComputeKeys();
                var results = new OLMEvaluationResult();
                if (UseIsingModel)
                {
                    results = new OLMEvaluationResult
                    {
                        ConformityParameter  = isingModel.ConformityParameter,
                        CorrelationParameter = isingModel.CorrelationParameter
                    };
                }
                else
                {
                    results = new OLMEvaluationResult
                    {
                        ConformityParameters = pottsModel.ConformityParameter,
                        //  CorrelationParameter = pottsModel.CorrelationParameter
                        CorrelationParameters = pottsModel.CorrelationParameter
                    };
                }

                if (UseIsingModel)
                {
                    Log.Post("Conformity: " + results.ConformityParameter + "\t Correlation: " + results.CorrelationParameter);
                }
                else
                {
                    for (int i = 0; i < results.ConformityParameters.Length; i++)
                    {
                        Log.Post("Conformity " + i + ": " + results.ConformityParameters[i] + "\t");
                    }
                    Log.Post("Correlation: " + results.CorrelationParameter);
                }

                // 1) start the Viterbi heuristic (request: SolveInference) + add additional parameters
                for (int graph = 0; graph < evaluationGraphs.Count; graph++)
                {
                    var request2 = new SolveInference(evaluationGraphs[graph], inputParameters.NumberOfLabels,
                                                      inputParameters.BufferSizeViterbi);

                    request2.RequestInDefaultContext();

                    // 2) evaluate the result of the request (request.Solution yields a labeling)
                    int[] predictionLabeling = request2.Solution.Labeling;

                    // 3) evaluate the results of all evaluation graphs (TP, TN, FP, FN, MCC) and cache them
                    // new object so that it can be accessed in step 5.
                    var result = keys.computeEvalutionGraphResult(evaluationGraphs[graph], predictionLabeling);
                    // insert into the dictionary -> list
                    evaluationResults[trainingVariant].GraphResults.Add(result);
                }

                // compute the average values
                foreach (OLMVariant variant in evaluationResults.Keys)
                {
                    results.ComputeValues(evaluationResults[trainingVariant]);
                }

                // debug output
                Log.Post("Average Values");
                Log.Post("Sensitivity: " + evaluationResults[trainingVariant].AverageSensitivity +
                         "\t Specificy: " + evaluationResults[trainingVariant].AverageSpecificity +
                         "\t MCC: " + evaluationResults[trainingVariant].AverageMCC +
                         //"\t Accuracy: " + evaluationResults[trainingVariant].AverageAccuracy +
                         "\t TotalTP: " + evaluationResults[trainingVariant].TotalTP +
                         "\t TotalFP: " + evaluationResults[trainingVariant].TotalFP +
                         "\t TotalTN: " + evaluationResults[trainingVariant].TotalTN +
                         "\t TotalFN: " + evaluationResults[trainingVariant].TotalFN);

                #endregion
            }

            #endregion
        }
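
In the non-Ising branch of Step 4.1 the resulting weight vector is consumed sequentially: the first entries become the Potts conformity parameters and the remaining entries become the correlation parameters. A minimal sketch of that unpacking on plain arrays (UnpackWeights is a hypothetical helper introduced here for illustration):

        // Minimal sketch (not from the cycle above): split a flat OLM weight vector into
        // conformity weights followed by correlation weights, in that order.
        public static (double[] Conformity, double[] Correlation) UnpackWeights(
            double[] weights, int conformityCount, int correlationCount)
        {
            var conformity  = new double[conformityCount];
            var correlation = new double[correlationCount];

            int w = 0;
            for (int i = 0; i < conformityCount; i++)
            {
                conformity[i] = weights[w++];
            }
            for (int j = 0; j < correlationCount; j++)
            {
                correlation[j] = weights[w++];
            }
            return (conformity, correlation);
        }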