/// <summary>
/// 'softmax' operation computes the softmax of the btm along the axis m_nAxis and stores the result in the top.
/// </summary>
/// <remarks>
/// When m_bNeedsBackprop is set, a backward action is also queued in m_rgBackprop (keyed by m_strMarker);
/// that action derives the gradient from the top diff and top data and passes it on to btm via 'apply'.
/// </remarks>
/// <param name="btm">Specifies the input data.</param>
/// <param name="top">Specifies the output data, which is reshaped like the btm.</param>
/// <returns>The top blob is returned.</returns>
public Blob<T> softmax(Blob<T> btm, Blob<T> top)
{
    // Capture the marker now so the debug label used by the deferred backward action is the one active here.
    string strMarker = marker;

    top.ReshapeLike(btm);

    int nOuter = btm.count(0, m_nAxis);
    int nInner = btm.count(m_nAxis + 1);
    int nDim = top.shape(m_nAxis);
    int nTotal = btm.count();
    // Number of per-channel reductions: one for each (outer, inner) position.
    int nReduce = nOuter * nInner;

    work.ReshapeLike(top);

    // Start from a copy of the input; every remaining forward step operates on top in place.
    m_cuda.copy(nTotal, btm.gpu_data, top.mutable_gpu_data);

    // Subtract the per-channel maximum first to avoid numerical issues, then exponentiate and normalize.

    // 1.) channel-wise max -> work data
    m_cuda.channel_max(nReduce, nOuter, nDim, nInner, top.gpu_data, work.mutable_gpu_data);
    // 2.) top -= max
    m_cuda.channel_sub(nTotal, nOuter, nDim, nInner, work.gpu_data, top.mutable_gpu_data);
    // 3.) top = exp(top)
    m_cuda.exp(nTotal, top.gpu_data, top.mutable_gpu_data);
    // 4.) channel-wise sum of the exponentials -> work data
    m_cuda.channel_sum(nReduce, nOuter, nDim, nInner, top.gpu_data, work.mutable_gpu_data);
    // 5.) top /= sum
    m_cuda.channel_div(nTotal, nOuter, nDim, nInner, work.gpu_data, top.mutable_gpu_data);

    // Nothing further to register when back-propagation is not needed.
    if (!m_bNeedsBackprop)
        return top;

    Action backward = () =>
    {
        work.ReshapeLike(top);

        // Seed the work diff with the top diff.
        m_cuda.copy(nTotal, top.gpu_diff, work.mutable_gpu_diff);

        // Compute inner1d(top_diff, top_data) into the work data and subtract it from the work diff.
        m_cuda.channel_dot(nReduce, nOuter, nDim, nInner, top.gpu_diff, top.gpu_data, work.mutable_gpu_data);
        m_cuda.channel_sub(nTotal, nOuter, nDim, nInner, work.gpu_data, work.mutable_gpu_diff);

        // Element-wise multiplication with the softmax output.
        m_cuda.mul(nTotal, work.gpu_diff, top.gpu_data, work.mutable_gpu_diff);

        // Hand the computed gradient over to the bottom blob.
        apply(work, btm);

        // Optional post-processing of the bottom blob, each step controlled by its own flag.
        if (m_bClipGradients)
            clip_gradient(btm);

        if (m_bCheckForNans)
            check_nan(btm);

        if (m_bAddDebug)
            add_debug(strMarker + " - softmax", btm, top);
    };

    m_rgBackprop.Add(new Tuple<string, Action>(m_strMarker, backward));

    return top;
}