/// <summary>
/// Backpropagates through the convolution: stores the filter (weight) and bias updates in the
/// learning context and, when the source layer is configured to backpropagate, returns the error
/// signal mapped back onto the layer's input; otherwise returns <paramref name="errorSignal"/> unchanged.
/// </summary>
/// <param name="fromNode">The node the error signal came from (not used here).</param>
/// <param name="errorSignal">Error with respect to this layer's output.</param>
/// <param name="context">Execution context; its learning context receives the updates.</param>
/// <param name="parents">Parent nodes (not used here).</param>
/// <returns>The input-side error as 4D tensor graph data, or the original error signal.</returns>
protected override IGraphData _Backpropagate(INode fromNode, IGraphData errorSignal, IContext context, IReadOnlyList<INode> parents)
{
    var errorTensor = errorSignal.GetMatrix().ReshapeAs4DTensor(_newHeight, _newWidth, _source._filter.ColumnCount);
    var pad = _source._padding;

    // Weight update = im2col(input)^T * error with depth slices combined; bias update = column sums of the error.
    // The multiplication result is only an intermediate, so it is released once the combined update is built.
    using (var product = _im2Col.TransposeThisAndMultiply(errorTensor))
    {
        var filterUpdate = product.CombineDepthSlices();
        var biasDelta = errorTensor.ColumnSums();
        context.LearningContext.StoreUpdate(_source, filterUpdate, err => _source.Update(err, context.LearningContext));
        context.LearningContext.StoreUpdate(_source, biasDelta, bu => _UpdateBias(bu, context.LearningContext));
    }

    if (!_source._shouldBackpropagate)
    {
        return errorSignal;
    }

    var layerInputDepth = _source._inputDepth;

    // Reverse im2col yields the input error at the padded size (input + 2 * padding in each spatial dimension).
    var inputError = errorTensor.ReverseIm2Col(
        _source._filter,
        _inputHeight + pad * 2,
        _inputWidth + pad * 2,
        _inputDepth,
        _source._filterWidth,
        _source._filterHeight,
        _source._xStride,
        _source._yStride);

    if (pad > 0)
    {
        // Strip the padding border and release the padded tensor.
        var trimmed = inputError.RemovePadding(pad);
        inputError.Dispose();
        inputError = trimmed;
    }

    // NOTE(review): errorTensor and the final inputError are not disposed here; presumably the
    // ReshapeAs* calls return views over shared storage — confirm before adding disposal.
    return new Tensor4DGraphData(inputError.ReshapeAsMatrix(), _inputHeight, _inputWidth, layerInputDepth);
}
/// <summary>
/// Backpropagates through the convolution: stores the filter (weight) and bias updates in the
/// learning context and, when the source layer is configured to backpropagate, returns the error
/// signal mapped back onto the layer's input; otherwise returns <paramref name="errorSignal"/> unchanged.
/// </summary>
/// <param name="fromNode">The node the error signal came from (not used here).</param>
/// <param name="errorSignal">Error with respect to this layer's output.</param>
/// <param name="context">Execution context; its learning context receives the updates.</param>
/// <param name="parents">Parent nodes (not used here).</param>
/// <returns>The input-side error as 4D tensor graph data, or the original error signal.</returns>
protected override IGraphData _Backpropagate(INode fromNode, IGraphData errorSignal, IContext context, IReadOnlyList<INode> parents)
{
    var tensor = errorSignal.GetMatrix().ConvertTo4DTensor(_newHeight, _newWidth, _source._filter.ColumnCount);
    var padding = _source._padding;

    // calculate the weight and bias updates
    // The product is only an intermediate for CombineDepthSlices, so dispose it (previously leaked).
    using (var update = _im2Col.TransposeThisAndMultiply(tensor))
    {
        var weightUpdate = update.CombineDepthSlices();
        var biasUpdate = tensor.ColumnSums();
        context.LearningContext.StoreUpdate(_source, weightUpdate, err => _source.Update(err, context.LearningContext));
        context.LearningContext.StoreUpdate(_source, biasUpdate, bu => _UpdateBias(bu, context.LearningContext));
    }

    if (_source._shouldBackpropagate)
    {
        var filters = _source._filter;
        var inputDepth = _source._inputDepth;
        var filterWidth = _source._filterWidth;
        var filterHeight = _source._filterHeight;
        var stride = _source._stride;

        // Each filter column is split into one vector per input depth slice, and each slice is
        // rotated by (slice length / filterWidth) before being passed to ReverseIm2Col.
        // NOTE(review): the vectors produced by Column/Split/Rotate are never disposed — confirm
        // whether they own device memory.
        var filterList = new List<IReadOnlyList<IVector>>();
        for (var i = 0; i < filters.ColumnCount; i++)
        {
            filterList.Add(filters.Column(i).Split(inputDepth).Select(v => v.Rotate(v.Count / filterWidth)).ToList());
        }

        using (var reverseIm2Col = tensor.ReverseIm2Col(filterList, _inputHeight, _inputWidth, inputDepth, padding, filterWidth, filterHeight, stride))
        {
            // The reverse im2col output is reshaped to the padded input size, then the padding is stripped.
            var delta = reverseIm2Col.ConvertTo4DTensor(_inputHeight + padding * 2, _inputWidth + padding * 2);
            if (padding > 0)
            {
                // NOTE(review): the padded delta is not disposed after RemovePadding; only safe to
                // dispose if ConvertTo4DTensor copies rather than sharing reverseIm2Col's storage — confirm.
                delta = delta.RemovePadding(padding);
            }
            return new Tensor4DGraphData(delta.ConvertToMatrix(), _inputHeight, _inputWidth, inputDepth);
        }
    }
    return errorSignal;
}