Exemplo n.º 1
0
        /// <summary>
        /// Flattens an image into per-pixel training targets for the SIREN.
        /// </summary>
        /// <param name="image">Pixel data indexed as [row, column, channel].</param>
        /// <returns>
        /// A float32 array of shape (height * width, channels): one row per pixel. Each row holds that
        /// pixel's channel values after they pass through <c>SirenTests.NormalizeChannelValue</c>.
        /// </returns>
        static ndarray <float> PrepareImage(byte[,,] image)
        {
            int pixelCount   = image.GetLength(0) * image.GetLength(1);
            int channelCount = image.GetLength(2);

            // Normalize, then collapse the row/column axes into one pixel axis and convert to float32.
            var pixelRows = SirenTests.NormalizeChannelValue(image.ToNumPyArray())
                                      .reshape(new[] { pixelCount, channelCount })
                                      .astype(np.float32_fn);

            return (ndarray <float>)pixelRows;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Evaluates the trained network at every pixel coordinate of a <paramref name="width"/> x
        /// <paramref name="height"/> grid and writes the result to <paramref name="path"/> as a PNG.
        /// </summary>
        /// <param name="siren">Trained model that maps a 2-value coordinate to 4 channel values.</param>
        /// <param name="width">Output image width in pixels.</param>
        /// <param name="height">Output image height in pixels.</param>
        /// <param name="path">Destination file; always written in PNG format.</param>
        static void Render(Model siren, int width, int height, string path)
        {
            const int channels = 4;
            int pixelCount = height * width;

            // One 2-component coordinate per pixel, flattened to a (pixelCount, 2) batch.
            var coordinateGrid  = SirenTests.Coord(height, width).ToNumPyArray();
            var flatCoordinates = coordinateGrid.reshape(new[] { pixelCount, 2 });
            ndarray <float> predictions = siren.predict(flatCoordinates, batch_size: 1024);

            // Fold the per-pixel predictions back into an image-shaped [row, column, channel] array.
            var pixels = (ndarray <float>)predictions.reshape(new[] { height, width, channels });
            using (var bitmap = ToImage(RestoreImage(pixels)))
            {
                bitmap.Save(path, ImageFormat.Png);
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Program entry point.
        /// </summary>
        /// <remarks>
        /// With no arguments, it loads weights from "sample.weights" and renders a
        /// (1034 * 3) x (1536 * 3) image to "sample6X.png".
        /// Otherwise it trains the SIREN on each image path in <paramref name="args"/>, in order.
        /// Whenever the loss improves (from epoch 10 on), it writes a 2x-per-axis upscaled preview
        /// to "sample4X.png" and saves the weights to "sample.weights".
        /// NOTE(review): all images train the same model instance and write to the same output
        /// files, so later images overwrite earlier previews and checkpoints. Confirm this is intended.
        /// </remarks>
        /// <param name="args">Paths of images to train on; may be empty.</param>
        /// <exception cref="NotSupportedException">
        /// An input image does not have the number of channels the network predicts.
        /// </exception>
        static void Main(string[] args)
        {
            GradientEngine.UseEnvironmentFromVariable();
            TensorFlowSetup.Instance.EnsureInitialized();

            // Number of channel values the network predicts per pixel; training images must match it.
            const int outputChannels = 4;

            // this allows SIREN to oversaturate channels without adding to the loss
            var clampToValidChannelRange = PythonFunctionContainer.Of <Tensor, Tensor>(ClampToValidChannelValueRange);
            var siren = new Sequential(new object[] {
                new GaussianNoise(stddev: 1f / (128 * 1024)),
                new Siren(2, Enumerable.Repeat(256, 5).ToArray()),
                new Dense(units: outputChannels, activation: clampToValidChannelRange),
                new GaussianNoise(stddev: 1f / 128),
            });

            siren.compile(
                // Adam with a lowered learning rate: SGD(momentum: 0.5) was too slow to converge,
                // and the learning rate was lowered to avoid destabilizing training.
                optimizer: new Adam(learning_rate: 0.00032),
                loss: new MeanSquaredError());

            if (args.Length == 0)
            {
                siren.load_weights("sample.weights");
                Render(siren, 1034 * 3, 1536 * 3, "sample6X.png");
                return;
            }

            foreach (string imagePath in args)
            {
                using var original = new Bitmap(imagePath);
                byte[,,] image     = ToBytesHWC(original);
                int height   = image.GetLength(0);
                int width    = image.GetLength(1);
                int channels = image.GetLength(2);

                // Debug.Assert is compiled out of Release builds. A mismatch would then only surface
                // later, e.g. when the preview callback reshapes outputChannels-wide predictions into
                // `channels` planes. Fail fast with a clear message instead.
                if (channels != outputChannels)
                {
                    throw new NotSupportedException(
                        $"'{imagePath}' has {channels} channels per pixel; expected {outputChannels}.");
                }

                var imageSamples = PrepareImage(image);

                // Training inputs: one 2-component coordinate per source pixel.
                var coords = SirenTests.Coord(height, width).ToNumPyArray()
                             .reshape(new[] { height * width, 2 });

                // Coordinates of a grid with twice the resolution along each axis. They are flattened
                // once here rather than on every callback invocation.
                var upscaleCoords = SirenTests.Coord(height * 2, width * 2).ToNumPyArray()
                                    .reshape(new[] { height * width * 4, 2 });

                var improved = new ImprovedCallback();
                improved.OnLossImproved += (sender, eventArgs) => {
                    // Don't write previews/checkpoints for epochs numbered below 10.
                    if (eventArgs.Epoch < 10)
                    {
                        return;
                    }
                    ndarray <float> upscaled = siren.predict(upscaleCoords, batch_size: 1024);
                    upscaled         = (ndarray <float>)upscaled.reshape(new[] { height * 2, width * 2, channels });
                    using var bitmap = ToImage(RestoreImage(upscaled));
                    bitmap.Save("sample4X.png", ImageFormat.Png);

                    siren.save_weights("sample.weights");

                    Console.WriteLine();
                    Console.WriteLine("saved!");
                };

                siren.fit(coords, imageSamples, epochs: 100, batchSize: 64, stepsPerEpoch: 200,
                          shuffleMode: TrainingShuffleMode.Batch,
                          callbacks: new ICallback[] { improved });
            }
        }
Exemplo n.º 4
0
 /// <summary>
 /// Element-wise clips <paramref name="input"/> between two bounds.
 /// The bounds are the channel values -0.01 and 255.01 passed through
 /// <c>SirenTests.NormalizeChannelValue</c>: the 0-255 range plus a small margin on each side.
 /// </summary>
 /// <param name="input">Tensor of normalized channel values.</param>
 /// <returns>The clipped tensor, as produced by <c>tf.clip_by_value</c>.</returns>
 static Tensor ClampToValidChannelValueRange(Tensor input)
 {
     // NOTE(review): assumes NormalizeChannelValue is increasing, so the lower bound stays below the upper.
     var lowerBound = SirenTests.NormalizeChannelValue(-0.01f);
     var upperBound = SirenTests.NormalizeChannelValue(255.01f);
     return tf.clip_by_value(input, clip_value_min: lowerBound, clip_value_max: upperBound);
 }