// TODO: refactor with FuseShapesIntoConstants
        public void InferAllShapes(Model model, ref IDictionary <string, TensorShape?> shapesByName, ref IDictionary <string, int?> ranksByName)
        {
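            // Summary of the walk below: visit each layer once; keep a concrete value for
            // every layer that is a constant or a Shape of a known tensor, and when all
            // inputs of a layer are known, bake the layer on the CPU so rank/shape
            // inference can see its output. Values only drive inference here: they are
            // disposed at the end and no Load layers are inserted.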
            var toRunnableNCHW = new IntermediateToRunnableNCHWPass();

            var knownLayersValue = new Dictionary <string, Tensor>();
            var newKnownLayers   = new HashSet <string>();
            var keepLayers       = new HashSet <string>();

            for (int l = 0; l < model.layers.Count; ++l)
            {
                var layer = model.layers[l];
                if (layer.flags == Layer.Flags.Preserve)
                {
                    keepLayers.Add(layer.name);
                }

                // The NN is a directed graph: since we may have just fused constants + shapes,
                // update the downstream nodes by re-evaluating ranks and shapes
                FuseInputsIntoLayer(ref layer, knownLayersValue, ranksByName, null); // TODO: handle potential folding errors/warnings
                // TODO optimization, pass in index, or add shape
                IRShapeInferenceHelper.RankInference.UpdateKnownTensorRanks(model, ranksByName);
                IRShapeInferenceHelper.ShapeInference.UpdateKnownTensorShapesNCHW(model, ranksByName, ref shapesByName);

                if (ModelOptimizer.IsLayerConstant(layer))
                {
                    knownLayersValue[layer.name] = new Tensor(layer.datasets[0].shape, layer.weights);
                }
                else if (layer.type == Layer.Type.Shape)
                {
                    // assert inputs.Length == 1
                    var input = layer.inputs[0];
                    if (shapesByName.ContainsKey(input) && shapesByName[input] != null &&
                        ranksByName.ContainsKey(input) && ranksByName[input] != null
                        )
                    {
                        var shape = shapesByName[input].Value;
                        var rank  = ranksByName[input].Value;
                        knownLayersValue[layer.name] = ShapeToNCHWTensor(shape, rank);
                        newKnownLayers.Add(layer.name);
                        continue;
                    }
                }

                bool allInputsAreKnown = layer.inputs.Length > 0 && knownLayersValue.ContainsKey(layer.inputs[0]);
                for (int i = 1; i < layer.inputs.Length; i++)
                {
                    allInputsAreKnown &= knownLayersValue.ContainsKey(layer.inputs[i]);
                }

                // the layer can only be executed (folded) if all of its inputs are known
                if (!allInputsAreKnown)
                {
                    continue;
                }

                var layerInputs = new Dictionary <string, Tensor>();
                var opsModel    = new Model();
                opsModel.layout = "iNCHW";
                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    Model.Input input;
                    input.name  = layer.inputs[i];
                    input.shape = shapesByName[input.name].Value.ToArray();
                    input.rank  = ranksByName[input.name].Value;

                    opsModel.inputs.Add(input);
                    layerInputs[input.name] = knownLayersValue[input.name];
                }
                Layer newLayer = new Layer(layer.name.ToString(), layer.activation);
                newLayer.type       = layer.type;
                newLayer.activation = layer.activation;
                newLayer.pad        = layer.pad.ToArray();
                newLayer.stride     = layer.stride.ToArray();
                newLayer.pool       = layer.pool.ToArray();
                newLayer.axis       = layer.axis;
                newLayer.alpha      = layer.alpha;
                newLayer.beta       = layer.beta;
                newLayer.inputs     = layer.inputs.ToArray();
                newLayer.datasets   = layer.datasets;
                newLayer.weights    = layer.weights;
                if (layer.outputs != null)
                {
                    newLayer.outputs = layer.outputs.ToArray();
                }
                if (layer.axes != null)
                {
                    newLayer.axes = layer.axes.ToArray();
                }


                opsModel.layers.Add(newLayer);
                opsModel.outputs.Add(newLayer.name);

                toRunnableNCHW.Run(ref opsModel);

                // bake
                var useCPUforBaking = WorkerFactory.Device.CPU;
                using (var worker = WorkerFactory.CreateWorker(opsModel, useCPUforBaking))
                {
                    var bakedConstant = worker.Execute(layerInputs).PeekOutput();
                    bakedConstant.TakeOwnership();
                    knownLayersValue[layer.name] = bakedConstant;
                    newKnownLayers.Add(layer.name);
                }
            }

            // clear allocated tensors
            foreach (var l in knownLayersValue)
            {
                l.Value.Dispose();
            }

            // remove unused constants
            var removeUnusedLayersPass = new Cleanup.RemoveUnusedLayersPass();

            removeUnusedLayersPass.Run(ref model);
        }
        private void FuseShapesIntoConstants(ref Model model, IDictionary <string, TensorShape?> shapesByName, IDictionary <string, int?> ranksByName, ref List <Model.ImporterWarning> warnings)
        {
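            // Same folding walk as InferAllShapes, but here the baked results are kept:
            // every newly-known layer is removed from the model and re-inserted as a
            // Layer.Type.Load constant (unless it is flagged Preserve).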
            var toRunnableNCHW = new IntermediateToRunnableNCHWPass();

            var knownLayersValue = new Dictionary <string, Tensor>();
            var newKnownLayers   = new HashSet <string>();
            var keepLayers       = new HashSet <string>();

            for (int l = 0; l < model.layers.Count; ++l)
            {
                var layer = model.layers[l];
                if (layer.flags == Layer.Flags.Preserve)
                {
                    keepLayers.Add(layer.name);
                }

                // The NN is a directed graph: since we may have just fused constants + shapes,
                // update the downstream nodes by re-evaluating ranks and shapes
                FuseInputsIntoLayer(ref layer, knownLayersValue, ranksByName, warnings);
                // TODO optimization, pass in index, or add shape
                IRShapeInferenceHelper.RankInference.UpdateKnownTensorRanks(model, ranksByName);
                IRShapeInferenceHelper.ShapeInference.UpdateKnownTensorShapesNCHW(model, ranksByName, ref shapesByName);

                if (ModelOptimizer.IsLayerConstant(layer))
                {
                    knownLayersValue[layer.name] = new Tensor(layer.datasets[0].shape, layer.weights);
                }
                else if (layer.type == Layer.Type.Shape)
                {
                    // assert inputs.Length == 1
                    var input = layer.inputs[0];
                    if (shapesByName.ContainsKey(input) && shapesByName[input] != null &&
                        ranksByName.ContainsKey(input) && ranksByName[input] != null
                        )
                    {
                        var shape = shapesByName[input].Value;
                        var rank  = ranksByName[input].Value;
                        knownLayersValue[layer.name] = ShapeToNCHWTensor(shape, rank);
                        newKnownLayers.Add(layer.name);
                        continue;
                    }
                }

                bool allInputsAreKnown = layer.inputs.Length > 0 && knownLayersValue.ContainsKey(layer.inputs[0]);
                for (int i = 1; i < layer.inputs.Length; i++)
                {
                    allInputsAreKnown &= knownLayersValue.ContainsKey(layer.inputs[i]);
                }

                // the layer can only be executed (folded) if all of its inputs are known
                if (!allInputsAreKnown)
                {
                    continue;
                }

                var layerInputs = new Dictionary <string, Tensor>();
                var opsModel    = new Model();
                opsModel.layout = "iNCHW";
                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    Model.Input input;
                    input.name  = layer.inputs[i];
                    input.shape = shapesByName[input.name].Value.ToArray();
                    input.rank  = ranksByName[input.name].Value;

                    opsModel.inputs.Add(input);
                    layerInputs[input.name] = knownLayersValue[input.name];
                }
                Layer newLayer = new Layer(layer.name.ToString(), layer.activation);
                newLayer.type       = layer.type;
                newLayer.activation = layer.activation;
                newLayer.pad        = layer.pad.ToArray();
                newLayer.stride     = layer.stride.ToArray();
                newLayer.pool       = layer.pool.ToArray();
                newLayer.axis       = layer.axis;
                newLayer.alpha      = layer.alpha;
                newLayer.beta       = layer.beta;
                newLayer.inputs     = layer.inputs.ToArray();
                newLayer.datasets   = layer.datasets;
                newLayer.weights    = layer.weights;
                if (layer.outputs != null)
                {
                    newLayer.outputs = layer.outputs.ToArray();
                }
                if (layer.axes != null)
                {
                    newLayer.axes = layer.axes.ToArray();
                }


                opsModel.layers.Add(newLayer);
                opsModel.outputs.Add(newLayer.name);

                toRunnableNCHW.Run(ref opsModel);

                // bake
                var useCPUforBaking = WorkerFactory.Device.CPU;
                using (var worker = WorkerFactory.CreateWorker(opsModel, useCPUforBaking))
                {
                    var bakedConstant = worker.Execute(layerInputs).CopyOutput();
                    knownLayersValue[layer.name] = bakedConstant;
                    newKnownLayers.Add(layer.name);
                }
            }

            // remove new baked layers since we will insert constants for those
            model.layers.RemoveAll(x => newKnownLayers.Contains(x.name) && !keepLayers.Contains(x.name));

            // TODO use ModelBuilder?
            foreach (var l in newKnownLayers)
            {
                if (keepLayers.Contains(l))
                {
                    continue;
                }

                var   name   = l;
                var   tensor = knownLayersValue[name];
                Layer c      = new Layer(name, Layer.Type.Load);

                c.datasets                    = new Layer.DataSet[1];
                c.datasets[0].name            = name;
                c.datasets[0].shape           = tensor.shape;
                c.datasets[0].itemSizeInBytes = 4;
                c.datasets[0].length          = tensor.shape.length;
                c.datasets[0].offset          = 0;

                c.axis = ranksByName[c.name].Value;

                c.weights = new BarracudaArray(tensor.length);
                BarracudaArray.Copy(tensor.ToReadOnlyArray(), c.weights, tensor.length);
                model.layers.Insert(0, c);
            }

            foreach (var l in knownLayersValue)
            {
                l.Value.Dispose();
            }

            // TODO remove?
            // remove unused constants
            var removeUnusedLayersPass = new Cleanup.RemoveUnusedLayersPass();

            removeUnusedLayersPass.Run(ref model);
        }
Example #3
        void CorrectConstantsForBroadCast(ref Model nhwc)
        {
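            // For every broadcastable layer whose constant input has a lower rank than the
            // layer itself, emit a transposed copy of the constant so its dimensions land
            // in the right NCHW (or NHWC) slots, then rewire the layer to the corrected copy.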
            List <Layer> correctedConstants = new List <Layer>();

            for (int l = 0; l < nhwc.layers.Count; l++)
            {
                Layer layer = nhwc.layers[l];
                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    var input = layer.inputs[i];

                    if (!ModelAnalyzer.IsLayerBroacastable(layer))
                    {
                        continue;
                    }

                    if (!m_RanksByName.ContainsKey(input) || !m_RanksByName.ContainsKey(layer.name))
                    {
                        continue;
                    }

                    Layer inputLayer;
                    bool  found = ModelAnalyzer.FindLayerByName(nhwc, input, out inputLayer);
                    if (!found)
                    {
                        continue;
                    }

                    if (!ModelOptimizer.IsLayerConstant(inputLayer))
                    {
                        continue;
                    }

                    if (m_RanksByName[input] < 1 || m_RanksByName[input] == m_RanksByName[layer.name])
                    {
                        continue;
                    }
                    if (inputLayer.weights.Length == 1)
                    {
                        continue;
                    }

                    if (m_RanksByName[input] > m_RanksByName[layer.name])
                    {
                        throw new Exception($"constant must be lower rank than input for broadcast to work, TODO add transpose before input");
                    }

                    Layer correctedConstLayer = new Layer("c_" + inputLayer.name + "For_" + layer.name, Layer.Type.Load);

                    // transpose dataset
                    correctedConstLayer.datasets = new Layer.DataSet[1];
                    Array.Copy(inputLayer.datasets, correctedConstLayer.datasets, inputLayer.datasets.Length);
                    correctedConstLayer.datasets[0].name = correctedConstLayer.name;


                    correctedConstLayer.weights = new float[inputLayer.weights.Length];

                    var X = inputLayer.DataSetToTensor(0);

                    int[] permutations = new[] { 0, 1, 2, 3 };
                    var   rank         = m_RanksByName[layer.name].Value;

                    switch (rank)
                    {
                    case 2:
                        // ONNX: 5,7 + 7
                        // Barracuda: 5,_,_,7 + 7,_,_,- => _,_,_,7
                        permutations = new[] { 1, 2, 3, 0 };
                        break;

                    case 3:
                        // ONNX: 5,7,3 + 3
                        // Barracuda: 5,_,3,7 + 3,_,_,_  => _,_,3,_
                        if (m_RanksByName[input] == 1)
                        {
                            permutations = new[] { 1, 2, 0, 3 };
                        }

                        // ONNX: 5,7,3 + 7,3
                        // Barracuda: 5,_,3,7 + 7,_,_,3 => _,_,3,7
                        else if (m_RanksByName[input] == 2)
                        {
                            permutations = new[] { 1, 2, 3, 0 };
                        }

                        break;

                    case 4:
                        // ONNX: 2,5,7,3 + 3
                        // Barracuda: 2,7,3,5 + 3,_,_,_  => _,_,3,_
                        if (m_RanksByName[input] == 1)
                        {
                            permutations = new[] { 1, 2, 0, 3 };
                        }

                        // ONNX: 2,5,7,3 + 7,3
                        // Barracuda: 2,7,3,5 + 7,_,_,3  => _,7,3,_
                        else if (m_RanksByName[input] == 2)
                        {
                            permutations = new[] { 2, 0, 1, 3 };
                        }

                        // ONNX: 2,5,7,3 + 5,7,3
                        // Barracuda: 2,7,3,5 + 5,_,3,7  => _,7,3,5
                        else if (m_RanksByName[input] == 3)
                        {
                            permutations = new[] { 1, 3, 2, 0 };
                        }
                        break;
                    }

                    if (m_isModelExportedFromNCHW && (m_layersChannelOrder[layer.name] == LayoutTransposeRemovalHelper.ChannelsOrder.NHWC))
                    {
                        switch (rank)
                        {
                        case 2:
                            // ONNX: 5,7 + 7
                            // Barracuda: 5,_,_,7 + 7,_,_,- => _,_,_,7
                            permutations = new[] { 1, 2, 3, 0 };
                            break;

                        case 3:
                            // ONNX: 5,7,3 + 3
                            // Barracuda: 5,_,7,3 + 3,_,_,_  => _,_,_,3
                            if (m_RanksByName[input] == 1)
                            {
                                permutations = new[] { 1, 2, 3, 0 };
                            }

                            // ONNX: 5,7,3 + 7,3
                            // Barracuda: 5,_,7,3 + 7,_,_,3 => _,_,7,3
                            else if (m_RanksByName[input] == 2)
                            {
                                permutations = new[] { 2, 3, 0, 1 };
                            }

                            break;

                        case 4:
                            // ONNX: 2,5,7,3 + 3
                            // Barracuda: 2,5,7,3 + 3,_,_,_  => _,_,_,3
                            if (m_RanksByName[input] == 1)
                            {
                                permutations = new[] { 1, 2, 3, 0 };
                            }

                            // ONNX: 2,5,7,3 + 7,3
                            // Barracuda: 2,5,7,3 + 7,_,_,3  => _,_,7,3,
                            else if (m_RanksByName[input] == 2)
                            {
                                permutations = new[] { 2, 3, 0, 1 };
                            }

                            // ONNX: 2,5,7,3 + 5,7,3
                            // Barracuda: 2,5,7,3 + 5,_,7,3  => _,5,7,3
                            else if (m_RanksByName[input] == 3)
                            {
                                permutations = new[] { 1, 0, 2, 3 };
                            }
                            break;
                        }
                    }

                    var O = m_Ops.Transpose(X, permutations);
                    correctedConstLayer.ApplyTensorToDataSet(O, 0);
                    O.Dispose();
                    X.Dispose();

                    correctedConstants.Add(correctedConstLayer);
                    layer.inputs[i] = correctedConstLayer.name;
                }

                nhwc.layers[l] = layer;
            }

            foreach (var l in correctedConstants)
            {
                nhwc.layers.Insert(0, l);
            }
        }
    }
}
Example #4
        public static void FuseConstants(ref Model model)
        {
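            // Classic constant folding: any layer whose inputs are all known constants is
            // executed once on the CPU, its output becomes a new Load constant, and the
            // original layer (unless flagged Preserve) is removed.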
            var knownLayersValue = new Dictionary <string, Tensor>();
            var newKnownLayers   = new HashSet <string>();
            var keepLayers       = new HashSet <string>();

            for (int l = 0; l < model.layers.Count; ++l)
            {
                var layer = model.layers[l];
                if (layer.flags == Layer.Flags.Preserve)
                {
                    keepLayers.Add(layer.name);
                }

                // The NN is a directed graph: since we may have just fused constants, update the downstream nodes
                // TODO optimization, pass in index, or add shape
                if (ModelOptimizer.IsLayerConstant(layer))
                {
                    knownLayersValue[layer.name] = new Tensor(layer.datasets[0].shape, layer.weights);
                }

                bool allInputsAreKnown = layer.inputs.Length > 0 && knownLayersValue.ContainsKey(layer.inputs[0]);
                for (int i = 1; i < layer.inputs.Length; i++)
                {
                    allInputsAreKnown &= knownLayersValue.ContainsKey(layer.inputs[i]);
                }

                // the layer can only be executed (folded) if all of its inputs are known
                if (!allInputsAreKnown)
                {
                    continue;
                }

                var layerInputs = new Dictionary <string, Tensor>();
                var opsModel    = new Model();
                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    Model.Input input;
                    input.name  = layer.inputs[i];
                    input.shape = knownLayersValue[input.name].shape.ToArray();
                    input.rank  = knownLayersValue[input.name].shape.dimensions;

                    opsModel.inputs.Add(input);
                    layerInputs[input.name] = knownLayersValue[input.name];
                }
                opsModel.layers.Add(layer);
                opsModel.outputs.Add(layer.name);

                // bake
                var useCPUforBaking = WorkerFactory.Device.CPU;
                using (var worker = WorkerFactory.CreateWorker(opsModel, useCPUforBaking))
                {
                    // TODO use ModelIR2RunnableNCHWPass
                    var bakedConstant = worker.Execute(layerInputs).PeekOutput();
                    bakedConstant.TakeOwnership();
                    knownLayersValue[layer.name] = bakedConstant;
                    newKnownLayers.Add(layer.name);
                }
            }

            // remove new baked layers since we will insert constants for those
            model.layers.RemoveAll(x => newKnownLayers.Contains(x.name) && !keepLayers.Contains(x.name));

            // TODO use ModelBuilder?
            foreach (var l in newKnownLayers)
            {
                if (keepLayers.Contains(l))
                {
                    continue;
                }

                var   name   = l;
                var   tensor = knownLayersValue[name];
                Layer c      = new Layer(name, Layer.Type.Load);

                c.datasets                    = new Layer.DataSet[1];
                c.datasets[0].name            = name;
                c.datasets[0].shape           = tensor.shape;
                c.datasets[0].itemSizeInBytes = 4;
                c.datasets[0].length          = tensor.shape.length;
                c.datasets[0].offset          = 0;

                c.axis = tensor.shape.dimensions;

                c.weights = new BarracudaArray(tensor.length);
                BarracudaArray.Copy(tensor.ToReadOnlyArray(), c.weights, tensor.length);
                model.layers.Insert(0, c);
            }

            // clear allocated tensors
            foreach (var l in knownLayersValue)
            {
                l.Value.Dispose();
            }

            // remove unused constants
            var removeUnusedLayersPass = new Cleanup.RemoveUnusedLayersPass();

            removeUnusedLayersPass.Run(ref model);
        }
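
        // A minimal usage sketch (hypothetical caller; FuseConstants is public static,
        // and ModelLoader.Load is assumed to be available for loading the model):
        //
        //     var model = ModelLoader.Load("model.nn");
        //     FuseConstants(ref model); // folds constant-only subgraphs into Load layers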
Example #5
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
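            // Typical MxNet-style training loop: declare the label symbol, allocate the
            // input/label NDArrays for one batch, initialize every argument (GlorotUniform
            // unless a per-parameter initializer was registered), then run
            // forward/backward/update over the mini-batches for each epoch.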
            string labelName = "label";

            var label = Symbol.Variable(labelName);

            List <uint> inputShape = new List <uint>();

            inputShape.Add(batchSize);
            inputShape.AddRange(InputShape);

            args["X"]       = new NDArray(new Shape(inputShape.ToArray()));
            args[labelName] = new NDArray(new Shape(batchSize));

            Model.InferArgsMap(mx.Device, args, args);

            var defaultInitializer = new Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                if (ParamInitializers.ContainsKey(arg.Key))
                {
                    ParamInitializers[arg.Key].Generate(arg.Value);
                }
                else
                {
                    defaultInitializer.Generate(arg.Value);
                }
            }

            using (var exec = Model.SimpleBind(mx.Device, args))
            {
                var argNames = Model.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 1; iter <= epochs; iter++)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();
                    Metric.Reset();
                    TrainMetric.Reset();
                    sw.Restart();

                    while (train.IterNext())
                    {
                        samples += batchSize;
                        var dataBatch = train.Next();

                        // Set data and label
                        dataBatch.Data[0].CopyTo(args["X"]);
                        dataBatch.Label[0].CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        TrainMetric.Update(args[labelName], exec.Output);

                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i], null);
                        }
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.IterNext())
                        {
                            var dataBatch = validation.Next();
                            dataBatch.Data[0].CopyTo(args["X"]);
                            dataBatch.Label[0].CopyTo(args[labelName]);
                            NDArray.WaitAll();
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(args[labelName], exec.Output);
                        }
                    }

                    var duration = sw.ElapsedMilliseconds == 0 ? 1 : sw.ElapsedMilliseconds;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec Train_Metric: {TrainMetric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {Convert.ToInt32(samples * 1000 / duration)} samples/sec, Train_Metric: {TrainMetric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            //MXNet.MXNotifyShutdown();
        }
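
        // A hypothetical call, assuming trainIter and valIter are prepared DataIter
        // instances matching this model's InputShape:
        //
        //     model.Fit(trainIter, epochs: 10, batchSize: 32, validation: valIter);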
Example #6
        int[] GetPermutationForBroadcast(int targetRank, int rank, bool isNHWC = false)
        {
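            // Maps a lower-rank operand onto the layout of a targetRank operand so that an
            // elementwise broadcast lines up in Barracuda's 4D layout; the ONNX/Barracuda
            // comments in each case show the source and destination dimension slots.
            // E.g. (from the table below) targetRank 4 with rank 2 yields { 1, 0, 3, 2 }
            // in the default layout.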
            int[] permutations = new[] { 0, 1, 2, 3 };

            if (rank == 0 || targetRank == 1)
            {
                return permutations;
            }

            switch (targetRank)
            {
            case 2:
                // ONNX: 5,7 + 7
                // Barracuda: 5,_,_,7 + 7,_,_,- => _,_,_,7
                permutations = new[] { 1, 2, 3, 0 };
                break;

            case 3:
                // ONNX: 5,7,3 + 3
                // Barracuda: 5,_,3,7 + 3,_,_,_  => _,_,3,_
                if (rank == 1)
                {
                    permutations = new[] { 1, 2, 0, 3 };
                }

                // ONNX: 5,7,3 + 7,3
                // Barracuda: 5,_,3,7 + 7,_,_,3 => _,_,3,7
                else if (rank == 2)
                {
                    permutations = new[] { 1, 2, 3, 0 };
                }

                break;

            case 4:
                // ONNX: 2,5,7,3 + 3
                // Barracuda: 2,7,3,5 + 3,_,_,_  => _,_,3,_
                if (rank == 1)
                {
                    permutations = new[] { 1, 2, 0, 3 };
                }

                // ONNX: 2,5,7,3 + 7,3
                // Barracuda: 2,7,3,5 + 7,_,_,3  => _,7,3,_
                else if (rank == 2)
                {
                    permutations = new[] { 1, 0, 3, 2 };
                }

                // ONNX: 2,5,7,3 + 5,7,3
                // Barracuda: 2,7,3,5 + 5,_,3,7  => _,7,3,5
                else if (rank == 3)
                {
                    permutations = new[] { 1, 3, 2, 0 };
                }
                break;
            }

            if (isNHWC)
            {
                switch (targetRank)
                {
                case 2:
                    // ONNX: 5,7 + 7
                    // Barracuda: 5,_,_,7 + 7,_,_,- => _,_,_,7
                    permutations = new[] { 1, 2, 3, 0 };
                    break;

                case 3:
                    // ONNX: 5,7,3 + 3
                    // Barracuda: 5,_,7,3 + 3,_,_,_  => _,_,_,3
                    if (rank == 1)
                    {
                        permutations = new[] { 1, 2, 3, 0 };
                    }

                    // ONNX: 5,7,3 + 7,3
                    // Barracuda: 5,_,7,3 + 7,_,_,3 => _,_,7,3
                    else if (rank == 2)
                    {
                        permutations = new[] { 1, 2, 0, 3 };
                    }

                    break;

                case 4:
                    // ONNX: 2,5,7,3 + 3
                    // Barracuda: 2,5,7,3 + 3,_,_,_  => _,_,_,3
                    if (rank == 1)
                    {
                        permutations = new[] { 1, 2, 3, 0 };
                    }

                    // ONNX: 2,5,7,3 + 7,3
                    // Barracuda: 2,5,7,3 + 7,_,_,3  => _,_,7,3,
                    else if (rank == 2)
                    {
                        permutations = new[] { 1, 2, 0, 3 };
                    }

                    // ONNX: 2,5,7,3 + 5,7,3
                    // Barracuda: 2,5,7,3 + 5,_,7,3  => _,5,7,3
                    else if (rank == 3)
                    {
                        permutations = new[] { 1, 0, 2, 3 };
                    }
                    break;
                }
            }
            return permutations;
        }

        void CorrectConstantsForBroadCast(ref Model nhwc)
        {
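            // Constant-broadcast correction as in the earlier example, here delegating the
            // permutation choice to GetPermutationForBroadcast and disposing the temporary
            // tensors after the transpose.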
            List <Layer> correctedConstants = new List <Layer>();

            for (int l = 0; l < nhwc.layers.Count; l++)
            {
                Layer layer = nhwc.layers[l];
                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    var input = layer.inputs[i];

                    if (!ModelAnalyzer.IsLayerBroacastable(layer))
                    {
                        continue;
                    }

                    if (!m_RanksByName.ContainsKey(input) || !m_RanksByName.ContainsKey(layer.name))
                    {
                        continue;
                    }

                    Layer inputLayer;
                    bool  found = ModelAnalyzer.FindLayerByName(nhwc, input, out inputLayer);
                    if (!found)
                    {
                        continue;
                    }

                    if (!ModelOptimizer.IsLayerConstant(inputLayer))
                    {
                        continue;
                    }

                    if (m_RanksByName[input] < 1 || m_RanksByName[input] == m_RanksByName[layer.name])
                    {
                        continue;
                    }
                    if (inputLayer.weights.Length == 1)
                    {
                        continue;
                    }

                    if (m_RanksByName[input] > m_RanksByName[layer.name])
                    {
                        throw new Exception($"constant must be lower rank than input for broadcast to work, TODO add transpose before input");
                    }

                    Layer correctedConstLayer = new Layer("c_" + inputLayer.name + "For_" + layer.name, Layer.Type.Load);

                    // transpose dataset
                    correctedConstLayer.datasets = new Layer.DataSet[1];
                    Array.Copy(inputLayer.datasets, correctedConstLayer.datasets, inputLayer.datasets.Length);
                    correctedConstLayer.datasets[0].name = correctedConstLayer.name;


                    correctedConstLayer.weights = new BarracudaArray(inputLayer.weights.Length);

                    var X = inputLayer.DataSetToTensor(0);

                    var rank = m_RanksByName[layer.name].Value;

                    var   inputRank    = m_RanksByName[input].Value;
                    int[] permutations = GetPermutationForBroadcast(rank, inputRank, (m_isModelExportedFromNHWC && (m_layersChannelOrder[layer.name] == LayoutTransposeRemovalHelper.ChannelsOrder.NHWC)));

                    var O = m_Ops.Transpose(X, permutations);
                    correctedConstLayer.ApplyTensorToDataSet(O, 0);
                    O.Dispose();
                    X.Dispose();

                    correctedConstants.Add(correctedConstLayer);
                    layer.inputs[i] = correctedConstLayer.name;
                }

                nhwc.layers[l] = layer;
            }

            foreach (var l in correctedConstants)
            {
                nhwc.layers.Insert(0, l);
            }
        }

        void CorrectDynamicInputsForBroadCast(ref Model nhwc)
        {
            // insert transposes before broadcastable ops whose dynamic inputs have a lower rank than the op
            for (int l = 0; l < nhwc.layers.Count; l++)
            {
                Layer layer = nhwc.layers[l];
                if (!ModelAnalyzer.IsLayerBroacastable(layer))
                {
                    continue;
                }

                if (!m_RanksByName.ContainsKey(layer.name) || m_RanksByName[layer.name] == null)
                {
                    continue;
                }

                int maxRank = m_RanksByName[layer.name].Value;
                if (maxRank <= 1)
                {
                    continue;
                }

                for (int i = 0; i < layer.inputs.Length; i++)
                {
                    string input = layer.inputs[i];

                    if (!m_RanksByName.ContainsKey(input) || m_RanksByName[input] == null)
                    {
                        continue;
                    }

                    int inputRank = m_RanksByName[input].Value;

                    if (inputRank < 1 || inputRank == maxRank)
                    {
                        continue;
                    }

                    int[] permutations = GetPermutationForBroadcast(maxRank, inputRank, (m_isModelExportedFromNHWC && (m_layersChannelOrder[layer.name] == LayoutTransposeRemovalHelper.ChannelsOrder.NHWC)));

                    Layer transpose = new Layer("transpose_forbroadcast_" + layer.name + "_" + input, Layer.Type.Transpose);
                    transpose.inputs = new[] { input };
                    transpose.pool   = permutations;

                    nhwc.layers[l].inputs[i] = transpose.name;
                    nhwc.layers.Insert(l, transpose);
                    l += 1;
                }
            }
        }
    }
}
Example #7
        public void Fit(DataIter train, uint epochs = 1, uint batchSize = 32, DataIter validation = null, bool shuffle = false)
        {
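            // Variant of the Fit loop above: a single GlorotUniform initializer for all
            // arguments, gradient rescaling by 1/batchSize on the optimizer, and one shared
            // Metric instance used for both training and validation.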
            var    args      = new SortedDictionary <string, NDArray>();
            string labelName = "label";
            var    label     = Symbol.Variable(labelName);

            args["X"]       = new NDArray(new Shape(batchSize, (uint)InputShape[0]));
            args[labelName] = new NDArray(new Shape(batchSize, (uint)OutputShape.Size));

            CompiledModel.InferArgsMap(GlobalParam.Device, args, args);

            var initializer = new SiaDNN.Initializers.GlorotUniform();

            foreach (var arg in args)
            {
                initializer.Operator(arg.Key, arg.Value);
            }

            ModelOptimizer.SetParam("rescale_grad", 1.0 / batchSize);

            using (var exec = CompiledModel.SimpleBind(GlobalParam.Device, args))
            {
                var argNames = CompiledModel.ListArguments();

                // Start training
                var sw = new Stopwatch();
                for (var iter = 0; iter < epochs; ++iter)
                {
                    uint samples = 0;
                    train.BatchSize = batchSize;
                    train.Reset();

                    sw.Restart();

                    while (train.Next())
                    {
                        samples += batchSize;
                        var dataBatch = train.GetDataBatch();
                        // Set data and label
                        dataBatch.Data.CopyTo(args["X"]);
                        dataBatch.Label.CopyTo(args[labelName]);

                        // Compute gradients
                        exec.Forward(true);
                        exec.Backward();
                        // Update parameters
                        for (var i = 0; i < argNames.Count; ++i)
                        {
                            if (argNames[i] == "X" || argNames[i] == labelName)
                            {
                                continue;
                            }

                            ModelOptimizer.Update(i, exec.ArgmentArrays[i], exec.GradientArrays[i]);
                        }

                        Metric.Update(dataBatch.Label, exec.Outputs[0]);
                    }

                    sw.Stop();

                    if (validation != null)
                    {
                        validation.BatchSize = batchSize;
                        validation.Reset();
                        while (validation.Next())
                        {
                            var dataBatch = validation.GetDataBatch();
                            dataBatch.Data.CopyTo(args["X"]);
                            dataBatch.Label.CopyTo(args[labelName]);
                            // Forward pass is enough as no gradient is needed when evaluating
                            exec.Forward(false);
                            Metric.Update(dataBatch.Label, exec.Outputs[0]);
                        }
                    }


                    var duration = sw.ElapsedMilliseconds / 1000.0;
                    if (validation == null)
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec Train_Metric: {Metric.Get()}");
                    }
                    else
                    {
                        Logging.LG($"Epoch: {iter} {samples / duration} samples/sec, Train_Metric: {Metric.Get()}, Val_Metric: {Metric.Get()}");
                    }
                }
            }

            MXNet.MXNotifyShutdown();
        }