Beispiel #1
0
        /// <summary>
        /// Overrides worker method to evaluate a specified batch into internal buffers.
        /// The batch is first expanded into a temporary flat float buffer rented from the
        /// shared array pool; the buffer is handed back to the pool once evaluation is done,
        /// even if evaluation throws.
        /// </summary>
        /// <param name="batch">Encoded positions to be evaluated.</param>
        /// <param name="retrieveSupplementalResults">Passed through unchanged to DoEvaluateBatch.</param>
        /// <returns>The evaluation batch produced by DoEvaluateBatch.</returns>
        public override IPositionEvaluationBatch EvaluateIntoBuffers(IEncodedPositionBatchFlat batch, bool retrieveSupplementalResults = false)
        {
            // 112 * 64 floats per position (presumably 112 planes of 64 squares each - TODO confirm).
            int bufferLength = 112 * batch.NumPos * 64;

            // Note: the pool may hand back an array longer than requested.
            float[] flatValues = ArrayPool<float>.Shared.Rent(bufferLength);
            try
            {
                batch.ValuesFlatFromPlanes(flatValues);
                PositionEvaluationBatch ret = DoEvaluateBatch(flatValues, batch.NumPos, retrieveSupplementalResults);
                return ret;
            }
            finally
            {
                // Always give the rented array back so a failed evaluation does not drain the pool.
                ArrayPool<float>.Shared.Return(flatValues);
            }
        }
Beispiel #2
0
        /// <summary>
        /// Evaluates all currently pending batches in a single evaluator call:
        ///   - the pending batches are combined into one batch
        ///     (or used directly if there is exactly one of them)
        ///   - the combined batch is evaluated in one shot
        ///   - the combined result is split back into per-batch results,
        ///     which are stored in completedBatches
        /// </summary>
        /// <param name="evaluator">Evaluator used to evaluate the combined batch.</param>
        /// <param name="retrieveSupplementalResults">Forwarded to both the evaluator and the disaggregation step.</param>
        internal void ProcessPooledBatch(NNEvaluator evaluator, bool retrieveSupplementalResults)
        {
            IEncodedPositionBatchFlat combinedBatch;
            if (pendingBatches.Count != 1)
            {
                // General case: merge the pending batches into one.
                combinedBatch = AggregateBatches();
            }
            else
            {
                // Exactly one batch pending, so it can be evaluated as is.
                combinedBatch = pendingBatches[0];
            }

            // Run the evaluator once over everything; the result must be a PositionEvaluationBatch.
            PositionEvaluationBatch combinedResult =
                (PositionEvaluationBatch)evaluator.EvaluateIntoBuffers(combinedBatch, retrieveSupplementalResults);

            // Hand each pending batch its share of the combined result.
            completedBatches = DisaggregateBatches(retrieveSupplementalResults, combinedResult, pendingBatches);
        }
Beispiel #3
0
        /// <summary>
        /// Splits up positions in an aggregated batch back into sub-batches.
        /// The full result is consumed front to back: each pending batch receives copies of
        /// the next NumPos entries of the policy/W (and, for WDL results, L/M) values.
        /// </summary>
        /// <param name="retrieveSupplementalResults">
        /// Not yet supported: if true and there is at least one pending batch,
        /// a NotImplementedException is thrown.
        /// </param>
        /// <param name="fullBatchResult">Evaluation result covering the positions of all pending batches.</param>
        /// <param name="pendingBatches">The batches that were aggregated; only their NumPos is read.</param>
        /// <returns>One result per pending batch, in the same order as pendingBatches.</returns>
        /// <exception cref="NotImplementedException">Supplemental results were requested.</exception>
        internal static PositionEvaluationBatch[] DisaggregateBatches(bool retrieveSupplementalResults,
                                                                      PositionEvaluationBatch fullBatchResult,
                                                                      List<IEncodedPositionBatchFlat> pendingBatches)
        {
            Span<CompressedPolicyVector> fullPolicyValues = fullBatchResult.Policies.Span;
            Span<FP16> fullW = fullBatchResult.W.Span;

            Span<FP16> fullL = fullBatchResult.IsWDL ? fullBatchResult.L.Span : default;
            // NOTE(review): M is gated on IsWDL rather than HasM - confirm this is intended.
            Span<FP16> fullM = fullBatchResult.IsWDL ? fullBatchResult.M.Span : default;
            Span<float> fullValueHeadConv = retrieveSupplementalResults ? fullBatchResult.ValueHeadConv.Span : default;

            // Finally, disaggregate the big batch back into a set of individual subbatch results
            PositionEvaluationBatch[] completedBatches = new PositionEvaluationBatch[pendingBatches.Count];

            int subBatchIndex = 0;
            int nextPosIndex = 0;

            // Iterate via the interface: only NumPos is needed, so no downcast
            // to a concrete batch type (which could fail at runtime) is required.
            foreach (IEncodedPositionBatchFlat thisBatch in pendingBatches)
            {
                if (retrieveSupplementalResults)
                {
                    // Kept inside the loop so that an empty pending list still yields an empty result.
                    throw new NotImplementedException();
                }

                int numPos = thisBatch.NumPos;

                // ToArray() copies each slice, so every sub-batch result owns its own arrays.
                PositionEvaluationBatch thisResultSubBatch =
                    new PositionEvaluationBatch(fullBatchResult.IsWDL, fullBatchResult.HasM, numPos,
                                                fullPolicyValues.Slice(nextPosIndex, numPos).ToArray(),
                                                fullW.Slice(nextPosIndex, numPos).ToArray(),
                                                fullBatchResult.IsWDL ? fullL.Slice(nextPosIndex, numPos).ToArray() : null,
                                                fullBatchResult.IsWDL ? fullM.Slice(nextPosIndex, numPos).ToArray() : null,
                                                retrieveSupplementalResults ? fullValueHeadConv.Slice(nextPosIndex, numPos).ToArray() : null,
                                                fullBatchResult.Stats);

                nextPosIndex += numPos;
                completedBatches[subBatchIndex++] = thisResultSubBatch;
            }

            return completedBatches;
        }
