/// <summary>
        /// Compares the first two dimensions of a sensor's observation shape against the
        /// Channels and Width of the model's input tensor.
        /// </summary>
        /// <param name="tensorProxy">The input tensor declared by the model.</param>
        /// <param name="sensor">The sensor whose observation spec is being validated.</param>
        /// <returns>
        /// A <see cref="FailedCheck"/> warning describing the mismatch when either dimension
        /// differs; null when both dimensions agree.
        /// </returns>
        static FailedCheck CheckRankTwoObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var sensorShape = sensor.GetObservationSpec().Shape;
            var sensorDim0 = sensorShape[0];
            var sensorDim1 = sensorShape[1];
            var tensorChannels = tensorProxy.Channels;
            var tensorWidth = tensorProxy.Width;
            var tensorHeight = tensorProxy.Height;

            // Nothing to report when both compared dimensions line up.
            if (sensorDim0 == tensorChannels && sensorDim1 == tensorWidth)
            {
                return null;
            }

            // The tensor's Height only influences how its shape is printed in the message;
            // it is not part of the comparison above.
            var tensorShapeText = tensorHeight > 1
                ? $"[?x{tensorHeight}x{tensorWidth}x{tensorChannels}]"
                : $"[?x{tensorChannels}x{tensorWidth}]";

            return FailedCheck.Warning(
                $"An Observation of the model does not match. " +
                $"Received TensorProxy of shape [?x{sensorDim0}x{sensorDim1}] but " +
                $"was expecting {tensorShapeText} for the {sensor.GetName()} Sensor.");
        }
示例#2
0
        /// <summary>
        /// Resizes the tensor for the batch and then, for each entry of <paramref name="infos"/>,
        /// writes the selected sensors' observations back to back into that entry's batch slot.
        /// Entries whose agent is done are filled with zeroes instead.
        /// </summary>
        /// <param name="tensorProxy">The tensor that receives the observations.</param>
        /// <param name="batchSize">The batch size the tensor is resized to.</param>
        /// <param name="infos">Agent info / sensor pairs; the n-th pair is written at batch index n.</param>
        public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable <AgentInfoSensorsPair> infos)
        {
            TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

            // The last dimension of the tensor is the number of values expected per agent.
            var expectedSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
            var batchRow = 0;

            foreach (var pair in infos)
            {
                if (!pair.agentInfo.done)
                {
                    // Each sensor continues writing where the previous one stopped.
                    var written = 0;
                    foreach (var index in m_SensorIndices)
                    {
                        var currentSensor = pair.sensors[index];
                        m_ObservationWriter.SetTarget(tensorProxy, batchRow, written);
                        written += currentSensor.Write(m_ObservationWriter);
                    }

                    Debug.AssertFormat(
                        written == expectedSize,
                        "mismatch between vector observation size ({0}) and number of observations written ({1})",
                        expectedSize,
                        written);
                }
                else
                {
                    // A finished agent may hold stale sensor references (e.g. a dependent
                    // object might already be disposed), so sensor.Write is not called and
                    // the slot is zero-filled instead.
                    TensorUtils.FillTensorBatch(tensorProxy, batchRow, 0.0f);
                }

                batchRow++;
            }
        }
示例#3
0
 /// <summary>
 /// Replaces the tensor's data with a freshly allocated 1x1 tensor whose single element
 /// is set to <paramref name="batchSize"/>.
 /// </summary>
 /// <param name="tensorProxy">The tensor whose data is replaced.</param>
 /// <param name="batchSize">The value stored in the tensor.</param>
 /// <param name="infos">Not used by this generator.</param>
 public void Generate(TensorProxy tensorProxy, int batchSize, IList <AgentInfoSensorsPair> infos)
 {
     // Release the previously attached data, if any, before swapping in the new tensor.
     tensorProxy.data?.Dispose();

     var batchSizeTensor = m_Allocator.Alloc(new TensorShape(1, 1));
     batchSizeTensor[0] = batchSize;
     tensorProxy.data = batchSizeTensor;
 }
示例#4
0
 /// <summary>
 /// Resizes the tensor for the batch and then fills it through
 /// TensorUtils.FillTensorWithRandomNormal using m_RandomNormal.
 /// </summary>
 /// <param name="tensorProxy">The tensor that is resized and filled.</param>
 /// <param name="batchSize">The batch size passed to TensorUtils.ResizeTensor.</param>
 /// <param name="infos">Not used by this generator.</param>
 public void Generate(TensorProxy tensorProxy, int batchSize, IList <AgentInfoSensorsPair> infos)
 {
     TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
     // NOTE(review): presumably draws every element from a normal distribution; confirm against TensorUtils.
     TensorUtils.FillTensorWithRandomNormal(tensorProxy, m_RandomNormal);
 }
