Example #1
0
        /// <summary>
        /// Entry point. Builds and compiles a SIREN-based model that takes 2-component
        /// coordinates and outputs 4 channel values. With no arguments it loads
        /// "sample.weights" and renders "sample6X.png"; otherwise it fits the model to
        /// every image listed in <paramref name="args"/>, one after another.
        /// </summary>
        /// <param name="args">Paths of the images to train on; may be empty.</param>
        static void Main(string[] args)
        {
            GradientEngine.UseEnvironmentFromVariable();
            TensorFlowSetup.Instance.EnsureInitialized();

            // clamping in the output activation allows SIREN to oversaturate channels
            // without adding to the loss
            var outputActivation = PythonFunctionContainer.Of<Tensor, Tensor>(ClampToValidChannelValueRange);
            int[] hiddenLayerSizes = Enumerable.Repeat(256, 5).ToArray();
            var layers = new object[]
            {
                new GaussianNoise(stddev: 1f / (128 * 1024)),
                new Siren(2, hiddenLayerSizes),
                new Dense(units: 4, activation: outputActivation),
                new GaussianNoise(stddev: 1f / 128),
            };
            var model = new Sequential(layers);

            // SGD(momentum: 0.5) was tried and was too slow to converge;
            // Adam is used with a lowered learning rate to avoid destabilization
            var optimizer = new Adam(learning_rate: 0.00032);
            model.compile(optimizer: optimizer, loss: "mse");

            if (args.Length == 0)
            {
                // no training input: render from previously saved weights and exit
                model.load_weights("sample.weights");
                Render(model, 1034 * 3, 1536 * 3, "sample6X.png");
                return;
            }

            foreach (string path in args)
            {
                using var sourceBitmap = new Bitmap(path);
                byte[,,] pixels = ToBytesHWC(sourceBitmap);
                int height = pixels.GetLength(0);
                int width = pixels.GetLength(1);
                int channels = pixels.GetLength(2);
                Debug.Assert(channels == 4);

                var targets = PrepareImage(pixels);

                // one 2-component coordinate row per source pixel
                int pixelCount = width * height;
                var trainingCoords = ImageTools.Coord(height, width).ToNumPyArray()
                    .reshape(new[] { pixelCount, 2 });

                // coordinate grid with twice the height and twice the width of the source
                int upscaledHeight = height * 2;
                int upscaledWidth = width * 2;
                var upscaleCoords = ImageTools.Coord(upscaledHeight, upscaledWidth).ToNumPyArray();

                var onImproved = ImprovedCallback.Create((sender, eventArgs) =>
                {
                    // no snapshots are written for the earliest epochs
                    if (eventArgs.Epoch < 10)
                    {
                        return;
                    }

                    ndarray<float> prediction = model.predict(
                        upscaleCoords.reshape(new[] { height * width * 4, 2 }),
                        batch_size: 1024);
                    var upscaledImage = (ndarray<float>)prediction.reshape(
                        new[] { upscaledHeight, upscaledWidth, channels });

                    using var preview = ToImage(RestoreImage(upscaledImage));
                    preview.Save("sample4X.png", ImageFormat.Png);

                    model.save_weights("sample.weights");

                    Console.WriteLine();
                    Console.WriteLine("saved!");
                });

                var callbacks = new ICallback[] { onImproved };
                model.fit(trainingCoords, targets,
                          epochs: 100,
                          batchSize: 16 * 1024,
                          shuffleMode: TrainingShuffleMode.Epoch,
                          callbacks: callbacks);
            }
        }
Example #2
0
 /// <summary>
 /// Clips <paramref name="input"/> to the range bounded by the normalized
 /// equivalents of the raw channel values -0.01 and 255.01.
 /// </summary>
 /// <param name="input">Tensor whose values are clipped.</param>
 /// <returns>The result of <c>tf.clip_by_value</c> applied to <paramref name="input"/>.</returns>
 static Tensor ClampToValidChannelValueRange(Tensor input)
 {
     // bounds sit slightly outside 0..255 and are normalized the same way as channel values
     var lowerBound = ImageTools.NormalizeChannelValue(-0.01f);
     var upperBound = ImageTools.NormalizeChannelValue(255.01f);

     return tf.clip_by_value(input,
                             clip_value_min: lowerBound,
                             clip_value_max: upperBound);
 }