Beispiel #4
0
        /// <summary>
        /// Implementation of virtual method to actually evaluate the batch.
        /// </summary>
        /// <param name="positions"></param>
        /// <param name="retrieveSupplementalResults"></param>
        /// <returns></returns>
        public override IPositionEvaluationBatch EvaluateIntoBuffers(IEncodedPositionBatchFlat positions, bool retrieveSupplementalResults = false)
        {
            if (retrieveSupplementalResults)
            {
                throw new NotImplementedException();
            }

            if (positions.NumPos <= MinSplitSize)
            {
                // Too small to profitably split across multiple devices
                return(Evaluators[indexPerferredEvalator].EvaluateIntoBuffers(positions, retrieveSupplementalResults));
            }
            else
            {
                // TODO: someday we could use the idea already used in LZTrainingPositionServerBatchSlice
                //       and construct custom WFEvaluationBatch instances which just use appropriate Memory slices
                //       Need to create a new constructor for WFEvaluationBatch
                IPositionEvaluationBatch[] results = new IPositionEvaluationBatch[Evaluators.Length];

                List <Task> tasks         = new List <Task>();
                int[]       subBatchSizes = new int[Evaluators.Length];
                for (int i = 0; i < Evaluators.Length; i++)
                {
                    int capI = i;
                    IEncodedPositionBatchFlat thisSubBatch = GetSubBatch(positions, PreferredFractions, capI);
                    subBatchSizes[capI] = thisSubBatch.NumPos;
                    tasks.Add(Task.Run(() => results[capI] = Evaluators[capI].EvaluateIntoBuffers(thisSubBatch, retrieveSupplementalResults)));
                }
                Task.WaitAll(tasks.ToArray());

                if (UseMergedBatch)
                {
                    return(new PositionsEvaluationBatchMerged(results, subBatchSizes));
                }
                else
                {
                    CompressedPolicyVector[] policies = new CompressedPolicyVector[positions.NumPos];
                    FP16[] w = new FP16[positions.NumPos];
                    FP16[] l = new FP16[positions.NumPos];
                    FP16[] m = new FP16[positions.NumPos];

                    bool isWDL = results[0].IsWDL;
                    bool hasM  = results[0].HasM;

                    int nextPosIndex = 0;
                    for (int i = 0; i < Evaluators.Length; i++)
                    {
                        PositionEvaluationBatch resultI = (PositionEvaluationBatch)results[i];
                        int thisNumPos = resultI.NumPos;

                        resultI.Policies.CopyTo(new Memory <CompressedPolicyVector>(policies).Slice(nextPosIndex, thisNumPos));
                        resultI.W.CopyTo(new Memory <FP16>(w).Slice(nextPosIndex, thisNumPos));

                        if (isWDL)
                        {
                            resultI.L.CopyTo(new Memory <FP16>(l).Slice(nextPosIndex, thisNumPos));
                            resultI.M.CopyTo(new Memory <FP16>(m).Slice(nextPosIndex, thisNumPos));
                        }

                        nextPosIndex += thisNumPos;
                    }

                    TimingStats stats = new TimingStats();
                    return(new PositionEvaluationBatch(isWDL, hasM, positions.NumPos, policies, w, l, m, null, stats));
                }
            }