/// <summary>
/// Emits the TensorFlow ops for this 2-D convolution layer: a Conv2D over the
/// constant HWIO weights, an optional BiasAdd (only when a bias is present),
/// and finally the fused activation. The resulting tensor is registered as
/// this layer's output.
/// </summary>
/// <param name="context">Planning context that owns the TF graph and the output map.</param>
protected override void OnPlanning(GraphPlanContext context)
{
    var tfGraph = context.TFGraph;
    var source = context.TFOutputs[Input.Connection.From];
    var hwioWeights = Weights.ToHWIO();
    var nhwcBias = Bias?.ToNHWC();

    // Strides are specified per axis of the [1, H, W, 1] stride vector;
    // the first and last entries are fixed at 1.
    var strides = new long[] { 1, StrideHeight, StrideWidth, 1 };
    // TF expects the padding mode as an upper-case string, e.g. "SAME" / "VALID".
    var paddingMode = Padding.ToString().ToUpperInvariant();

    var result = tfGraph.Conv2D(source, tfGraph.Const(hwioWeights), strides, paddingMode);

    if (nhwcBias != null)
    {
        result = tfGraph.BiasAdd(result, tfGraph.Const(nhwcBias));
    }

    context.TFOutputs[Output] = tfGraph.AddActivation(result, FusedActivationFunction);
}
/// <summary>
/// Emits the TensorFlow ops for this fully-connected layer: an optional
/// flatten of a 4-D input whose dimensions [2] and [3] are both 1, a MatMul
/// against the constant weights, an optional BiasAdd (only when a bias is
/// present), and finally the fused activation.
/// </summary>
/// <param name="context">Planning context that owns the TF graph and the output map.</param>
protected override void OnPlanning(GraphPlanContext context)
{
    var tfGraph = context.TFGraph;
    var activations = context.TFOutputs[Input.Connection.From];
    var hwioWeights = Weights.ToHWIO();

    // TF MatMul works on rank-2 operands, so a 4-D input with trailing
    // dimensions of 1 is collapsed to [dims[0], dims[1]] first.
    // NOTE(review): presumably Input.Dimensions is [N, C, H, W] — confirm.
    var dims = Input.Dimensions;
    var collapsible = dims.Length == 4 && dims[2] == 1 && dims[3] == 1;
    if (collapsible)
    {
        var flatShape = tfGraph.Const(new[] { dims[0], dims[1] });
        activations = tfGraph.Reshape(activations, flatShape);
    }

    var product = tfGraph.MatMul(activations, tfGraph.Const(hwioWeights));

    if (Bias != null)
    {
        product = tfGraph.BiasAdd(product, tfGraph.Const(Bias.ToNHWC()));
    }

    context.TFOutputs[Output] = tfGraph.AddActivation(product, FusedActivationFunction);
}
/// <summary>
/// Emits the TensorFlow ops for this K210 convolution block, in this order:
/// <list type="number">
/// <item><description>the convolution (regular or depthwise);</description></item>
/// <item><description>BiasAdd;</description></item>
/// <item><description>the activation;</description></item>
/// <item><description>an optional pooling op selected by <c>PoolType</c>.</description></item>
/// </list>
/// The tensor produced before the activation is also published through
/// <c>context.AdditionalTFOutputs[OutputBeforeActivation]</c>.
/// </summary>
/// <param name="context">Planning context that owns the TF graph and the output maps.</param>
protected override void OnPlanning(GraphPlanContext context)
{
    var tfGraph = context.TFGraph;
    var source = context.TFOutputs[Input.Connection.From];
    var hwioWeights = Weights.ToHWIO();
    // Bias is dereferenced unconditionally here (no null check).
    var nhwcBias = Bias.ToNHWC();

    // LeftTop mode is expressed as a stride-2 VALID convolution. For kernels
    // wider than 1 the input is first padded by one element on each side of
    // both spatial axes. SpaceToBatchND with a 1x1 block shape only applies
    // that padding.
    var leftTop = PoolType == K210PoolType.LeftTop;
    var convInput = source;
    if (leftTop && KernelWidth != 1)
    {
        convInput = tfGraph.SpaceToBatchND(source, tfGraph.Const(new[] { 1, 1 }), tfGraph.Const(new[,] { { 1, 1 }, { 1, 1 } }));
    }

    var convStrides = leftTop ? new long[] { 1, 2, 2, 1 } : new long[] { 1, 1, 1, 1 };
    var convPadding = leftTop ? "VALID" : "SAME";
    var kernel = tfGraph.Const(hwioWeights);

    TFOutput y;
    if (Conv2dType == K210Conv2dType.Conv2d)
    {
        y = tfGraph.Conv2D(convInput, kernel, convStrides, convPadding);
    }
    else
    {
        y = tfGraph.DepthwiseConv2dNative(convInput, kernel, convStrides, convPadding);
    }

    y = tfGraph.BiasAdd(y, tfGraph.Const(nhwcBias));
    context.AdditionalTFOutputs[OutputBeforeActivation] = y;
    y = AddActivation(tfGraph, y, FusedActivationFunction, NonTrivialActivation);

    // Pooling parameters derived from PoolType:
    //   *2x2        -> window 2, stride 2, VALID
    //   *4x4        -> window 4, stride 4, VALID
    //   *2x2Stride1 -> window 2, stride 1, SAME
    // Any other mode (including LeftTop) adds no pooling op.
    var isMax = PoolType == K210PoolType.MaxPool2x2
        || PoolType == K210PoolType.MaxPool4x4
        || PoolType == K210PoolType.MaxPool2x2Stride1;
    var isAvg = PoolType == K210PoolType.AveragePool2x2
        || PoolType == K210PoolType.AveragePool4x4
        || PoolType == K210PoolType.AveragePool2x2Stride1;

    if (isMax || isAvg)
    {
        var strideOne = PoolType == K210PoolType.MaxPool2x2Stride1
            || PoolType == K210PoolType.AveragePool2x2Stride1;
        var window = (PoolType == K210PoolType.MaxPool4x4 || PoolType == K210PoolType.AveragePool4x4) ? 4 : 2;
        var step = strideOne ? 1 : window;

        var ksize = new long[] { 1, window, window, 1 };
        var poolStrides = new long[] { 1, step, step, 1 };
        var poolPadding = strideOne ? "SAME" : "VALID";

        y = isMax
            ? tfGraph.MaxPool(y, ksize, poolStrides, poolPadding)
            : tfGraph.AvgPool(y, ksize, poolStrides, poolPadding);
    }

    context.TFOutputs[Output] = y;
}