/// <summary>
/// Iteratively modifies <paramref name="recon"/> by backpropagating through <paramref name="net"/>
/// with the first <c>get_network_output_size(net)</c> values of <paramref name="features"/> as the
/// truth. <paramref name="update"/> is carried across iterations as a momentum buffer.
/// </summary>
/// <param name="net">Network used for the forward/backward passes.</param>
/// <param name="features">Target values; at least <c>get_network_output_size(net)</c> elements are copied.</param>
/// <param name="recon">Image being reconstructed; modified in place every iteration.</param>
/// <param name="update">Momentum buffer, same size as <paramref name="recon"/>; modified in place.</param>
/// <param name="rate">Step size applied when adding <paramref name="update"/> into <paramref name="recon"/>.</param>
/// <param name="momentum">Factor <paramref name="update"/> is scaled by at the end of each iteration.</param>
/// <param name="lambda">Forwarded to <c>Smooth</c>.</param>
/// <param name="smoothSize">Forwarded to <c>Smooth</c>.</param>
/// <param name="iters">Number of iterations to run.</param>
public static void reconstruct_picture(Network net, float[] features, Image recon, Image update, float rate, float momentum, float lambda, int smoothSize, int iters)
{
    for (int step = 0; step < iters; ++step)
    {
        int pixelCount = recon.W * recon.H * recon.C;

        // Fresh image per iteration; its buffer seeds the input-gradient slot of the state.
        var gradient = new Image(recon.W, recon.H, recon.C);

        var state = new NetworkState
        {
            Input = (float[])recon.Data.Clone(),
            Delta = (float[])gradient.Data.Clone(),
            Truth = new float[Network.get_network_output_size(net)]
        };
        Array.Copy(features, 0, state.Truth, 0, state.Truth.Length);

        Network.forward_network_gpu(net, state);
        Network.backward_network_gpu(net, state);

        // Pull the gradient with respect to the input back into the image-shaped buffer.
        Array.Copy(state.Delta, gradient.Data, pixelCount);

        // Accumulate the gradient into the momentum buffer (axpy with alpha = 1), then let Smooth
        // adjust it before it is applied.
        Blas.Axpy_cpu(pixelCount, 1, gradient.Data, update.Data);
        Smooth(recon, update, lambda, smoothSize);

        // Apply the step, decay the momentum buffer, then constrain the result via LoadArgs.
        Blas.Axpy_cpu(pixelCount, rate, update.Data, recon.Data);
        Blas.Scal_cpu(pixelCount, momentum, update.Data, 1);
        LoadArgs.constrain_image(recon);
    }
}
/// <summary>
/// Performs one optimisation step on <paramref name="orig"/>:
/// 1. Jitters and rescales the image, and optionally flips it.
/// 2. Runs the network up to <paramref name="maxLayer"/>.
/// 3. Turns that layer's output into a loss gradient via <c>calculate_loss</c>.
/// 4. Backpropagates to the input.
/// 5. Undoes the flip, resize and jitter on the gradient.
/// 6. Adds the gradient, scaled by <paramref name="rate"/>, into <paramref name="orig"/>.
/// </summary>
/// <param name="net">Network to use. NOTE: <c>net.N</c> is overwritten and the network is resized; both persist after return.</param>
/// <param name="orig">Image to update in place.</param>
/// <param name="maxLayer">Index of the last layer to run; <c>net.N</c> becomes <c>maxLayer + 1</c>.</param>
/// <param name="scale">Factor applied to the image size before it is fed to the network.</param>
/// <param name="rate">Multiplier on the gradient when it is added into <paramref name="orig"/>.</param>
/// <param name="thresh">Forwarded to <c>calculate_loss</c>.</param>
/// <param name="norm">When true, the gradient image is passed through <c>Utils.normalize_array</c> before being applied.</param>
private static void optimize_picture(Network net, Image orig, int maxLayer, float scale, float rate, float thresh, bool norm)
{
    // Truncate the network so the passes end at layer maxLayer. This mutates the caller's network.
    net.N = maxLayer + 1;

    // Random jitter and optional flip. Assuming Utils.Rand.Next() is non-negative
    // (System.Random semantics), dx and dy fall in [-8, 7].
    int dx = Utils.Rand.Next() % 16 - 8;
    int dy = Utils.Rand.Next() % 16 - 8;
    bool flip = Utils.Rand.Next() % 2 != 0;

    Image crop = LoadArgs.crop_image(orig, dx, dy, orig.W, orig.H);
    Image im = LoadArgs.resize_image(crop, (int)(orig.W * scale), (int)(orig.H * scale));
    if (flip)
    {
        LoadArgs.flip_image(im);
    }

    Network.resize_network(net, im.W, im.H);
    Layer last = net.Layers[net.N - 1];

    Image delta = new Image(im.W, im.H, im.C);
    NetworkState state = new NetworkState();
    state.Input = (float[])im.Data.Clone();
    // Seed the input-gradient buffer from the freshly created (presumably zeroed) delta image, as
    // reconstruct_picture does. It was previously seeded with the image pixels. If the backward pass
    // accumulates into state.Delta (darknet layers do), that added the image itself to the gradient.
    state.Delta = (float[])delta.Data.Clone();

    Network.forward_network_gpu(net, state);

    // Use the last layer's own output as its delta, let calculate_loss transform it in place, and
    // write it back to the GPU-side buffer used by the backward pass.
    Blas.copy_ongpu(last.Outputs, last.OutputGpu, last.DeltaGpu);
    Array.Copy(last.DeltaGpu, last.Delta, last.Outputs);
    calculate_loss(last.Delta, last.Delta, last.Outputs, thresh);
    Array.Copy(last.Delta, last.DeltaGpu, last.Outputs);

    Network.backward_network_gpu(net, state);
    Array.Copy(state.Delta, delta.Data, im.W * im.H * im.C);

    // Undo the flip, scaling and jitter so the gradient lines up with orig.
    if (flip)
    {
        LoadArgs.flip_image(delta);
    }
    Image resized = LoadArgs.resize_image(delta, orig.W, orig.H);
    Image outi = LoadArgs.crop_image(resized, -dx, -dy, orig.W, orig.H);

    if (norm)
    {
        Utils.normalize_array(outi.Data, outi.W * outi.H * outi.C);
    }

    // orig += rate * gradient (BLAS axpy convention), then constrain via LoadArgs.
    Blas.Axpy_cpu(orig.W * orig.H * orig.C, rate, outi.Data, orig.Data);
    LoadArgs.constrain_image(orig);
}