public void MNISTTwoHiddenLayerNetworkGPUTest()
        {
            // Parameters
            var learningRate = 0.1f;
            var epochs       = 5;
            var numGPUs      = 4;

            var mnist = new Mnist();

            mnist.ReadDataSets("/tmp");
            int batchSize  = 400;
            int numBatches = mnist.TrainImages.Length / batchSize;

            using (var graph = new TFGraph())
            {
                var X = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784));
                var Y = graph.Placeholder(TFDataType.Float, new TFShape(-1, 10));

                var Xs = graph.Split(graph.Const(0), X, numGPUs);
                var Ys = graph.Split(graph.Const(0), Y, numGPUs);

                var        sgd                   = new SGD(graph, learningRate, 0.9f);
                TFOutput[] costs                 = new TFOutput[numGPUs];
                TFOutput[] accuracys             = new TFOutput[numGPUs];
                var        variablesAndGradients = new Dictionary <Variable, List <TFOutput> >();
                for (int i = 0; i < numGPUs; i++)
                {
                    using (var device = graph.WithDevice("/GPU:" + i))
                    {
                        (costs[i], _, accuracys[i]) = CreateNetwork(graph, Xs[i], Ys[i], variablesAndGradients);
                        foreach (var gv in sgd.ComputeGradient(costs[i], colocateGradientsWithOps: true))
                        {
                            if (!variablesAndGradients.ContainsKey(gv.variable))
                            {
                                variablesAndGradients[gv.variable] = new List <TFOutput>();
                            }
                            variablesAndGradients[gv.variable].Add(gv.gradient);
                        }
                    }
                }
                var cost     = graph.ReduceMean(graph.Stack(costs));
                var accuracy = graph.ReduceMean(graph.Stack(accuracys));

                var gradientsAndVariables = new (TFOutput gradient, Variable variable)[variablesAndGradients.Count];