// Fuses chains of consecutive linear layers into single layers to reduce the
// model's layer count and execution cost. Model outputs and memories are never
// removed; `keepLayers` optionally names extra layers that must survive the
// final unused-constant cleanup even if nothing references them.
public static void FuseLinear(Model model, HashSet<string> keepLayers = null)
{
    // outputs and memories can be queried by the user, make sure they are not removed
    var preserve = new HashSet<string>(
        model.memories.Select(mem => mem.input).Concat(
        model.memories.Select(mem => mem.output)).Concat(
        model.outputs));

    // index all constant (Load) layers by name for O(1) lookup during fusion
    var constantLayers = new Dictionary<string, Layer>();
    foreach (var l in model.layers)
    {
        if (IsLayerConstant(l))
        {
            constantLayers[l.name] = l;
        }
    }

    // pack constants into layer database
    PackConstants(model, constantLayers);

    // remap: original layer name -> name of the (possibly fused) layer that now
    // produces that output. An identity entry marks a linear layer that a child
    // is allowed to fuse into.
    var remap = new Dictionary<string, string>();
    var mergedLayers = new HashSet<Layer>();
    for (int l = 0; l < model.layers.Count; ++l)
    {
        var layer = model.layers[l];
        bool isLayerLinear = LinearLayerFusing.IsLayerLinear(layer, constantLayers);
        bool isLayerPreserved = preserve.Contains(layer.name);
        bool layerHasActivation = IsLayerFusedActivation(layer);

        if (!isLayerLinear)
        {
            continue;
        }

        // if layer has an activation, we fuse it, but treat it as non linear for future children
        if (!layerHasActivation)
        {
            remap[layer.name] = layer.name;
        }

        // Multi input nodes can only fuse constants and same inputs
        // only merge constants. @TODO: fuse equal input nodes
        var nonLinearInputs = layer.inputs.Where(x => !remap.ContainsKey(x) && !constantLayers.ContainsKey(x)).ToList();
        var linearInputs = layer.inputs.Where(x => remap.ContainsKey(x)).ToList();

        // merge layer with one linearInput and eventual constants
        if (nonLinearInputs.Count > 0 || linearInputs.Count > 1)
        {
            continue;
        }

        var input = linearInputs[0];

        // input is a linear layer, fuse it
        // NOTE(review): FindIndex is a linear scan inside the main loop (O(n^2)
        // overall); fine for typical model sizes, but a name->index map would
        // help if this ever shows up in profiles.
        int inputLayerIndex = model.layers.FindIndex(x => x.name == remap[input]);
        Layer inputLayer = model.layers[inputLayerIndex];

        if (!AreLayersFusable(inputLayer, layer))
        {
            continue;
        }

        // convention: layer will be fused into inputLayer
        // => fused layer will have the same inputs as inputLayer
        Layer fusedLayer = FuseConsecutiveLayers(inputLayer, layer);

        // skip the fusion when the combined layer would cost more to execute
        // than running the two original layers separately
        if (LayerComplextity(fusedLayer) > LayerComplextity(inputLayer) + LayerComplextity(layer))
        {
            continue;
        }

        if (layerHasActivation)
        {
            fusedLayer.activation = layer.activation;
        }

        bool hasNoSkipConnection = (model.GetDownStreamLayersCount(input) == 1);
        // if input has more than 1 child, we can't override input with fused result
        // same if input is preserved
        if (!hasNoSkipConnection || preserve.Contains(input))
        {
            fusedLayer.name = layer.name;
            model.layers[l] = fusedLayer;
            continue;
        }

        // preserve layer if output/memory
        if (isLayerPreserved)
        {
            // cannot merge layer into input:
            // remove input, no need to remap as inputs == input.inputs
            fusedLayer.name = layer.name;
            mergedLayers.Add(inputLayer);
            model.layers[l] = fusedLayer;
        }
        else
        {
            // merge layer into input
            // remove current and remap input names
            mergedLayers.Add(layer);
            remap[layer.name] = fusedLayer.name;
            model.layers[inputLayerIndex] = fusedLayer;
        }
    }

    // remove merged layers
    model.layers.RemoveAll(x => mergedLayers.Contains(x));

    // update remapped inputs: point every consumer at the surviving fused layer
    for (int l = 0; l < model.layers.Count; ++l)
    {
        Layer layer = model.layers[l];
        for (int i = 0; i < layer.inputs.Length; ++i)
        {
            var input = layer.inputs[i];
            if (remap.ContainsKey(input))
            {
                model.layers[l].inputs[i] = remap[input];
            }
        }
    }

    // unpack constants
    UnpackConstants(model);

    // remove unused constants: first drop from the candidate set every constant
    // still referenced by a remaining layer...
    foreach (var l in model.layers)
    {
        foreach (var i in l.inputs)
        {
            if (constantLayers.ContainsKey(i))
            {
                constantLayers.Remove(i);
            }
        }
    }
    // ...then delete the rest, unless preserved (output/memory) or explicitly
    // kept by the caller via keepLayers
    model.layers.RemoveAll(x => constantLayers.ContainsKey(x.name) &&
                                !preserve.Contains(x.name) &&
                                (keepLayers == null ? true : !keepLayers.Contains(x.name)));
}
// Combines two consecutive linear layers into a single equivalent layer.
// By convention `current` is folded into `previous`, so the fused result
// keeps `previous`'s inputs.
static Layer FuseConsecutiveLayers(Layer previous, Layer current)
{
    var fused = linearLayerFuser.FuseLayers(previous, current);
    return fused;
}
// Decides whether l1 can be fused into l0.
static bool AreLayersFusable(Layer l0, Layer l1)
{
    // A layer that already carries a fused activation is no longer purely
    // linear, so nothing may be fused into it.
    if (IsLayerFusedActivation(l0))
        return false;

    // Otherwise defer to the fuser's own capability check (fusing code may
    // simply not be implemented for this layer pair).
    return linearLayerFuser.AreLayersFusable(l0, l1);
}
// True when the layer carries an activation other than the identity (None).
static bool IsLayerFusedActivation(Layer layer) => layer.activation != Layer.Activation.None;
// Estimated execution cost of a layer, used to decide whether a fusion pays off.
// NOTE: "Complextity" is a pre-existing typo kept because callers depend on it.
static long LayerComplextity(Layer l) => m_LayerComplexity.LayerComplextity(l);
// A layer is constant when it is a Load layer (i.e. it only provides baked-in data).
public static bool IsLayerConstant(Layer layer) => layer.type == Layer.Type.Load;
// Runs the base preparation, then records whether this layer was flagged as
// requiring storage (membership in m_LayersWithStorage).
public override void PrepareStorage(Layer forLayer)
{
    base.PrepareStorage(forLayer);
    bool requiresStorage = m_LayersWithStorage.Contains(forLayer);
    m_LayerRequiresStorage = requiresStorage;
}
// Forwards the layer-start event to the active model execution report.
// A model execution must be in flight (report non-null) when this is called.
public void LayerExecutionStarted(Layer layer)
{
    var report = CurrentModelExecutionReport;
    Assert.IsNotNull(report);
    report.LayerExecutionStarted(layer);
}
// Delegates snapshot capture to the memory snapshots report, tagging it with
// the given context string and the layer being executed.
public void TakeMemorySnapshot(IOps ops, IVars vars, string context, Layer layer)
{
    MemorySnapshotsReport.TakeMemorySnapshot(ops: ops, vars: vars, context: context, layer: layer);
}
// Opens a fresh per-layer execution report. Exactly one layer report may be
// open at a time: the previous one must have been closed (nulled) first.
internal void LayerExecutionStarted(Layer layer)
{
    Assert.IsNull(CurrentLayerExecutionReport);

    var report = new LayerExecutionReport(layer);
    CurrentLayerExecutionReport = report;
}