示例#5
0
 /// <summary>
 /// Only resizes the tensor for the batch through TensorUtils.ResizeTensor; this generator
 /// writes no values of its own.
 /// </summary>
 /// <param name="tensorProxy">The tensor that is resized.</param>
 /// <param name="batchSize">The batch size passed to TensorUtils.ResizeTensor.</param>
 /// <param name="infos">Not used by this generator.</param>
 public void Generate(TensorProxy tensorProxy, int batchSize, IList <AgentInfoSensorsPair> infos)
 {
     TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
 }
示例#6
0
        /// <summary>
        /// For every action branch, copies that branch's slice of the model output into a
        /// temporary tensor, runs <c>Eval</c> with <c>m_Multinomial</c> on it, and stores the
        /// resulting value of each batch row as the discrete action of the matching agent.
        /// </summary>
        /// <param name="tensorProxy">
        /// The model output. It is read at [batch row, branch start + index within branch],
        /// where the branch starts come from the cumulative sum of <c>m_ActionSize</c>.
        /// </param>
        /// <param name="actionIds">
        /// Agent ids; the n-th id corresponds to batch row n. The sequence is enumerated
        /// exactly once.
        /// </param>
        /// <param name="lastActions">
        /// Map from agent id to its action buffers. Ids that are not present are skipped;
        /// an empty buffer is replaced by a new one built from <c>m_ActionSpec</c>.
        /// </param>
        public void Apply(TensorProxy tensorProxy, IEnumerable <int> actionIds, Dictionary <int, ActionBuffers> lastActions)
        {
            // Materialize once: the batch size and the write-back loop below must see the
            // same sequence even when actionIds is a lazily evaluated query.
            var agentIds           = actionIds as List <int> ?? actionIds.ToList();
            var batchSize          = agentIds.Count;
            var branchCount        = m_ActionSize.Length;
            var actionValues       = new float[batchSize, branchCount];
            var startActionIndices = Utilities.CumSum(m_ActionSize);

            for (var actionIndex = 0; actionIndex < branchCount; actionIndex++)
            {
                var nBranchAction = m_ActionSize[actionIndex];
                var actionProbs   = new TensorProxy()
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    shape     = new long[] { batchSize, nBranchAction },
                    data      = m_Allocator.Alloc(new TensorShape(batchSize, nBranchAction))
                };
                var outputTensor = new TensorProxy()
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    shape     = new long[] { batchSize, 1 },
                    data      = m_Allocator.Alloc(new TensorShape(batchSize, 1))
                };

                try
                {
                    // Extract this branch's columns from the combined model output.
                    for (var batchIndex = 0; batchIndex < batchSize; batchIndex++)
                    {
                        for (var branchActionIndex = 0;
                             branchActionIndex < nBranchAction;
                             branchActionIndex++)
                        {
                            actionProbs.data[batchIndex, branchActionIndex] =
                                tensorProxy.data[batchIndex, startActionIndices[actionIndex] + branchActionIndex];
                        }
                    }

                    Eval(actionProbs, outputTensor, m_Multinomial);

                    for (var ii = 0; ii < batchSize; ii++)
                    {
                        actionValues[ii, actionIndex] = outputTensor.data[ii, 0];
                    }
                }
                finally
                {
                    // Release the temporary tensors even when sampling or copying throws.
                    actionProbs.data.Dispose();
                    outputTensor.data.Dispose();
                }
            }

            for (var agentIndex = 0; agentIndex < batchSize; agentIndex++)
            {
                var agentId = agentIds[agentIndex];

                ActionBuffers actionBuffer;
                if (!lastActions.TryGetValue(agentId, out actionBuffer))
                {
                    // No buffer registered for this agent; its sampled actions are dropped.
                    continue;
                }

                if (actionBuffer.IsEmpty())
                {
                    actionBuffer         = new ActionBuffers(m_ActionSpec);
                    lastActions[agentId] = actionBuffer;
                }

                var discreteBuffer = actionBuffer.DiscreteActions;
                for (var j = 0; j < branchCount; j++)
                {
                    discreteBuffer[j] = (int)actionValues[agentIndex, j];
                }
            }
        }