/// <summary>
/// Evaluates <paramref name="siren"/> on a coordinate grid of the requested size
/// and writes the restored result to <paramref name="path"/> as a PNG.
/// </summary>
/// <param name="siren">Model whose <c>predict</c> output is rendered.</param>
/// <param name="width">Width of the output grid.</param>
/// <param name="height">Height of the output grid.</param>
/// <param name="path">Destination file; saved with <see cref="ImageFormat.Png"/>.</param>
static void Render(Model siren, int width, int height, string path) {
    const int channelCount = 4;

    // Note the argument order: the coordinate helper takes (height, width).
    var grid = ImageTools.Coord(height, width).ToNumPyArray();

    // Flatten to one 2-component row per grid cell before prediction.
    var flatGrid = grid.reshape(new[] { height * width, 2 });
    ndarray<float> prediction = siren.predict(flatGrid, batch_size: 1024);

    // Fold the flat prediction back into (height, width, channels).
    var pixels = (ndarray<float>)prediction.reshape(new[] { height, width, channelCount });

    using var rendered = ToImage(RestoreImage(pixels));
    rendered.Save(path, ImageFormat.Png);
}
/// <summary>
/// Entry point. Builds and compiles the SIREN model, then either renders from
/// previously saved weights (no arguments) or fits the model to each image
/// whose path is given on the command line.
/// </summary>
/// <param name="args">
/// Paths of images to train on. When empty, weights are loaded from
/// <c>sample.weights</c> and a render is written to <c>sample6X.png</c>.
/// </param>
static void Main(string[] args) {
    GradientEngine.UseEnvironmentFromVariable();
    TensorFlowSetup.Instance.EnsureInitialized();

    // Using the clamp as the output activation allows SIREN to oversaturate
    // channels without adding to the loss.
    var clampToValidChannelRange =
        PythonFunctionContainer.Of<Tensor, Tensor>(ClampToValidChannelValueRange);

    var inputNoise = new GaussianNoise(stddev: 1f / (128 * 1024));
    var sirenLayers = new Siren(2, Enumerable.Repeat(256, 5).ToArray());
    var channelOutput = new Dense(units: 4, activation: clampToValidChannelRange);
    var outputNoise = new GaussianNoise(stddev: 1f / 128);

    var siren = new Sequential(new object[] {
        inputNoise,
        sirenLayers,
        channelOutput,
        outputNoise,
    });

    // SGD(momentum: 0.5) was tried and was too slow to converge.
    // Adam is used with a lowered learning rate to avoid destabilization.
    var optimizer = new Adam(learning_rate: 0.00032);
    siren.compile(optimizer: optimizer, loss: "mse");

    if (args.Length == 0) {
        // No input images: render from previously saved weights and exit.
        siren.load_weights("sample.weights");
        Render(siren, 1034 * 3, 1536 * 3, "sample6X.png");
        return;
    }

    foreach (string imagePath in args) {
        using var source = new Bitmap(imagePath);
        byte[,,] pixels = ToBytesHWC(source);

        int rows = pixels.GetLength(0);
        int columns = pixels.GetLength(1);
        int channelCount = pixels.GetLength(2);
        Debug.Assert(channelCount == 4);

        var targets = PrepareImage(pixels);

        // Training inputs: one 2-component coordinate row per source pixel.
        var trainingCoords = ImageTools.Coord(rows, columns).ToNumPyArray()
            .reshape(new[] { columns * rows, 2 });

        // Grid with twice the resolution along each axis, used for preview renders.
        var previewCoords = ImageTools.Coord(rows * 2, columns * 2).ToNumPyArray();

        // NOTE(review): presumably invoked when training reports an improvement —
        // confirm against ImprovedCallback.
        var improved = ImprovedCallback.Create((source_, progress) => {
            // Skip the preview/checkpoint while the epoch number is below 10.
            if (progress.Epoch < 10) {
                return;
            }

            var flatPreviewCoords = previewCoords.reshape(new[] { rows * columns * 4, 2 });
            ndarray<float> prediction = siren.predict(flatPreviewCoords, batch_size: 1024);
            var preview = (ndarray<float>)prediction.reshape(
                new[] { rows * 2, columns * 2, channelCount });

            using var previewBitmap = ToImage(RestoreImage(preview));
            previewBitmap.Save("sample4X.png", ImageFormat.Png);

            siren.save_weights("sample.weights");

            Console.WriteLine();
            Console.WriteLine("saved!");
        });

        siren.fit(
            trainingCoords,
            targets,
            epochs: 100,
            batchSize: 16 * 1024,
            shuffleMode: TrainingShuffleMode.Epoch,
            callbacks: new ICallback[] { improved });
    }
}