/// <summary>
/// Returns a CompressedPolicyVector array with the top-K policy entries
/// extracted from all the positions in this batch.
/// </summary>
/// <param name="numPos">Number of positions in the batch (must be non-negative).</param>
/// <param name="topK">
/// Number of (index, probability) entries per position (must be non-negative).
/// Entries for position i occupy the contiguous range [i * topK, (i + 1) * topK)
/// of both <paramref name="indices"/> and <paramref name="probabilities"/>.
/// </param>
/// <param name="indices">Flat buffer of policy indices for all positions.</param>
/// <param name="probabilities">Flat buffer of values parallel to <paramref name="indices"/>.</param>
/// <param name="probType">Encoding of <paramref name="probabilities"/>; LogProbabilities is not supported.</param>
/// <returns>
/// One CompressedPolicyVector per position, or null if both buffers are the default (empty) span.
/// </returns>
/// <exception cref="NotImplementedException"><paramref name="probType"/> is LogProbabilities.</exception>
/// <exception cref="ArgumentException">The two buffers have different lengths.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="numPos"/> or <paramref name="topK"/> is negative, or the buffers hold fewer than
/// numPos * topK entries.
/// </exception>
static CompressedPolicyVector[] ExtractPoliciesTopK(int numPos, int topK, Span<int> indices, Span<float> probabilities, PolicyType probType)
{
    if (probType == PolicyType.LogProbabilities)
    {
        // Not supported. (A previous exp() conversion loop sat after this throw and was unreachable.)
        throw new NotImplementedException();
    }

    // "No data" case: both spans are default. This is equivalent to the former `span == null` comparisons,
    // since a null array converts to the default span.
    if (indices == Span<int>.Empty && probabilities == Span<float>.Empty)
    {
        return null;
    }

    if (probabilities.Length != indices.Length)
    {
        throw new ArgumentException("Indices and probabilities expected to be same length");
    }

    if (numPos < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(numPos), "Number of positions must be non-negative.");
    }

    if (topK < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(topK), "topK must be non-negative.");
    }

    // Validate up front (in long to avoid int overflow) rather than failing inside Slice with an opaque error.
    long requiredLength = (long)numPos * topK;
    if (requiredLength > indices.Length)
    {
        throw new ArgumentOutOfRangeException(nameof(indices),
            $"Expected at least {requiredLength} entries (numPos * topK) but buffers contain {indices.Length}.");
    }

    CompressedPolicyVector[] retPolicies = new CompressedPolicyVector[numPos];

    int offset = 0;
    for (int i = 0; i < numPos; i++)
    {
        CompressedPolicyVector.Initialize(ref retPolicies[i],
                                          indices.Slice(offset, topK),
                                          probabilities.Slice(offset, topK));
        offset += topK;
    }

    return retPolicies;
}
/// <summary>
/// Returns a CompressedPolicyVector array built from a flat buffer containing one full-length
/// policy (EncodedPolicyVector.POLICY_VECTOR_LENGTH values) per position.
/// </summary>
/// <param name="numPos">Number of positions in the batch.</param>
/// <param name="policyProbs">
/// Flat buffer of length POLICY_VECTOR_LENGTH * numPos. Values are probabilities when
/// <paramref name="probType"/> is PolicyType.Probabilities; otherwise they are treated as logits and
/// converted with a softmax.
/// </param>
/// <param name="probType">Encoding of the values in <paramref name="policyProbs"/>.</param>
/// <param name="alreadySorted">Passed through to CompressedPolicyVector.Initialize.</param>
/// <returns>One CompressedPolicyVector per position, or null if <paramref name="policyProbs"/> is null.</returns>
/// <exception cref="ArgumentException">
/// The buffer length is not POLICY_VECTOR_LENGTH * numPos, or a position's values sum to zero.
/// </exception>
static CompressedPolicyVector[] ExtractPoliciesBufferFlat(int numPos, float[] policyProbs, PolicyType probType, bool alreadySorted)
{
    // TODO: possibly needs work.
    //       Do we handle WDL correctly? Do we flip the moves if we are black (using positions)?

    if (policyProbs == null)
    {
        return null;
    }

    int policyLength = EncodedPolicyVector.POLICY_VECTOR_LENGTH;

    if (policyProbs.Length != policyLength * numPos)
    {
        throw new ArgumentException("Wrong policy size");
    }

    CompressedPolicyVector[] retPolicies = new CompressedPolicyVector[numPos];

    // Scratch buffer reused for every position.
    float[] buffer = new float[policyLength];

    for (int i = 0; i < numPos; i++)
    {
        int startIndex = policyLength * i;

        if (probType == PolicyType.Probabilities)
        {
            Array.Copy(policyProbs, startIndex, buffer, 0, policyLength);
        }
        else
        {
            // Softmax of logits. Subtract the row maximum so that the largest exponent is exp(0) = 1.
            // This avoids overflow for large logits and underflow-to-zero when all logits are very negative.
            // The max must start at -infinity: starting at 0 would skip the shift for all-negative rows.
            float max = float.NegativeInfinity;
            for (int j = 0; j < policyLength; j++)
            {
                float val = policyProbs[startIndex + j];
                if (val > max)
                {
                    max = val;
                }
            }

            // An infinite max would make (x - max) NaN for infinite x, so do not shift in that case.
            // An all -infinity row then exponentiates to all zeros and is rejected below.
            float shift = float.IsInfinity(max) ? 0.0f : max;

            for (int j = 0; j < policyLength; j++)
            {
                buffer[j] = (float)Math.Exp(policyProbs[startIndex + j] - shift); // TODO: make faster
            }
        }

        double acc = 0;
        for (int j = 0; j < policyLength; j++)
        {
            acc += buffer[j];
        }

        if (acc == 0.0)
        {
            throw new ArgumentException("Sum of unnormalized probabilities was zero.");
        }

        // As performance optimization, only adjust if significantly different from 1.0
        const float MAX_DEVIATION = 0.001f;
        if (acc < 1.0f - MAX_DEVIATION || acc > 1.0f + MAX_DEVIATION)
        {
            for (int j = 0; j < policyLength; j++)
            {
                buffer[j] = (float)(buffer[j] / acc);
            }
        }

        CompressedPolicyVector.Initialize(ref retPolicies[i], buffer, alreadySorted);
    }

    return retPolicies;
}