Example #1
0
        /// <summary>
        /// Concatenates this volume with <paramref name="right"/> per batch sample and writes the result into
        /// <paramref name="result"/>. Operands are viewed as (1, 1, -1, batchSize); for sample n the first
        /// <c>threshold</c> output elements come from this volume and the remaining ones from <paramref name="right"/>.
        /// A volume whose TotalLength is 1 is broadcast as a scalar: a scalar left operand fills exactly one leading
        /// slot per sample, a scalar right operand fills every slot after the left operand's elements.
        /// </summary>
        /// <param name="right">Volume whose elements follow this volume's elements within each sample.</param>
        /// <param name="result">Destination; result.Shape.TotalLength / batchSize elements are written per sample.</param>
        public override void Concat(Volume <float> right, Volume <float> result)
        {
            var batchSize   = Math.Max(this.Shape.Dimensions[3], right.Shape.Dimensions[3]);
            var leftLength  = this.Shape.TotalLength;
            var rightLength = right.Shape.TotalLength;

            // Same three cases as the per-branch implementation: two full volumes, a scalar left operand,
            // or (in every remaining case) a scalar right operand.
            var leftIsScalar  = leftLength == 1 && rightLength > 1;
            var rightIsScalar = !leftIsScalar && !(leftLength > 1 && rightLength > 1);

            var leftView  = leftIsScalar ? null : ReShape(new Shape(1, 1, -1, batchSize));
            var rightView = rightIsScalar ? null : right.ReShape(new Shape(1, 1, -1, batchSize));

            // Broadcast scalars are read once rather than once per output element. The right scalar is only
            // read when the right volume actually has an element, so concatenating an empty right volume
            // never touches element 0.
            var leftScalar  = leftIsScalar ? Get(0) : 0.0f;
            var rightScalar = rightIsScalar && rightLength > 0 ? right.Get(0) : 0.0f;

            // Number of leading elements per sample taken from the left operand.
            var threshold       = leftIsScalar ? 1 : leftView.Shape.Dimensions[2];
            var elementPerBatch = result.Shape.TotalLength / batchSize;

            // Write through a (1, 1, -1, batchSize) view of result so Set(0, 0, i, n) addresses element i of
            // sample n even if result was not built with exactly that layout. Assumes ReShape returns a view
            // sharing result's storage, as the other reshape-then-Set code paths of this class also rely on.
            var output = result.ReShape(new Shape(1, 1, -1, batchSize));

            for (var n = 0; n < batchSize; n++)
            {
                for (var i = 0; i < elementPerBatch; i++)
                {
                    float value;

                    if (i < threshold)
                    {
                        value = leftIsScalar ? leftScalar : leftView.Get(0, 0, i, n);
                    }
                    else
                    {
                        value = rightIsScalar ? rightScalar : rightView.Get(0, 0, i - threshold, n);
                    }

                    output.Set(0, 0, i, n, value);
                }
            }
        }
Example #2
0
        /// <summary>
        /// Back-propagates through softmax, independently for each batch sample (shape dimension 3).
        /// This volume is assumed to hold the softmax output p, <paramref name="outputGradient"/> the upstream
        /// gradient dy, and the input gradient dx is written into <paramref name="inputGradient"/>:
        ///     dx_i = p_i * (dy_i - sum_j dy_j * p_j)
        /// For a one-hot dy with a 1 at class c (finite p), this gives the same values as the former one-hot special case:
        /// p_c * (1 - p_c) at i = c and -p_c * p_i elsewhere. Unlike that special case, it is also well-defined
        /// when dy contains no exact 1.0 (which previously led to reading class index -1).
        /// </summary>
        /// <param name="outputGradient">Gradient with respect to the softmax output (one-hot labels are a special case).</param>
        /// <param name="inputGradient">Destination for the gradient with respect to the softmax input.</param>
        public override void SoftmaxGradient(Volume <float> outputGradient, Volume <float> inputGradient)
        {
            var batchSize = this.Shape.Dimensions[3];

            // View every volume as (elements, batch) so each sample is one column.
            var outputReshape         = ReShape(-1, batchSize);
            var outputGradientReshape = outputGradient.ReShape(-1, batchSize);
            var inputGradientReshape  = inputGradient.ReShape(-1, batchSize);

            var firstDim = outputReshape.Shape.Dimensions[0];

            for (var b = 0; b < batchSize; b++)
            {
                // sum_j dy_j * p_j. For a one-hot dy at class c every other term is 0, so this is exactly p_c.
                var dot = 0.0f;

                for (var j = 0; j < firstDim; j++)
                {
                    dot += outputGradientReshape.Get(j, b) * outputReshape.Get(j, b);
                }

                // dx_i = p_i * (dy_i - dot): no class lookup, so no out-of-range index when dy is not one-hot.
                for (var i = 0; i < firstDim; i++)
                {
                    var pi = outputReshape.Get(i, b);
                    inputGradientReshape.Set(i, b, pi * (outputGradientReshape.Get(i, b) - dot));
                }
            }
        }
Example #3
0
        /// <summary>
        /// Sums this volume into <paramref name="result"/>. The input is viewed as (-1, batchCount), where
        /// batchCount is Shape.GetDimension(-1) when the shape has more than one dimension and 1 otherwise.
        /// When batchCount exceeds 1 and result has more than one dimension with dimension 3 equal to 1,
        /// the samples are added element-wise (one total per element position, written through a (-1, 1)
        /// view of result). Otherwise each sample's elements are added into one total, stored with
        /// result.Set(new[] { sample }, total).
        /// </summary>
        /// <param name="result">Destination volume for the sums.</param>
        public override void DoSum(Volume <float> result)
        {
            var batchCount = this.Shape.DimensionCount > 1 ? this.Shape.GetDimension(-1) : 1;
            var flattened  = ReShape(-1, batchCount);
            var rowCount   = flattened.Shape.GetDimension(0);

            // Short-circuit order matters: GetDimension(3) is only queried when result has more than one dimension.
            var collapseBatch = batchCount > 1
                                && result.Shape.DimensionCount > 1
                                && result.Shape.GetDimension(3) == 1;

            if (!collapseBatch)
            {
                // One total per sample, accumulated over that sample's elements.
                for (var sample = 0; sample < batchCount; sample++)
                {
                    var total = 0.0f;

                    for (var row = 0; row < rowCount; row++)
                    {
                        total += flattened.Get(row, sample);
                    }

                    result.Set(new[] { sample }, total);
                }

                return;
            }

            // One total per element position, accumulated across the samples.
            var column = result.ReShape(-1, 1);

            for (var row = 0; row < rowCount; row++)
            {
                var total = 0.0f;

                for (var sample = 0; sample < batchCount; sample++)
                {
                    total += flattened.Get(row, sample);
                }

                column.Set(row, 0, total);
            }
        }