/// <summary>
/// Concatenates this volume (left operand) with <paramref name="right"/> and writes the
/// values into <paramref name="result"/>. Both operands are read through a
/// (1, 1, -1, batch) view; for each batch entry the left values come first, followed
/// by the right values. An operand whose total length is 1 is handled as a scalar
/// whose single value is reused for every batch entry.
/// </summary>
/// <param name="right">Volume whose values are placed after the values of this volume.</param>
/// <param name="result">Volume that receives the concatenated values.</param>
public override void Concat(Volume<float> right, Volume<float> result)
{
    // The batch size is the larger of the two fourth dimensions.
    var batchCount = Math.Max(this.Shape.Dimensions[3], right.Shape.Dimensions[3]);
    var leftLength = this.Shape.TotalLength;
    var rightHasManyValues = right.Shape.TotalLength > 1;

    if (rightHasManyValues && leftLength > 1)
    {
        // Neither operand is a scalar: flatten both per batch entry.
        var leftFlat = ReShape(new Shape(1, 1, -1, batchCount));
        right = right.ReShape(new Shape(1, 1, -1, batchCount));
        var perBatch = result.Shape.TotalLength / batchCount;
        var leftCount = leftFlat.Shape.Dimensions[2];

        for (var batch = 0; batch < batchCount; batch++)
        {
            for (var pos = 0; pos < perBatch; pos++)
            {
                if (pos < leftCount)
                {
                    result.Set(0, 0, pos, batch, leftFlat.Get(0, 0, pos, batch));
                }
                else
                {
                    result.Set(0, 0, pos, batch, right.Get(0, 0, pos - leftCount, batch));
                }
            }
        }
    }
    else if (rightHasManyValues && leftLength == 1)
    {
        // This volume holds a single value: it occupies the first slot of every batch entry.
        right = right.ReShape(new Shape(1, 1, -1, batchCount));
        var perBatch = result.Shape.TotalLength / batchCount;
        var leftCount = 1;

        for (var batch = 0; batch < batchCount; batch++)
        {
            for (var pos = 0; pos < perBatch; pos++)
            {
                if (pos < leftCount)
                {
                    result.Set(0, 0, pos, batch, Get(0));
                }
                else
                {
                    result.Set(0, 0, pos, batch, right.Get(0, 0, pos - leftCount, batch));
                }
            }
        }
    }
    else
    {
        // Remaining cases: the right volume is read as a single value that fills
        // every slot located after the left values.
        var leftFlat = ReShape(new Shape(1, 1, -1, batchCount));
        var perBatch = result.Shape.TotalLength / batchCount;
        var leftCount = leftFlat.Shape.Dimensions[2];

        for (var batch = 0; batch < batchCount; batch++)
        {
            for (var pos = 0; pos < perBatch; pos++)
            {
                if (pos < leftCount)
                {
                    result.Set(0, 0, pos, batch, leftFlat.Get(0, 0, pos, batch));
                }
                else
                {
                    result.Set(0, 0, pos, batch, right.Get(0));
                }
            }
        }
    }
}
/// <summary>
/// Fills <paramref name="inputGradient"/> from the values of this volume and the entries of
/// <paramref name="outputGradient"/>. For each batch entry, the last index whose
/// <paramref name="outputGradient"/> value equals exactly 1.0f is taken as the class index.
/// </summary>
/// <param name="outputGradient">
/// Volume scanned for entries equal to 1.0f (presumably a one-hot target per batch entry;
/// confirm against callers).
/// </param>
/// <param name="inputGradient">Volume that receives the computed gradient.</param>
public override void SoftmaxGradient(Volume<float> outputGradient, Volume<float> inputGradient)
{
    var batchCount = this.Shape.Dimensions[3];

    // View the three volumes as (values, batch) matrices.
    var probabilities = ReShape(-1, batchCount);
    var targets = outputGradient.ReShape(-1, batchCount);
    var gradients = inputGradient.ReShape(-1, batchCount);

    var valueCount = probabilities.Shape.Dimensions[0];

    for (var batch = 0; batch < batchCount; batch++)
    {
        // Keep the last index flagged with 1.0f; stays at -1 when none is found.
        var classIndex = -1;
        for (var k = 0; k < valueCount; k++)
        {
            if (targets.Get(k, batch) == 1.0f)
            {
                classIndex = k;
            }
        }

        // NOTE(review): when no entry equals 1.0f this reads index -1; the outcome depends on Get.
        var classProbability = probabilities.Get(classIndex, batch);

        // Input gradient:
        //   p_class * (1 - p_class)  at the class index
        //   -p_class * p_k           at every other index k
        for (var k = 0; k < valueCount; k++)
        {
            var probability = probabilities.Get(k, batch);
            var gradient = k == classIndex
                ? classProbability * (1.0f - classProbability)
                : -classProbability * probability;

            gradients.Set(k, batch, gradient);
        }
    }
}
/// <summary>
/// Sums the values of this volume into <paramref name="result"/>.
/// The volume is viewed as a (values, batch) matrix. When there is more than one batch entry
/// and <paramref name="result"/> has more than one dimension with a dimension 3 of size 1,
/// the sum runs over the batch and one total is written per value position. Otherwise the
/// values of each batch entry are summed and one total is written per batch entry.
/// </summary>
/// <param name="result">Volume that receives the sums.</param>
public override void DoSum(Volume<float> result)
{
    var batchCount = this.Shape.DimensionCount > 1 ? this.Shape.GetDimension(-1) : 1;
    var matrix = ReShape(-1, batchCount);
    var valueCount = matrix.Shape.GetDimension(0);

    // Same test as "batch > 1 && rank > 1 && dim(3) == 1", negated; evaluation order is kept.
    if (batchCount <= 1 || result.Shape.DimensionCount <= 1 || result.Shape.GetDimension(3) != 1)
    {
        // One total per batch entry.
        for (var batch = 0; batch < batchCount; batch++)
        {
            var total = 0.0f;
            for (var row = 0; row < valueCount; row++)
            {
                total += matrix.Get(row, batch);
            }

            result.Set(new[] { batch }, total);
        }
    }
    else
    {
        // Sum over the batch: one total per value position.
        var target = result.ReShape(-1, 1);
        for (var row = 0; row < valueCount; row++)
        {
            var total = 0.0f;
            for (var batch = 0; batch < batchCount; batch++)
            {
                total += matrix.Get(row, batch);
            }

            target.Set(row, 0, total);
        }
    }
}