コード例 #1
0
        /// <summary>
        /// Fuses chains of consecutive linear layers (as classified by
        /// <c>LinearLayerFusing.IsLayerLinear</c>) into single layers, modifying
        /// <paramref name="model"/> in place.
        /// </summary>
        /// <remarks>
        /// Layers named by <c>model.outputs</c> or by any memory's input/output are never removed.
        /// Constant layers (see <c>IsLayerConstant</c>) are packed before fusion and unpacked
        /// afterwards. Constants no longer referenced by any layer are then removed, unless they
        /// are preserved or listed in <paramref name="keepLayers"/>.
        /// </remarks>
        /// <param name="model">Model whose <c>layers</c> list is rewritten in place.</param>
        /// <param name="keepLayers">
        /// Optional names of constant layers that must survive the unused-constant cleanup even
        /// when nothing references them. May be <c>null</c>.
        /// </param>
        public static void FuseLinear(Model model, HashSet <string> keepLayers = null)
        {
            // outputs and memories can be queried by the user, make sure they are not removed
            var preserve = new HashSet <string>(
                model.memories.Select(mem => mem.input).Concat(
                    model.memories.Select(mem => mem.output)).Concat(
                    model.outputs));

            var constantLayers = new Dictionary <string, Layer>();
            foreach (var candidate in model.layers)
            {
                if (IsLayerConstant(candidate))
                {
                    constantLayers[candidate.name] = candidate;
                }
            }

            // pack constants into layer database
            PackConstants(model, constantLayers);

            // remap: name of a layer usable as a linear input -> name of the layer that currently
            // holds its (possibly fused) computation. Also used at the end to redirect inputs of
            // layers that consumed a layer which got merged away.
            var remap        = new Dictionary <string, string>();
            var mergedLayers = new HashSet <Layer>();

            for (int l = 0; l < model.layers.Count; ++l)
            {
                var layer = model.layers[l];

                bool isLayerLinear      = LinearLayerFusing.IsLayerLinear(layer, constantLayers);
                bool isLayerPreserved   = preserve.Contains(layer.name);
                bool layerHasActivation = IsLayerFusedActivation(layer);

                if (!isLayerLinear)
                {
                    continue;
                }

                // if layer has an activation, we fuse it, but treat it as non linear for future children
                if (!layerHasActivation)
                {
                    remap[layer.name] = layer.name;
                }

                // Multi input nodes can only fuse constants and same inputs
                // only merge constants. @TODO: fuse equal input nodes
                bool hasNonLinearInput = layer.inputs.Any(x => !remap.ContainsKey(x) && !constantLayers.ContainsKey(x));
                var  linearInputs      = layer.inputs.Where(x => remap.ContainsKey(x)).ToList();

                // merge only a layer with exactly one linear input (plus optional constants).
                // A layer whose inputs are all constants, or that has no inputs at all, has no
                // linear predecessor to fuse into; indexing linearInputs[0] would throw.
                if (hasNonLinearInput || linearInputs.Count != 1)
                {
                    continue;
                }

                var input = linearInputs[0];

                // input is a linear layer, fuse it
                string inputLayerName  = remap[input];
                int    inputLayerIndex = model.layers.FindIndex(x => x.name == inputLayerName);
                Layer  inputLayer      = model.layers[inputLayerIndex];

                if (!AreLayersFusable(inputLayer, layer))
                {
                    continue;
                }

                // convention: layer will be fused into inputLayer
                // => fused layer will have the same inputs as inputLayer
                Layer fusedLayer = FuseConsecutiveLayers(inputLayer, layer);

                // reject fusions that make the graph more expensive than leaving the layers apart
                if (LayerComplextity(fusedLayer) > LayerComplextity(inputLayer) + LayerComplextity(layer))
                {
                    continue;
                }

                if (layerHasActivation)
                {
                    fusedLayer.activation = layer.activation;
                }

                bool hasNoSkipConnection = (model.GetDownStreamLayersCount(input) == 1);
                //  if input has more than 1 child, we can't override input with fused result
                //  same if input is preserved
                if (!hasNoSkipConnection || preserve.Contains(input))
                {
                    fusedLayer.name = layer.name;
                    model.layers[l] = fusedLayer;
                    continue;
                }

                // preserve layer if output/memory
                if (isLayerPreserved)
                {
                    // cannot merge layer into input:
                    // remove input, no need to remap as inputs == input.inputs
                    fusedLayer.name = layer.name;
                    mergedLayers.Add(inputLayer);
                    model.layers[l] = fusedLayer;
                }
                else
                {
                    // merge layer into input
                    // remove current and remap input names
                    mergedLayers.Add(layer);
                    remap[layer.name]             = fusedLayer.name;
                    model.layers[inputLayerIndex] = fusedLayer;
                }
            }

            // remove merged layers
            model.layers.RemoveAll(x => mergedLayers.Contains(x));

            // update remapped inputs: point consumers of merged layers at the absorbing layer
            foreach (var consumer in model.layers)
            {
                for (int i = 0; i < consumer.inputs.Length; ++i)
                {
                    string target;
                    if (remap.TryGetValue(consumer.inputs[i], out target))
                    {
                        consumer.inputs[i] = target;
                    }
                }
            }

            // unpack constants
            UnpackConstants(model);

            // remove unused constants: drop every constant still referenced by some layer from the
            // candidate set, then delete the remaining ones unless preserved or explicitly kept
            foreach (var consumer in model.layers)
            {
                foreach (var inputName in consumer.inputs)
                {
                    constantLayers.Remove(inputName);
                }
            }
            model.layers.RemoveAll(x => constantLayers.ContainsKey(x.name) &&
                                   !preserve.Contains(x.name) &&
                                   (keepLayers == null || !keepLayers.Contains(x.name)));
        }
