Пример #1
0
        /// <summary>
        /// Click handler for the re-training button. Prepares new training examples (written next to the
        /// first movie under "boxnet2training/"), optionally adds the example corpus found in the
        /// "boxnet2training" folder of the working directory, fine-tunes the selected BoxNet2 model and
        /// exports it to "boxnet2models/" under the name entered in TextNewName.
        /// </summary>
        /// <remarks>
        /// The heavy work runs inside Task.Run; every UI access from there is marshalled through Dispatcher.
        /// If preparation or training fails, or IsTrainingCanceled is set, no model is exported and Close is invoked.
        /// </remarks>
        private async void ButtonRetrain2_OnClick(object sender, RoutedEventArgs e)
        {
            Options.MainWindow.MicrographDisplayControl.DropBoxNetworks();
            Options.MainWindow.MicrographDisplayControl.DropNoiseNetworks();

            Movie[] Movies = Options.MainWindow.FileDiscoverer.GetImmutableFiles();
            if (Movies.Length == 0)
            {
                Close?.Invoke();
                return;
            }

            TextNewName.Text = Helper.RemoveInvalidChars(TextNewName.Text);

            // IsChecked is a nullable bool; treat the indeterminate state as unchecked instead of throwing on a cast.
            bool UseCorpus    = CheckCorpus.IsChecked == true;
            bool TrainMasking = CheckTrainMask.IsChecked == true;

            string PotentialNewName = TextNewName.Text;

            PanelSettings.Visibility = Visibility.Collapsed;
            PanelTraining.Visibility = Visibility.Visible;

            await Task.Run(async () =>
            {
                GPU.SetDevice(0);

                #region Prepare new examples

                Dispatcher.Invoke(() => TextProgress.Text = "Preparing new examples...");

                Directory.CreateDirectory(Movies[0].DirectoryName + "boxnet2training/");
                NewExamplesPath = Movies[0].DirectoryName + "boxnet2training/" + PotentialNewName + ".tif";

                try
                {
                    PrepareData(NewExamplesPath);
                }
                catch (Exception exc)
                {
                    // This runs on a thread-pool thread; the message dialog has to be shown from the UI thread.
                    await ShowTrainingErrorAsync(exc);
                    IsTrainingCanceled = true;

                    return;
                }

                #endregion

                #region Load background and new examples

                Dispatcher.Invoke(() => TextProgress.Text = "Loading examples...");

                int2 DimsLargest = new int2(1);

                // Index 0 holds the freshly prepared examples, index 1 the optional shared corpus.
                List<float[]>[] AllMicrographs = { new List<float[]>(), new List<float[]>() };
                List<float[]>[] AllLabels      = { new List<float[]>(), new List<float[]>() };
                List<float[]>[] AllUncertains  = { new List<float[]>(), new List<float[]>() };
                List<int2>[] AllDims           = { new List<int2>(), new List<int2>() };
                List<float3>[] AllLabelWeights = { new List<float3>(), new List<float3>() };

                // The corpus folder is optional: only use it if it exists and contains files, otherwise
                // EnumerateFiles would throw or the sampler would index into an empty list.
                string CorpusDir     = System.IO.Path.Combine(Environment.CurrentDirectory, "boxnet2training");
                string[] CorpusPaths = UseCorpus && Directory.Exists(CorpusDir)
                                           ? Directory.EnumerateFiles(CorpusDir, "*.tif").ToArray()
                                           : new string[0];

                string[][] AllPaths = CorpusPaths.Length > 0
                                          ? new[] { new[] { NewExamplesPath }, CorpusPaths }
                                          : new[] { new[] { NewExamplesPath } };

                // Pixel count per class label (0, 1, 2) over all loaded examples.
                long[] ClassHist = new long[3];

                for (int icorpus = 0; icorpus < AllPaths.Length; icorpus++)
                {
                    foreach (string examplePath in AllPaths[icorpus])
                    {
                        Image ExampleImage = Image.FromFile(examplePath);
                        try
                        {
                            // Slices come in triplets per example: micrograph, labels, uncertainty.
                            var Slices       = ExampleImage.GetHost(Intent.Read);
                            int2 ExampleDims = new int2(ExampleImage.Dims);
                            int N            = ExampleImage.Dims.Z / 3;

                            for (int n = 0; n < N; n++)
                            {
                                float[] Mic        = Slices[n * 3 + 0];
                                float[] Labels     = Slices[n * 3 + 1];
                                float[] Uncertains = Slices[n * 3 + 2];

                                MathHelper.FitAndSubtractGrid(Mic, ExampleDims, new int2(4));

                                AllMicrographs[icorpus].Add(Mic);
                                AllLabels[icorpus].Add(Labels);
                                AllUncertains[icorpus].Add(Uncertains);
                                AllDims[icorpus].Add(ExampleDims);

                                for (int i = 0; i < Labels.Length; i++)
                                {
                                    int Label = (int)Labels[i];
                                    if (!TrainMasking && Label == 2)
                                    {
                                        // Masking is not trained: fold label 2 into 0, also in the stored label array.
                                        Label     = 0;
                                        Labels[i] = 0;
                                    }
                                    ClassHist[Label]++;
                                }
                            }

                            DimsLargest.X = Math.Max(DimsLargest.X, ExampleImage.Dims.X);
                            DimsLargest.Y = Math.Max(DimsLargest.Y, ExampleImage.Dims.Y);
                        }
                        finally
                        {
                            ExampleImage.Dispose();
                        }
                    }
                }

                // Label 0 is re-weighted by the cube root of (class 1 / class 0) pixel counts, falling back to
                // class 2 when no class-1 pixels exist. Degenerate histograms keep weight 1 instead of
                // producing Infinity, NaN or a zero weight.
                float[] LabelWeights = { 1f, 1f, 1f };
                long Foreground      = ClassHist[1] > 0 ? ClassHist[1] : ClassHist[2];
                if (ClassHist[0] > 0 && Foreground > 0)
                {
                    LabelWeights[0] = (float)Math.Pow((float)Foreground / ClassHist[0], 1 / 3.0);
                }

                float3 ExampleLabelWeights = new float3(LabelWeights[0], LabelWeights[1], LabelWeights[2]);
                for (int icorpus = 0; icorpus < AllPaths.Length; icorpus++)
                {
                    for (int i = 0; i < AllMicrographs[icorpus].Count; i++)
                    {
                        AllLabelWeights[icorpus].Add(ExampleLabelWeights);
                    }
                }

                int NNewExamples = AllMicrographs[0].Count;
                // Only sample from corpora that actually yielded examples.
                int NCorpora = AllPaths.Length > 1 && AllMicrographs[1].Count > 0 ? 2 : 1;

                #endregion

                BoxNet2 NetworkTrain = null;
                try
                {
                    #region Load models

                    Dispatcher.Invoke(() => TextProgress.Text = "Loading old BoxNet model...");

                    int NThreads = 2;

                    Image.FreeDeviceAll();

                    Dispatcher.Invoke(() =>
                    {
                        Options.MainWindow.MicrographDisplayControl.FreeOnDevice();
                        Options.MainWindow.MicrographDisplayControl.DropBoxNetworks();
                        Options.MainWindow.MicrographDisplayControl.DropNoiseNetworks();
                    });

                    try
                    {
                        NetworkTrain = new BoxNet2(Options.MainWindow.LocatePickingModel(ModelName), GPU.GetDeviceWithMostMemory(), NThreads, 8, true);
                    }
                    catch   // It might be an old version of BoxNet that doesn't support batch size != 1
                    {
                        NetworkTrain = new BoxNet2(Options.MainWindow.LocatePickingModel(ModelName), GPU.GetDeviceWithMostMemory(), NThreads, 1, true);
                    }

                    #endregion

                    #region Training

                    Dispatcher.Invoke(() => TextProgress.Text = "Training...");

                    int2 DimsAugmented = BoxNet2.BoxDimensionsTrain;
                    int Border         = (BoxNet2.BoxDimensionsTrain.X - BoxNet2.BoxDimensionsValidTrain.X) / 2;
                    int BatchSize      = NetworkTrain.BatchSize;
                    int PlotEveryN     = 10;
                    int SmoothN        = 30;
                    float LearningRate = 0.00005f;

                    List<ObservablePoint>[] AccuracyPoints = Helper.ArrayOfFunction(i => new List<ObservablePoint>(), 4);
                    Queue<float>[] LastAccuracies          = { new Queue<float>(SmoothN), new Queue<float>(SmoothN) };
                    object ProgressLock                    = new object();

                    GPU.SetDevice(0);

                    // Per-pixel weight mask for one augmented box: 1 inside the valid region, 0.1 in the border.
                    float[] DataUncertain = new float[DimsAugmented.Elements()];
                    for (int y = 0; y < DimsAugmented.Y; y++)
                    {
                        for (int x = 0; x < DimsAugmented.X; x++)
                        {
                            bool Inside = x >= Border &&
                                          y >= Border &&
                                          x < DimsAugmented.X - Border &&
                                          y < DimsAugmented.Y - Border;
                            DataUncertain[y * DimsAugmented.X + x] = Inside ? 1 : 0.1f;
                        }
                    }
                    IntPtr d_MaskUncertain = GPU.MallocDeviceFromHost(DataUncertain, DataUncertain.Length);

                    IntPtr[] d_OriData       = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsLargest.Elements()), NetworkTrain.MaxThreads);
                    IntPtr[] d_OriLabels     = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsLargest.Elements()), NetworkTrain.MaxThreads);
                    IntPtr[] d_OriUncertains = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsLargest.Elements()), NetworkTrain.MaxThreads);

                    IntPtr[] d_AugmentedData    = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsAugmented.Elements() * BatchSize), NetworkTrain.MaxThreads);
                    IntPtr[] d_AugmentedLabels  = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsAugmented.Elements() * BatchSize * 3), NetworkTrain.MaxThreads);
                    IntPtr[] d_AugmentedWeights = Helper.ArrayOfFunction(i => GPU.MallocDevice(DimsAugmented.Elements() * BatchSize), NetworkTrain.MaxThreads);

                    Stopwatch Watch = new Stopwatch();
                    Watch.Start();

                    Random[] RG        = Helper.ArrayOfFunction(i => new Random(i), NetworkTrain.MaxThreads);
                    RandomNormal[] RGN = Helper.ArrayOfFunction(i => new RandomNormal(i), NetworkTrain.MaxThreads);

                    // AllMicrographs.Length is always 2, so the iteration count does not depend on UseCorpus.
                    int NIterations = NNewExamples * 100 * AllMicrographs.Length;
                    int NDone       = 0;

                    try
                    {
                        Helper.ForCPU(0, NIterations, NetworkTrain.MaxThreads,
                            threadID => GPU.SetDevice(0),
                            (b, threadID) =>
                            {
                                // No cancellation token is passed to ForCPU, so skip the remaining work once cancel was requested.
                                if (IsTrainingCanceled)
                                {
                                    return;
                                }

                                int icorpus;
                                lock (ProgressLock)
                                {
                                    icorpus = NDone % NCorpora;
                                }

                                for (int ib = 0; ib < BatchSize; ib++)
                                {
                                    int ExampleID = RG[threadID].Next(AllMicrographs[icorpus].Count);
                                    int2 Dims     = AllDims[icorpus][ExampleID];

                                    float2[] Translations = Helper.ArrayOfFunction(x => new float2(RG[threadID].Next(Dims.X - Border * 2) + Border - DimsAugmented.X / 2,
                                                                                                   RG[threadID].Next(Dims.Y - Border * 2) + Border - DimsAugmented.Y / 2), 1);

                                    float[] Rotations = Helper.ArrayOfFunction(i => (float)(RG[threadID].NextDouble() * Math.PI * 2), 1);
                                    float3[] Scales   = Helper.ArrayOfFunction(i => new float3(0.8f + (float)RG[threadID].NextDouble() * 0.4f,
                                                                                               0.8f + (float)RG[threadID].NextDouble() * 0.4f,
                                                                                               (float)(RG[threadID].NextDouble() * Math.PI * 2)), 1);
                                    float StdDev = (float)Math.Abs(RGN[threadID].NextSingle(0, 0.3f));

                                    GPU.CopyHostToDevice(AllMicrographs[icorpus][ExampleID], d_OriData[threadID], Dims.Elements());
                                    GPU.CopyHostToDevice(AllLabels[icorpus][ExampleID], d_OriLabels[threadID], Dims.Elements());
                                    GPU.CopyHostToDevice(AllUncertains[icorpus][ExampleID], d_OriUncertains[threadID], Dims.Elements());

                                    // Each batch item is written at its own offset inside the per-thread augmented buffers.
                                    GPU.BoxNet2Augment(d_OriData[threadID],
                                                       d_OriLabels[threadID],
                                                       d_OriUncertains[threadID],
                                                       Dims,
                                                       new IntPtr((long)d_AugmentedData[threadID] + ib * DimsAugmented.Elements() * sizeof(float)),
                                                       new IntPtr((long)d_AugmentedLabels[threadID] + ib * DimsAugmented.Elements() * 3 * sizeof(float)),
                                                       new IntPtr((long)d_AugmentedWeights[threadID] + ib * DimsAugmented.Elements() * sizeof(float)),
                                                       DimsAugmented,
                                                       AllLabelWeights[icorpus][ExampleID],
                                                       Helper.ToInterleaved(Translations),
                                                       Rotations,
                                                       Helper.ToInterleaved(Scales),
                                                       StdDev,
                                                       RG[threadID].Next(99999),
                                                       (uint)1);
                                }

                                GPU.MultiplySlices(d_AugmentedWeights[threadID],
                                                   d_MaskUncertain,
                                                   d_AugmentedWeights[threadID],
                                                   DimsAugmented.Elements(),
                                                   (uint)BatchSize);

                                long[] ResultLabels;
                                float[] ResultProbabilities;
                                float Loss;

                                lock (NetworkTrain)
                                {
                                    Loss = NetworkTrain.Train(d_AugmentedData[threadID],
                                                              d_AugmentedLabels[threadID],
                                                              d_AugmentedWeights[threadID],
                                                              LearningRate,
                                                              threadID,
                                                              out ResultLabels,
                                                              out ResultProbabilities);
                                }

                                lock (ProgressLock)
                                {
                                    NDone++;

                                    LastAccuracies[icorpus].Enqueue(Loss);
                                    if (LastAccuracies[icorpus].Count > SmoothN)
                                    {
                                        LastAccuracies[icorpus].Dequeue();
                                    }

                                    if (NDone % PlotEveryN == 0)
                                    {
                                        for (int iicorpus = 0; iicorpus < NCorpora; iicorpus++)
                                        {
                                            AccuracyPoints[iicorpus * 2 + 0].Add(new ObservablePoint((float)NDone / NIterations * 100,
                                                                                                     MathHelper.Mean(LastAccuracies[iicorpus].Where(v => !float.IsNaN(v)))));
                                        }

                                        long Elapsed           = Watch.ElapsedMilliseconds;
                                        double Estimated       = (double)Elapsed / NDone * NIterations;
                                        int Remaining          = (int)(Estimated - Elapsed);
                                        TimeSpan SpanRemaining = new TimeSpan(0, 0, 0, 0, Remaining);

                                        // Copy the series while the lock is held: the UI callback runs later, when
                                        // worker threads may already be appending to the live lists again.
                                        ObservablePoint[] TrainPoints      = AccuracyPoints[0].ToArray();
                                        ObservablePoint[] BackgroundPoints = AccuracyPoints[2].ToArray();

                                        Dispatcher.InvokeAsync(() =>
                                        {
                                            SeriesTrainAccuracy.Values = new ChartValues<ObservablePoint>(TrainPoints);

                                            if (UseCorpus)
                                            {
                                                SeriesBackgroundAccuracy.Values = new ChartValues<ObservablePoint>(BackgroundPoints);
                                            }

                                            TextProgress.Text = SpanRemaining.ToString((int)SpanRemaining.TotalHours > 0 ? @"hh\:mm\:ss" : @"mm\:ss");
                                        });
                                    }
                                }
                            },
                            null);
                    }
                    finally
                    {
                        // Release all per-thread device buffers even if training threw.
                        foreach (IntPtr[] Buffers in new[] { d_OriData, d_OriLabels, d_OriUncertains, d_AugmentedData, d_AugmentedLabels, d_AugmentedWeights })
                        {
                            foreach (IntPtr ptr in Buffers)
                            {
                                GPU.FreeDevice(ptr);
                            }
                        }
                        GPU.FreeDevice(d_MaskUncertain);
                    }

                    #endregion

                    if (!IsTrainingCanceled)
                    {
                        Dispatcher.Invoke(() => TextProgress.Text = "Saving new BoxNet model...");

                        string BoxNetDir = System.IO.Path.Combine(Environment.CurrentDirectory, "boxnet2models/");
                        Directory.CreateDirectory(BoxNetDir);

                        NetworkTrain.Export(BoxNetDir + PotentialNewName);
                    }
                }
                catch (Exception exc)
                {
                    // Report instead of letting the exception escape this async void handler and take down the app.
                    await ShowTrainingErrorAsync(exc);
                    IsTrainingCanceled = true;
                }
                finally
                {
                    NetworkTrain?.Dispose();

                    Image.FreeDeviceAll();
                    TFHelper.TF_FreeAllMemory();
                }
            });

            if (!IsTrainingCanceled)
            {
                NewName = TextNewName.Text;
            }

            TextProgress.Text = "Done.";

            if (IsTrainingCanceled)
            {
                Close?.Invoke();
                return;
            }

            ButtonCancelTraining.Content = "CLOSE";
            ButtonCancelTraining.Click  -= ButtonCancelTraining_OnClick;
            ButtonCancelTraining.Click  += async (a, b) =>
            {
                Close?.Invoke();

                if (!MainWindow.GlobalOptions.ShowBoxNetReminder)
                {
                    return;
                }

                var DialogResult = await ((MainWindow)Application.Current.MainWindow).ShowMessageAsync("Sharing is caring 🙂",
                                                                                                       "BoxNet performs well because of the wealth of training data it\n" +
                                                                                                       "can use. However, it could do even better with the data you just\n" +
                                                                                                       "used for re-training! Would you like to open a website to guide\n" +
                                                                                                       "you through contributing your data to the central repository?\n\n" +
                                                                                                       $"Your training data have been saved to \n{NewExamplesPath.Replace('/', '\\')}.\n\n",
                                                                                                       MessageDialogStyle.AffirmativeAndNegative,
                                                                                                       new MetroDialogSettings()
                                                                                                       {
                                                                                                           AffirmativeButtonText = "Yes",
                                                                                                           NegativeButtonText    = "No",
                                                                                                           DialogMessageFontSize = 18,
                                                                                                           DialogTitleFontSize   = 28
                                                                                                       });
                if (DialogResult == MessageDialogResult.Affirmative)
                {
                    string argument = "/select, \"" + NewExamplesPath.Replace('/', '\\') + "\"";
                    Process.Start("explorer.exe", argument);

                    Process.Start("http://www.warpem.com/warp/?page_id=72");
                }
            };
        }

        /// <summary>
        /// Shows <paramref name="exc"/> in the main window's message dialog. The call is marshalled to this
        /// control's Dispatcher because the dialog must be created on the UI thread, while callers run inside Task.Run.
        /// </summary>
        /// <param name="exc">The exception to display.</param>
        /// <returns>A task that completes when the dialog has been closed.</returns>
        private Task ShowTrainingErrorAsync(Exception exc)
        {
            return Dispatcher.InvokeAsync(() => Options.MainWindow.ShowMessageAsync("Oopsie", exc.ToString())).Task.Unwrap();
        }
