/// <summary>
/// Overrides worker method to evaluate a specified batch into internal buffers.
/// </summary>
/// <param name="batch">Batch of encoded positions to be evaluated.</param>
/// <param name="retrieveSupplementalResults">If supplemental results should also be retrieved.</param>
/// <returns>The evaluation results for the batch.</returns>
public override IPositionEvaluationBatch EvaluateIntoBuffers(IEncodedPositionBatchFlat batch, bool retrieveSupplementalResults = false)
{
  // 112 input planes per position, 64 squares per plane.
  int bufferLength = 112 * batch.NumPos * 64;
  float[] flatValues = ArrayPool<float>.Shared.Rent(bufferLength);
  try
  {
    batch.ValuesFlatFromPlanes(flatValues);
    return DoEvaluateBatch(flatValues, batch.NumPos, retrieveSupplementalResults);
  }
  finally
  {
    // Always return the rented buffer to the pool, even if evaluation throws;
    // previously an exception would leave the buffer unreturned.
    ArrayPool<float>.Shared.Return(flatValues);
  }
}
/// <summary>
/// Processes the current set of pending batches by:
///   - aggregating them into one large batch,
///   - evaluating that large batch in a single call,
///   - splitting the returned evaluations back into per-sub-batch results.
/// </summary>
/// <param name="evaluator">Evaluator used to process the combined batch.</param>
/// <param name="retrieveSupplementalResults">If supplemental results should also be retrieved.</param>
internal void ProcessPooledBatch(NNEvaluator evaluator, bool retrieveSupplementalResults)
{
  // With exactly one pending batch no aggregation is needed.
  IEncodedPositionBatchFlat fullBatch = pendingBatches.Count == 1
                                      ? pendingBatches[0]
                                      : AggregateBatches();

  // Evaluate everything in a single pass over the combined batch.
  IPositionEvaluationBatch evalResult = evaluator.EvaluateIntoBuffers(fullBatch, retrieveSupplementalResults);

  // Split the combined result back into one result per original sub-batch.
  completedBatches = DisaggregateBatches(retrieveSupplementalResults,
                                         (PositionEvaluationBatch)evalResult,
                                         pendingBatches);
}
/// <summary>
/// Splits up positions in an aggregated batch back into sub-batches.
/// </summary>
/// <param name="retrieveSupplementalResults">If supplemental results should also be retrieved (not yet supported).</param>
/// <param name="fullBatchResult">The evaluation result for the single aggregated batch.</param>
/// <param name="pendingBatches">The original sub-batches whose sizes determine how the result is sliced.</param>
/// <returns>One evaluation result per original sub-batch, in order.</returns>
/// <exception cref="NotImplementedException">If retrieveSupplementalResults is true.</exception>
internal static PositionEvaluationBatch[] DisaggregateBatches(bool retrieveSupplementalResults,
                                                              PositionEvaluationBatch fullBatchResult,
                                                              List<IEncodedPositionBatchFlat> pendingBatches)
{
  // Fail fast before any work is done. Previously this check lived inside the
  // per-batch loop and was silently skipped when pendingBatches was empty.
  if (retrieveSupplementalResults)
  {
    throw new NotImplementedException();
  }

  Span<CompressedPolicyVector> fullPolicyValues = fullBatchResult.Policies.Span;
  Span<FP16> fullW = fullBatchResult.W.Span;
  Span<FP16> fullL = fullBatchResult.IsWDL ? fullBatchResult.L.Span : default;
  // FIX: the M (moves-left) head is gated on HasM rather than IsWDL; previously a
  // result with HasM but not IsWDL would have produced null M sub-batches even
  // though HasM is propagated to each sub-batch constructor below.
  Span<FP16> fullM = fullBatchResult.HasM ? fullBatchResult.M.Span : default;

  // Disaggregate the big batch back into a set of individual sub-batch results.
  PositionEvaluationBatch[] completedBatches = new PositionEvaluationBatch[pendingBatches.Count];

  int subBatchIndex = 0;
  int nextPosIndex = 0;
  foreach (EncodedPositionBatchFlat thisBatch in pendingBatches)
  {
    int numPos = thisBatch.NumPos;
    PositionEvaluationBatch thisResultSubBatch =
      new PositionEvaluationBatch(fullBatchResult.IsWDL, fullBatchResult.HasM, numPos,
                                  fullPolicyValues.Slice(nextPosIndex, numPos).ToArray(),
                                  fullW.Slice(nextPosIndex, numPos).ToArray(),
                                  fullBatchResult.IsWDL ? fullL.Slice(nextPosIndex, numPos).ToArray() : null,
                                  fullBatchResult.HasM ? fullM.Slice(nextPosIndex, numPos).ToArray() : null,
                                  null, // supplemental (value head conv) results unsupported; guarded above
                                  fullBatchResult.Stats);
    nextPosIndex += numPos;
    completedBatches[subBatchIndex++] = thisResultSubBatch;
  }

  return completedBatches;
}
/// <summary>
/// Implementation of virtual method to actually evaluate the batch,
/// splitting it across the available device evaluators when large enough.
/// </summary>
/// <param name="positions">Batch of encoded positions to be evaluated.</param>
/// <param name="retrieveSupplementalResults">If supplemental results should also be retrieved (not supported).</param>
/// <returns>The (possibly merged) evaluation results for the batch.</returns>
/// <exception cref="NotImplementedException">If retrieveSupplementalResults is true.</exception>
public override IPositionEvaluationBatch EvaluateIntoBuffers(IEncodedPositionBatchFlat positions, bool retrieveSupplementalResults = false)
{
  if (retrieveSupplementalResults)
  {
    throw new NotImplementedException();
  }

  if (positions.NumPos <= MinSplitSize)
  {
    // Too small to profitably split across multiple devices.
    return Evaluators[indexPerferredEvalator].EvaluateIntoBuffers(positions, retrieveSupplementalResults);
  }

  // TODO: someday we could use the idea already used in LZTrainingPositionServerBatchSlice
  //       and construct custom WFEvaluationBatch which are just using appropriate Memory slices.
  //       Need to create a new constructor for WFEvaluationBatch.
  int numEvaluators = Evaluators.Length;
  IPositionEvaluationBatch[] subResults = new IPositionEvaluationBatch[numEvaluators];
  int[] subBatchSizes = new int[numEvaluators];
  List<Task> evalTasks = new List<Task>();

  // Launch one evaluation task per device, each operating on its own slice of the batch.
  for (int evaluatorIndex = 0; evaluatorIndex < numEvaluators; evaluatorIndex++)
  {
    int thisIndex = evaluatorIndex; // capture a stable copy of the loop variable for the closure
    IEncodedPositionBatchFlat subBatch = GetSubBatch(positions, PreferredFractions, thisIndex);
    subBatchSizes[thisIndex] = subBatch.NumPos;
    evalTasks.Add(Task.Run(() => subResults[thisIndex] = Evaluators[thisIndex].EvaluateIntoBuffers(subBatch, retrieveSupplementalResults)));
  }
  Task.WaitAll(evalTasks.ToArray());

  if (UseMergedBatch)
  {
    return new PositionsEvaluationBatchMerged(subResults, subBatchSizes);
  }

  // Concatenate the per-device results back into one flat result batch.
  int totalPositions = positions.NumPos;
  CompressedPolicyVector[] policies = new CompressedPolicyVector[totalPositions];
  FP16[] w = new FP16[totalPositions];
  FP16[] l = new FP16[totalPositions];
  FP16[] m = new FP16[totalPositions];

  bool isWDL = subResults[0].IsWDL;
  bool hasM = subResults[0].HasM;

  int writePos = 0;
  foreach (IPositionEvaluationBatch subResult in subResults)
  {
    PositionEvaluationBatch batchResult = (PositionEvaluationBatch)subResult;
    int count = batchResult.NumPos;

    batchResult.Policies.CopyTo(new Memory<CompressedPolicyVector>(policies).Slice(writePos, count));
    batchResult.W.CopyTo(new Memory<FP16>(w).Slice(writePos, count));
    if (isWDL)
    {
      batchResult.L.CopyTo(new Memory<FP16>(l).Slice(writePos, count));
      batchResult.M.CopyTo(new Memory<FP16>(m).Slice(writePos, count));
    }

    writePos += count;
  }

  TimingStats stats = new TimingStats();
  return new PositionEvaluationBatch(isWDL, hasM, totalPositions, policies, w, l, m, null, stats);
}