コード例 #2
0
 /// <summary>
 /// Produces a single layer equivalent to running <paramref name="previous"/> followed by
 /// <paramref name="current"/>, by delegating to <c>linearLayerFuser.FuseLayers</c>.
 /// </summary>
 /// <param name="previous">The upstream layer of the pair.</param>
 /// <param name="current">The downstream layer that consumes <paramref name="previous"/>.</param>
 /// <returns>The fused layer returned by the fuser.</returns>
 static Layer FuseConsecutiveLayers(Layer previous, Layer current) =>
     linearLayerFuser.FuseLayers(previous, current);
コード例 #3
0
 /// <summary>
 /// Decides whether <paramref name="l1"/> may be fused into its input <paramref name="l0"/>.
 /// </summary>
 /// <param name="l0">The input (upstream) layer.</param>
 /// <param name="l1">The consuming (downstream) layer.</param>
 /// <returns>
 /// <c>false</c> when <paramref name="l0"/> already carries a fused activation; otherwise
 /// whatever <c>linearLayerFuser.AreLayersFusable</c> reports for the pair.
 /// </returns>
 static bool AreLayersFusable(Layer l0, Layer l1)
 {
     // an input that already applies an activation cannot absorb further linear work
     if (IsLayerFusedActivation(l0))
     {
         return false;
     }

     // otherwise it depends on whether fusing code exists for this pair of layers
     return linearLayerFuser.AreLayersFusable(l0, l1);
 }
コード例 #4
0
 /// <summary>
 /// Reports whether <paramref name="layer"/> has an activation attached, i.e. its
 /// <c>activation</c> field is anything other than <c>Layer.Activation.None</c>.
 /// </summary>
 /// <param name="layer">The layer to inspect.</param>
 /// <returns><c>true</c> if an activation is set on the layer.</returns>
 static bool IsLayerFusedActivation(Layer layer) =>
     layer.activation != Layer.Activation.None;
コード例 #5
0
 /// <summary>
 /// Returns the cost estimate for <paramref name="l"/> as computed by
 /// <c>m_LayerComplexity.LayerComplextity</c>. The (misspelled) name is kept as-is
 /// because callers depend on it.
 /// </summary>
 /// <param name="l">The layer to score.</param>
 /// <returns>The complexity value reported by the complexity helper.</returns>
 static long LayerComplextity(Layer l) =>
     m_LayerComplexity.LayerComplextity(l);
コード例 #6
0
 /// <summary>
 /// Reports whether <paramref name="layer"/> is treated as a constant, which here means
 /// its type is <c>Layer.Type.Load</c>.
 /// </summary>
 /// <param name="layer">The layer to classify.</param>
 /// <returns><c>true</c> for <c>Load</c> layers, otherwise <c>false</c>.</returns>
 public static bool IsLayerConstant(Layer layer) =>
     layer.type == Layer.Type.Load;
コード例 #7
0
 /// <summary>
 /// Runs the base storage preparation for <paramref name="forLayer"/>, then records whether
 /// that layer is one of the layers tracked in <c>m_LayersWithStorage</c>.
 /// </summary>
 /// <param name="forLayer">The layer about to be executed.</param>
 public override void PrepareStorage(Layer forLayer)
 {
     base.PrepareStorage(forLayer);

     // cache the membership test so later calls for this layer can consult a flag
     bool needsStorage = m_LayersWithStorage.Contains(forLayer);
     m_LayerRequiresStorage = needsStorage;
 }
コード例 #8
0
 /// <summary>
 /// Forwards the start of <paramref name="layer"/>'s execution to the current model
 /// execution report.
 /// </summary>
 /// <param name="layer">The layer whose execution is beginning.</param>
 /// <remarks>
 /// Asserts that <c>CurrentModelExecutionReport</c> is non-null before using it.
 /// NOTE(review): if <c>Assert</c> is UnityEngine.Assertions.Assert, the check may be
 /// compiled out in some builds, and a null report would then fail on the call below — confirm.
 /// </remarks>
 public void LayerExecutionStarted(Layer layer)
 {
     Assert.IsNotNull(CurrentModelExecutionReport);
     CurrentModelExecutionReport.LayerExecutionStarted(layer);
 }
コード例 #9
0
 /// <summary>
 /// Records a memory snapshot by passing all arguments unchanged to
 /// <c>MemorySnapshotsReport.TakeMemorySnapshot</c>.
 /// </summary>
 /// <param name="ops">Ops instance forwarded to the snapshot report.</param>
 /// <param name="vars">Vars instance forwarded to the snapshot report.</param>
 /// <param name="context">Caller-supplied label forwarded to the snapshot report.</param>
 /// <param name="layer">Layer forwarded to the snapshot report.</param>
 public void TakeMemorySnapshot(IOps ops, IVars vars, string context, Layer layer) =>
     MemorySnapshotsReport.TakeMemorySnapshot(ops, vars, context, layer);
コード例 #10
0
 /// <summary>
 /// Begins a new per-layer report for <paramref name="layer"/> and stores it as
 /// <c>CurrentLayerExecutionReport</c>.
 /// </summary>
 /// <param name="layer">The layer whose execution is beginning.</param>
 /// <remarks>Asserts that no layer report is currently set before replacing it.</remarks>
 internal void LayerExecutionStarted(Layer layer)
 {
     Assert.IsNull(CurrentLayerExecutionReport);

     var freshReport = new LayerExecutionReport(layer);
     CurrentLayerExecutionReport = freshReport;
 }