Пример #2
0
        public FlexNet3D(string modelDir, int3 boxDimensions, int gpuID = 0, int nThreads = 1, bool forTraining = true, int batchSize = 128, int bottleneckWidth = 2, int layerWidth = 64, int nlayers = 4)
        {
            BoxDimensions   = boxDimensions;
            ForTraining     = forTraining;
            BatchSize       = batchSize;
            BottleneckWidth = bottleneckWidth;
            NWeights0       = layerWidth;
            NLayers         = nlayers;
            ModelDir        = modelDir;
            MaxThreads      = nThreads;

            TFSessionOptions SessionOptions = TFHelper.CreateOptions();
            TFSession        Dummy          = new TFSession(new TFGraph(), SessionOptions);

            Session = TFHelper.FromSavedModel(SessionOptions, null, ModelDir, new[] { forTraining ? "train" : "serve" }, new TFGraph(), $"/device:GPU:{gpuID}");
            Graph   = Session.Graph;

            NodeInputSource       = Graph["volume_source"][0];
            NodeInputTarget       = Graph["volume_target"][0];
            NodeInputWeightSource = Graph["volume_weight_source"][0];
            NodeInputWeightTarget = Graph["volume_weight_target"][0];
            NodeDropoutRate       = Graph["training_dropout_rate"][0];
            if (forTraining)
            {
                NodeLearningRate      = Graph["training_learning_rate"][0];
                NodeOrthogonalityRate = Graph["training_orthogonality"][0];
                NodeOpTrain           = Graph["train_momentum"][0];
                NodeOutputLoss        = Graph["l2_loss"][0];
                NodeOutputLossKL      = Graph["kl_loss"][0];
                NodeBottleneck        = Graph["bottleneck"][0];
            }

            NodeCode = Graph["volume_code"][0];

            NodeOutputPredicted = Graph["volume_predict"][0];

            NodeWeights0 = Graph["encoder_0/weights_0"][0];
            NodeWeights1 = Graph[$"decoder_{nlayers - 1}/weights_{nlayers - 1}"][0];
            if (forTraining)
            {
                NodeWeights0Assign = Graph["encoder_0/assign_layer0"][0];
                NodeWeights0Input  = Graph["encoder_0/assign_layer0_values"][0];

                NodeWeights1Assign = Graph[$"decoder_{nlayers - 1}/assign_layer0"][0];
                NodeWeights1Input  = Graph[$"decoder_{nlayers - 1}/assign_layer0_values"][0];
            }

            TensorSource = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, (BoxDimensions.X / 2 + 1), BoxDimensions.Y, BoxDimensions.Z, 2),
                                                                           new float[BatchSize * BoxDimensions.ElementsFFT() * 2],
                                                                           0,
                                                                           BatchSize * (int)BoxDimensions.ElementsFFT() * 2),
                                                  nThreads);

            TensorTarget = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, (BoxDimensions.X / 2 + 1), BoxDimensions.Y, BoxDimensions.Z, 2),
                                                                           new float[BatchSize * BoxDimensions.ElementsFFT() * 2],
                                                                           0,
                                                                           BatchSize * (int)BoxDimensions.ElementsFFT() * 2),
                                                  nThreads);

            TensorWeightSource = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, (BoxDimensions.X / 2 + 1), BoxDimensions.Y, BoxDimensions.Z, 1),
                                                                                 new float[BatchSize * BoxDimensions.ElementsFFT()],
                                                                                 0,
                                                                                 BatchSize * (int)BoxDimensions.ElementsFFT()),
                                                        nThreads);

            TensorWeightTarget = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, (BoxDimensions.X / 2 + 1), BoxDimensions.Y, BoxDimensions.Z, 1),
                                                                                 new float[BatchSize * BoxDimensions.ElementsFFT()],
                                                                                 0,
                                                                                 BatchSize * (int)BoxDimensions.ElementsFFT()),
                                                        nThreads);

            TensorCode = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BottleneckWidth),
                                                                         new float[BatchSize * BottleneckWidth],
                                                                         0,
                                                                         BatchSize * BottleneckWidth),
                                                nThreads);

            TensorLearningRate = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(1),
                                                                                 new float[1],
                                                                                 0,
                                                                                 1),
                                                        nThreads);

            TensorDropoutRate = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(1),
                                                                                new float[1],
                                                                                0,
                                                                                1),
                                                       nThreads);

            TensorOrthogonalityRate = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(1),
                                                                                      new float[1],
                                                                                      0,
                                                                                      1),
                                                             nThreads);

            ResultPredicted  = Helper.ArrayOfFunction(i => new float[BatchSize * BoxDimensions.ElementsFFT() * 2], nThreads);
            ResultBottleneck = Helper.ArrayOfFunction(i => new float[BatchSize * BottleneckWidth], nThreads);
            ResultLoss       = Helper.ArrayOfFunction(i => new float[1], nThreads);
            ResultLossKL     = Helper.ArrayOfFunction(i => new float[1], nThreads);

            RetrievedWeights = new float[boxDimensions.ElementsFFT() * 2 * NWeights0];

            //if (!ForTraining)
            RunnerPrediction = Helper.ArrayOfFunction(i => Session.GetRunner().
                                                      AddInput(NodeCode, TensorCode[i]).
                                                      AddInput(NodeDropoutRate, TensorDropoutRate[i]).
                                                      Fetch(NodeOutputPredicted),
                                                      nThreads);
            //else
            RunnerTraining = Helper.ArrayOfFunction(i => Session.GetRunner().
                                                    AddInput(NodeInputSource, TensorSource[i]).
                                                    AddInput(NodeInputTarget, TensorTarget[i]).
                                                    AddInput(NodeInputWeightSource, TensorWeightSource[i]).
                                                    AddInput(NodeInputWeightTarget, TensorWeightTarget[i]).
                                                    AddInput(NodeDropoutRate, TensorDropoutRate[i]).
                                                    AddInput(NodeLearningRate, TensorLearningRate[i]).
                                                    AddInput(NodeOrthogonalityRate, TensorOrthogonalityRate[i]).
                                                    Fetch(NodeOutputPredicted, NodeOutputLoss, NodeOutputLossKL, NodeBottleneck, NodeOpTrain),
                                                    nThreads);

            RunnerEncode = Helper.ArrayOfFunction(i => Session.GetRunner().
                                                  AddInput(NodeInputSource, TensorSource[i]).
                                                  AddInput(NodeInputWeightSource, TensorWeightSource[i]).
                                                  AddInput(NodeDropoutRate, TensorDropoutRate[i]).
                                                  Fetch(NodeBottleneck),
                                                  nThreads);

            RunnerRetrieveWeights0 = Session.GetRunner().Fetch(NodeWeights0);
            RunnerRetrieveWeights1 = Session.GetRunner().Fetch(NodeWeights1);

            if (ForTraining)
            {
                TensorWeights0 = TFTensor.FromBuffer(new TFShape(NWeights0, BoxDimensions.ElementsFFT() * 2),
                                                     new float[BoxDimensions.ElementsFFT() * 2 * NWeights0],
                                                     0,
                                                     (int)BoxDimensions.ElementsFFT() * 2 * NWeights0);

                RunnerAssignWeights0 = Session.GetRunner().AddInput(NodeWeights0Input, TensorWeights0).
                                       Fetch(NodeWeights0Assign);
                RunnerAssignWeights1 = Session.GetRunner().AddInput(NodeWeights1Input, TensorWeights0).
                                       Fetch(NodeWeights1Assign);
            }

            // Run prediction or training for one batch to claim all the memory needed
            float[] InitDecoded;
            float[] InitBottleneck;
            float[] InitLoss, InitLossKL;
            if (!ForTraining)
            {
                RandomNormal RandN = new RandomNormal(123);
                Predict(Helper.ArrayOfFunction(i => RandN.NextSingle(0, 1), BottleneckWidth * BatchSize),
                        0,
                        out InitDecoded);
            }
            else
            {
                RandomNormal RandN = new RandomNormal();

                Encode(Helper.ArrayOfFunction(i => RandN.NextSingle(0, 1), BatchSize * (int)BoxDimensions.ElementsFFT() * 2),
                       Helper.ArrayOfFunction(i => 1f, BatchSize * (int)BoxDimensions.ElementsFFT()),
                       0,
                       out InitBottleneck);

                Train(Helper.ArrayOfFunction(i => RandN.NextSingle(0, 1), BatchSize * (int)BoxDimensions.ElementsFFT() * 2),
                      Helper.ArrayOfFunction(i => RandN.NextSingle(0, 1), BatchSize * (int)BoxDimensions.ElementsFFT() * 2),
                      Helper.ArrayOfFunction(i => 1f, BatchSize * (int)BoxDimensions.ElementsFFT()),
                      Helper.ArrayOfFunction(i => 1f, BatchSize * (int)BoxDimensions.ElementsFFT()),
                      0.5f,
                      1e-10f,
                      1e-5f,
                      0,
                      out InitDecoded,
                      out InitBottleneck,
                      out InitLoss,
                      out InitLossKL);
            }
        }
Пример #3
0
        /// <summary>
        /// Loads the NoiseNet3D SavedModel from <paramref name="modelDir"/> onto GPU <paramref name="deviceID"/>,
        /// allocates <paramref name="nThreads"/> sets of input tensors, host result buffers and session runners,
        /// and then runs one dummy batch so TensorFlow claims the memory it needs up front.
        /// </summary>
        /// <param name="modelDir">SavedModel directory; loaded with the "train" tag when <paramref name="forTraining"/> is true, "serve" otherwise.</param>
        /// <param name="boxDimensions">Dimensions of a single volume fed to and produced by the network.</param>
        /// <param name="nThreads">Number of independent tensor/buffer/runner sets to create.</param>
        /// <param name="batchSize">Number of volumes per batch.</param>
        /// <param name="forTraining">When true, also wires the target, learning-rate, loss and training-op nodes and runs a dummy training step.</param>
        /// <param name="deviceID">GPU index used for the per-device lock and the "/device:GPU:N" placement string.</param>
        public NoiseNet3D(string modelDir, int3 boxDimensions, int nThreads = 1, int batchSize = 8, bool forTraining = true, int deviceID = 0)
        {
            // Model loading and tensor/runner construction are serialized per device.
            lock (TFHelper.DeviceSync[deviceID])
            {
                DeviceID      = deviceID;
                BoxDimensions = boxDimensions;
                ForTraining   = forTraining;
                ModelDir      = modelDir;
                MaxThreads    = nThreads;
                BatchSize     = batchSize;

                TFSessionOptions SessionOpts = TFHelper.CreateOptions();

                // Never referenced again; kept exactly as before.
                // NOTE(review): presumably created for a side effect on the TF runtime/device - confirm before removing.
                TFSession UnusedSession = new TFSession(new TFGraph(), SessionOpts);

                string[] ModelTags  = { forTraining ? "train" : "serve" };
                string   DevicePath = $"/device:GPU:{deviceID}";

                Session = TFHelper.FromSavedModel(SessionOpts, null, ModelDir, ModelTags, new TFGraph(), DevicePath);
                Graph   = Session.Graph;

                // --- Graph nodes: source and prediction always, training nodes only in training mode ---
                NodeInputSource = Graph["volume_source"][0];
                if (forTraining)
                {
                    NodeInputTarget  = Graph["volume_target"][0];
                    NodeLearningRate = Graph["training_learning_rate"][0];
                    NodeOpTrain      = Graph["train_momentum"][0];
                    NodeOutputLoss   = Graph["l2_loss"][0];
                }
                NodeOutputPredicted = Graph["volume_predict"][0];

                // Total voxel count of one batch of volumes.
                long BatchVoxels = BatchSize * BoxDimensions.Elements();

                // --- Per-thread input tensors ---
                TensorSource = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensions.X, BoxDimensions.Y, BoxDimensions.Z, 1),
                                                                               new float[BatchVoxels],
                                                                               0,
                                                                               (int)BatchVoxels),
                                                      nThreads);

                if (ForTraining)
                {
                    TensorTarget = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensions.X, BoxDimensions.Y, BoxDimensions.Z, 1),
                                                                                   new float[BatchVoxels],
                                                                                   0,
                                                                                   (int)BatchVoxels),
                                                          nThreads);

                    // Single-element tensor carrying the learning rate.
                    TensorLearningRate = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(1), new float[1], 0, 1),
                                                                nThreads);
                }

                // --- Per-thread host result buffers ---
                ResultPredicted = Helper.ArrayOfFunction(i => new float[BatchVoxels], nThreads);
                ResultLoss      = Helper.ArrayOfFunction(i => new float[1], nThreads);

                // --- Per-thread session runners (the prediction runner exists in both modes) ---
                RunnerPrediction = Helper.ArrayOfFunction(i => Session.GetRunner()
                                                                      .AddInput(NodeInputSource, TensorSource[i])
                                                                      .Fetch(NodeOutputPredicted),
                                                          nThreads);

                if (ForTraining)
                {
                    RunnerTraining = Helper.ArrayOfFunction(i => Session.GetRunner()
                                                                        .AddInput(NodeInputSource, TensorSource[i])
                                                                        .AddInput(NodeInputTarget, TensorTarget[i])
                                                                        .AddInput(NodeLearningRate, TensorLearningRate[i])
                                                                        .Fetch(NodeOutputPredicted, NodeOutputLoss, NodeOpTrain),
                                                            nThreads);
                }
            }

            // Push one dummy batch through the network so all needed memory is claimed now.
            float[] WarmupPredicted, WarmupLoss;
            Predict(new float[BoxDimensions.Elements() * BatchSize], 0, out WarmupPredicted);

            if (ForTraining)
            {
                RandomNormal Noise      = new RandomNormal();
                int          NoiseCount = BatchSize * (int)BoxDimensions.Elements();

                var NoiseSource = Helper.ArrayOfFunction(i => Noise.NextSingle(0, 1), NoiseCount);
                var NoiseTarget = Helper.ArrayOfFunction(i => Noise.NextSingle(0, 1), NoiseCount);

                // A tiny learning rate is used; this step exists to claim memory (see comment above).
                Train(NoiseSource, NoiseTarget, 1e-10f, 0, out WarmupPredicted, out WarmupLoss);
            }
        }
Пример #4
0
        /// <summary>
        /// Loads the CubeNet SavedModel from <paramref name="modelDir"/> onto GPU <paramref name="deviceID"/> and wires
        /// either the training graph ("images", "image_classes", "image_weights", learning rate, "train_momentum",
        /// "cross_entropy") or the prediction graph ("images_predict"). It then allocates <paramref name="nThreads"/>
        /// sets of tensors, host result buffers and runners, and runs one dummy batch so TensorFlow claims its memory
        /// up front.
        /// </summary>
        /// <param name="modelDir">SavedModel directory; loaded with the "train" tag when <paramref name="forTraining"/> is true, "serve" otherwise.</param>
        /// <param name="deviceID">GPU index used for the per-device lock and the "/device:GPU:N" placement string.</param>
        /// <param name="nThreads">Number of independent tensor/buffer/runner sets to create.</param>
        /// <param name="batchSize">Number of boxes per batch.</param>
        /// <param name="nClasses">Number of classes; the last dimension of the label tensor and the softmax buffer.</param>
        /// <param name="forTraining">Selects training mode (boxes of BoxDimensionsTrain) or prediction mode (boxes of BoxDimensionsPredict).</param>
        public CubeNet(string modelDir, int deviceID = 0, int nThreads = 1, int batchSize = 1, int nClasses = 2, bool forTraining = false)
        {
            // Model loading and tensor/runner construction are serialized per device.
            lock (TFHelper.DeviceSync[deviceID])
            {
                DeviceID    = deviceID;
                ForTraining = forTraining;
                ModelDir    = modelDir;
                MaxThreads  = nThreads;
                BatchSize   = batchSize;
                NClasses    = nClasses;

                TFSessionOptions SessionOpts = TFHelper.CreateOptions();

                // Never referenced again; kept exactly as before.
                // NOTE(review): presumably created for a side effect on the TF runtime/device - confirm before removing.
                TFSession UnusedSession = new TFSession(new TFGraph(), SessionOpts);

                string[] ModelTags  = { forTraining ? "train" : "serve" };
                string   DevicePath = $"/device:GPU:{deviceID}";

                Session = TFHelper.FromSavedModel(SessionOpts, null, ModelDir, ModelTags, new TFGraph(), DevicePath);
                Graph   = Session.Graph;

                // --- Mode-specific input nodes ---
                if (!forTraining)
                {
                    NodeInputMicTilePredict = Graph["images_predict"][0];
                }
                else
                {
                    NodeInputMicTile = Graph["images"][0];
                    NodeInputLabels  = Graph["image_classes"][0];
                    NodeInputWeights = Graph["image_weights"][0];
                    NodeLearningRate = Graph["training_learning_rate"][0];
                    NodeOpTrain      = Graph["train_momentum"][0];
                    NodeOutputLoss   = Graph["cross_entropy"][0];
                }

                // Output nodes shared by both modes.
                NodeOutputArgMax  = Graph["argmax_tensor"][0];
                NodeOutputSoftMax = Graph["softmax_tensor"][0];

                // --- Mode-specific tensors, host result buffers and runners ---
                if (forTraining)
                {
                    // Voxel count of one training box.
                    long TrainVoxels = BoxDimensionsTrain.Elements();

                    TensorMicTile = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensionsTrain.X, BoxDimensionsTrain.Y, BoxDimensionsTrain.Z, 1),
                                                                                    new float[BatchSize * TrainVoxels],
                                                                                    0,
                                                                                    BatchSize * (int)TrainVoxels),
                                                           nThreads);

                    TensorTrainingLabels = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensionsTrain.X, BoxDimensionsTrain.Y, BoxDimensionsTrain.Z, NClasses),
                                                                                           new float[BatchSize * TrainVoxels * NClasses],
                                                                                           0,
                                                                                           BatchSize * (int)TrainVoxels * NClasses),
                                                                  nThreads);

                    TensorTrainingWeights = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensionsTrain.X, BoxDimensionsTrain.Y, BoxDimensionsTrain.Z, 1),
                                                                                            new float[BatchSize * TrainVoxels],
                                                                                            0,
                                                                                            BatchSize * (int)TrainVoxels),
                                                                   nThreads);

                    // Single-element tensor carrying the learning rate.
                    TensorLearningRate = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(1), new float[1], 0, 1),
                                                                nThreads);

                    ResultArgMax  = Helper.ArrayOfFunction(i => new long[BatchSize * (int)TrainVoxels], nThreads);
                    ResultSoftMax = Helper.ArrayOfFunction(i => new float[BatchSize * (int)TrainVoxels * NClasses], nThreads);
                    ResultLoss    = Helper.ArrayOfFunction(i => new float[BatchSize], nThreads);

                    RunnerTraining = Helper.ArrayOfFunction(i => Session.GetRunner()
                                                                        .AddInput(NodeInputMicTile, TensorMicTile[i])
                                                                        .AddInput(NodeInputLabels, TensorTrainingLabels[i])
                                                                        .AddInput(NodeInputWeights, TensorTrainingWeights[i])
                                                                        .AddInput(NodeLearningRate, TensorLearningRate[i])
                                                                        .Fetch(NodeOpTrain, NodeOutputArgMax, NodeOutputSoftMax, NodeOutputLoss),
                                                            nThreads);
                }
                else
                {
                    // Voxel count of one prediction box.
                    long PredictVoxels = BoxDimensionsPredict.Elements();

                    TensorMicTilePredict = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, BoxDimensionsPredict.X, BoxDimensionsPredict.Y, BoxDimensionsPredict.Z, 1),
                                                                                           new float[BatchSize * PredictVoxels],
                                                                                           0,
                                                                                           BatchSize * (int)PredictVoxels),
                                                                  nThreads);

                    ResultArgMax  = Helper.ArrayOfFunction(i => new long[BatchSize * (int)PredictVoxels], nThreads);
                    ResultSoftMax = Helper.ArrayOfFunction(i => new float[BatchSize * (int)PredictVoxels * NClasses], nThreads);

                    RunnerPrediction = Helper.ArrayOfFunction(i => Session.GetRunner()
                                                                          .AddInput(NodeInputMicTilePredict, TensorMicTilePredict[i])
                                                                          .Fetch(NodeOutputArgMax, NodeOutputSoftMax),
                                                              nThreads);
                }
            }

            // Push one dummy batch through the network so all needed memory is claimed now.
            long[]  WarmupArgMax;
            float[] WarmupProb;
            if (ForTraining)
            {
                RandomNormal Noise      = new RandomNormal();
                int          TrainCount = BatchSize * (int)BoxDimensionsTrain.Elements();

                // Random input with all-zero labels and weights, trained with a tiny learning rate.
                Train(Helper.ArrayOfFunction(i => Noise.NextSingle(0, 1), TrainCount),
                      Helper.ArrayOfConstant(0.0f, TrainCount * NClasses),
                      Helper.ArrayOfConstant(0.0f, TrainCount),
                      1e-6f,
                      0,
                      out WarmupArgMax,
                      out WarmupProb);
            }
            else
            {
                Predict(new float[BoxDimensionsPredict.Elements() * BatchSize], 0, out WarmupArgMax, out WarmupProb);
            }
        }
Пример #5
0
        static void Main(string[] args)
        {
            #region Command line options

            Options Options = new Options();
            string  WorkingDirectory;

            string ProgramFolder = System.Reflection.Assembly.GetEntryAssembly().Location;
            ProgramFolder = ProgramFolder.Substring(0, Math.Max(ProgramFolder.LastIndexOf('\\'), ProgramFolder.LastIndexOf('/')) + 1);

            if (!Debugger.IsAttached)
            {
                Parser.Default.ParseArguments <Options>(args).WithParsed <Options>(opts => Options = opts);
                WorkingDirectory = Environment.CurrentDirectory + "/";
            }
            else
            {
                Options.Observation1Path  = @"E:\sara_debug\TS_for_M\reconstruction\odd";
                Options.Observation2Path  = @"E:\sara_debug\TS_for_M\reconstruction\even";
                Options.DenoiseSeparately = false;
                Options.MaskPath          = @"";
                Options.OldModelName      = "";
                Options.DontFlatten       = true;
                Options.Overflatten       = 1.0f;
                Options.PixelSize         = 15f;
                Options.Upsample          = 1.0f;
                Options.Lowpass           = -1;
                Options.KeepDimensions    = true;
                Options.MaskOutput        = false;
                Options.NIterations       = 600;
                Options.BatchSize         = 4;
                Options.GPUNetwork        = 0;
                Options.GPUPreprocess     = 1;
                WorkingDirectory          = @"G:\localsharpen\";
            }

            if (!Options.DontFlatten && Options.PixelSize < 0)
            {
                throw new Exception("Flattening requested, but pixel size not specified.");
            }

            if (Options.BatchSize != 4)
            {
                if (Options.BatchSize < 1)
                {
                    throw new Exception("Batch size must be at least 1.");
                }

                Options.NIterations = Options.NIterations * 4 / Options.BatchSize;
                Console.WriteLine($"Adjusting the number of iterations to {Options.NIterations} to match different batch size.\n");
            }

            #endregion

            GPU.SetDevice(Options.GPUPreprocess);

            #region Mask

            Console.Write("Loading mask... ");

            Image Mask        = null;
            int3  BoundingBox = new int3(-1);
            if (!string.IsNullOrEmpty(Options.MaskPath))
            {
                Mask = Image.FromFile(Options.MaskPath);

                Mask.TransformValues((x, y, z, v) =>
                {
                    if (v > 1e-3f)
                    {
                        BoundingBox.X = Math.Max(BoundingBox.X, Math.Abs(x - Mask.Dims.X / 2) * 2);
                        BoundingBox.Y = Math.Max(BoundingBox.Y, Math.Abs(y - Mask.Dims.Y / 2) * 2);
                        BoundingBox.Z = Math.Max(BoundingBox.Z, Math.Abs(z - Mask.Dims.Z / 2) * 2);
                    }

                    return(v);
                });

                if (BoundingBox.X < 2)
                {
                    throw new Exception("Mask does not seem to contain any non-zero values.");
                }

                BoundingBox += 64;

                BoundingBox.X = Math.Min(BoundingBox.X, Mask.Dims.X);
                BoundingBox.Y = Math.Min(BoundingBox.Y, Mask.Dims.Y);
                BoundingBox.Z = Math.Min(BoundingBox.Z, Mask.Dims.Z);
            }

            Console.WriteLine("done.\n");

            #endregion

            #region Load and prepare data

            Console.WriteLine("Preparing data:");

            List <Image>  Maps1                   = new List <Image>();
            List <Image>  Maps2                   = new List <Image>();
            List <Image>  MapsForDenoising        = new List <Image>();
            List <Image>  MapsForDenoising2       = new List <Image>();
            List <string> NamesForDenoising       = new List <string>();
            List <int3>   DimensionsForDenoising  = new List <int3>();
            List <int3>   OriginalBoxForDenoising = new List <int3>();
            List <float2> MeanStdForDenoising     = new List <float2>();
            List <float>  PixelSizeForDenoising   = new List <float>();

            foreach (var file in Directory.EnumerateFiles(Options.Observation1Path, "*.mrc"))
            {
                string   MapName   = Helper.PathToName(file);
                string[] Map2Paths = Directory.EnumerateFiles(Options.Observation2Path, MapName + ".mrc").ToArray();
                if (Map2Paths == null || Map2Paths.Length == 0)
                {
                    continue;
                }

                Console.Write($"Preparing {MapName}... ");

                Image Map1 = Image.FromFile(file);
                Image Map2 = Image.FromFile(Map2Paths.First());

                float MapPixelSize = Map1.PixelSize / (Options.KeepDimensions ? 1 : Options.Upsample);

                if (!Options.DontFlatten)
                {
                    Image Average = Map1.GetCopy();
                    Average.Add(Map2);

                    if (Mask != null)
                    {
                        Average.Multiply(Mask);
                    }

                    float[] Spectrum = Average.AsAmplitudes1D(true, 1, (Average.Dims.X + Average.Dims.Y + Average.Dims.Z) / 6);
                    Average.Dispose();

                    int   i10A   = (int)(Options.PixelSize * 2 / 10 * Spectrum.Length);
                    float Amp10A = Spectrum[i10A];

                    for (int i = 0; i < Spectrum.Length; i++)
                    {
                        Spectrum[i] = i < i10A ? 1 : (float)Math.Pow(Amp10A / Spectrum[i], Options.Overflatten);
                    }

                    Image Map1Flat = Map1.AsSpectrumMultiplied(true, Spectrum);
                    Map1.Dispose();
                    Map1 = Map1Flat;
                    Map1.FreeDevice();

                    Image Map2Flat = Map2.AsSpectrumMultiplied(true, Spectrum);
                    Map2.Dispose();
                    Map2 = Map2Flat;
                    Map2.FreeDevice();
                }

                if (Options.Lowpass > 0)
                {
                    Map1.Bandpass(0, Options.PixelSize * 2 / Options.Lowpass, true, 0.01f);
                    Map2.Bandpass(0, Options.PixelSize * 2 / Options.Lowpass, true, 0.01f);
                }

                //{
                //    int NShells = Map1.Dims.X / 2;
                //    float[] ResInv = Helper.ArrayOfFunction(i => Math.Min((int)(0.45 * NShells), i) / (Map1.Dims.X * MapPixelSize), NShells);
                //    float[] FilterSharpen = new float[NShells];
                //    for (int i = 0; i < NShells; i++)
                //        FilterSharpen[i] = (float)Math.Exp(100 / 4 * ResInv[i] * ResInv[i]);

                //    Image Map1Sharp = FSC.ApplyRamp(Map1, FilterSharpen);
                //    Map1.Dispose();
                //    Map1 = Map1Sharp;

                //    Image Map2Sharp = FSC.ApplyRamp(Map2, FilterSharpen);
                //    Map2.Dispose();
                //    Map2 = Map2Sharp;
                //}

                OriginalBoxForDenoising.Add(Map1.Dims);

                if (BoundingBox.X > 0)
                {
                    Image Map1Cropped = Map1.AsPadded(BoundingBox);
                    Map1.Dispose();
                    Map1 = Map1Cropped;
                    Map1.FreeDevice();

                    Image Map2Cropped = Map2.AsPadded(BoundingBox);
                    Map2.Dispose();
                    Map2 = Map2Cropped;
                    Map2.FreeDevice();
                }

                DimensionsForDenoising.Add(Map1.Dims);

                if (Options.Upsample != 1f)
                {
                    Image Map1Scaled = Map1.AsScaled(Map1.Dims * Options.Upsample / 2 * 2);
                    Map1.Dispose();
                    Map1 = Map1Scaled;
                    Map1.FreeDevice();

                    Image Map2Scaled = Map2.AsScaled(Map2.Dims * Options.Upsample / 2 * 2);
                    Map2.Dispose();
                    Map2 = Map2Scaled;
                    Map2.FreeDevice();
                }

                float2 MeanStd = MathHelper.MeanAndStd(Helper.Combine(Map1.GetHostContinuousCopy(), Map2.GetHostContinuousCopy()));
                MeanStdForDenoising.Add(MeanStd);

                Map1.TransformValues(v => (v - MeanStd.X) / MeanStd.Y);
                Map2.TransformValues(v => (v - MeanStd.X) / MeanStd.Y);

                Image ForDenoising  = Map1.GetCopy();
                Image ForDenoising2 = Options.DenoiseSeparately ? Map2.GetCopy() : null;

                GPU.PrefilterForCubic(Map1.GetDevice(Intent.ReadWrite), Map1.Dims);
                Map1.FreeDevice();
                Maps1.Add(Map1);

                if (!Options.DenoiseSeparately)
                {
                    ForDenoising.Add(Map2);
                    ForDenoising.Multiply(0.5f);
                }

                GPU.PrefilterForCubic(Map2.GetDevice(Intent.ReadWrite), Map2.Dims);
                Map2.FreeDevice();
                Maps2.Add(Map2);

                ForDenoising.FreeDevice();
                MapsForDenoising.Add(ForDenoising);
                NamesForDenoising.Add(MapName);

                PixelSizeForDenoising.Add(MapPixelSize);

                if (Options.DenoiseSeparately)
                {
                    ForDenoising2.FreeDevice();
                    MapsForDenoising2.Add(ForDenoising2);
                }

                Console.WriteLine(" Done.");
            }

            Mask?.FreeDevice();

            if (Maps1.Count == 0)
            {
                throw new Exception("No maps were found. Please make sure the paths are correct and the names are consistent between the two observations.");
            }

            Console.WriteLine("");

            #endregion

            NoiseNet3D TrainModel       = null;
            string     NameTrainedModel = Options.OldModelName;
            int        Dim = 64;

            if (string.IsNullOrEmpty(Options.OldModelName))
            {
                #region Load model

                Console.WriteLine("Loading model, " + GPU.GetFreeMemory(Options.GPUNetwork) + " MB free.");
                TrainModel = new NoiseNet3D(ProgramFolder + "noisenet3dmodel", new int3(Dim), 1, Options.BatchSize, true, Options.GPUNetwork);
                Console.WriteLine("Loaded model, " + GPU.GetFreeMemory(Options.GPUNetwork) + " MB remaining.\n");

                #endregion

                GPU.SetDevice(Options.GPUPreprocess);

                #region Training

                Random Rand = new Random(123);

                int NMaps         = Maps1.Count;
                int NMapsPerBatch = Math.Min(128, NMaps);
                int MapSamples    = Options.BatchSize;

                Image[] ExtractedSource = Helper.ArrayOfFunction(i => new Image(new int3(Dim, Dim, Dim * MapSamples)), NMapsPerBatch);
                Image[] ExtractedTarget = Helper.ArrayOfFunction(i => new Image(new int3(Dim, Dim, Dim * MapSamples)), NMapsPerBatch);

                for (int iter = 0; iter < Options.NIterations; iter++)
                {
                    int[] ShuffledMapIDs = Helper.RandomSubset(Helper.ArrayOfSequence(0, NMaps, 1), NMapsPerBatch, Rand.Next(9999));

                    for (int m = 0; m < NMapsPerBatch; m++)
                    {
                        int MapID = ShuffledMapIDs[m];

                        Image Map1 = Maps1[MapID];
                        Image Map2 = Maps2[MapID];

                        int3 DimsMap = Map1.Dims;

                        int3 Margin = new int3((int)(Dim / 2 * 1.5f));
                        //Margin.Z = 0;
                        float3[] Position = Helper.ArrayOfFunction(i => new float3((float)Rand.NextDouble() * (DimsMap.X - Margin.X * 2) + Margin.X,
                                                                                   (float)Rand.NextDouble() * (DimsMap.Y - Margin.Y * 2) + Margin.Y,
                                                                                   (float)Rand.NextDouble() * (DimsMap.Z - Margin.Z * 2) + Margin.Z), MapSamples);

                        float3[] Angle = Helper.ArrayOfFunction(i => new float3((float)Rand.NextDouble() * 360,
                                                                                (float)Rand.NextDouble() * 360,
                                                                                (float)Rand.NextDouble() * 360) * Helper.ToRad, MapSamples);

                        {
                            ulong[] Texture = new ulong[1], TextureArray = new ulong[1];
                            GPU.CreateTexture3D(Map1.GetDevice(Intent.Read), Map1.Dims, Texture, TextureArray, true);
                            //Map1.FreeDevice();

                            GPU.Rotate3DExtractAt(Texture[0],
                                                  Map1.Dims,
                                                  ExtractedSource[m].GetDevice(Intent.Write),
                                                  new int3(Dim),
                                                  Helper.ToInterleaved(Angle),
                                                  Helper.ToInterleaved(Position),
                                                  (uint)MapSamples);

                            //ExtractedSource[MapID].WriteMRC("d_extractedsource.mrc", true);

                            GPU.DestroyTexture(Texture[0], TextureArray[0]);
                        }

                        {
                            ulong[] Texture = new ulong[1], TextureArray = new ulong[1];
                            GPU.CreateTexture3D(Map2.GetDevice(Intent.Read), Map2.Dims, Texture, TextureArray, true);
                            //Map2.FreeDevice();

                            GPU.Rotate3DExtractAt(Texture[0],
                                                  Map2.Dims,
                                                  ExtractedTarget[m].GetDevice(Intent.Write),
                                                  new int3(Dim),
                                                  Helper.ToInterleaved(Angle),
                                                  Helper.ToInterleaved(Position),
                                                  (uint)MapSamples);

                            //ExtractedTarget.WriteMRC("d_extractedtarget.mrc", true);

                            GPU.DestroyTexture(Texture[0], TextureArray[0]);
                        }

                        //Map1.FreeDevice();
                        //Map2.FreeDevice();
                    }

                    float[] PredictedData = null, Loss = null;

                    {
                        float CurrentLearningRate = 0.0001f * (float)Math.Pow(10, -iter / (float)Options.NIterations * 2);

                        for (int m = 0; m < ShuffledMapIDs.Length; m++)
                        {
                            int MapID = m;

                            bool Twist = Rand.Next(2) == 0;

                            if (Twist)
                            {
                                TrainModel.Train(ExtractedSource[MapID].GetDevice(Intent.Read),
                                                 ExtractedTarget[MapID].GetDevice(Intent.Read),
                                                 CurrentLearningRate,
                                                 0,
                                                 out PredictedData,
                                                 out Loss);
                            }
                            else
                            {
                                TrainModel.Train(ExtractedTarget[MapID].GetDevice(Intent.Read),
                                                 ExtractedSource[MapID].GetDevice(Intent.Read),
                                                 CurrentLearningRate,
                                                 0,
                                                 out PredictedData,
                                                 out Loss);
                            }
                        }
                    }

                    ClearCurrentConsoleLine();
                    Console.Write($"{iter + 1}/{Options.NIterations}");
                }

                NameTrainedModel = "noisenet3d_64_" + DateTime.Now.ToString("yyyyMMdd_HHmmss");
                TrainModel.Export(NameTrainedModel);
                TrainModel.Dispose();

                TFHelper.TF_FreeAllMemory();

                Console.WriteLine("\nDone training!\n");

                #endregion
            }

            #region Denoise

            Console.WriteLine("Loading trained model, " + GPU.GetFreeMemory(Options.GPUNetwork) + " MB free.");
            TrainModel = new NoiseNet3D(NameTrainedModel, new int3(Dim), 1, Options.BatchSize, false, Options.GPUNetwork);
            //TrainModel = new NoiseNet3D(@"H:\denoise_refine\noisenet3d_64_20180808_010023", new int3(Dim), 1, Options.BatchSize, false, Options.GPUNetwork);
            Console.WriteLine("Loaded trained model, " + GPU.GetFreeMemory(Options.GPUNetwork) + " MB remaining.\n");

            //Directory.Delete(NameTrainedModel, true);

            Directory.CreateDirectory("denoised");

            GPU.SetDevice(Options.GPUPreprocess);

            for (int imap = 0; imap < MapsForDenoising.Count; imap++)
            {
                Console.Write($"Denoising {NamesForDenoising[imap]}... ");

                Image Map1 = MapsForDenoising[imap];
                NoiseNet3D.Denoise(Map1, new NoiseNet3D[] { TrainModel });

                float2 MeanStd = MeanStdForDenoising[imap];

                Map1.TransformValues(v => v * MeanStd.Y);

                if (Options.KeepDimensions)
                {
                    if (DimensionsForDenoising[imap] != Map1.Dims)
                    {
                        Image Scaled = Map1.AsScaled(DimensionsForDenoising[imap]);
                        Map1.Dispose();
                        Map1 = Scaled;
                    }
                    if (OriginalBoxForDenoising[imap] != Map1.Dims)
                    {
                        Image Padded = Map1.AsPadded(OriginalBoxForDenoising[imap]);
                        Map1.Dispose();
                        Map1 = Padded;
                    }
                }
                Map1.PixelSize = PixelSizeForDenoising[imap];

                Map1.TransformValues(v => v + MeanStd.X);

                if (Options.Lowpass > 0)
                {
                    Map1.Bandpass(0, Map1.PixelSize * 2 / Options.Lowpass, true, 0.01f);
                }

                if (Options.KeepDimensions && Options.MaskOutput)
                {
                    Map1.Multiply(Mask);
                }

                string SavePath1 = "denoised/" + NamesForDenoising[imap] + (Options.DenoiseSeparately ? "_1" : "") + ".mrc";
                Map1.WriteMRC(SavePath1, true);
                Map1.Dispose();

                Console.WriteLine("Done. Saved to " + SavePath1);

                if (Options.DenoiseSeparately)
                {
                    Console.Write($"Denoising {NamesForDenoising[imap]} (2nd observation)... ");

                    Image Map2 = MapsForDenoising2[imap];
                    NoiseNet3D.Denoise(Map2, new NoiseNet3D[] { TrainModel });

                    Map2.TransformValues(v => v * MeanStd.Y);

                    if (Options.KeepDimensions)
                    {
                        if (DimensionsForDenoising[imap] != Map2.Dims)
                        {
                            Image Scaled = Map2.AsScaled(DimensionsForDenoising[imap]);
                            Map2.Dispose();
                            Map2 = Scaled;
                        }
                        if (OriginalBoxForDenoising[imap] != Map2.Dims)
                        {
                            Image Padded = Map2.AsPadded(OriginalBoxForDenoising[imap]);
                            Map2.Dispose();
                            Map2 = Padded;
                        }
                    }
                    Map2.PixelSize = PixelSizeForDenoising[imap];

                    Map2.TransformValues(v => v + MeanStd.X);

                    if (Options.Lowpass > 0)
                    {
                        Map2.Bandpass(0, Map2.PixelSize * 2 / Options.Lowpass, true, 0.01f);
                    }

                    if (Options.KeepDimensions && Options.MaskOutput)
                    {
                        Map2.Multiply(Mask);
                    }

                    string SavePath2 = "denoised/" + NamesForDenoising[imap] + "_2" + ".mrc";
                    Map2.WriteMRC(SavePath2, true);
                    Map2.Dispose();

                    Console.WriteLine("Done. Saved to " + SavePath2);
                }
            }

            Console.WriteLine("\nAll done!");

            #endregion
        }
Пример #6
0
        /// <summary>
        /// Loads a BoxNet model stored as a TensorFlow SavedModel and prepares, for each of
        /// <paramref name="nThreads"/> slots, its own input tensors, output buffers and session runner.
        /// Finally runs one batch (prediction or training, depending on <paramref name="forTraining"/>)
        /// so the memory needed for later calls is claimed up front.
        /// </summary>
        /// <param name="modelDir">Directory containing the SavedModel to load.</param>
        /// <param name="gpuID">GPU index; the graph is loaded onto <c>/device:GPU:{gpuID}</c>.</param>
        /// <param name="nThreads">Number of independent tensor/runner sets to allocate (must be at least 1).</param>
        /// <param name="batchSize">Number of tiles per batch (must be at least 1).</param>
        /// <param name="forTraining">
        /// If true, loads the meta-graph tagged "train" and wires up label, learning-rate and
        /// training-op nodes; otherwise loads the meta-graph tagged "serve".
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="nThreads"/> or <paramref name="batchSize"/> is less than 1.
        /// </exception>
        public BoxNet(string modelDir, int gpuID = 0, int nThreads = 1, int batchSize = 128, bool forTraining = false)
        {
            // Fail early: with zero threads or an empty batch every per-thread array below would be
            // empty/zero-sized and the warm-up call at the end would fail with a less helpful error.
            if (nThreads < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(nThreads), "At least one thread is required.");
            }
            if (batchSize < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSize), "Batch size must be at least 1.");
            }

            ForTraining = forTraining;
            BatchSize   = batchSize;
            ModelDir    = modelDir;
            MaxThreads  = nThreads;

            TFSessionOptions SessionOptions = TFHelper.CreateOptions();
            // NOTE(review): this session is never used or disposed here; presumably it exists only to
            // initialize the TF runtime with these options before the saved model is loaded — confirm
            // before removing it.
            TFSession        Dummy          = new TFSession(new TFGraph(), SessionOptions);

            Session = TFHelper.FromSavedModel(SessionOptions, null, ModelDir, new[] { forTraining ? "train" : "serve" }, new TFGraph(), $"/device:GPU:{gpuID}");
            Graph   = Session.Graph;

            NodeInputMicTile = Graph["mic_tiles"][0];
            if (forTraining)
            {
                // These nodes only exist in the "train" meta-graph.
                NodeInputLabels  = Graph["training_labels"][0];
                NodeLearningRate = Graph["training_learning_rate"][0];
                NodeOpTrain      = Graph["train_momentum"][0];
            }

            NodeOutputArgMax  = Graph["ArgMax"][0];
            NodeOutputSoftMax = Graph["softmax_tensor"][0];

            // One input tensor per thread slot, shaped (BatchSize, 1, BoxDimensions.Y, BoxDimensions.X).
            TensorMicTile = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, 1, BoxDimensions.Y, BoxDimensions.X),
                                                                            new float[BatchSize * BoxDimensions.Elements()],
                                                                            0,
                                                                            BatchSize * (int)BoxDimensions.Elements()),
                                                   nThreads);

            // Two label values per tile, shaped (BatchSize, 2).
            TensorTrainingLabels = Helper.ArrayOfFunction(i => TFTensor.FromBuffer(new TFShape(BatchSize, 2),
                                                                                   new float[BatchSize * 2],
                                                                                   0,
                                                                                   BatchSize * 2),
                                                          nThreads);

            TensorLearningRate = Helper.ArrayOfFunction(i => new TFTensor(0.0f),
                                                        nThreads);

            ResultArgMax  = Helper.ArrayOfFunction(i => new long[BatchSize], nThreads);
            ResultSoftMax = Helper.ArrayOfFunction(i => new float[BatchSize * 2], nThreads);

            if (!ForTraining)
            {
                RunnerPrediction = Helper.ArrayOfFunction(i => Session.GetRunner().
                                                          AddInput(NodeInputMicTile, TensorMicTile[i]).
                                                          Fetch(NodeOutputArgMax, NodeOutputSoftMax),
                                                          nThreads);
            }
            else
            {
                RunnerTraining = Helper.ArrayOfFunction(i => Session.GetRunner().
                                                        AddInput(NodeInputMicTile, TensorMicTile[i]).
                                                        AddInput(NodeInputLabels, TensorTrainingLabels[i]).
                                                        AddInput(NodeLearningRate, TensorLearningRate[i]).
                                                        Fetch(NodeOutputArgMax, NodeOutputSoftMax, NodeOpTrain),
                                                        nThreads);
            }

            // Run prediction or training for one batch to claim all the memory needed
            long[]  InitArgMax;
            float[] InitProb;
            if (!ForTraining)
            {
                Predict(new float[BoxDimensions.Elements() * BatchSize],
                        0,
                        out InitArgMax,
                        out InitProb);
            }
            else
            {
                RandomNormal RandN = new RandomNormal();
                // One (1, 0) label pair per tile in the batch, so the label buffer matches the
                // BatchSize * 2 label tensor (this was previously hard-coded to 128 tiles).
                Train(Helper.ArrayOfFunction(i => RandN.NextSingle(0, 1), BatchSize * (int)BoxDimensions.Elements()),
                      Helper.Combine(Helper.ArrayOfFunction(i => new[] { 1.0f, 0.0f }, BatchSize)),
                      1e-6f,
                      0,
                      out InitArgMax,
                      out InitProb);
            }
        }